test_nn.py 71.2 KB
Newer Older
1
import io
2
3
4
5
6
import pickle
from copy import deepcopy

import backend as F

7
import dgl
8
import dgl.function as fn
9
10
import dgl.nn.pytorch as nn
import networkx as nx
11
import pytest
12
import scipy as sp
LuckyLiuM's avatar
LuckyLiuM committed
13
import torch
14
import torch as th
15
16
17
18
from torch.optim import Adam, SparseAdam
from torch.utils.data import DataLoader
from utils import parametrize_idtype
from utils.graph_cases import (
19
20
21
22
23
    get_cases,
    random_bipartite,
    random_dglgraph,
    random_graph,
)
24

25
26
tmp_buffer = io.BytesIO()

27

28
29
30
31
32
def _AXWb(A, X, W, b):
    """Dense reference for a graph convolution: ``A @ (X @ W) + b``.

    ``X`` is first projected by ``W``; the result is flattened to two
    dimensions (rows kept, everything else merged) so it can be multiplied
    by ``A``, and is then restored to the projected shape before ``b`` is
    added.
    """
    projected = th.matmul(X, W)
    num_rows = projected.shape[0]
    flattened = projected.view(num_rows, -1)
    aggregated = th.matmul(A, flattened)
    return aggregated.view_as(projected) + b

33
34

@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv0(out_dim):
    """Check GraphConv on a 3-node path graph against a dense reference.

    With ``norm="none"`` and a bias, the layer output is compared with
    ``_AXWb`` for 2-D and 3-D node features.  A default-configured layer is
    then run on the same inputs, and ``reset_parameters`` is checked to
    produce a different weight tensor.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)
    feat_shapes = [(3, 5), (3, 5, 5)]  # basic and more-dim inputs

    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1 / test#2: compare against the dense reference
    for shape in feat_shapes:
        h0 = F.ones(shape)
        h1 = conv(g, h0)
        # the forward pass must not leave features stored on the graph
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # test#3 / test#4: default-configured layer, only check the graph
    # stays clean (this section used to be duplicated verbatim)
    conv = nn.GraphConv(5, out_dim)
    conv = conv.to(ctx)
    for shape in feat_shapes:
        conv(g, F.ones(shape))
        assert len(g.ndata) == 0
        assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)
91

92

nv-dlasalle's avatar
nv-dlasalle committed
93
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right", "left"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(idtype, g, norm, weight, bias, out_dim):
    """Run GraphConv on a single feature tensor and check the output shape."""
    device = F.ctx()
    g = g.astype(idtype).to(device)
    layer = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    layer = layer.to(device)
    external_w = F.randn((5, out_dim)).to(device)
    num_src = g.number_of_src_nodes()
    num_dst = g.number_of_dst_nodes()
    feat = F.randn((num_src, 5)).to(device)
    # When the layer is built without its own weight, pass one in explicitly.
    extra_kwargs = {} if weight else {"weight": external_w}
    out = layer(g, feat, **extra_kwargs)
    assert out.shape == (num_dst, out_dim)
116

117

nv-dlasalle's avatar
nv-dlasalle committed
118
@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
    """Run GraphConv with the graph's ``scalar_w`` edge feature as weight."""
    device = F.ctx()
    g = g.astype(idtype).to(device)
    layer = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    layer = layer.to(device)
    external_w = F.randn((5, out_dim)).to(device)
    num_src = g.number_of_src_nodes()
    num_dst = g.number_of_dst_nodes()
    feat = F.randn((num_src, 5)).to(device)
    call_kwargs = {"edge_weight": g.edata["scalar_w"]}
    if not weight:
        # the layer has no weight of its own, so hand one in
        call_kwargs["weight"] = external_w
    out = layer(g, feat, **call_kwargs)
    assert out.shape == (num_dst, out_dim)
142

143

nv-dlasalle's avatar
nv-dlasalle committed
144
@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
    """Run GraphConv with edge weights produced by EdgeWeightNorm."""
    device = F.ctx()
    g = g.astype(idtype).to(device)
    layer = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    layer = layer.to(device)

    # test pickle
    th.save(layer, tmp_buffer)

    external_w = F.randn((5, out_dim)).to(device)
    num_src = g.number_of_src_nodes()
    num_dst = g.number_of_dst_nodes()
    feat = F.randn((num_src, 5)).to(device)
    # normalise the raw scalar edge feature with the same norm mode
    normalizer = nn.EdgeWeightNorm(norm=norm)
    call_kwargs = {"edge_weight": normalizer(g, g.edata["scalar_w"])}
    if not weight:
        call_kwargs["weight"] = external_w
    out = layer(g, feat, **call_kwargs)
    assert out.shape == (num_dst, out_dim)
173

174

nv-dlasalle's avatar
nv-dlasalle committed
175
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
    """Run GraphConv on a (source, destination) feature pair."""
    device = F.ctx()
    g = g.astype(idtype).to(device)
    layer = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    layer = layer.to(device)

    # test pickle
    th.save(layer, tmp_buffer)

    external_w = F.randn((5, out_dim)).to(device)
    num_src = g.number_of_src_nodes()
    num_dst = g.number_of_dst_nodes()
    feat_pair = (
        F.randn((num_src, 5)).to(device),
        F.randn((num_dst, out_dim)).to(device),
    )
    # supply an external weight only when the layer was built without one
    extra_kwargs = {} if weight else {"weight": external_w}
    out = layer(g, feat_pair, **extra_kwargs)
    assert out.shape == (num_dst, out_dim)
203

204

205
206
207
208
209
210
211
212
213
214
215
216
def _S2AXWb(A, N, X, W, b):
    """Dense reference for a two-hop TAGConv-style layer.

    Computes ``X1 = N * (A @ (N * X))`` and ``X2 = N * (A @ (N * X1))``,
    concatenates ``[X, X1, X2]`` along the last dimension and applies the
    linear map ``Y = cat @ W.T + b``.

    ``W`` is expected in ``torch.nn.Linear`` layout, i.e.
    ``(out_features, in_features)``, so it is transposed before the product.
    The previous ``W.rot90()`` also yields an ``(in, out)`` matrix but with
    the input-feature axis reversed, which is only correct when the
    concatenated features happen to be symmetric under that reversal.
    """
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.t())

    return Y + b

217
218

@pytest.mark.parametrize("out_dim", [1, 2])
def test_tagconv(out_dim):
    """Check TAGConv on a 3-node path graph against ``_S2AXWb``.

    Also runs a default-configured layer for an output-size check and
    verifies that ``reset_parameters`` changes the linear weight.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)
    # inverse square root of the in-degrees, used by the reference
    inv_sqrt_deg = th.pow(g.in_degrees().float(), -0.5)

    layer = nn.TAGConv(5, out_dim, bias=True).to(ctx)
    print(layer)

    # test pickle
    th.save(layer, tmp_buffer)

    # test#1: basic
    feat = F.ones((3, 5))
    out = layer(g, feat)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # add trailing singleton dims so the norm broadcasts over the features
    broadcast_shape = inv_sqrt_deg.shape + (1,) * (feat.dim() - 1)
    inv_sqrt_deg = th.reshape(inv_sqrt_deg, broadcast_shape).to(ctx)
    expected = _S2AXWb(
        adj, inv_sqrt_deg, feat, layer.lin.weight, layer.lin.bias
    )
    assert F.allclose(out, expected)

    # test#2: basic
    layer = nn.TAGConv(5, out_dim).to(ctx)
    feat = F.ones((3, 5))
    out = layer(g, feat)
    assert out.shape[-1] == out_dim

    # test reset_parameters
    weight_before = deepcopy(layer.lin.weight.data)
    layer.reset_parameters()
    weight_after = layer.lin.weight.data
    assert not F.allclose(weight_before, weight_after)

259

260
def test_set2set():
    """Check Set2Set readout shapes for a single graph and a batch."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10)).to(ctx)

    readout = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    readout = readout.to(ctx)
    print(readout)

    # test#1: basic
    feat = F.randn((g.num_nodes(), 5))
    out = readout(g, feat)
    assert out.dim() == 2
    assert tuple(out.shape) == (1, 10)

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11)).to(ctx)
    g2 = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    batched = dgl.batch([g, g1, g2])
    feat = F.randn((batched.num_nodes(), 5))
    out = readout(batched, feat)
    # one output row per graph in the batch
    assert out.dim() == 2
    assert tuple(out.shape) == (3, 10)

282

283
def test_glob_att_pool():
    """Check GlobalAttentionPooling shapes for a single graph and a batch."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10)).to(ctx)

    gate_nn = th.nn.Linear(5, 1)
    feat_nn = th.nn.Linear(5, 10)
    pool = nn.GlobalAttentionPooling(gate_nn, feat_nn)
    pool = pool.to(ctx)
    print(pool)

    # test pickle
    th.save(pool, tmp_buffer)

    # test#1: basic
    feat = F.randn((g.num_nodes(), 5))
    out = pool(g, feat)
    assert out.dim() == 2
    assert tuple(out.shape) == (1, 10)

    # test#2: batched graph
    batched = dgl.batch([g, g, g, g])
    feat = F.randn((batched.num_nodes(), 5))
    out = pool(batched, feat)
    # one output row per graph in the batch
    assert out.dim() == 2
    assert tuple(out.shape) == (4, 10)

306

307
def test_simple_pool():
    """Check sum/avg/max/sort pooling on one graph and on a batch.

    Sum, average and max readouts are compared with the matching backend
    reduction over each graph's node rows; sort pooling is only checked
    for its output shape.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(ctx)

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    feat = F.randn((g.num_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    # each readout paired with the backend reduction it should match
    readouts = [(sum_pool, F.sum), (avg_pool, F.mean), (max_pool, F.max)]
    for pool, reduce_fn in readouts:
        pooled = pool(g, feat)
        assert F.allclose(F.squeeze(pooled, 0), reduce_fn(feat, 0))
    pooled = sort_pool(g, feat)
    assert pooled.dim() == 2
    assert tuple(pooled.shape) == (1, 10 * 5)

    # test#2: batched graph
    small = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    batched = dgl.batch([g, small, g, small, g])
    feat = F.randn((batched.num_nodes(), 5))
    # node-row ranges of the five batched graphs (15, 5, 15, 5, 15 nodes)
    segments = [(0, 15), (15, 20), (20, 35), (35, 40), (40, 55)]
    for pool, reduce_fn in readouts:
        pooled = pool(batched, feat)
        truth = th.stack(
            [reduce_fn(feat[lo:hi], 0) for lo, hi in segments],
            0,
        )
        assert F.allclose(pooled, truth)

    pooled = sort_pool(batched, feat)
    assert pooled.dim() == 2
    assert tuple(pooled.shape) == (5, 10 * 5)

379

380
def test_set_trans():
    """Check SetTransformer encoder/decoder shapes on one graph and a batch.

    Both encoder variants ("sab" and "isab") must keep the feature shape;
    the decoder must return one row of width 200 per graph.

    The graphs are moved to ``F.ctx()`` like the modules, matching the
    other readout tests in this file (previously they were left on the
    default device).
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(ctx)

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "sab")
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "isab", 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    g2 = dgl.DGLGraph(nx.path_graph(10)).to(ctx)
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.num_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    # one decoder output row per graph in the batch
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2

414

nv-dlasalle's avatar
nv-dlasalle committed
415
@parametrize_idtype
@pytest.mark.parametrize("O", [1, 8, 32])
def test_rgcn(idtype, O):
    """Exercise RelGraphConv with no, "basis" and "bdd" regularizers.

    For every layer the test checks the output shape, that a graph whose
    edges were pre-sorted by type (``presorted=True``) gives the same
    result as the unsorted graph, and that a per-edge ``norm`` tensor is
    accepted.  The "bdd" layer is only built when ``O`` is divisible by
    the number of bases ``B``.
    """
    ctx = F.ctx()
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(ctx)
    R = 5  # number of edge types
    B = 2  # number of bases handed to the regularized layers
    I = 10  # input feature size
    # Edge types are assigned round-robin; use R rather than a literal so
    # the type ids always stay within the number of relations.
    etype = [i % R for i in range(g.num_edges())]

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.num_edges(), 1)).to(ctx)
    # permute the edges so that they are grouped by edge type
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )

    layers = [
        nn.RelGraphConv(I, O, R).to(ctx),
        nn.RelGraphConv(I, O, R, "basis", B).to(ctx),
    ]
    if O % B == 0:
        layers.append(nn.RelGraphConv(I, O, R, "bdd", B).to(ctx))

    for layer in layers:
        th.save(layer, tmp_buffer)  # test pickle

        # basic usage
        h_new = layer(g, h, r)
        assert h_new.shape == (100, O)

        # sorted input
        h_new_sorted = layer(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)

        # norm input
        h_new = layer(g, h, r, norm)
        assert h_new.shape == (100, O)
474

475

476
@parametrize_idtype
@pytest.mark.parametrize("O", [1, 10, 40])
def test_rgcn_default_nbasis(idtype, O):
    """Exercise RelGraphConv regularizers without an explicit basis count.

    Same checks as for an explicit number of bases: output shape,
    agreement between unsorted and type-sorted (``presorted=True``)
    inputs, and acceptance of a per-edge ``norm`` tensor.  The "bdd" layer
    is only built when ``O`` is divisible by the number of relations ``R``.
    """
    ctx = F.ctx()
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(ctx)
    R = 5  # number of edge types
    I = 10  # input feature size
    # Edge types are assigned round-robin; use R rather than a literal so
    # the type ids always stay within the number of relations.
    etype = [i % R for i in range(g.num_edges())]

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.num_edges(), 1)).to(ctx)
    # permute the edges so that they are grouped by edge type
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )

    layers = [
        nn.RelGraphConv(I, O, R).to(ctx),
        nn.RelGraphConv(I, O, R, "basis").to(ctx),
    ]
    if O % R == 0:
        layers.append(nn.RelGraphConv(I, O, R, "bdd").to(ctx))

    for layer in layers:
        th.save(layer, tmp_buffer)  # test pickle

        # basic usage
        h_new = layer(g, h, r)
        assert h_new.shape == (100, O)

        # sorted input
        h_new_sorted = layer(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)

        # norm input
        h_new = layer(g, h, r, norm)
        assert h_new.shape == (100, O)
534

535

nv-dlasalle's avatar
nv-dlasalle committed
536
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    """Check GATConv output and attention shapes, plus a residual run."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    layer = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    layer = layer.to(ctx)
    out = layer(g, feat)

    # test pickle
    th.save(layer, tmp_buffer)

    expected_out_shape = (g.number_of_dst_nodes(), num_heads, out_dim)
    assert out.shape == expected_out_shape
    _, attn = layer(g, feat, get_attention=True)
    # one attention value per edge and head
    assert attn.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    layer = nn.GATConv(5, out_dim, num_heads, residual=True).to(ctx)
    out = layer(g, feat)

562

nv-dlasalle's avatar
nv-dlasalle committed
563
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    """Check GATConv shapes when given a (source, destination) pair."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    layer = nn.GATConv(5, out_dim, num_heads)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    feat_pair = (src_feat, dst_feat)
    layer = layer.to(ctx)
    out = layer(g, feat_pair)
    assert out.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, attn = layer(g, feat_pair, get_attention=True)
    # one attention value per edge and head
    assert attn.shape == (g.num_edges(), num_heads, 1)
580

581

582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_edge_weight(g, idtype, out_dim, num_heads):
    """Check GATConv shapes when a per-edge weight is supplied.

    The attention call now passes ``edge_weight`` as well, so the returned
    attention comes from the edge-weighted path this test is named for
    (previously the second call silently dropped the weight).
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    ew = F.randn((g.num_edges(),))  # one scalar weight per edge
    h = gat(g, feat, edge_weight=ew)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, edge_weight=ew, get_attention=True)
    assert a.shape[0] == ew.shape[0]
    assert a.shape == (g.num_edges(), num_heads, 1)


nv-dlasalle's avatar
nv-dlasalle committed
603
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv(g, idtype, out_dim, num_heads):
    """Check GATv2Conv output and attention shapes, plus a residual run.

    The residual check builds a ``GATv2Conv``; it previously built a
    ``GATConv`` by copy-paste, so the GATv2 residual path was never run.
    """
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    # one attention value per edge and head
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)

629

nv-dlasalle's avatar
nv-dlasalle committed
630
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
    """Check GATv2Conv shapes when given a (source, destination) pair."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    layer = nn.GATv2Conv(5, out_dim, num_heads)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    feat_pair = (src_feat, dst_feat)
    layer = layer.to(ctx)
    out = layer(g, feat_pair)
    assert out.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, attn = layer(g, feat_pair, get_attention=True)
    # one attention value per edge and head
    assert attn.shape == (g.num_edges(), num_heads, 1)
Shaked Brody's avatar
Shaked Brody committed
647

648

nv-dlasalle's avatar
nv-dlasalle committed
649
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
    """Check EGATConv node, edge and attention output shapes."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    layer = nn.EGATConv(
        in_node_feats=10,
        in_edge_feats=5,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    num_nodes, num_edges = g.num_nodes(), g.num_edges()
    node_feat = F.randn((num_nodes, 10))
    edge_feat = F.randn((num_edges, 5))
    layer = layer.to(ctx)
    node_out, edge_out = layer(g, node_feat, edge_feat)

    th.save(layer, tmp_buffer)

    assert node_out.shape == (g.num_nodes(), num_heads, out_node_feats)
    assert edge_out.shape == (g.num_edges(), num_heads, out_edge_feats)
    # the fourth positional argument requests the attention values as well
    _, _, attn = layer(g, node_feat, edge_feat, True)
    assert attn.shape == (g.num_edges(), num_heads, 1)
675

676

677
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
    """Check EGATConv shapes with separate source/destination features."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    layer = nn.EGATConv(
        in_node_feats=(10, 15),
        in_edge_feats=7,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    src_feat = F.randn((g.number_of_src_nodes(), 10))
    dst_feat = F.randn((g.number_of_dst_nodes(), 15))
    node_feat = (src_feat, dst_feat)
    edge_feat = F.randn((g.num_edges(), 7))
    layer = layer.to(ctx)
    node_out, edge_out = layer(g, node_feat, edge_feat)

    th.save(layer, tmp_buffer)

    expected_node_shape = (g.number_of_dst_nodes(), num_heads, out_node_feats)
    assert node_out.shape == expected_node_shape
    assert edge_out.shape == (g.num_edges(), num_heads, out_edge_feats)
    # the fourth positional argument requests the attention values as well
    _, _, attn = layer(g, node_feat, edge_feat, True)
    assert attn.shape == (g.num_edges(), num_heads, 1)
schmidt-ju's avatar
schmidt-ju committed
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv(g, idtype, out_feats, num_heads):
    """Check EdgeGATConv output and attention shapes on homogeneous graphs.

    Uses ``num_nodes``/``num_edges`` like the other tests in this file
    instead of the older ``number_of_nodes``/``number_of_edges`` spellings.
    """
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edgegat = nn.EdgeGATConv(
        in_feats=10, edge_feats=5, out_feats=out_feats, num_heads=num_heads
    )
    nfeat = F.randn((g.num_nodes(), 10))
    efeat = F.randn((g.num_edges(), 5))
    edgegat = edgegat.to(ctx)
    h = edgegat(g, nfeat, efeat)

    th.save(edgegat, tmp_buffer)

    assert h.shape == (g.num_nodes(), num_heads, out_feats)
    # the fourth positional argument requests the attention values as well
    _, attn = edgegat(g, nfeat, efeat, True)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv_bi(g, idtype, out_feats, num_heads):
    """Check EdgeGATConv shapes with separate source/destination features.

    Uses ``num_edges`` like the other tests in this file instead of the
    older ``number_of_edges`` spelling.
    """
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edgegat = nn.EdgeGATConv(
        in_feats=(10, 15),
        edge_feats=7,
        out_feats=out_feats,
        num_heads=num_heads,
    )
    nfeat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), 15)),
    )
    efeat = F.randn((g.num_edges(), 7))
    edgegat = edgegat.to(ctx)
    h = edgegat(g, nfeat, efeat)

    th.save(edgegat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_feats)
    # the fourth positional argument requests the attention values as well
    _, attn = edgegat(g, nfeat, efeat, True)
    assert attn.shape == (g.num_edges(), num_heads, 1)
756

757

nv-dlasalle's avatar
nv-dlasalle committed
758
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
def test_sage_conv(idtype, g, aggre_type):
    """Check that SAGEConv produces the requested output feature size."""
    device = F.ctx()
    g = g.astype(idtype).to(device)
    layer = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    layer = layer.to(device)
    # test pickle
    th.save(layer, tmp_buffer)
    out = layer(g, feat)
    assert out.shape[-1] == 10

771

nv-dlasalle's avatar
nv-dlasalle committed
772
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    """Check SAGEConv output shape for (source, destination) features."""
    device = F.ctx()
    g = g.astype(idtype).to(device)
    # the "gcn" aggregator is given equal source/destination sizes here
    if aggre_type == "gcn":
        dst_dim = 10
    else:
        dst_dim = 5
    layer = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    src_feat = F.randn((g.number_of_src_nodes(), 10))
    dst_feat = F.randn((g.number_of_dst_nodes(), dst_dim))
    layer = layer.to(device)
    out = layer(g, (src_feat, dst_feat))
    assert out.shape[-1] == out_dim
    assert out.shape[0] == g.number_of_dst_nodes()
788

789

nv-dlasalle's avatar
nv-dlasalle committed
790
@parametrize_idtype
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv2(idtype, out_dim):
    """Check SAGEConv on a bipartite graph that has no edges at all."""
    # TODO: add test for blocks
    # Test the case for graphs without edges
    empty_edges = {("_U", "_E", "_V"): ([], [])}
    node_counts = {"_U": 5, "_V": 3}
    g = dgl.heterograph(empty_edges, node_counts)
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    layer = nn.SAGEConv((3, 3), out_dim, "gcn")
    src_feat, dst_feat = F.randn((5, 3)), F.randn((3, 3))
    layer = layer.to(ctx)
    src_on_ctx = F.copy_to(src_feat, F.ctx())
    dst_on_ctx = F.copy_to(dst_feat, F.ctx())
    out = layer(g, (src_on_ctx, dst_on_ctx))
    assert out.shape[-1] == out_dim
    assert out.shape[0] == 3  # one row per "_V" node

    for aggre_type in ["mean", "pool", "lstm"]:
        layer = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat_pair = (F.randn((5, 3)), F.randn((3, 1)))
        layer = layer.to(ctx)
        out = layer(g, feat_pair)
        assert out.shape[-1] == out_dim
        assert out.shape[0] == 3

812

nv-dlasalle's avatar
nv-dlasalle committed
813
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    """Check SGConv output size with and without caching enabled."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)

    # not cached
    layer = nn.SGConv(5, out_dim, 3)

    # test pickle
    th.save(layer, tmp_buffer)

    feat = F.randn((g.num_nodes(), 5))
    layer = layer.to(ctx)
    out = layer(g, feat)
    assert out.shape[-1] == out_dim

    # cached
    cached_layer = nn.SGConv(5, out_dim, 3, True).to(ctx)
    first = cached_layer(g, feat)
    second = cached_layer(g, feat + 1)
    # the second call is expected to give the same result even though the
    # input features were shifted
    assert F.allclose(first, second)
    assert first.shape[-1] == out_dim
838

839

nv-dlasalle's avatar
nv-dlasalle committed
840
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv(g, idtype):
    """Check that APPNPConv keeps the input feature size."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    layer = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.num_nodes(), 5))
    layer = layer.to(ctx)

    # test pickle
    th.save(layer, tmp_buffer)

    out = layer(g, feat)
    assert out.shape[-1] == 5

855

nv-dlasalle's avatar
nv-dlasalle committed
856
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv_e_weight(g, idtype):
    """Run APPNPConv with all-ones edge weights and check the output width."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.APPNPConv(10, 0.1)
    node_feat = F.randn((graph.num_nodes(), 5))
    edge_w = F.ones((graph.num_edges(),))

    out = layer.to(device)(graph, node_feat, edge_weight=edge_w)
    assert out.shape[-1] == 5

869

nv-dlasalle's avatar
nv-dlasalle committed
870
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("bias", [True, False])
def test_gcn2conv_e_weight(g, idtype, bias):
    """Run GCN2Conv with edge weights, with and without a bias term."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.GCN2Conv(
        5, layer=2, alpha=0.5, bias=bias, project_initial_features=True
    )
    node_feat = F.randn((graph.num_nodes(), 5))
    edge_w = F.ones((graph.num_edges(),))
    layer = layer.to(device)
    # The same tensor is supplied for both feature arguments of the layer.
    out = layer(graph, node_feat, node_feat, edge_weight=edge_w)
    assert out.shape[-1] == 5


nv-dlasalle's avatar
nv-dlasalle committed
887
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_sgconv_e_weight(g, idtype):
    """Run SGConv with all-ones edge weights and check the output width."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.SGConv(5, 5, 3)
    node_feat = F.randn((graph.num_nodes(), 5))
    edge_w = F.ones((graph.num_edges(),))

    out = layer.to(device)(graph, node_feat, edge_weight=edge_w)
    assert out.shape[-1] == 5

899

nv-dlasalle's avatar
nv-dlasalle committed
900
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_tagconv_e_weight(g, idtype):
    """Run TAGConv with all-ones edge weights and check the output width.

    The layer is built with 5 input and 5 output features and a bias, moved
    to the test device once, and applied to random node features.
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # A single device transfer is enough; ``Module.to`` returns the module.
    conv = nn.TAGConv(5, 5, bias=True).to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    h = conv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5

913

nv-dlasalle's avatar
nv-dlasalle committed
914
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv(g, idtype, aggregator_type):
    """Check GINConv output shape and that the layer can be saved."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)

    # Layer wrapping a linear map from 5 to 12 features; saved once before
    # and once after the move to the device.
    layer = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    th.save(layer, tmp_buffer)
    feat = F.randn((graph.number_of_src_nodes(), 5))
    layer = layer.to(device)
    out = layer(graph, feat)

    # test pickle
    th.save(layer, tmp_buffer)

    assert out.shape == (graph.number_of_dst_nodes(), 12)

    # Layer built with ``None`` in place of the linear module: only saving
    # and the forward call are exercised, the output is not inspected.
    bare = nn.GINConv(None, aggregator_type)
    th.save(bare, tmp_buffer)
    bare.to(device)(graph, feat)
937

938

nv-dlasalle's avatar
nv-dlasalle committed
939
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
def test_gine_conv(g, idtype):
    """Check GINEConv output shape and that the layer can be saved."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)

    # Layer wrapping a linear map from 5 to 12 features; saved once before
    # and once after the move to the device.
    layer = nn.GINEConv(th.nn.Linear(5, 12))
    th.save(layer, tmp_buffer)
    node_feat = F.randn((graph.number_of_src_nodes(), 5))
    edge_feat = F.randn((graph.num_edges(), 5))
    layer = layer.to(device)
    out = layer(graph, node_feat, edge_feat)

    # test pickle
    th.save(layer, tmp_buffer)
    assert out.shape == (graph.number_of_dst_nodes(), 12)

    # Layer built with ``None`` in place of the linear module: only saving
    # and the forward call are exercised, the output is not inspected.
    bare = nn.GINEConv(None)
    th.save(bare, tmp_buffer)
    bare.to(device)(graph, node_feat, edge_feat)

960

nv-dlasalle's avatar
nv-dlasalle committed
961
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """Check GINConv output shape when given a (src, dst) feature pair."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    dst_feat = F.randn((graph.number_of_dst_nodes(), 5))

    out = layer.to(device)(graph, (src_feat, dst_feat))
    assert out.shape == (graph.number_of_dst_nodes(), 12)
975

976

nv-dlasalle's avatar
nv-dlasalle committed
977
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_agnn_conv(g, idtype):
    """AGNNConv must return one 5-wide row per destination node."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.AGNNConv(1)
    src_feat = F.randn((graph.number_of_src_nodes(), 5))

    out = layer.to(device)(graph, src_feat)
    assert out.shape == (graph.number_of_dst_nodes(), 5)
989

990

nv-dlasalle's avatar
nv-dlasalle committed
991
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_agnn_conv_bi(g, idtype):
    """Check AGNNConv output shape when given a (src, dst) feature pair."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.AGNNConv(1)
    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    dst_feat = F.randn((graph.number_of_dst_nodes(), 5))

    out = layer.to(device)(graph, (src_feat, dst_feat))
    assert out.shape == (graph.number_of_dst_nodes(), 5)
1004

1005

nv-dlasalle's avatar
nv-dlasalle committed
1006
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv(g, idtype):
    """Run GatedGraphConv with three edge types and check the output width."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.GatedGraphConv(5, 10, 5, 3)
    # Edge ids are mapped onto types 0, 1, 2 by taking them modulo 3.
    edge_types = th.arange(graph.num_edges()) % 3
    node_feat = F.randn((graph.num_nodes(), 5))
    layer = layer.to(device)
    edge_types = edge_types.to(device)

    out = layer(graph, node_feat, edge_types)
    # Only the trailing dimension is verified.
    assert out.shape[-1] == 10

1021

nv-dlasalle's avatar
nv-dlasalle committed
1022
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv_one_etype(g, idtype):
    """With one edge type, passing or omitting etypes must agree."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.GatedGraphConv(5, 10, 5, 1)
    edge_types = th.zeros(graph.num_edges())
    node_feat = F.randn((graph.num_nodes(), 5))
    layer = layer.to(device)
    edge_types = edge_types.to(device)

    with_types = layer(graph, node_feat, edge_types)
    without_types = layer(graph, node_feat)
    # Both call forms must produce the same values, and the expected width.
    assert F.allclose(with_types, without_types)
    assert with_types.shape[-1] == 10

1039

nv-dlasalle's avatar
nv-dlasalle committed
1040
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_nn_conv(g, idtype):
    """Run NNConv with a linear edge function and check the output width."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    in_dim, out_dim = 5, 10
    # The edge function maps 4 edge features to in_dim * out_dim values.
    layer = nn.NNConv(
        in_dim, out_dim, th.nn.Linear(4, in_dim * out_dim), "mean"
    )
    node_feat = F.randn((graph.number_of_src_nodes(), in_dim))
    edge_feat = F.randn((graph.num_edges(), 4))

    out = layer.to(device)(graph, node_feat, edge_feat)
    # Only the trailing dimension is verified.
    assert out.shape[-1] == out_dim

1056

nv-dlasalle's avatar
nv-dlasalle committed
1057
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_nn_conv_bi(g, idtype):
    """Run NNConv with separate src/dst feature widths (5 and 2)."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    # The edge function maps 4 edge features to 5 * 10 values.
    layer = nn.NNConv((5, 2), 10, th.nn.Linear(4, 5 * 10), "mean")
    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    dst_feat = F.randn((graph.number_of_dst_nodes(), 2))
    edge_feat = F.randn((graph.num_edges(), 4))

    out = layer.to(device)(graph, (src_feat, dst_feat), edge_feat)
    # Only the trailing dimension is verified.
    assert out.shape[-1] == 10

1072

nv-dlasalle's avatar
nv-dlasalle committed
1073
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gmm_conv(g, idtype):
    """Run GMMConv with 3-wide pseudo coordinates and check the width."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.GMMConv(5, 10, 3, 4, "mean")
    node_feat = F.randn((graph.num_nodes(), 5))
    pseudo = F.randn((graph.num_edges(), 3))

    out = layer.to(device)(graph, node_feat, pseudo)
    # Only the trailing dimension is verified.
    assert out.shape[-1] == 10

1086

nv-dlasalle's avatar
nv-dlasalle committed
1087
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite", "block-bipartite"], exclude=["zero-degree"])
)
def test_gmm_conv_bi(g, idtype):
    """Run GMMConv with separate src/dst feature widths (5 and 2)."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.GMMConv((5, 2), 10, 3, 4, "mean")
    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    dst_feat = F.randn((graph.number_of_dst_nodes(), 2))
    pseudo = F.randn((graph.num_edges(), 3))

    out = layer.to(device)(graph, (src_feat, dst_feat), pseudo)
    # Only the trailing dimension is verified.
    assert out.shape[-1] == 10

1103

nv-dlasalle's avatar
nv-dlasalle committed
1104
@parametrize_idtype
@pytest.mark.parametrize("norm_type", ["both", "right", "none"])
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_graph_conv(norm_type, g, idtype, out_dim):
    """GraphConv and DenseGraphConv must agree when sharing parameters."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    # TODO(minjie): enable the following option after #1385
    dense_adj = graph.adjacency_matrix(transpose=True, ctx=device).to_dense()

    sparse_layer = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_layer = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
    # Give the dense layer the sparse layer's weight and bias before either
    # module is moved to the device.
    dense_layer.weight.data = sparse_layer.weight.data
    dense_layer.bias.data = sparse_layer.bias.data

    feat = F.randn((graph.number_of_src_nodes(), 5))
    sparse_layer = sparse_layer.to(device)
    dense_layer = dense_layer.to(device)

    expected = sparse_layer(graph, feat)
    actual = dense_layer(dense_adj, feat)
    assert F.allclose(expected, actual)

1126

nv-dlasalle's avatar
nv-dlasalle committed
1127
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "bipartite"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_sage_conv(g, idtype, out_dim):
    """SAGEConv("gcn") and DenseSAGEConv must agree with shared parameters."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    dense_adj = graph.adjacency_matrix(transpose=True, ctx=device).to_dense()

    sparse_layer = nn.SAGEConv(5, out_dim, "gcn")
    dense_layer = nn.DenseSAGEConv(5, out_dim)
    # Give the dense layer the sparse layer's neighbour weight and bias.
    dense_layer.fc.weight.data = sparse_layer.fc_neigh.weight.data
    dense_layer.fc.bias.data = sparse_layer.bias.data

    # Graphs with two node types get a (src, dst) feature pair.
    if len(graph.ntypes) != 2:
        feat = F.randn((graph.num_nodes(), 5))
    else:
        src_feat = F.randn((graph.number_of_src_nodes(), 5))
        dst_feat = F.randn((graph.number_of_dst_nodes(), 5))
        feat = (src_feat, dst_feat)

    sparse_layer = sparse_layer.to(device)
    dense_layer = dense_layer.to(device)
    expected = sparse_layer(graph, feat)
    actual = dense_layer(dense_adj, feat)
    assert F.allclose(expected, actual), graph

1151

nv-dlasalle's avatar
nv-dlasalle committed
1152
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv(g, idtype, out_dim):
    """Check EdgeConv printing, saving and output shape."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.EdgeConv(5, out_dim).to(device)
    print(layer)

    # test pickle
    th.save(layer, tmp_buffer)

    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    out = layer(graph, src_feat)
    assert out.shape == (graph.number_of_dst_nodes(), out_dim)
1169

1170

nv-dlasalle's avatar
nv-dlasalle committed
1171
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    """Check EdgeConv output shape when given a (src, dst) feature pair."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.EdgeConv(5, out_dim).to(device)
    print(layer)
    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    dst_feat = F.randn((graph.number_of_dst_nodes(), 5))
    out = layer(graph, (src_feat, dst_feat))
    assert out.shape == (graph.number_of_dst_nodes(), out_dim)
Mufei Li's avatar
Mufei Li committed
1183

1184

nv-dlasalle's avatar
nv-dlasalle committed
1185
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv(g, idtype, out_dim, num_heads):
    """Check DotGatConv saving plus output and attention shapes."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.DotGatConv(5, out_dim, num_heads)
    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    layer = layer.to(device)

    # test pickle
    th.save(layer, tmp_buffer)

    out = layer(graph, src_feat)
    assert out.shape == (graph.number_of_dst_nodes(), num_heads, out_dim)
    # With get_attention=True the second returned value is checked to have
    # one row per edge.
    _, attn = layer(graph, src_feat, get_attention=True)
    assert attn.shape == (graph.num_edges(), num_heads, 1)
1205

1206

nv-dlasalle's avatar
nv-dlasalle committed
1207
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv_bi(g, idtype, out_dim, num_heads):
    """Check DotGatConv shapes when given a (src, dst) feature pair."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.DotGatConv((5, 5), out_dim, num_heads)
    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    dst_feat = F.randn((graph.number_of_dst_nodes(), 5))
    feat_pair = (src_feat, dst_feat)
    layer = layer.to(device)

    out = layer(graph, feat_pair)
    assert out.shape == (graph.number_of_dst_nodes(), num_heads, out_dim)
    # With get_attention=True the second returned value is checked to have
    # one row per edge.
    _, attn = layer(graph, feat_pair, get_attention=True)
    assert attn.shape == (graph.num_edges(), num_heads, 1)
1224

1225
1226

@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_cheb_conv(out_dim):
    """ChebConv and DenseChebConv must agree for k = 1, 2, 3."""
    device = F.ctx()
    for order in range(1, 4):
        # Fresh random 100-node graph for every value of k.
        graph = dgl.DGLGraph(
            sp.sparse.random(100, 100, density=0.1), readonly=True
        ).to(device)
        adj = graph.adjacency_matrix(transpose=True, ctx=device).to_dense()

        sparse_layer = nn.ChebConv(5, out_dim, order, None)
        dense_layer = nn.DenseChebConv(5, out_dim, order)
        # Reshape the sparse layer's single linear weight into the dense
        # layer's (k, 5, out_dim) parameter; copy the bias when present.
        shared_weight = sparse_layer.linear.weight.data.transpose(-1, -2)
        dense_layer.W.data = shared_weight.view(order, 5, out_dim)
        if sparse_layer.linear.bias is not None:
            dense_layer.bias.data = sparse_layer.linear.bias.data

        feat = F.randn((100, 5))
        sparse_layer = sparse_layer.to(device)
        dense_layer = dense_layer.to(device)
        out_sparse = sparse_layer(graph, feat, [2.0])
        out_dense = dense_layer(adj, feat, 2.0)
        print(order, out_sparse, out_dense)
        assert F.allclose(out_sparse, out_dense)

1250

1251
1252
def test_sequential():
    """Exercise nn.Sequential on a single graph and on a list of graphs."""
    device = F.ctx()

    # Test single graph
    class NodeEdgeLayer(th.nn.Module):
        """Updates node and edge features in place and returns both."""

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            aggregated = graph.ndata["h"]
            n_feat += aggregated
            graph.apply_edges(fn.u_add_v("h", "h", "e"))
            edge_sum = graph.edata["e"]
            e_feat += edge_sum
            return n_feat, e_feat

    # Three nodes with an edge for every (src, dst) pair, self-loops included.
    single = dgl.DGLGraph()
    single.add_nodes(3)
    single.add_edges(
        [0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2]
    )
    single = single.to(F.ctx())
    net = nn.Sequential(*(NodeEdgeLayer() for _ in range(3)))
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(device)
    n_feat, e_feat = net(single, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graph
    class HalvingLayer(th.nn.Module):
        """Updates node features, then sums each pair of consecutive rows."""

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            aggregated = graph.ndata["h"]
            n_feat += aggregated
            paired = n_feat.view(graph.num_nodes() // 2, 2, -1)
            return paired.sum(1)

    # Each layer halves the row count: 32 -> 16 -> 8 -> 4.
    graphs = [
        dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx()),
        dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx()),
        dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx()),
    ]
    net = nn.Sequential(*(HalvingLayer() for _ in range(3)))
    net = net.to(device)
    n_feat = F.randn((32, 4))
    n_feat = net(graphs, n_feat)
    assert n_feat.shape == (4, 4)

1301

nv-dlasalle's avatar
nv-dlasalle committed
1302
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_atomic_conv(g, idtype):
    """Run AtomicConv on random features/distances and check the width."""
    graph = g.astype(idtype).to(F.ctx())
    layer = nn.AtomicConv(
        interaction_cutoffs=F.tensor([12.0, 12.0]),
        rbf_kernel_means=F.tensor([0.0, 2.0]),
        rbf_kernel_scaling=F.tensor([4.0, 4.0]),
        features_to_use=F.tensor([6.0, 8.0]),
    )

    # The layer is only moved when the backend reports a GPU context.
    device = F.ctx()
    if F.gpu_ctx():
        layer = layer.to(device)

    node_feat = F.randn((graph.num_nodes(), 1))
    distances = F.randn((graph.num_edges(), 1))
    out = layer(graph, node_feat, distances)

    # Only the trailing dimension is verified.
    assert out.shape[-1] == 4

1325

nv-dlasalle's avatar
nv-dlasalle committed
1326
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 3])
def test_cf_conv(g, idtype, out_dim):
    """Run CFConv with plain and paired node features; check the width."""
    graph = g.astype(idtype).to(F.ctx())
    layer = nn.CFConv(
        node_in_feats=2, edge_in_feats=3, hidden_feats=2, out_feats=out_dim
    )

    # The layer is only moved when the backend reports a GPU context.
    device = F.ctx()
    if F.gpu_ctx():
        layer = layer.to(device)

    src_feats = F.randn((graph.number_of_src_nodes(), 2))
    edge_feats = F.randn((graph.num_edges(), 3))
    out = layer(graph, src_feats, edge_feats)
    # Only the trailing dimension is verified.
    assert out.shape[-1] == out_dim

    # case for bipartite graphs
    dst_feats = F.randn((graph.number_of_dst_nodes(), 3))
    out = layer(graph, (src_feats, dst_feats), edge_feats)
    # Only the trailing dimension is verified.
    assert out.shape[-1] == out_dim
1352

1353

1354
1355
1356
1357
1358
1359
def myagg(alist, dsttype):
    """Combine a sequence of results into a position-weighted sum.

    The first entry is taken as-is; the entry at index ``i`` (``i >= 1``) is
    multiplied by ``i + 1`` before being added. ``dsttype`` is accepted but
    not used.
    """
    total = alist[0]
    for weight, item in enumerate(alist[1:], start=2):
        total = total + weight * item
    return total

1360

nv-dlasalle's avatar
nv-dlasalle committed
1361
@parametrize_idtype
@pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg])
@pytest.mark.parametrize("canonical_keys", [False, True])
def test_hetero_conv(agg, idtype, canonical_keys):
    """Exercise HeteroGraphConv on a graph, on blocks, with per-module
    arguments, and on a graph whose edges have all been removed."""
    follows = ("user", "follows", "user")
    plays = ("user", "plays", "game")
    sells = ("store", "sells", "game")
    g = dgl.heterograph(
        {
            follows: ([0, 0, 2, 1], [1, 2, 1, 3]),
            plays: ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
            sells: ([0, 0, 1, 1], [0, 3, 1, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )

    # Build one GraphConv per relation keyed by canonical edge type, then
    # re-key by the bare edge-type name when canonical keys are not wanted.
    mods = {
        follows: nn.GraphConv(2, 3, allow_zero_in_degree=True),
        plays: nn.GraphConv(2, 4, allow_zero_in_degree=True),
        sells: nn.GraphConv(3, 4, allow_zero_in_degree=True),
    }
    if not canonical_keys:
        mods = {etype: mod for (_, etype, _), mod in mods.items()}
    conv = nn.HeteroGraphConv(mods, agg)
    conv = conv.to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    def all_feats():
        # Features for every node type.
        return {"user": uf, "game": gf, "store": sf}

    def block_feats():
        # (source, destination) features; the destination side has an empty
        # "store" slice.
        return all_feats(), {"user": uf, "game": gf, "store": sf[0:0]}

    def make_block(graph):
        # Block over all users and games and no stores, built on CPU and
        # then moved to the test device.
        seeds = {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
        return dgl.to_block(graph.to(F.cpu()), seeds).to(F.ctx())

    def check_keys(out):
        assert set(out.keys()) == {"user", "game"}

    def check_shapes(out):
        # "stack" keeps one slot per incoming relation; every other
        # aggregator collapses that axis.
        if agg != "stack":
            assert out["user"].shape == (4, 3)
            assert out["game"].shape == (4, 4)
        else:
            assert out["user"].shape == (4, 1, 3)
            assert out["game"].shape == (4, 2, 4)

    h = conv(g, all_feats())
    check_keys(h)
    check_shapes(h)

    block = make_block(g)
    h = conv(block, block_feats())
    check_keys(h)
    check_shapes(h)

    h = conv(block, all_feats())
    check_keys(h)
    check_shapes(h)

    # test with mod args
    class MyMod(th.nn.Module):
        """Counts how often ``arg1`` / ``arg2`` are supplied to forward."""

        def __init__(self, s1, s2):
            super().__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None, *, arg2=None):
            self.carg1 += int(arg1 is not None)
            self.carg2 += int(arg2 is not None)
            return th.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv(
        {"follows": mod1, "plays": mod2, "sells": mod3}, agg
    )
    conv = conv.to(F.ctx())
    h = conv(
        g,
        all_feats(),
        mod_args={"follows": (1,), "plays": (1,)},
        mod_kwargs={"sells": {"arg2": "abc"}},
    )
    # Positional args reached "follows"/"plays" only; the keyword arg
    # reached "sells" only.
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on graph without any edges
    for etype in g.etypes:
        eids = g.edges(form="eid", etype=etype)
        g = dgl.remove_edges(g, eids, etype=etype)
    assert g.num_edges() == 0
    h = conv(g, all_feats())
    check_keys(h)

    block = make_block(g)
    h = conv(block, block_feats())
    check_keys(h)
1500
1501


1502
@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_linear(out_dim):
    """HeteroLinear must map every keyed input to ``out_dim`` columns."""
    node_key = "user"
    edge_key = ("user", "follows", "user")
    inputs = {
        node_key: F.randn((2, 1)),
        edge_key: F.randn((3, 2)),
    }

    # Input widths per key: 1 for the string key, 2 for the tuple key.
    layer = nn.HeteroLinear({node_key: 1, edge_key: 2}, out_dim)
    outputs = layer.to(F.ctx())(inputs)
    assert outputs[node_key].shape == (2, out_dim)
    assert outputs[edge_key].shape == (3, out_dim)

1517

1518
@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_embedding(out_dim):
    """Check HeteroEmbedding weight shapes, parameter reset and lookup."""
    node_key = "user"
    edge_key = ("user", "follows", "user")
    layer = nn.HeteroEmbedding({node_key: 2, edge_key: 3}, out_dim)
    layer = layer.to(F.ctx())

    def check_weight_shapes():
        weights = layer.weight
        assert weights[node_key].shape == (2, out_dim)
        assert weights[edge_key].shape == (3, out_dim)

    # Shapes must hold both right after construction and after a reset.
    check_weight_shapes()
    layer.reset_parameters()
    check_weight_shapes()

    # Looking up one / two ids yields one / two rows respectively.
    looked_up = layer(
        {
            node_key: F.tensor([0], dtype=F.int64),
            edge_key: F.tensor([0, 2], dtype=F.int64),
        }
    )
    assert looked_up[node_key].shape == (1, out_dim)
    assert looked_up[edge_key].shape == (2, out_dim)
YJ-Zhao's avatar
YJ-Zhao committed
1542

1543

nv-dlasalle's avatar
nv-dlasalle committed
1544
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_gnnexplainer(g, idtype, out_dim):
    """Smoke-test GNNExplainer for node-level and graph-level models."""
    g = g.astype(idtype).to(F.ctx())
    feat = F.randn((g.num_nodes(), 5))

    class Model(th.nn.Module):
        """Linear layer followed by one (optionally edge-weighted) sum
        aggregation, with average pooling when ``graph`` is true."""

        def __init__(self, in_feats, out_feats, graph=False):
            super().__init__()
            self.linear = th.nn.Linear(in_feats, out_feats)
            self.pool = nn.AvgPooling() if graph else None

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                graph.ndata["h"] = self.linear(feat)
                # Choose the message function; edge weights, when given,
                # scale the source features.
                if eweight is None:
                    message = fn.copy_u("h", "m")
                else:
                    graph.edata["w"] = eweight
                    message = fn.u_mul_e("h", "w", "m")
                graph.update_all(message, fn.sum("m", "h"))

                hidden = graph.ndata["h"]
                if self.pool:
                    return self.pool(graph, hidden)
                return hidden

    # Explain node prediction
    node_model = Model(5, out_dim).to(F.ctx())
    explainer = nn.GNNExplainer(node_model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)

    # Explain graph prediction
    graph_model = Model(5, out_dim, graph=True).to(F.ctx())
    explainer = nn.GNNExplainer(graph_model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)

1589
1590
1591
1592
1593

@pytest.mark.parametrize("g", get_cases(["hetero"], exclude=["zero-degree"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("output_dim", [1, 2])
def test_heterognnexplainer(g, idtype, input_dim, output_dim):
    """Smoke-test HeteroGNNExplainer for node-level and graph-level models.

    The test only checks that ``explain_node`` and ``explain_graph`` run and
    return the expected number of values; the masks are not inspected.
    """
    g = g.astype(idtype).to(F.ctx())
    device = g.device

    # Add self-loops (as new edge types), then reverse edges.
    g = dgl.transforms.AddSelfLoop(new_etypes=True)(g)
    g = dgl.transforms.AddReverse(copy_edata=True)(g)

    feat = {
        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)
        for ntype in g.ntypes
    }

    class Model(th.nn.Module):
        """One linear layer per canonical edge type with mean aggregation."""

        def __init__(self, in_dim, num_classes, canonical_etypes, graph=False):
            super().__init__()
            self.graph = graph
            per_etype = {}
            for c_etype in canonical_etypes:
                per_etype["_".join(c_etype)] = th.nn.Linear(
                    in_dim, num_classes
                )
            self.etype_weights = th.nn.ModuleDict(per_etype)

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                funcs = {}
                for c_etype in graph.canonical_etypes:
                    src_type = c_etype[0]
                    key = f"h_{c_etype}"
                    linear = self.etype_weights["_".join(c_etype)]
                    graph.nodes[src_type].data[key] = linear(feat[src_type])
                    if eweight is not None:
                        # Scale each message by the supplied edge weight.
                        graph.edges[c_etype].data["w"] = eweight[c_etype]
                        message = fn.u_mul_e(key, "w", "m")
                    else:
                        message = fn.copy_u(key, "m")
                    funcs[c_etype] = (message, fn.mean("m", "h"))
                graph.multi_update_all(funcs, "sum")

                if not self.graph:
                    return graph.ndata["h"]

                # Graph-level readout: sum the per-node-type mean of "h",
                # skipping node types that have no nodes.
                hg = 0
                for ntype in graph.ntypes:
                    if graph.num_nodes(ntype):
                        hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)
                return hg

    # Explain node prediction
    model = Model(input_dim, output_dim, g.canonical_etypes).to(F.ctx())
    ntype = g.ntypes[0]
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(
        ntype, 0, g, feat
    )

    # Explain graph prediction
    model = Model(input_dim, output_dim, g.canonical_etypes, graph=True).to(
        F.ctx()
    )
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)


1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(
        ["homo"],
        exclude=[
            "zero-degree",
            "homo-zero-degree",
            "has_feature",
            "has_scalar_e_feature",
            "row_sorted",
            "col_sorted",
            "batched",
        ],
    ),
)
@pytest.mark.parametrize("n_classes", [2])
def test_subgraphx(g, idtype, n_classes):
    """Smoke-test SubgraphX.explain_graph on a small GraphConv + pooling model."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))

    class Model(th.nn.Module):
        """GraphConv followed by ReLU and average pooling over the graph."""

        def __init__(self, in_dim, n_classes):
            super().__init__()
            self.conv = nn.GraphConv(in_dim, n_classes)
            self.pool = nn.AvgPooling()

        def forward(self, g, h):
            activated = th.nn.functional.relu(self.conv(g, h))
            return self.pool(g, activated)

    model = Model(feat.shape[1], n_classes).to(ctx)
    explainer = nn.SubgraphX(
        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0
    )
    # Only checks that the explanation runs; the result is not inspected.
    explainer.explain_graph(g, feat, target_class=0)


Mufei Li's avatar
Mufei Li committed
1705
1706
1707
1708
1709
1710
def test_jumping_knowledge():
    """Check the output shape of JumpingKnowledge for "cat", "max" and "lstm"."""
    ctx = F.ctx()
    num_layers = 2
    num_nodes = 3
    num_feats = 4

    feat_list = [
        th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)
    ]

    # (constructor arguments, expected output shape)
    cases = [
        (("cat",), (num_nodes, num_layers * num_feats)),
        (("max",), (num_nodes, num_feats)),
        (("lstm", num_feats, num_layers), (num_nodes, num_feats)),
    ]
    for ctor_args, expected_shape in cases:
        model = nn.JumpingKnowledge(*ctor_args).to(ctx)
        model.reset_parameters()
        assert model(feat_list).shape == expected_shape

1727
1728

@pytest.mark.parametrize("op", ["dot", "cos", "ele", "cat"])
def test_edge_predictor(op):
    """Check EdgePredictor output shapes with and without an output projection."""
    ctx = F.ctx()
    num_pairs, in_feats, out_feats = 3, 4, 5
    h_src = th.randn((num_pairs, in_feats)).to(ctx)
    h_dst = th.randn((num_pairs, in_feats)).to(ctx)

    # Expected output width when no out_feats projection is requested.
    if op in ("dot", "cos"):
        raw_width = 1
    elif op == "ele":
        raw_width = in_feats
    else:
        raw_width = 2 * in_feats

    predictor = nn.EdgePredictor(op)
    assert predictor(h_src, h_dst).shape == (num_pairs, raw_width)

    projected = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
    assert projected(h_src, h_dst).shape == (num_pairs, out_feats)

Mufei Li's avatar
Mufei Li committed
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761

def test_ke_score_funcs():
    """Check that TransE and TransR return one score per edge.

    The earlier version evaluated ``... .shape == (num_edges)`` as a bare
    expression: nothing was asserted, and ``(num_edges)`` is an int rather
    than a shape tuple. Both comparisons are now real assertions against the
    1-D shape ``(num_edges,)``.
    """
    ctx = F.ctx()
    num_edges = 30
    num_rels = 3
    nfeats = 4

    h_src = th.randn((num_edges, nfeats)).to(ctx)
    h_dst = th.randn((num_edges, nfeats)).to(ctx)
    rels = th.randint(low=0, high=num_rels, size=(num_edges,)).to(ctx)

    score_func = nn.TransE(num_rels=num_rels, feats=nfeats).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)

    # TransR uses a relation feature size different from the node feature size.
    score_func = nn.TransR(
        num_rels=num_rels, rfeats=nfeats - 1, nfeats=nfeats
    ).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)


1769
def test_twirls():
    """Check the output shape of TWIRLSConv on a small 6-node graph."""
    src_ids = [0, 1, 2, 3, 2, 5]
    dst_ids = [1, 2, 3, 4, 0, 3]
    graph = dgl.graph((src_ids, dst_ids))
    node_feat = th.ones(6, 10)

    layer = nn.TWIRLSConv(10, 2, 128, prop_step=64)
    out = layer(graph, node_feat)
    assert out.size() == (6, 2)

1776

1777
1778
1779
1780
@pytest.mark.parametrize("feat_size", [4, 32])
@pytest.mark.parametrize(
    "regularizer,num_bases", [(None, None), ("basis", 4), ("bdd", 4)]
)
def test_typed_linear(feat_size, regularizer, num_bases):
    """Check TypedLinear on unsorted and type-sorted inputs.

    The same rows are fed once in random type order and once sorted by type
    with ``sorted_by_type=True``; after undoing the sort both outputs must
    agree.
    """
    dev = F.ctx()
    # Single source of truth for the type count and the row count (the
    # previous code left ``num_types`` unused and repeated the literals).
    num_types = 5
    num_rows = 100
    out_size = feat_size * 2

    lin = nn.TypedLinear(
        feat_size,
        out_size,
        num_types,
        regularizer=regularizer,
        num_bases=num_bases,
    ).to(dev)
    print(lin)

    x = th.randn(num_rows, feat_size).to(dev)
    x_type = th.randint(0, num_types, (num_rows,)).to(dev)
    x_type_sorted, idx = th.sort(x_type)
    # rev_idx maps sorted positions back to the original row order.
    _, rev_idx = th.sort(idx)
    x_sorted = x[idx]

    # test unsorted
    y = lin(x, x_type)
    assert y.shape == (num_rows, out_size)
    # test sorted
    y_sorted = lin(x_sorted, x_type_sorted, sorted_by_type=True)
    assert y_sorted.shape == (num_rows, out_size)

    assert th.allclose(y, y_sorted[rev_idx], atol=1e-4, rtol=1e-4)
1806

1807

nv-dlasalle's avatar
nv-dlasalle committed
1808
@parametrize_idtype
@pytest.mark.parametrize("in_size", [4])
@pytest.mark.parametrize("num_heads", [1])
def test_hgt(idtype, in_size, num_heads):
    """Check HGTConv output shapes on a full graph, a reordered graph and a block."""
    dev = F.ctx()
    num_etypes = 5
    num_ntypes = 2
    head_size = in_size // num_heads
    out_size = head_size * num_heads

    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.01))
    g = g.astype(idtype).to(dev)
    # Assign edge/node types round-robin.
    etype = th.tensor([i % num_etypes for i in range(g.num_edges())]).to(dev)
    ntype = th.tensor([i % num_ntypes for i in range(g.num_nodes())]).to(dev)
    x = th.randn(g.num_nodes(), in_size).to(dev)

    m = nn.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes)
    m = m.to(dev)

    y = m(g, x, ntype, etype)
    assert y.shape == (g.num_nodes(), out_size)

    # presorted
    sorted_ntype, idx_nt = th.sort(ntype)
    sorted_etype, idx_et = th.sort(etype)
    _, rev_idx = th.sort(idx_nt)  # only referenced by the disabled check below
    g.ndata["t"] = ntype
    g.ndata["x"] = x
    g.edata["t"] = etype
    permute_config = {
        "nodes_perm": idx_nt.to(idtype),
        "edges_perm": idx_et.to(idtype),
    }
    sorted_g = dgl.reorder_graph(
        g,
        node_permute_algo="custom",
        edge_permute_algo="custom",
        permute_config=permute_config,
    )
    print(sorted_g.ndata["t"])
    print(sorted_g.edata["t"])
    sorted_x = sorted_g.ndata["x"]
    # NOTE(review): the inputs here are sorted by type, yet presorted=False is
    # passed -- confirm whether presorted=True was intended.
    sorted_y = m(
        sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=False
    )
    assert sorted_y.shape == (g.num_nodes(), out_size)

    # mini-batch
    train_idx = th.randperm(100, dtype=idtype)[:10]
    sampler = dgl.dataloading.NeighborSampler([-1])
    train_loader = dgl.dataloading.DataLoader(
        g, train_idx.to(dev), sampler, batch_size=8, device=dev, shuffle=True
    )
    (input_nodes, output_nodes, blocks) = next(iter(train_loader))
    block = blocks[0]
    # Slice the full-graph tensors down to the sampled nodes and edges.
    node_ids = input_nodes.to(th.long)
    edge_ids = block.edata[dgl.EID].to(th.long)
    batch_x = x[node_ids]
    batch_ntype = ntype[node_ids]
    batch_etype = etype[edge_ids]
    y = m(block, batch_x, batch_ntype, batch_etype)
    assert y.shape == (block.number_of_dst_nodes(), out_size)
    # TODO(minjie): enable the following check
    # assert th.allclose(y, sorted_y[rev_idx], atol=1e-4, rtol=1e-4)

1869

1870
1871
@pytest.mark.parametrize("self_loop", [True, False])
@pytest.mark.parametrize("get_distances", [True, False])
def test_radius_graph(self_loop, get_distances):
    """Compare RadiusGraph(0.3) edges (and distances) with hard-coded targets."""
    pos = th.tensor(
        [
            [0.1, 0.3, 0.4],
            [0.5, 0.2, 0.1],
            [0.7, 0.9, 0.5],
            [0.3, 0.2, 0.5],
            [0.2, 0.8, 0.2],
            [0.9, 0.2, 0.1],
            [0.7, 0.4, 0.4],
            [0.2, 0.1, 0.6],
            [0.5, 0.3, 0.5],
            [0.4, 0.2, 0.6],
        ]
    )

    rg = nn.RadiusGraph(0.3, self_loop=self_loop)

    if get_distances:
        g, dists = rg(pos, get_distances=get_distances)
    else:
        g = rg(pos)

    # Expected edge endpoints and per-edge distances, listed in edge order.
    if self_loop:
        # fmt: off
        src_ids = [
            0, 0, 1, 2, 3, 3, 3, 3, 3, 4, 5, 6,
            6, 7, 7, 7, 8, 8, 8, 8, 9, 9, 9, 9,
        ]
        dst_ids = [
            0, 3, 1, 2, 0, 3, 7, 8, 9, 4, 5, 6,
            8, 3, 7, 9, 3, 6, 8, 9, 3, 7, 8, 9,
        ]
        dist_values = [
            0.0000, 0.2449, 0.0000, 0.0000, 0.2449, 0.0000,
            0.1732, 0.2236, 0.1414, 0.0000, 0.0000, 0.0000,
            0.2449, 0.1732, 0.0000, 0.2236, 0.2236, 0.2449,
            0.0000, 0.1732, 0.1414, 0.2236, 0.1732, 0.0000,
        ]
        # fmt: on
    else:
        # fmt: off
        src_ids = [0, 3, 3, 3, 3, 6, 7, 7, 8, 8, 8, 9, 9, 9]
        dst_ids = [3, 0, 7, 8, 9, 8, 3, 9, 3, 6, 9, 3, 7, 8]
        dist_values = [
            0.2449, 0.2449, 0.1732, 0.2236, 0.1414, 0.2449, 0.1732,
            0.2236, 0.2236, 0.2449, 0.1732, 0.1414, 0.2236, 0.1732,
        ]
        # fmt: on

    src_target = th.tensor(src_ids)
    dst_target = th.tensor(dst_ids)

    src, dst = g.edges()

    assert th.equal(src, src_target)
    assert th.equal(dst, dst_target)

    if get_distances:
        # The target is a column vector with one distance per edge.
        dists_target = th.tensor(dist_values).unsqueeze(1)
        assert th.allclose(dists, dists_target, rtol=1e-03)

2014

nv-dlasalle's avatar
nv-dlasalle committed
2015
@parametrize_idtype
def test_group_rev_res(idtype):
    """Run a forward and backward pass through GroupRevRes wrapping GraphConv.

    The graph is cast to ``idtype`` so that the parametrization takes effect;
    previously ``idtype`` was accepted but never applied to the graph.
    """
    dev = F.ctx()

    num_nodes = 5
    num_edges = 20
    feats = 32
    groups = 2
    g = dgl.rand_graph(num_nodes, num_edges).astype(idtype).to(dev)
    h = th.randn(num_nodes, feats).to(dev)
    # The wrapped module is built for feats // groups features.
    conv = nn.GraphConv(feats // groups, feats // groups)
    model = nn.GroupRevRes(conv, groups).to(dev)
    result = model(g, h)
    result.sum().backward()
rudongyu's avatar
rudongyu committed
2029

2030
2031
2032
2033
2034

@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("hidden_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize("edge_feat_size", [16, 10, 0])
def test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):
    """Smoke-test a forward pass of EGNNConv on a random graph."""
    dev = F.ctx()
    num_nodes, num_edges = 5, 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    node_feat = th.randn(num_nodes, in_size).to(dev)
    coord_feat = th.randn(num_nodes, 3).to(dev)
    edge_feat = th.randn(num_edges, edge_feat_size).to(dev)
    conv = nn.EGNNConv(in_size, hidden_size, out_size, edge_feat_size).to(dev)
    # Only checks that the call succeeds; outputs are not inspected.
    conv(g, node_feat, coord_feat, edge_feat)

2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075

@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators",
    [
        ["mean", "max", "sum"],
        ["min", "std", "var"],
        ["moment3", "moment4", "moment5"],
    ],
)
@pytest.mark.parametrize(
    "scalers", [["identity"], ["amplification", "attenuation"]]
)
@pytest.mark.parametrize("delta", [2.5, 7.4])
@pytest.mark.parametrize("dropout", [0.0, 0.1])
@pytest.mark.parametrize("num_towers", [1, 4])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
@pytest.mark.parametrize("residual", [True, False])
def test_pna_conv(
    in_size,
    out_size,
    aggregators,
    scalers,
    delta,
    dropout,
    num_towers,
    edge_feat_size,
    residual,
):
    """Smoke-test a forward pass of PNAConv across its configuration grid."""
    dev = F.ctx()
    num_nodes, num_edges = 5, 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    node_feat = th.randn(num_nodes, in_size).to(dev)
    edge_feat = th.randn(num_edges, edge_feat_size).to(dev)
    conv = nn.PNAConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        dropout,
        num_towers,
        edge_feat_size,
        residual,
    )
    conv = conv.to(dev)
    # Only checks that the call succeeds; the output is not inspected.
    conv(g, node_feat, edge_feat=edge_feat)
2094

2095
2096
2097
2098
2099
2100
2101

@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("norm_type", ["sym", "row"])
@pytest.mark.parametrize("clamp", [True, False])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("reset", [True, False])
def test_label_prop(k, alpha, norm_type, clamp, normalize, reset):
    """Smoke-test LabelPropagation on single-label and multi-label inputs."""
    dev = F.ctx()
    num_nodes, num_edges, num_classes = 5, 20, 4
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    # One integer label per node.
    labels = th.tensor([0, 2, 1, 3, 0]).long().to(dev)
    # Boolean (num_nodes, num_classes) matrix for the multi-label case.
    ml_labels = th.rand(num_nodes, num_classes).to(dev) > 0.7
    mask = th.tensor([0, 1, 1, 1, 0]).bool().to(dev)

    propagate = nn.LabelPropagation(
        k, alpha, norm_type, clamp, normalize, reset
    )
    for label_input in (labels, ml_labels):
        propagate(g, label_input, mask)

2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127

@pytest.mark.parametrize("in_size", [16])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators", [["mean", "max", "dir2-av"], ["min", "std", "dir1-dx"]]
)
@pytest.mark.parametrize("scalers", [["amplification", "attenuation"]])
@pytest.mark.parametrize("delta", [2.5])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
def test_dgn_conv(
    in_size, out_size, aggregators, scalers, delta, edge_feat_size
):
    """Smoke-test DGNConv with and without directional ("dir*") aggregators."""
    dev = F.ctx()
    num_nodes, num_edges = 5, 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    # Store the LaplacianPE transform output under ndata["eig"].
    g = dgl.LaplacianPE(k=3, feat_name="eig")(g)
    eig = g.ndata["eig"]

    # First pass: the full aggregator list, with eigenvectors supplied.
    conv = nn.DGNConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    conv(g, h, edge_feat=e, eig_vec=eig)

    # Second pass: drop the "dir*" aggregators and call without eig_vec.
    aggregators_non_eig = []
    for aggr in aggregators:
        if not aggr.startswith("dir"):
            aggregators_non_eig.append(aggr)
    conv = nn.DGNConv(
        in_size,
        out_size,
        aggregators_non_eig,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    conv(g, h, edge_feat=e)
LuckyLiuM's avatar
LuckyLiuM committed
2159

2160

LuckyLiuM's avatar
LuckyLiuM committed
2161
2162
2163
def test_DeepWalk():
    """Run one optimisation step of DeepWalk in sparse and dense modes."""
    dev = F.ctx()
    g = dgl.graph(([0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]))

    # (fast_neg, sparse, optimizer class): the sparse configuration is paired
    # with SparseAdam, the dense one with Adam.
    configs = [
        (True, True, SparseAdam),
        (False, False, Adam),
    ]
    for fast_neg, sparse, optim_cls in configs:
        model = nn.DeepWalk(
            g,
            emb_dim=8,
            walk_length=2,
            window_size=1,
            fast_neg=fast_neg,
            sparse=sparse,
        )
        model = model.to(dev)
        # model.sample is used as the collate function for node-ID batches.
        dataloader = DataLoader(
            torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
        )
        optim = optim_cls(model.parameters(), lr=0.01)
        walk = next(iter(dataloader)).to(dev)
        loss = model(walk)
        loss.backward()
        optim.step()
LuckyLiuM's avatar
LuckyLiuM committed
2189

2190
2191
2192
2193

@pytest.mark.parametrize("max_degree", [2, 6])
@pytest.mark.parametrize("embedding_dim", [8, 16])
@pytest.mark.parametrize("direction", ["in", "out", "both"])
def test_degree_encoder(max_degree, embedding_dim, direction):
    """Check DegreeEncoder output shapes on a homograph and a heterograph."""
    homo_src = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    homo_dst = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    g = dgl.graph((homo_src, homo_dst))

    # test heterograph
    hetero_edges = {
        ("drug", "interacts", "drug"): (
            th.tensor([0, 1]),
            th.tensor([1, 2]),
        ),
        ("drug", "interacts", "gene"): (
            th.tensor([0, 1]),
            th.tensor([2, 3]),
        ),
        ("drug", "treats", "disease"): (th.tensor([1]), th.tensor([2])),
    }
    hg = dgl.heterograph(hetero_edges)

    encoder = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)
    de_g = encoder(g)
    de_hg = encoder(hg)
    assert de_g.shape == (4, embedding_dim)
    assert de_hg.shape == (10, embedding_dim)

2221

LuckyLiuM's avatar
LuckyLiuM committed
2222
2223
2224
@parametrize_idtype
def test_MetaPath2Vec(idtype):
    """Check that MetaPath2Vec holds one embedding row per graph node."""
    dev = F.ctx()
    user_company = ([0, 0, 2, 1, 3], [1, 2, 1, 3, 0])
    company_product = ([0, 0, 0, 1, 2, 3], [0, 2, 3, 0, 2, 1])
    # "cu" and "pc" are the reverses of "uc" and "cp".
    edges = {
        ("user", "uc", "company"): user_company,
        ("company", "cp", "product"): company_product,
        ("company", "cu", "user"): (user_company[1], user_company[0]),
        ("product", "pc", "company"): (
            company_product[1],
            company_product[0],
        ),
    }
    g = dgl.heterograph(edges, idtype=idtype, device=dev)

    model = nn.MetaPath2Vec(g, ["uc", "cu"], window_size=1).to(dev)
    embeds = model.node_embed.weight
    assert embeds.shape[0] == g.num_nodes()
2245

2246
2247
2248
2249
2250
2251
2252
2253
2254
2255

@pytest.mark.parametrize("num_layer", [1, 4])
@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("lpe_dim", [4, 16])
@pytest.mark.parametrize("n_head", [1, 4])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("num_post_layer", [0, 1, 2])
def test_LaplacianPosEnc(
    num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
):
    """Check LaplacianPosEnc output shape for "Transformer" and "DeepSet"."""
    ctx = F.ctx()
    num_nodes = 4
    expected_shape = (num_nodes, lpe_dim)

    EigVals = th.randn((num_nodes, k)).to(ctx)
    EigVecs = th.randn((num_nodes, k)).to(ctx)

    transformer_enc = nn.LaplacianPosEnc(
        "Transformer", num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
    ).to(ctx)
    assert transformer_enc(EigVals, EigVecs).shape == expected_shape

    # n_head is not passed for the "DeepSet" variant.
    deepset_enc = nn.LaplacianPosEnc(
        "DeepSet",
        num_layer,
        k,
        lpe_dim,
        batch_norm=batch_norm,
        num_post_layer=num_post_layer,
    ).to(ctx)
    assert deepset_enc(EigVals, EigVecs).shape == expected_shape
2276

2277
2278
2279
2280
2281
2282
2283
2284
2285

@pytest.mark.parametrize("feat_size", [128, 512])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("attn_drop", [0.1, 0.5])
def test_BiasedMultiheadAttention(
    feat_size, num_heads, bias, attn_bias_type, attn_drop
):
    """Check that BiasedMultiheadAttention preserves the input feature shape."""
    batch_size, num_nodes = 16, 100
    ndata = th.rand(batch_size, num_nodes, feat_size)
    attn_bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
    attn_mask = th.rand(batch_size, num_nodes, num_nodes) < 0.5

    attention = nn.BiasedMultiheadAttention(
        feat_size, num_heads, bias, attn_bias_type, attn_drop
    )
    out = attention(ndata, attn_bias, attn_mask)

    assert out.shape == (batch_size, num_nodes, feat_size)
2296

2297
2298
2299

@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("norm_first", [True, False])
def test_GraphormerLayer(attn_bias_type, norm_first):
    """Check that GraphormerLayer preserves the input feature shape."""
    batch_size, num_nodes = 16, 100
    feat_size, num_heads = 512, 8
    nfeat = th.rand(batch_size, num_nodes, feat_size)
    attn_bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
    attn_mask = th.rand(batch_size, num_nodes, num_nodes) < 0.5

    layer_kwargs = dict(
        feat_size=feat_size,
        hidden_size=2048,
        num_heads=num_heads,
        attn_bias_type=attn_bias_type,
        norm_first=norm_first,
        dropout=0.1,
        activation=th.nn.ReLU(),
    )
    layer = nn.GraphormerLayer(**layer_kwargs)
    out = layer(nfeat, attn_bias, attn_mask)

    assert out.shape == (batch_size, num_nodes, feat_size)

2322
2323
2324
2325

@pytest.mark.parametrize("max_len", [1, 4])
@pytest.mark.parametrize("feat_dim", [16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_PathEncoder(max_len, feat_dim, num_heads):
    """Check the PathEncoder output shape on a batch of two small graphs."""
    dev = F.ctx()
    edge_lists = [
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        ),
        (th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3])),
    ]
    graphs = [dgl.graph(edges).to(dev) for edges in edge_lists]
    bg = dgl.batch(graphs)

    edge_feat = th.rand(bg.num_edges(), feat_dim).to(dev)
    encoder = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
    bias = encoder(bg, edge_feat)
    # The expected shape is (2, 6, 6, num_heads) for the two batched graphs.
    assert bias.shape == (2, 6, 6, num_heads)
2342

2343
2344
2345
2346

@pytest.mark.parametrize("max_dist", [1, 4])
@pytest.mark.parametrize("num_kernels", [8, 16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_SpatialEncoder(max_dist, num_kernels, num_heads):
    """Check output shapes of SpatialEncoder and SpatialEncoder3d."""
    dev = F.ctx()
    g1_src = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    g1_dst = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    g1 = dgl.graph((g1_src, g1_dst)).to(dev)
    g2_src = th.tensor([0, 1, 2, 3, 2, 5])
    g2_dst = th.tensor([1, 2, 3, 4, 0, 3])
    g2 = dgl.graph((g2_src, g2_dst)).to(dev)
    bg = dgl.batch([g1, g2])

    # Random 3-column per-node input for the 3d encoders, plus random node
    # type ids in [0, 512).
    ndata = th.rand(bg.num_nodes(), 3).to(dev)
    num_nodes = bg.num_nodes()
    node_type = th.randint(0, 512, (num_nodes,)).to(dev)

    model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
    model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
    model_3 = nn.SpatialEncoder3d(
        num_kernels, num_heads=num_heads, max_node_type=512
    ).to(dev)

    encoding = model_1(bg)
    encoding3d_1 = model_2(bg, ndata)
    encoding3d_2 = model_3(bg, ndata, node_type)

    expected_shape = (2, 6, 6, num_heads)
    assert encoding.shape == expected_shape
    assert encoding3d_1.shape == expected_shape
    assert encoding3d_2.shape == expected_shape