# "vscode:/vscode.git/clone" did not exist on "f4cd8040732f348b7c55e432ae772b6ea70520db"
# test_nn.py 47.6 KB
# Newer Older
1
import io
2
3
4
5
import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
6
import dgl.function as fn
7
import backend as F
8
import pytest
9
10
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
11
from copy import deepcopy
12
import pickle
13

14
15
import scipy as sp

16
17
# Shared in-memory sink for the ``th.save`` pickling checks; the tests in this
# file only ever write to it.
tmp_buffer = io.BytesIO()


def _AXWb(A, X, W, b):
    """Dense reference ``A @ (X @ W) + b`` used to check GraphConv outputs.

    ``X`` may carry extra feature dimensions: the projected features are
    flattened per row before multiplying by ``A`` and reshaped back afterwards.
    """
    projected = th.matmul(X, W)
    n_rows = projected.shape[0]
    aggregated = th.matmul(A, projected.view(n_rows, -1))
    return aggregated.view_as(projected) + b


@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv0(out_dim):
    """GraphConv on a 3-node path graph.

    Checks the exact output against ``_AXWb`` for ``norm='none'``, output
    shapes for the default configuration (2-D and 3-D inputs), that no
    features are left on the graph, and that ``reset_parameters`` changes
    the weight.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)

    conv = nn.GraphConv(5, out_dim, norm='none', bias=True).to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic -- with norm='none' the result must equal A X W + b
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    # the forward pass must leave no node/edge features on the graph
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # default configuration
    conv = nn.GraphConv(5, out_dim).to(ctx)

    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert h1.shape == (3, out_dim)

    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert h1.shape == (3, 5, out_dim)

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right', 'left'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv(idtype, g, norm, weight, bias, out_dim):
    """GraphConv with a single feature tensor yields one ``out_dim``-wide row
    per destination node, using either its own or an external weight."""
    # Test one tensor input
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    if weight:
        h_out = conv(g, h)
    else:
        # no internal weight was created, so supply one explicitly
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
    """GraphConv fed the graph's raw ``'scalar_w'`` edge data as
    ``edge_weight`` yields one ``out_dim``-wide row per destination node."""
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    e_w = g.edata['scalar_w']
    if weight:
        h_out = conv(g, h, edge_weight=e_w)
    else:
        # no internal weight was created, so supply one explicitly
        h_out = conv(g, h, weight=ext_w, edge_weight=e_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
    """GraphConv fed edge weights pre-normalised by ``EdgeWeightNorm`` yields
    one ``out_dim``-wide row per destination node."""
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    edgenorm = nn.EdgeWeightNorm(norm=norm)
    norm_weight = edgenorm(g, g.edata['scalar_w'])
    if weight:
        h_out = conv(g, h, edge_weight=norm_weight)
    else:
        # no internal weight was created, so supply one explicitly
        h_out = conv(g, h, weight=ext_w, edge_weight=norm_weight)
    assert h_out.shape == (ndst, out_dim)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
    """GraphConv on bipartite graphs given a ``(src, dst)`` feature pair
    yields one ``out_dim``-wide row per destination node."""
    # Test a pair of tensor inputs
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    h_dst = F.randn((ndst, out_dim)).to(F.ctx())
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        # no internal weight was created, so supply one explicitly
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


def _S2AXWb(A, N, X, W, b):
    """Dense two-hop reference used to check TAGConv.

    Each hop scales by ``N``, multiplies by ``A`` (features flattened per row)
    and scales by ``N`` again. The input, hop-1 and hop-2 features are
    concatenated on the last axis, projected with ``W.rot90()``, and ``b`` is
    added.
    """
    def _hop(feat):
        scaled = feat * N
        return th.matmul(A, scaled.view(scaled.shape[0], -1)) * N

    first = _hop(X)
    second = _hop(first)
    stacked = th.cat([X, first, second], dim=-1)
    return th.matmul(stacked, W.rot90()) + b


@pytest.mark.parametrize('out_dim', [1, 2])
def test_tagconv(out_dim):
    """TAGConv on a 3-node path graph.

    Checks the exact output against ``_S2AXWb``, the output width of the
    default configuration, and that ``reset_parameters`` changes the weight.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)
    # in-degree ** -0.5, one coefficient per node
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, out_dim, bias=True).to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # reshape the per-node norm so it broadcasts over the feature dimensions
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)
    assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))

    conv = nn.TAGConv(5, out_dim).to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == out_dim

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)


def test_set2set():
    """Set2Set readout over 5-dim node features yields one 10-wide row per
    graph, for a single graph and for a batch of three."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10)).to(ctx)

    s2s = nn.Set2Set(5, 3, 3).to(ctx)  # hidden size 5, 3 iters, 3 layers
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11)).to(ctx)
    g2 = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2

def test_glob_att_pool():
    """GlobalAttentionPooling built from ``Linear(5, 1)`` and ``Linear(5, 10)``
    yields one 10-wide row per graph, for one graph and a batch of four."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10)).to(ctx)

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10)).to(ctx)
    print(gap)

    # test pickle
    th.save(gap, tmp_buffer)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2

def test_simple_pool():
    """Sum/Avg/Max pooling match per-graph reductions of the node features;
    SortPooling(10) on 5-dim features yields ``10 * 5`` columns per graph."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(ctx)

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    graphs = [g, g_, g, g_, g]
    bg = dgl.batch(graphs)
    h0 = F.randn((bg.number_of_nodes(), 5))
    # node features of each member graph, in batch order (sizes derived from
    # the graphs instead of hard-coded slice bounds)
    segments = th.split(h0, [sub.number_of_nodes() for sub in graphs])

    h1 = sum_pool(bg, h0)
    assert F.allclose(h1, th.stack([F.sum(seg, 0) for seg in segments], 0))

    h1 = avg_pool(bg, h0)
    assert F.allclose(h1, th.stack([F.mean(seg, 0) for seg in segments], 0))

    h1 = max_pool(bg, h0)
    assert F.allclose(h1, th.stack([F.max(seg, 0) for seg in segments], 0))

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2

def test_set_trans():
    """SetTransformerEncoder ('sab' and 'isab') preserves the input feature
    shape, and SetTransformerDecoder yields one 200-wide row per graph."""
    ctx = F.ctx()
    # NOTE(review): unlike the other readout tests, these graphs are not moved
    # to ``ctx``; confirm this is intended.
    g = dgl.DGLGraph(nx.path_graph(15))

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2


@pytest.mark.parametrize('O', [1, 2, 8])
def test_rgcn(O):
    """RelGraphConv with ``low_mem=True`` must reproduce the default module
    when both share parameters.

    Covered: the 'basis' regularizer (and 'bdd' when ``O % B == 0``), with and
    without a per-edge norm, and integer node-id input. Edge types are a
    tensor.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(ctx)
    # 5 etypes, assigned to the edges round-robin
    R = 5
    etype = [i % R for i in range(g.number_of_edges())]
    B = 2
    I = 10

    def make_low_mem_twin(conv, regularizer):
        # Build a low_mem module that reuses ``conv``'s parameters.
        twin = nn.RelGraphConv(I, O, R, regularizer, B, low_mem=True).to(ctx)
        twin.weight = conv.weight
        if regularizer == "basis":
            twin.w_comp = conv.w_comp
        twin.loop_weight = conv.loop_weight
        return twin

    def check_same(conv, twin, h, r, *extra):
        # Both modules must produce [100, O] outputs that agree numerically.
        h_new = conv(g, h, r, *extra)
        h_new_low = twin(g, h, r, *extra)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)

    # test pickle
    th.save(rgc_basis, tmp_buffer)

    rgc_basis_low = make_low_mem_twin(rgc_basis, "basis")
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    check_same(rgc_basis, rgc_basis_low, h, r)

    # 'bdd' is only exercised when O is divisible by B
    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        rgc_bdd_low = make_low_mem_twin(rgc_bdd, "bdd")
        h = th.randn((100, I)).to(ctx)
        r = th.tensor(etype).to(ctx)
        check_same(rgc_bdd, rgc_bdd_low, h, r)

    # with norm
    norm = th.rand((g.number_of_edges(), 1)).to(ctx)
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = make_low_mem_twin(rgc_basis, "basis")
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    check_same(rgc_basis, rgc_basis_low, h, r, norm)

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        rgc_bdd_low = make_low_mem_twin(rgc_bdd, "bdd")
        h = th.randn((100, I)).to(ctx)
        r = th.tensor(etype).to(ctx)
        check_same(rgc_bdd, rgc_bdd_low, h, r, norm)

    # id input: integer node ids in [0, I) instead of dense features
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = make_low_mem_twin(rgc_basis, "basis")
    h = th.randint(0, I, (100,)).to(ctx)
    r = th.tensor(etype).to(ctx)
    check_same(rgc_basis, rgc_basis_low, h, r)


@pytest.mark.parametrize('O', [1, 2, 8])
def test_rgcn_sorted(O):
    """Same low_mem-vs-default equivalence checks as ``test_rgcn``, but with
    edge types given as a plain Python list of per-type counts."""
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(ctx)
    # 5 etypes. ``r`` is a list of 5 counts of 200 (5 * 200 = 1000, the number
    # of nonzeros requested from a 100x100 matrix at density 0.1); presumably
    # this means edges are grouped by type -- TODO confirm list semantics.
    R = 5
    etype = [200, 200, 200, 200, 200]
    B = 2
    I = 10

    def make_low_mem_twin(conv, regularizer):
        # Build a low_mem module that reuses ``conv``'s parameters.
        twin = nn.RelGraphConv(I, O, R, regularizer, B, low_mem=True).to(ctx)
        twin.weight = conv.weight
        if regularizer == "basis":
            twin.w_comp = conv.w_comp
        twin.loop_weight = conv.loop_weight
        return twin

    def check_same(conv, twin, h, r, *extra):
        # Both modules must produce [100, O] outputs that agree numerically.
        h_new = conv(g, h, r, *extra)
        h_new_low = twin(g, h, r, *extra)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = make_low_mem_twin(rgc_basis, "basis")
    h = th.randn((100, I)).to(ctx)
    r = etype
    check_same(rgc_basis, rgc_basis_low, h, r)

    # 'bdd' is only exercised when O is divisible by B
    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        rgc_bdd_low = make_low_mem_twin(rgc_bdd, "bdd")
        h = th.randn((100, I)).to(ctx)
        r = etype
        check_same(rgc_bdd, rgc_bdd_low, h, r)

    # with norm
    norm = th.rand((g.number_of_edges(), 1)).to(ctx)
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = make_low_mem_twin(rgc_basis, "basis")
    h = th.randn((100, I)).to(ctx)
    r = etype
    check_same(rgc_basis, rgc_basis_low, h, r, norm)

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        rgc_bdd_low = make_low_mem_twin(rgc_bdd, "bdd")
        h = th.randn((100, I)).to(ctx)
        r = etype
        check_same(rgc_bdd, rgc_bdd_low, h, r, norm)

    # id input: integer node ids in [0, I) instead of dense features
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = make_low_mem_twin(rgc_basis, "basis")
    h = th.randint(0, I, (100,)).to(ctx)
    r = etype
    check_same(rgc_basis, rgc_basis_low, h, r)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    """GATConv yields ``(num_dst, num_heads, out_dim)`` outputs and
    ``(num_edges, num_heads, 1)`` attention, with and without a residual."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(5, out_dim, num_heads, residual=True).to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    """GATConv on bipartite graphs given a ``(src, dst)`` feature pair yields
    ``(num_dst, num_heads, out_dim)`` outputs and per-edge, per-head
    attention."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gatv2_conv(g, idtype, out_dim, num_heads):
    """GATv2Conv yields ``(num_dst, num_heads, out_dim)`` outputs and
    ``(num_edges, num_heads, 1)`` attention, with and without a residual."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection -- this must exercise GATv2Conv; the original
    # copy-pasted GATConv here, leaving GATv2's residual path untested
    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
    """GATv2Conv on bipartite graphs given a ``(src, dst)`` feature pair yields
    ``(num_dst, num_heads, out_dim)`` outputs and per-edge, per-head
    attention."""
    graph = g.astype(idtype).to(F.ctx())
    device = F.ctx()
    layer = nn.GATv2Conv(5, out_dim, num_heads)
    n_src, n_dst = graph.number_of_src_nodes(), graph.number_of_dst_nodes()
    src_feat = F.randn((n_src, 5))
    dst_feat = F.randn((n_dst, 5))
    feats = (src_feat, dst_feat)
    layer = layer.to(device)

    out = layer(graph, feats)
    assert out.shape == (n_dst, num_heads, out_dim)

    attn = layer(graph, feats, get_attention=True)[1]
    assert attn.shape == (graph.number_of_edges(), num_heads, 1)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_node_feats', [1, 5])
@pytest.mark.parametrize('out_edge_feats', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
    """EGATConv returns per-head node and edge features, and optionally
    per-edge, per-head attention; the module must also be picklable."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    egat = nn.EGATConv(in_node_feats=10,
                       in_edge_feats=5,
                       out_node_feats=out_node_feats,
                       out_edge_feats=out_edge_feats,
                       num_heads=num_heads)
    nfeat = F.randn((g.number_of_nodes(), 10))
    efeat = F.randn((g.number_of_edges(), 5))

    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)
    # the original only ran the module; pin the output shapes as well
    assert h.shape == (g.number_of_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.number_of_edges(), num_heads, out_edge_feats)
    h, f, attn = egat(g, nfeat, efeat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)

    # test pickle
    th.save(egat, tmp_buffer)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(idtype, g, aggre_type):
    """SAGEConv with every aggregator yields one 10-wide row per destination
    node and is picklable."""
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    sage = sage.to(F.ctx())

    # test pickle
    th.save(sage, tmp_buffer)

    h = sage(g, feat)
    assert h.shape[-1] == 10
    # row count, checked the same way as in test_sage_conv_bi
    assert h.shape[0] == g.number_of_dst_nodes()

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    """SAGEConv on bipartite graphs with different src/dst feature widths
    yields one ``out_dim``-wide row per destination node."""
    g = g.astype(idtype).to(F.ctx())
    # 'gcn' gets matching src/dst widths (10); the other aggregators use a
    # narrower dst width (5)
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()

@parametrize_dtype
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv2(idtype, out_dim):
    """SAGEConv on a bipartite graph with no edges (5 src, 3 dst nodes) still
    yields one ``out_dim``-wide row per destination node for every
    aggregator."""
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), out_dim, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    for aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    """SGConv output width, plus the cached mode: a second call with different
    features must return the same result as the first."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)

    # test pickle
    th.save(sgc, tmp_buffer)

    feat = F.randn((g.number_of_nodes(), 5))
    sgc = sgc.to(ctx)
    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # cached: the second call's changed input must not change the output
    sgc = nn.SGConv(5, out_dim, 3, True)
    sgc = sgc.to(ctx)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv(g, idtype):
    """APPNPConv keeps the feature width (5) and is picklable."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.number_of_nodes(), 5))
    appnp = appnp.to(ctx)

    # test pickle
    th.save(appnp, tmp_buffer)

    h = appnp(g, feat)
    assert h.shape[-1] == 5

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv_e_weight(g, idtype):
    """APPNPConv accepts an all-ones ``edge_weight`` and keeps the feature
    width of 5."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    propagate = nn.APPNPConv(10, 0.1)
    node_feat = F.randn((graph.number_of_nodes(), 5))
    edge_w = F.ones((graph.num_edges(),))
    propagate = propagate.to(device)

    out = propagate(graph, node_feat, edge_weight=edge_w)
    assert out.shape[-1] == 5

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gcn2conv_e_weight(g, idtype):
    """GCN2Conv accepts an all-ones ``edge_weight`` and keeps the feature
    width of 5."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.GCN2Conv(5, layer=2, alpha=0.5,
                        project_initial_features=True)
    node_feat = F.randn((graph.number_of_nodes(), 5))
    edge_w = F.ones((graph.num_edges(),))
    layer = layer.to(device)
    # the same tensor is passed as both feature arguments
    out = layer(graph, node_feat, node_feat, edge_weight=edge_w)
    assert out.shape[-1] == 5


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_sgconv_e_weight(g, idtype):
    """SGConv(5, 5, 3) accepts an all-ones ``edge_weight`` and yields 5-wide
    outputs."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.SGConv(5, 5, 3)
    node_feat = F.randn((graph.number_of_nodes(), 5))
    edge_w = F.ones((graph.num_edges(),))
    out = layer.to(device)(graph, node_feat, edge_weight=edge_w)
    assert out.shape[-1] == 5

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_tagconv_e_weight(g, idtype):
    """TAGConv(5, 5) accepts an all-ones ``edge_weight`` and yields 5-wide
    outputs."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # moved to ctx once (the original repeated ``conv.to(ctx)`` redundantly)
    conv = nn.TAGConv(5, 5, bias=True).to(ctx)
    feat = F.randn((g.number_of_nodes(), 5))
    eweight = F.ones((g.num_edges(), ))
    h = conv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """GINConv with a ``Linear(5, 12)`` apply function yields 12-wide rows per
    destination node; without one it keeps the input width. Both are
    picklable."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    th.save(gin, tmp_buffer)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gin = gin.to(ctx)
    h = gin(g, feat)

    # test pickle
    th.save(gin, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), 12)

    # no apply function: the output keeps the 5-wide input features
    gin = nn.GINConv(None, aggregator_type)
    th.save(gin, tmp_buffer)
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """GINConv on bipartite graphs given a ``(src, dst)`` feature pair yields
    12-wide rows per destination node."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_agnn_conv(g, idtype):
    """AGNNConv keeps the 5-dim feature width on homogeneous graphs and blocks."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_src_nodes(), 5))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_agnn_conv_bi(g, idtype):
    """AGNNConv on bipartite graphs given a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gated_graph_conv(g, idtype):
    """GatedGraphConv with three edge types (shape check only)."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    # edge types cycle through 0, 1, 2 in edge-id order
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gated_graph_conv_one_etype(g, idtype):
    """With one edge type, passing all-zero ``etypes`` must equal omitting them."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 1)
    # integer type ids, matching the th.arange-based ids of the multi-type test
    etypes = th.zeros(g.number_of_edges(), dtype=th.int64)
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    h2 = ggconv(g, feat)
    # explicit zero etypes and the etypes-omitted call must agree
    assert F.allclose(h, h2)
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_nn_conv(g, idtype):
    """NNConv with a linear edge network and mean aggregation (shape check only)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # maps 4-dim edge features to 5 * 10 values, presumably one 5x10 weight per edge
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_nn_conv_bi(g, idtype):
    """NNConv on bipartite graphs with different src (5) and dst (2) widths."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gmm_conv(g, idtype):
    """GMMConv with 3-dim per-edge pseudo-coordinates (shape check only)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
    feat = F.randn((g.number_of_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite', 'block-bipartite'], exclude=['zero-degree']))
def test_gmm_conv_bi(g, idtype):
    """GMMConv on bipartite graphs with different src (5) and dst (2) widths."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_graph_conv(norm_type, g, idtype, out_dim):
    """DenseGraphConv on a dense adjacency matches GraphConv with shared weights."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
    # share parameters so both layers compute the same function
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_sage_conv(g, idtype, out_dim):
    """DenseSAGEConv on a dense adjacency matches SAGEConv('gcn') with shared weights."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, out_dim, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, out_dim)
    # copy SAGEConv's neighbour projection and bias into the dense layer
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.bias.data
    if len(g.ntypes) == 2:
        # bipartite case: separate source and destination features
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5))
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage), g

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv(g, idtype, out_dim):
    """EdgeConv output shape on homogeneous graphs and blocks; also checks pickling."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)

    # test pickle
    th.save(edge_conv, tmp_buffer)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    """EdgeConv on bipartite graphs given a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_dotgat_conv(g, idtype, out_dim, num_heads):
    """DotGatConv output and attention shapes; also checks pickling."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    dotgat = dotgat.to(ctx)

    # test pickle
    th.save(dotgat, tmp_buffer)

    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    # one attention value per edge and head
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_dotgat_conv_bi(g, idtype, out_dim, num_heads):
    """DotGatConv on bipartite graphs: output and attention shapes."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv((5, 5), out_dim, num_heads)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    dotgat = dotgat.to(ctx)
    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    # one attention value per edge and head
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim):
    """DenseChebConv on a dense adjacency matches ChebConv for k = 1, 2, 3."""
    ctx = F.ctx()
    for k in range(1, 4):
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        g = g.to(ctx)
        adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, out_dim, k, None)
        dense_cheb = nn.DenseChebConv(5, out_dim, k)
        # ChebConv keeps all k filters in a single Linear layer; transpose it
        # and split it into the (k, 5, out_dim) layout DenseChebConv.W uses.
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, out_dim)
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)

        # the same value 2.0 is passed as a list to ChebConv and as a scalar
        # to DenseChebConv (presumably lambda_max — confirm against the API)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)

def test_sequential():
    """nn.Sequential chaining layers on one graph, and on a list of graphs."""
    ctx = F.ctx()
    # Test single graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    # every ordered pair of the 3 nodes, self-loops included: 9 edges
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            # sum adjacent pairs of nodes so the output fits the next, half-size graph
            return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)

    # node counts halve layer by layer: 32 -> 16 -> 8 -> 4 rows of output
    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_atomic_conv(g, idtype):
    """AtomicConv with 2-element cutoff/kernel/feature settings (shape check only)."""
    g = g.astype(idtype).to(F.ctx())
    aconv = nn.AtomicConv(interaction_cutoffs=F.tensor([12.0, 12.0]),
                          rbf_kernel_means=F.tensor([0.0, 2.0]),
                          rbf_kernel_scaling=F.tensor([4.0, 4.0]),
                          features_to_use=F.tensor([6.0, 8.0]))

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((g.number_of_nodes(), 1))
    dist = F.randn((g.number_of_edges(), 1))

    h = aconv(g, feat, dist)

    # currently we only do shape check
    assert h.shape[-1] == 4

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 3])
def test_cf_conv(g, idtype, out_dim):
    """CFConv with plain source features and with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    cfconv = nn.CFConv(node_in_feats=2,
                       edge_in_feats=3,
                       hidden_feats=2,
                       out_feats=out_dim)

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    src_feats = F.randn((g.number_of_src_nodes(), 2))
    edge_feats = F.randn((g.number_of_edges(), 3))
    h = cfconv(g, src_feats, edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim

    # case for bipartite graphs
    # NOTE(review): dst_feats has 3 columns while node_in_feats=2; this passes
    # only if CFConv ignores destination features — confirm.
    dst_feats = F.randn((g.number_of_dst_nodes(), 3))
    h = cfconv(g, (src_feats, dst_feats), edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim

def myagg(alist, dsttype):
    """Custom HeteroGraphConv aggregator: position-weighted sum of ``alist``.

    Returns ``alist[0] + 2 * alist[1] + 3 * alist[2] + ...``. ``dsttype`` is
    part of the aggregator signature and is not used.
    """
    total = alist[0]
    for weight, item in enumerate(alist[1:], start=2):
        total = total + weight * item
    return total

@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """HeteroGraphConv on a heterograph and on blocks, for every aggregator.

    Also covers forwarding of per-relation ``mod_args`` / ``mod_kwargs`` and
    running on a graph from which every edge has been removed.
    """
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    conv = conv.to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    def check_outputs(h):
        # 'store' is only ever a source node type, so it gets no output
        assert set(h.keys()) == {'user', 'game'}
        if agg != 'stack':
            assert h['user'].shape == (4, 3)
            assert h['game'].shape == (4, 4)
        else:
            # 'stack' keeps one slot per incoming relation: 'user' receives
            # only 'follows'; 'game' receives 'plays' and 'sells'
            assert h['user'].shape == (4, 1, 3)
            assert h['game'].shape == (4, 2, 4)

    h = conv(g, {'user': uf, 'game': gf, 'store': sf})
    check_outputs(h)

    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
    h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
    check_outputs(h)

    h = conv(block, {'user': uf, 'game': gf, 'store': sf})
    check_outputs(h)

    # test with mod args
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2
        def forward(self, g, h, arg1=None, *, arg2=None):
            # count which extra arguments HeteroGraphConv forwarded
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))
    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    conv = conv.to(F.ctx())
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    h = conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on graph without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form='eid', etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}

    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [
                         0, 1, 2, 3], 'store': []}).to(F.ctx())
    h = conv(block, ({'user': uf, 'game': gf, 'store': sf},
             {'user': uf, 'game': gf, 'store': sf[0:0]}))
    assert set(h.keys()) == {'user', 'game'}

@pytest.mark.parametrize('out_dim', [1, 2, 100])
def test_hetero_linear(out_dim):
    """HeteroLinear projects each keyed input (node or edge type) to ``out_dim``."""
    follows = ('user', 'follows', 'user')
    in_feats = {
        'user': F.randn((2, 1)),
        follows: F.randn((3, 2)),
    }

    layer = nn.HeteroLinear({'user': 1, follows: 2}, out_dim).to(F.ctx())
    out_feats = layer(in_feats)
    assert out_feats['user'].shape == (2, out_dim)
    assert out_feats[follows].shape == (3, out_dim)

@pytest.mark.parametrize('out_dim', [1, 2, 100])
def test_hetero_embedding(out_dim):
    """HeteroEmbedding exposes per-key weight tables and looks up ids per key."""
    layer = nn.HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, out_dim)
    layer = layer.to(F.ctx())

    # one table per key, sized (number of entries, out_dim)
    embeds = layer.weight
    assert embeds['user'].shape == (2, out_dim)
    assert embeds[('user', 'follows', 'user')].shape == (3, out_dim)

    # lookup returns one row per requested id
    embeds = layer({
        'user': F.tensor([0], dtype=F.int64),
        ('user', 'follows', 'user'): F.tensor([0, 2], dtype=F.int64)
    })
    assert embeds['user'].shape == (1, out_dim)
    assert embeds[('user', 'follows', 'user')].shape == (2, out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_gnnexplainer(g, idtype, out_dim):
    """Smoke test: GNNExplainer explains a node-level and a graph-level model."""
    g = g.astype(idtype).to(F.ctx())
    feat = F.randn((g.num_nodes(), 5))

    class Model(th.nn.Module):
        def __init__(self, in_feats, out_feats, graph=False):
            super(Model, self).__init__()
            self.linear = th.nn.Linear(in_feats, out_feats)
            if graph:
                self.pool = nn.AvgPooling()
            else:
                self.pool = None

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                feat = self.linear(feat)
                graph.ndata['h'] = feat
                if eweight is None:
                    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                else:
                    # weight each message by its edge weight
                    graph.edata['w'] = eweight
                    graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))

                if self.pool:
                    return self.pool(graph, graph.ndata['h'])
                else:
                    return graph.ndata['h']

    # Explain node prediction
    model = Model(5, out_dim)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)

    # Explain graph prediction
    model = Model(5, out_dim, graph=True)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)

def test_jumping_knowledge():
    """JumpingKnowledge output widths for the 'cat', 'max' and 'lstm' modes."""
    ctx = F.ctx()
    num_layers = 2
    num_nodes = 3
    num_feats = 4

    feat_list = [th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)]

    # concatenation stacks every layer's features side by side
    model = nn.JumpingKnowledge('cat').to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_layers * num_feats)

    model = nn.JumpingKnowledge('max').to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)

    model = nn.JumpingKnowledge('lstm', num_feats, num_layers).to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)

@pytest.mark.parametrize('op', ['dot', 'cos', 'ele', 'cat'])
def test_edge_predictor(op):
    """EdgePredictor output width per combine op, with and without a projection."""
    ctx = F.ctx()
    num_pairs = 3
    in_feats = 4
    out_feats = 5
    h_src = th.randn((num_pairs, in_feats)).to(ctx)
    h_dst = th.randn((num_pairs, in_feats)).to(ctx)

    # without in/out feats the raw combined features are returned
    pred = nn.EdgePredictor(op)
    if op in ['dot', 'cos']:
        assert pred(h_src, h_dst).shape == (num_pairs, 1)
    elif op == 'ele':
        assert pred(h_src, h_dst).shape == (num_pairs, in_feats)
    else:
        assert pred(h_src, h_dst).shape == (num_pairs, 2 * in_feats)
    # with in/out feats the result is projected to out_feats
    pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
    assert pred(h_src, h_dst).shape == (num_pairs, out_feats)


def test_ke_score_funcs():
    """TransE and TransR return one score per (head, tail, relation) triple."""
    ctx = F.ctx()
    num_edges = 30
    num_rels = 3
    nfeats = 4

    h_src = th.randn((num_edges, nfeats)).to(ctx)
    h_dst = th.randn((num_edges, nfeats)).to(ctx)
    rels = th.randint(low=0, high=num_rels, size=(num_edges,)).to(ctx)

    score_func = nn.TransE(num_rels=num_rels, feats=nfeats).to(ctx)
    score_func.reset_parameters()
    # one score per edge: a 1-D tensor of length num_edges
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)

    score_func = nn.TransR(num_rels=num_rels, rfeats=nfeats - 1, nfeats=nfeats).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)

def test_twirls():
    """TWIRLSConv maps 10-dim node features to 2 outputs on a 6-node graph."""
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    feat = th.ones(6, 10)
    conv = nn.TWIRLSConv(10, 2, 128, prop_step=64)
    res = conv(g, feat)
    assert res.size() == (6, 2)

if __name__ == '__main__':
    # Most tests in this module take pytest-parametrized arguments (graph
    # cases, idtype, output dims), so calling them directly without arguments
    # raises TypeError. Let pytest collect and run them, and propagate its
    # exit code.
    raise SystemExit(pytest.main([__file__]))