import io
import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
import dgl.function as fn
import backend as F
import pytest
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
from copy import deepcopy
import pickle

import scipy as sp
import scipy.sparse  # scipy does not auto-import submodules; needed for sp.sparse.random below

tmp_buffer = io.BytesIO()

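# Dense reference for GraphConv with norm='none': computes A @ (X W) + b,
# where A is the (transposed) adjacency matrix of the test graph.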
def _AXWb(A, X, W, b):
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b

@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv0(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)

    conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, out_dim)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right', 'left'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv(idtype, g, norm, weight, bias, out_dim):
    # Test one tensor input
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    e_w = g.edata['scalar_w']
    if weight:
        h_out = conv(g, h, edge_weight=e_w)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=e_w)
    assert h_out.shape == (ndst, out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
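    # EdgeWeightNorm normalizes positive scalar edge weights following GCN's
    # degree normalization ('both' is the symmetric sqrt form), producing
    # values that GraphConv accepts directly as edge_weight.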
    edgenorm = nn.EdgeWeightNorm(norm=norm)
    norm_weight = edgenorm(g, g.edata['scalar_w'])
    if weight:
        h_out = conv(g, h, edge_weight=norm_weight)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=norm_weight)
    assert h_out.shape == (ndst, out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
    # Test a pair of tensor inputs
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    h_dst = F.randn((ndst, out_dim)).to(F.ctx())
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)

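# Dense reference for a two-hop TAGConv with symmetric normalization:
# N carries D^{-1/2}; each hop is scaled by N before and after propagating
# through A, and the hops [X, X1, X2] are concatenated before the projection.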
def _S2AXWb(A, N, X, W, b):
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b

@pytest.mark.parametrize('out_dim', [1, 2])
def test_tagconv(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3))
    g = g.to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, out_dim, bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))

    conv = nn.TAGConv(5, out_dim)
    conv = conv.to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == out_dim

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)

def test_set2set():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
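    # Set2Set concatenates the LSTM query with the attention readout, so each
    # graph yields a vector of size 2 * input_dim = 10.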
    s2s = s2s.to(ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2

def test_glob_att_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)

    # test pickle
    th.save(gap, tmp_buffer)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2

def test_simple_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))
    g = g.to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = th.stack([F.sum(h0[:15], 0),
                      F.sum(h0[15:20], 0),
                      F.sum(h0[20:35], 0),
                      F.sum(h0[35:40], 0),
                      F.sum(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = th.stack([F.mean(h0[:15], 0),
                      F.mean(h0[15:20], 0),
                      F.mean(h0[20:35], 0),
                      F.mean(h0[35:40], 0),
                      F.mean(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = th.stack([F.max(h0[:15], 0),
                      F.max(h0[15:20], 0),
                      F.max(h0[20:35], 0),
                      F.max(h0[35:40], 0),
                      F.max(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2

def test_set_trans():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2

@pytest.mark.parametrize('O', [1, 2, 8])
def test_rgcn(O):
    ctx = F.ctx()
    etype = []
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    B = 2
    I = 10
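    # "basis" parameterizes the R relation weights as combinations of B shared
    # bases; the low_mem variant computes the same transform with a
    # memory-friendlier kernel, so the outputs should match (allclose below).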

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)

    # test pickle
    th.save(rgc_basis, tmp_buffer)

    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

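    # Block-diagonal decomposition ("bdd") builds each relation weight from B
    # diagonal blocks, which requires B to divide both I and O; hence the guard.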
    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
        rgc_bdd_low.weight = rgc_bdd.weight
        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
        h = th.randn((100, I)).to(ctx)
        r = th.tensor(etype).to(ctx)
        h_new = rgc_bdd(g, h, r)
        h_new_low = rgc_bdd_low(g, h, r)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # with norm
    norm = th.rand((g.number_of_edges(), 1)).to(ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r, norm)
    h_new_low = rgc_basis_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
        rgc_bdd_low.weight = rgc_bdd.weight
        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
        h = th.randn((100, I)).to(ctx)
        r = th.tensor(etype).to(ctx)
        h_new = rgc_bdd(g, h, r, norm)
        h_new_low = rgc_bdd_low(g, h, r, norm)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = th.randint(0, I, (100,)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

@pytest.mark.parametrize('O', [1, 2, 8])
def test_rgcn_sorted(O):
    ctx = F.ctx()
    etype = []
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    # 5 etypes
    R = 5
    etype = [200, 200, 200, 200, 200]
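    # Sorted-etype format: a plain Python list gives the number of edges per
    # relation type (the 1000 edges are treated as 5 contiguous runs of 200).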
    B = 2
    I = 10

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = th.randn((100, I)).to(ctx)
    r = etype
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
        rgc_bdd_low.weight = rgc_bdd.weight
        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
        h = th.randn((100, I)).to(ctx)
        r = etype
        h_new = rgc_bdd(g, h, r)
        h_new_low = rgc_bdd_low(g, h, r)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # with norm
    norm = th.rand((g.number_of_edges(), 1)).to(ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = th.randn((100, I)).to(ctx)
    r = etype
    h_new = rgc_basis(g, h, r, norm)
    h_new_low = rgc_basis_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
        rgc_bdd_low.weight = rgc_bdd.weight
        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
        h = th.randn((100, I)).to(ctx)
        r = etype
        h_new = rgc_bdd(g, h, r, norm)
        h_new_low = rgc_bdd_low(g, h, r, norm)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = th.randint(0, I, (100,)).to(ctx)
    r = etype
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_node_feats', [1, 5])
@pytest.mark.parametrize('out_edge_feats', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    egat = nn.EGATConv(in_node_feats=10,
                       in_edge_feats=5,
                       out_node_feats=out_node_feats,
                       out_edge_feats=out_edge_feats,
                       num_heads=num_heads)
    nfeat = F.randn((g.number_of_nodes(), 10))
    efeat = F.randn((g.number_of_edges(), 5))

    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)
    h, f, attn = egat(g, nfeat, efeat, get_attention=True)

    # test pickle
    th.save(egat, tmp_buffer)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(idtype, g, aggre_type):
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    sage = sage.to(F.ctx())
    # test pickle
    th.save(sage, tmp_buffer)
    h = sage(g, feat)
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()

@parametrize_dtype
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv2(idtype, out_dim):
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), out_dim, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    for aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)

    # test pickle
    th.save(sgc, tmp_buffer)

    feat = F.randn((g.number_of_nodes(), 5))
    sgc = sgc.to(ctx)

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # cached
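    # With cached=True, SGConv stores the propagated features from the first
    # forward pass and reuses them, so even a shifted input yields the same output.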
    sgc = nn.SGConv(5, out_dim, 3, True)
    sgc = sgc.to(ctx)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.number_of_nodes(), 5))
    appnp = appnp.to(ctx)

    # test pickle
    th.save(appnp, tmp_buffer)

    # APPNPConv only propagates, so the feature dimension is unchanged
    h = appnp(g, feat)
    assert h.shape[-1] == 5

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    feat = F.randn((g.number_of_src_nodes(), 5))
    gin = gin.to(ctx)
    h = gin(g, feat)

    # test pickle (save the module, matching the other pickle checks)
    th.save(gin, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), 12)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_agnn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_src_nodes(), 5))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_agnn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gated_graph_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gated_graph_conv_one_etype(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 1)
    etypes = th.zeros(g.number_of_edges())
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    h2 = ggconv(g, feat)
    # with a single edge type, omitting etypes should give the same result
    assert F.allclose(h, h2)
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_nn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_nn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gmm_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
    feat = F.randn((g.number_of_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite', 'block-bipartite'], exclude=['zero-degree']))
def test_gmm_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_graph_conv(norm_type, g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_sage_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, out_dim, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, out_dim)
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.bias.data
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5))
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage), g

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)

    # test pickle
    th.save(edge_conv, tmp_buffer)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_dotgat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    dotgat = dotgat.to(ctx)

    # test pickle
    th.save(dotgat, tmp_buffer)

    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_dotgat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv((5, 5), out_dim, num_heads)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    dotgat = dotgat.to(ctx)
    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim):
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        g = g.to(F.ctx())
        adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, out_dim, k, None)
        dense_cheb = nn.DenseChebConv(5, out_dim, k)
        #for i in range(len(cheb.fc)):
        #    dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, out_dim)
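        # ChebConv stores its k hop filters in a single Linear with weight of
        # shape (out_dim, k * 5); transposing and viewing it as (k, 5, out_dim)
        # lines it up with DenseChebConv's per-hop weight tensor W.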
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)

def test_sequential():
    ctx = F.ctx()
    # Test single graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)

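    # Each ExampleLayer halves the node count by summing node pairs, so the
    # 32 input rows shrink to 16, then 8, then 4 across the three graphs.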
    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_atomic_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    aconv = nn.AtomicConv(interaction_cutoffs=F.tensor([12.0, 12.0]),
                          rbf_kernel_means=F.tensor([0.0, 2.0]),
                          rbf_kernel_scaling=F.tensor([4.0, 4.0]),
                          features_to_use=F.tensor([6.0, 8.0]))
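    # 2 radial kernels x 2 atom types in features_to_use should yield 4 output
    # channels per node (assumed layout; see the shape assert below).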

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((g.number_of_nodes(), 1))
    dist = F.randn((g.number_of_edges(), 1))

    h = aconv(g, feat, dist)

    # currently we only do shape check
    assert h.shape[-1] == 4

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 3])
def test_cf_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    cfconv = nn.CFConv(node_in_feats=2,
                       edge_in_feats=3,
                       hidden_feats=2,
                       out_feats=out_dim)

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    src_feats = F.randn((g.number_of_src_nodes(), 2))
    edge_feats = F.randn((g.number_of_edges(), 3))
    h = cfconv(g, src_feats, edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim

    # case for bipartite graphs
    dst_feats = F.randn((g.number_of_dst_nodes(), 3))
    h = cfconv(g, (src_feats, dst_feats), edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim

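# Custom cross-type aggregator for HeteroGraphConv: merges the per-relation
# results of a destination type as alist[0] + 2 * alist[1] + 3 * alist[2] + ...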
def myagg(alist, dsttype):
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst

@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    conv = conv.to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))
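    # Only destination types with incoming relations show up in the result:
    # 'user' (via follows) and 'game' (via plays and sells). With 'stack',
    # each relation contributes one slot along dim 1.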

    h = conv(g, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
    h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    h = conv(block, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    # test with mod args
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2
        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))
    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    conv = conv.to(F.ctx())
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    h = conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on graphs without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form='eid', etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}

    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [
                         0, 1, 2, 3], 'store': []}).to(F.ctx())
    h = conv(block, ({'user': uf, 'game': gf, 'store': sf},
             {'user': uf, 'game': gf, 'store': sf[0:0]}))
    assert set(h.keys()) == {'user', 'game'}

if __name__ == '__main__':
    test_graph_conv()
    test_graph_conv_e_weight()
    test_graph_conv_e_weight_norm()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_set_trans()
    test_rgcn()
    test_rgcn_sorted()
    test_tagconv()
    test_gat_conv()
    test_egat_conv()
    test_sage_conv()
    test_sgc_conv()
    test_appnp_conv()
    test_gin_conv()
    test_agnn_conv()
    test_gated_graph_conv()
    test_gated_graph_conv_one_etype()
    test_nn_conv()
    test_gmm_conv()
    test_dotgat_conv()
    test_dense_graph_conv()
    test_dense_sage_conv()
    test_dense_cheb_conv()
    test_sequential()
    test_atomic_conv()
    test_cf_conv()
    test_hetero_conv()