import io

import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
import dgl.function as fn
import backend as F
import pytest

from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
from copy import deepcopy
import pickle

import scipy as sp

tmp_buffer = io.BytesIO()
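# Modules below are written to this in-memory buffer with th.save as a cheap
# pickle round-trip check; nothing is ever loaded back.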

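# Dense reference for GraphConv with norm='none': Y = A (X W) + b.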
def _AXWb(A, X, W, b):
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b

@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv0(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)

    conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, out_dim)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right', 'left'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv(idtype, g, norm, weight, bias, out_dim):
    # Test one tensor input
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    e_w = g.edata['scalar_w']
    if weight:
        h_out = conv(g, h, edge_weight=e_w)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=e_w)
    assert h_out.shape == (ndst, out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
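    # EdgeWeightNorm pre-normalizes raw positive edge weights the way
    # GraphConv normalizes adjacency entries (e.g. w_ij / sqrt(deg_i * deg_j)
    # for norm='both'), so the result can be fed straight to edge_weight.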
    edgenorm = nn.EdgeWeightNorm(norm=norm)
    norm_weight = edgenorm(g, g.edata['scalar_w'])
    if weight:
        h_out = conv(g, h, edge_weight=norm_weight)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=norm_weight)
    assert h_out.shape == (ndst, out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
    # Test a pair of tensor inputs
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    h_dst = F.randn((ndst, out_dim)).to(F.ctx())
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)

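# Dense reference for a 2-hop TAGConv: stacks [X, A_hat X, A_hat^2 X] with
# A_hat = N A N and N = diag(in_degree^-1/2), then applies the fused linear
# layer of the conv.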
def _S2AXWb(A, N, X, W, b):
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b

@pytest.mark.parametrize('out_dim', [1, 2])
def test_tagconv(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3))
    g = g.to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, out_dim, bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))

    conv = nn.TAGConv(5, out_dim)
    conv = conv.to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == out_dim

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)

def test_set2set():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    s2s = s2s.to(ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2

def test_glob_att_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())
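    # GlobalAttentionPooling takes a gate network scoring each node (5 -> 1)
    # and a feature network (5 -> 10); the readout is the gate-weighted sum.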

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)

    # test pickle
    th.save(gap, tmp_buffer)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2

def test_simple_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))
    g = g.to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
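    # SortPooling sorts nodes by their last feature channel and keeps the top
    # k, flattening to a fixed-size (1, k * feat) readout per graph.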
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = th.stack([F.sum(h0[:15], 0),
                      F.sum(h0[15:20], 0),
                      F.sum(h0[20:35], 0),
                      F.sum(h0[35:40], 0),
                      F.sum(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = th.stack([F.mean(h0[:15], 0),
                      F.mean(h0[15:20], 0),
                      F.mean(h0[20:35], 0),
                      F.mean(h0[35:40], 0),
                      F.mean(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = th.stack([F.max(h0[:15], 0),
                      F.max(h0[15:20], 0),
                      F.max(h0[20:35], 0),
                      F.max(h0[35:40], 0),
                      F.max(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2

def test_set_trans():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2

@parametrize_dtype
@pytest.mark.parametrize('O', [1, 8, 32])
def test_rgcn(idtype, O):
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    B = 2
    I = 10

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.number_of_edges(), 1)).to(ctx)
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(g, edge_permute_algo='custom', permute_config={'edges_perm' : idx.to(idtype)})
    sorted_norm = norm[idx]

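    # RelGraphConv weight-sharing schemes: the default keeps one weight per
    # relation, 'basis' composes the R weights from B shared bases, and 'bdd'
    # uses block-diagonal decomposition (which needs O % B == 0).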
    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % B == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % B == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % B == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

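# GATv2 (Brody et al., "How Attentive are Graph Attention Networks?") applies
# the attention vector after the LeakyReLU, yielding dynamic attention; its
# interface mirrors GATConv.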
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gatv2_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_node_feats', [1, 5])
@pytest.mark.parametrize('out_edge_feats', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    egat = nn.EGATConv(in_node_feats=10,
                       in_edge_feats=5,
                       out_node_feats=out_node_feats,
                       out_edge_feats=out_edge_feats,
                       num_heads=num_heads)
    nfeat = F.randn((g.number_of_nodes(), 10))
    efeat = F.randn((g.number_of_edges(), 5))

    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)
    h, f, attn = egat(g, nfeat, efeat, True)

    th.save(egat, tmp_buffer)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(idtype, g, aggre_type):
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    sage = sage.to(F.ctx())

    # test pickle
    th.save(sage, tmp_buffer)

    h = sage(g, feat)
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != 'gcn' else 10
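    # The 'gcn' aggregator folds the destination node into the neighbor mean,
    # so its dst input width must match the src width (10); the other
    # aggregators accept a different dst dim.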
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()

@parametrize_dtype
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv2(idtype, out_dim):
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), out_dim, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    for aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)

    # test pickle
    th.save(sgc, tmp_buffer)

    feat = F.randn((g.number_of_nodes(), 5))
    sgc = sgc.to(ctx)

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim
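
    # With cached=True, SGConv memorizes the propagated features from its
    # first call, so a second call with different inputs (feat + 1) must
    # return the same output.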

    # cached
    sgc = nn.SGConv(5, out_dim, 3, True)
    sgc = sgc.to(ctx)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
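    # APPNPConv(k, alpha) runs k personalized-PageRank propagation steps with
    # teleport probability alpha; it has no learnable weights, so the feature
    # width (5) is preserved.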
    feat = F.randn((g.number_of_nodes(), 5))
    appnp = appnp.to(ctx)

    # test pickle
    th.save(appnp, tmp_buffer)

    h = appnp(g, feat)
    assert h.shape[-1] == 5


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.number_of_nodes(), 5))
    eweight = F.ones((g.num_edges(), ))
    appnp = appnp.to(ctx)

    h = appnp(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gcn2conv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gcn2conv = nn.GCN2Conv(5, layer=2, alpha=0.5,
                           project_initial_features=True)
    feat = F.randn((g.number_of_nodes(), 5))
    eweight = F.ones((g.num_edges(), ))
    gcn2conv = gcn2conv.to(ctx)
    res = feat
    h = gcn2conv(g, res, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_sgconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    sgconv = nn.SGConv(5, 5, 3)
    feat = F.randn((g.number_of_nodes(), 5))
    eweight = F.ones((g.num_edges(), ))
    sgconv = sgconv.to(ctx)
    h = sgconv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_tagconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.TAGConv(5, 5, bias=True)
    conv = conv.to(ctx)
    feat = F.randn((g.number_of_nodes(), 5))
    eweight = F.ones((g.num_edges(), ))
    h = conv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    th.save(gin, tmp_buffer)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gin = gin.to(ctx)
    h = gin(g, feat)

    # test pickle
    th.save(gin, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), 12)

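    # GINConv also accepts apply_func=None, in which case it returns the
    # aggregated (1 + eps) * h_self + aggregate(h_neigh) features unchanged.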
    gin = nn.GINConv(None, aggregator_type)
    th.save(gin, tmp_buffer)
    gin = gin.to(ctx)
    h = gin(g, feat)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_gine_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gine = nn.GINEConv(
        th.nn.Linear(5, 12)
    )
    th.save(gine, tmp_buffer)
    nfeat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.num_edges(), 5))
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)

    # test pickle
    th.save(gine, tmp_buffer)
    assert h.shape == (g.number_of_dst_nodes(), 12)

    gine = nn.GINEConv(None)
    th.save(gine, tmp_buffer)
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_agnn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_src_nodes(), 5))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_agnn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gated_graph_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gated_graph_conv_one_etype(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 1)
    etypes = th.zeros(g.number_of_edges())
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    h2 = ggconv(g, feat)
    # currently we only do shape check
    assert F.allclose(h, h2)
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_nn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean')
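    # edge_func maps each 4-d edge feature to a flattened 5x10 weight matrix,
    # which NNConv applies to the source node features before 'mean' pooling.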
    feat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_nn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gmm_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
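    # GMMConv (MoNet) weights each message by 4 learnable Gaussian kernels
    # evaluated on a 3-d pseudo-coordinate attached to every edge.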
    feat = F.randn((g.number_of_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite', 'block-bipartite'], exclude=['zero-degree']))
def test_gmm_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_graph_conv(norm_type, g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_sage_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, out_dim, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, out_dim)
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.bias.data
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5))
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage), g

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)

    # test pickle
    th.save(edge_conv, tmp_buffer)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_dotgat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    dotgat = dotgat.to(ctx)

    # test pickle
    th.save(dotgat, tmp_buffer)

    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_dotgat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv((5, 5), out_dim, num_heads)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    dotgat = dotgat.to(ctx)
    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim):
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        g = g.to(F.ctx())
        adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, out_dim, k, None)
        dense_cheb = nn.DenseChebConv(5, out_dim, k)
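        # ChebConv keeps its k hop-wise weights stacked in one linear layer;
        # reshape them into DenseChebConv's (k, in_feats, out_feats) tensor so
        # both modules share identical parameters.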
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, out_dim)
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)

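# dgl's nn.Sequential differs from torch.nn.Sequential: every layer receives
# the graph plus the running feature arguments, and a list of graphs feeds a
# different graph to each layer (see the second case below).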
def test_sequential():
    ctx = F.ctx()
    # Test single graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_atomic_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    aconv = nn.AtomicConv(interaction_cutoffs=F.tensor([12.0, 12.0]),
                          rbf_kernel_means=F.tensor([0.0, 2.0]),
                          rbf_kernel_scaling=F.tensor([4.0, 4.0]),
                          features_to_use=F.tensor([6.0, 8.0]))

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((g.number_of_nodes(), 1))
    dist = F.randn((g.number_of_edges(), 1))

    h = aconv(g, feat, dist)
    # currently we only do shape check
    assert h.shape[-1] == 4

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 3])
def test_cf_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    cfconv = nn.CFConv(node_in_feats=2,
                       edge_in_feats=3,
                       hidden_feats=2,
                       out_feats=out_dim)

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    src_feats = F.randn((g.number_of_src_nodes(), 2))
    edge_feats = F.randn((g.number_of_edges(), 3))
    h = cfconv(g, src_feats, edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim

    # case for bipartite graphs
    dst_feats = F.randn((g.number_of_dst_nodes(), 3))
    h = cfconv(g, (src_feats, dst_feats), edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim

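# A user-defined cross-type reducer for HeteroGraphConv: it receives the list
# of per-relation outputs arriving at one destination node type (plus the
# type name) and here returns a position-weighted sum.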
def myagg(alist, dsttype):
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst

@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())
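    # HeteroGraphConv runs one sub-module per relation and combines the
    # results per destination node type with `agg` ('stack' keeps a separate
    # axis per incoming relation, hence the extra dim asserted below).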
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    conv = conv.to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
    h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    h = conv(block, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    # test with mod args
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2
        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))
    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    conv = conv.to(F.ctx())
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    h = conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on graph without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form='eid', etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}

    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [
                         0, 1, 2, 3], 'store': []}).to(F.ctx())
    h = conv(block, ({'user': uf, 'game': gf, 'store': sf},
             {'user': uf, 'game': gf, 'store': sf[0:0]}))
    assert set(h.keys()) == {'user', 'game'}

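# HeteroLinear and HeteroEmbedding keep one weight table per (node or edge)
# type key and apply/look up each input under its own key.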
@pytest.mark.parametrize('out_dim', [1, 2, 100])
def test_hetero_linear(out_dim):
    in_feats = {
        'user': F.randn((2, 1)),
        ('user', 'follows', 'user'): F.randn((3, 2))
    }

    layer = nn.HeteroLinear({'user': 1, ('user', 'follows', 'user'): 2}, out_dim)
    layer = layer.to(F.ctx())
    out_feats = layer(in_feats)
    assert out_feats['user'].shape == (2, out_dim)
    assert out_feats[('user', 'follows', 'user')].shape == (3, out_dim)

@pytest.mark.parametrize('out_dim', [1, 2, 100])
def test_hetero_embedding(out_dim):
    layer = nn.HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, out_dim)
    layer = layer.to(F.ctx())

    embeds = layer.weight
    assert embeds['user'].shape == (2, out_dim)
    assert embeds[('user', 'follows', 'user')].shape == (3, out_dim)

    embeds = layer({
        'user': F.tensor([0], dtype=F.int64),
        ('user', 'follows', 'user'): F.tensor([0, 2], dtype=F.int64)
    })
    assert embeds['user'].shape == (1, out_dim)
    assert embeds[('user', 'follows', 'user')].shape == (2, out_dim)

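# GNNExplainer learns a soft feature mask and edge mask that preserve the
# model's prediction; the wrapped model must accept an optional eweight
# argument, as the Model below does.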
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_gnnexplainer(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    feat = F.randn((g.num_nodes(), 5))

    class Model(th.nn.Module):
        def __init__(self, in_feats, out_feats, graph=False):
            super(Model, self).__init__()
            self.linear = th.nn.Linear(in_feats, out_feats)
            if graph:
                self.pool = nn.AvgPooling()
            else:
                self.pool = None

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                feat = self.linear(feat)
                graph.ndata['h'] = feat
                if eweight is None:
                    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                else:
                    graph.edata['w'] = eweight
                    graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))

                if self.pool:
                    return self.pool(graph, graph.ndata['h'])
                else:
                    return graph.ndata['h']

    # Explain node prediction
    model = Model(5, out_dim)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)

    # Explain graph prediction
    model = Model(5, out_dim, graph=True)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)

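# JumpingKnowledge fuses per-layer node representations: 'cat' concatenates
# them (num_layers * num_feats), while 'max' and 'lstm' keep the width at
# num_feats.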
def test_jumping_knowledge():
    ctx = F.ctx()
    num_layers = 2
    num_nodes = 3
    num_feats = 4

    feat_list = [th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)]

    model = nn.JumpingKnowledge('cat').to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_layers * num_feats)

    model = nn.JumpingKnowledge('max').to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)

    model = nn.JumpingKnowledge('lstm', num_feats, num_layers).to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)

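# EdgePredictor scores (src, dst) pairs: 'dot' and 'cos' give a scalar per
# pair, 'ele' an elementwise product the size of the input, and 'cat' a
# concatenation; passing in_feats/out_feats adds a learned projection.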
@pytest.mark.parametrize('op', ['dot', 'cos', 'ele', 'cat'])
def test_edge_predictor(op):
    ctx = F.ctx()
    num_pairs = 3
    in_feats = 4
    out_feats = 5
    h_src = th.randn((num_pairs, in_feats)).to(ctx)
    h_dst = th.randn((num_pairs, in_feats)).to(ctx)

    pred = nn.EdgePredictor(op)
    if op in ['dot', 'cos']:
        assert pred(h_src, h_dst).shape == (num_pairs, 1)
    elif op == 'ele':
        assert pred(h_src, h_dst).shape == (num_pairs, in_feats)
    else:
        assert pred(h_src, h_dst).shape == (num_pairs, 2 * in_feats)
    pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
    assert pred(h_src, h_dst).shape == (num_pairs, out_feats)

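# Knowledge-graph score functions: TransE scores h + r - t in the entity
# space, while TransR first projects entities into a relation-specific space
# (rfeats) before scoring.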

def test_ke_score_funcs():
    ctx = F.ctx()
    num_edges = 30
    num_rels = 3
    nfeats = 4

    h_src = th.randn((num_edges, nfeats)).to(ctx)
    h_dst = th.randn((num_edges, nfeats)).to(ctx)
    rels = th.randint(low=0, high=num_rels, size=(num_edges,)).to(ctx)

    score_func = nn.TransE(num_rels=num_rels, feats=nfeats).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)

    score_func = nn.TransR(num_rels=num_rels, rfeats=nfeats - 1, nfeats=nfeats).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)


def test_twirls():
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    feat = th.ones(6, 10)
    conv = nn.TWIRLSConv(10, 2, 128, prop_step=64)
    res = conv(g, feat)
    assert res.size() == (6, 2)

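# TypedLinear applies a per-type weight (optionally factored with 'basis' or
# 'bdd'); sorted_by_type=True takes a segmented fast path, so sorted and
# unsorted calls must agree up to numerical tolerance.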
@pytest.mark.parametrize('feat_size', [4, 32])
@pytest.mark.parametrize('regularizer,num_bases', [(None, None), ('basis', 4), ('bdd', 4)])
def test_typed_linear(feat_size, regularizer, num_bases):
    dev = F.ctx()
    num_types = 5
    lin = nn.TypedLinear(feat_size, feat_size * 2, num_types, regularizer=regularizer, num_bases=num_bases).to(dev)
    print(lin)
    x = th.randn(100, feat_size).to(dev)
    x_type = th.randint(0, num_types, (100,)).to(dev)
    x_type_sorted, idx = th.sort(x_type)
    _, rev_idx = th.sort(idx)
    x_sorted = x[idx]

    # test unsorted
    y = lin(x, x_type)
    assert y.shape == (100, feat_size * 2)
    # test sorted
    y_sorted = lin(x_sorted, x_type_sorted, sorted_by_type=True)
    assert y_sorted.shape == (100, feat_size * 2)

    assert th.allclose(y, y_sorted[rev_idx], atol=1e-4, rtol=1e-4)

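# HGTConv implements the Heterogeneous Graph Transformer on a homogeneous
# graph representation: per-node ntype and per-edge etype id tensors select
# the type-specific projections.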
@parametrize_dtype
@pytest.mark.parametrize('in_size', [4])
@pytest.mark.parametrize('num_heads', [1])
def test_hgt(idtype, in_size, num_heads):
    dev = F.ctx()
    num_etypes = 5
    num_ntypes = 2
    head_size = in_size // num_heads

    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.01))
    g = g.astype(idtype).to(dev)
    etype = th.tensor([i % num_etypes for i in range(g.num_edges())]).to(dev)
    ntype = th.tensor([i % num_ntypes for i in range(g.num_nodes())]).to(dev)
    x = th.randn(g.num_nodes(), in_size).to(dev)
    m = nn.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes).to(dev)

    y = m(g, x, ntype, etype)
    assert y.shape == (g.num_nodes(), head_size * num_heads)
    # presorted
    sorted_ntype, idx_nt = th.sort(ntype)
    sorted_etype, idx_et = th.sort(etype)
    _, rev_idx = th.sort(idx_nt)
    g.ndata['t'] = ntype
    g.ndata['x'] = x
    g.edata['t'] = etype
    sorted_g = dgl.reorder_graph(g, node_permute_algo='custom', edge_permute_algo='custom',
                                 permute_config={'nodes_perm' : idx_nt.to(idtype), 'edges_perm' : idx_et.to(idtype)})
    print(sorted_g.ndata['t'])
    print(sorted_g.edata['t'])
    sorted_x = sorted_g.ndata['x']
    sorted_y = m(sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=False)
    assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
    # TODO(minjie): enable the following check
    #assert th.allclose(y, sorted_y[rev_idx], atol=1e-4, rtol=1e-4)
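# RadiusGraph connects every pair of points within Euclidean distance r
# (0.3 here); with get_distances=True it also returns the per-edge distances.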
@pytest.mark.parametrize('self_loop', [True, False])
@pytest.mark.parametrize('get_distances', [True, False])
def test_radius_graph(self_loop, get_distances):
    pos = th.tensor([[0.1, 0.3, 0.4],
                     [0.5, 0.2, 0.1],
                     [0.7, 0.9, 0.5],
                     [0.3, 0.2, 0.5],
                     [0.2, 0.8, 0.2],
                     [0.9, 0.2, 0.1],
                     [0.7, 0.4, 0.4],
                     [0.2, 0.1, 0.6],
                     [0.5, 0.3, 0.5],
                     [0.4, 0.2, 0.6]])

    rg = nn.RadiusGraph(0.3, self_loop=self_loop)

    if get_distances:
        g, dists = rg(pos, get_distances=get_distances)
    else:
        g = rg(pos)

    if self_loop:
        src_target = th.tensor([0, 0, 1, 2, 3, 3, 3, 3, 3, 4, 5, 6, 6, 7, 7, 7,
                                8, 8, 8, 8, 9, 9, 9, 9])
        dst_target = th.tensor([0, 3, 1, 2, 0, 3, 7, 8, 9, 4, 5, 6, 8, 3, 7, 9,
                                3, 6, 8, 9, 3, 7, 8, 9])

        if get_distances:
            dists_target = th.tensor([[0.0000],
                                      [0.2449],
                                      [0.0000],
                                      [0.0000],
                                      [0.2449],
                                      [0.0000],
                                      [0.1732],
                                      [0.2236],
                                      [0.1414],
                                      [0.0000],
                                      [0.0000],
                                      [0.0000],
                                      [0.2449],
                                      [0.1732],
                                      [0.0000],
                                      [0.2236],
                                      [0.2236],
                                      [0.2449],
                                      [0.0000],
                                      [0.1732],
                                      [0.1414],
                                      [0.2236],
                                      [0.1732],
                                      [0.0000]])
    else:
        src_target = th.tensor([0, 3, 3, 3, 3, 6, 7, 7, 8, 8, 8, 9, 9, 9])
        dst_target = th.tensor([3, 0, 7, 8, 9, 8, 3, 9, 3, 6, 9, 3, 7, 8])

        if get_distances:
            dists_target = th.tensor([[0.2449],
                                      [0.2449],
                                      [0.1732],
                                      [0.2236],
                                      [0.1414],
                                      [0.2449],
                                      [0.1732],
                                      [0.2236],
                                      [0.2236],
                                      [0.2449],
                                      [0.1732],
                                      [0.1414],
                                      [0.2236],
                                      [0.1732]])

    src, dst = g.edges()

    assert th.equal(src, src_target)
    assert th.equal(dst, dst_target)

    if get_distances:
        assert th.allclose(dists, dists_target, rtol=1e-03)

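# GroupRevRes wraps conv in grouped reversible residual connections: the
# feature dim is split into `groups` partitions (so conv sees feats // groups
# channels), and activations can be recomputed in the backward pass instead
# of stored.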
@parametrize_dtype
def test_group_rev_res(idtype):
    dev = F.ctx()

    num_nodes = 5
    num_edges = 20
    feats = 32
    groups = 2
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, feats).to(dev)
    conv = nn.GraphConv(feats // groups, feats // groups)
    model = nn.GroupRevRes(conv, groups).to(dev)
    model(g, h)
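
# EGNNConv (E(n)-equivariant GNN) updates node features h and 3-d coordinates
# x together; edge_feat_size=0 yields a zero-width edge tensor that is simply
# unused.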

@pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('hidden_size', [16, 32])
@pytest.mark.parametrize('out_size', [16, 32])
@pytest.mark.parametrize('edge_feat_size', [16, 10, 0])
def test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    x = th.randn(num_nodes, 3).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    model = nn.EGNNConv(in_size, hidden_size, out_size, edge_feat_size).to(dev)
    model(g, h, x, e)
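
# PNAConv combines several neighborhood aggregators with degree-based scalers
# ('amplification'/'attenuation' use delta, the average log-degree); when
# num_towers > 1 it splits the feature dim across towers, so in_size and
# out_size must be divisible by num_towers.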

@pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('out_size', [16, 32])
@pytest.mark.parametrize('aggregators',
    [['mean', 'max', 'sum'], ['min', 'std', 'var'], ['moment3', 'moment4', 'moment5']])
@pytest.mark.parametrize('scalers', [['identity'], ['amplification', 'attenuation']])
@pytest.mark.parametrize('delta', [2.5, 7.4])
@pytest.mark.parametrize('dropout', [0., 0.1])
@pytest.mark.parametrize('num_towers', [1, 4])
@pytest.mark.parametrize('edge_feat_size', [16, 0])
@pytest.mark.parametrize('residual', [True, False])
def test_pna_conv(in_size, out_size, aggregators, scalers, delta,
    dropout, num_towers, edge_feat_size, residual):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    model = nn.PNAConv(in_size, out_size, aggregators, scalers, delta, dropout,
        num_towers, edge_feat_size, residual).to(dev)
    model(g, h, edge_feat=e)