import io
import pickle
from copy import deepcopy

import backend as F

import dgl
import dgl.function as fn
import dgl.nn.pytorch as nn
import networkx as nx
import pytest
import scipy as sp
import torch
import torch as th
from test_utils import parametrize_idtype
from test_utils.graph_cases import (
    get_cases,
    random_bipartite,
    random_dglgraph,
    random_graph,
)
from torch.optim import Adam, SparseAdam
from torch.utils.data import DataLoader

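# Shared in-memory buffer: the th.save(...) calls below verify each module can be pickled.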
tmp_buffer = io.BytesIO()


def _AXWb(A, X, W, b):
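    # Dense reference computing Y = A (X W) + b, matching GraphConv with norm='none'.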
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b


@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv0(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)

    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, out_dim)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    conv = nn.GraphConv(5, out_dim)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right", "left"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(idtype, g, norm, weight, bias, out_dim):
    # Test one tensor input
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
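    # weight=True means the module owns its weight; otherwise pass the external one.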
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    e_w = g.edata["scalar_w"]
    if weight:
        h_out = conv(g, h, edge_weight=e_w)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=e_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    edgenorm = nn.EdgeWeightNorm(norm=norm)
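    # Normalize the raw scalar edge weights before handing them to GraphConv.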
    norm_weight = edgenorm(g, g.edata["scalar_w"])
    if weight:
        h_out = conv(g, h, edge_weight=norm_weight)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=norm_weight)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
    # Test a pair of tensor inputs
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    h_dst = F.randn((ndst, out_dim)).to(F.ctx())
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


def _S2AXWb(A, N, X, W, b):
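    # Dense reference for TAGConv with k=2: two normalized propagation hops,
    # concatenating [X, X1, X2] before the linear projection.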
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b


@pytest.mark.parametrize("out_dim", [1, 2])
def test_tagconv(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3))
    g = g.to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, out_dim, bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(
        h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias)
    )

    conv = nn.TAGConv(5, out_dim)
    conv = conv.to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == out_dim

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)


def test_set2set():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    s2s = s2s.to(ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
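    # Set2Set readout is (1, 2 * input_dim) per graph, here 2 * 5 = 10.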
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2


def test_glob_att_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)

    # test pickle
    th.save(gap, tmp_buffer)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2


def test_simple_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))
    g = g.to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
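    # bg batches graphs with 15, 5, 15, 5, 15 nodes; the expected values slice h0 at those offsets.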
    truth = th.stack(
        [
            F.sum(h0[:15], 0),
            F.sum(h0[15:20], 0),
            F.sum(h0[20:35], 0),
            F.sum(h0[35:40], 0),
            F.sum(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = th.stack(
        [
            F.mean(h0[:15], 0),
            F.mean(h0[15:20], 0),
            F.mean(h0[20:35], 0),
            F.mean(h0[35:40], 0),
            F.mean(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = th.stack(
        [
            F.max(h0[:15], 0),
            F.max(h0[15:20], 0),
            F.max(h0[20:35], 0),
            F.max(h0[35:40], 0),
            F.max(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2


def test_set_trans():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "sab")
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "isab", 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2


@parametrize_idtype
@pytest.mark.parametrize("O", [1, 8, 32])
def test_rgcn(idtype, O):
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    B = 2
    I = 10

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.number_of_edges(), 1)).to(ctx)
    sorted_r, idx = th.sort(r)
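    # Reorder edges so that edge types are contiguous, which presorted=True assumes.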
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )
    sorted_norm = norm[idx]

    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % B == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % B == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % B == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)


@parametrize_idtype
@pytest.mark.parametrize("O", [1, 10, 40])
def test_rgcn_default_nbasis(idtype, O):
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    I = 10

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.number_of_edges(), 1)).to(ctx)
    sorted_r, idx = th.sort(r)
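    # Reorder edges so that edge types are contiguous, which presorted=True assumes.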
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )
    sorted_norm = norm[idx]

    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis").to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % R == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd").to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % R == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % R == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % R == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    egat = nn.EGATConv(
        in_node_feats=10,
        in_edge_feats=5,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    nfeat = F.randn((g.number_of_nodes(), 10))
    efeat = F.randn((g.number_of_edges(), 5))
    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)

    th.save(egat, tmp_buffer)

    assert h.shape == (g.number_of_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.number_of_edges(), num_heads, out_edge_feats)
    _, _, attn = egat(g, nfeat, efeat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    egat = nn.EGATConv(
        in_node_feats=(10, 15),
        in_edge_feats=7,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    nfeat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), 15)),
    )
    efeat = F.randn((g.number_of_edges(), 7))
    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)

    th.save(egat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.number_of_edges(), num_heads, out_edge_feats)
    _, _, attn = egat(g, nfeat, efeat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
def test_sage_conv(idtype, g, aggre_type):
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    sage = sage.to(F.ctx())
    # test pickle
    th.save(sage, tmp_buffer)
    h = sage(g, feat)
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != "gcn" else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), dst_dim)),
    )
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()


@parametrize_idtype
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv2(idtype, out_dim):
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.heterograph({("_U", "_E", "_V"): ([], [])}, {"_U": 5, "_V": 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), out_dim, "gcn")
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    for aggre_type in ["mean", "pool", "lstm"]:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)

    # test pickle
    th.save(sgc, tmp_buffer)

    feat = F.randn((g.number_of_nodes(), 5))
    sgc = sgc.to(ctx)

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # cached
    sgc = nn.SGConv(5, out_dim, 3, True)
    sgc = sgc.to(ctx)
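    # With cached=True the propagation result is computed once and reused, so h_0 == h_1.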
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.number_of_nodes(), 5))
    appnp = appnp.to(ctx)

    # test pickle
    th.save(appnp, tmp_buffer)

    h = appnp(g, feat)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.number_of_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    appnp = appnp.to(ctx)

    h = appnp(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("bias", [True, False])
def test_gcn2conv_e_weight(g, idtype, bias):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gcn2conv = nn.GCN2Conv(
        5, layer=2, alpha=0.5, bias=bias, project_initial_features=True
    )
    feat = F.randn((g.number_of_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    gcn2conv = gcn2conv.to(ctx)
    res = feat
    h = gcn2conv(g, res, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_sgconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    sgconv = nn.SGConv(5, 5, 3)
    feat = F.randn((g.number_of_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    sgconv = sgconv.to(ctx)
    h = sgconv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_tagconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.TAGConv(5, 5, bias=True)
    conv = conv.to(ctx)
    feat = F.randn((g.number_of_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    conv = conv.to(ctx)
    h = conv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    th.save(gin, tmp_buffer)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gin = gin.to(ctx)
    h = gin(g, feat)

    # test pickle
    th.save(gin, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), 12)

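    # GINConv also accepts apply_func=None, which skips the learnable update.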
    gin = nn.GINConv(None, aggregator_type)
    th.save(gin, tmp_buffer)
    gin = gin.to(ctx)
    h = gin(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
def test_gine_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gine = nn.GINEConv(th.nn.Linear(5, 12))
    th.save(gine, tmp_buffer)
    nfeat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.num_edges(), 5))
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)

    # test pickle
    th.save(gine, tmp_buffer)
    assert h.shape == (g.number_of_dst_nodes(), 12)

    gine = nn.GINEConv(None)
    th.save(gine, tmp_buffer)
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv_bi(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_agnn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_src_nodes(), 5))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_agnn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do a shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv_one_etype(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 1)
    etypes = th.zeros(g.number_of_edges())
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    h2 = ggconv(g, feat)
    # an explicit all-zero etype vector should match the default call
    assert F.allclose(h, h2)
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_nn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do a shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_nn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do a shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gmm_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv(5, 10, 3, 4, "mean")
    feat = F.randn((g.number_of_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do a shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite", "block-bipartite"], exclude=["zero-degree"])
)
def test_gmm_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do a shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("norm_type", ["both", "right", "none"])
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_graph_conv(norm_type, g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "bipartite"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_sage_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, out_dim, "gcn")
    dense_sage = nn.DenseSAGEConv(5, out_dim)
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.bias.data
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)),
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage), g


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)

    # test pickle
    th.save(edge_conv, tmp_buffer)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    dotgat = dotgat.to(ctx)

    # test pickle
    th.save(dotgat, tmp_buffer)

    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv((5, 5), out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    dotgat = dotgat.to(ctx)
    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)


@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_cheb_conv(out_dim):
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        g = g.to(F.ctx())
        adj = g.adjacency_matrix(transpose=True, ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, out_dim, k, None)
        dense_cheb = nn.DenseChebConv(5, out_dim, k)
        # for i in range(len(cheb.fc)):
        #    dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
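        # Copy ChebConv's stacked linear weight into DenseChebConv's per-hop tensor W.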
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(
            k, 5, out_dim
        )
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)


def test_sequential():
    ctx = F.ctx()

    # Test single graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
            graph.apply_edges(fn.u_add_v("h", "h", "e"))
            e_feat += graph.edata["e"]
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graphs
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
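            # Sum node pairs so each layer halves the node count (32 -> 16 -> 8 -> 4).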
            return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_atomic_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    aconv = nn.AtomicConv(
        interaction_cutoffs=F.tensor([12.0, 12.0]),
        rbf_kernel_means=F.tensor([0.0, 2.0]),
        rbf_kernel_scaling=F.tensor([4.0, 4.0]),
        features_to_use=F.tensor([6.0, 8.0]),
    )

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((g.number_of_nodes(), 1))
    dist = F.randn((g.number_of_edges(), 1))

    h = aconv(g, feat, dist)

    # currently we only do a shape check
    assert h.shape[-1] == 4


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 3])
def test_cf_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    cfconv = nn.CFConv(
        node_in_feats=2, edge_in_feats=3, hidden_feats=2, out_feats=out_dim
    )

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    src_feats = F.randn((g.number_of_src_nodes(), 2))
    edge_feats = F.randn((g.number_of_edges(), 3))
    h = cfconv(g, src_feats, edge_feats)
    # currently we only do a shape check
    assert h.shape[-1] == out_dim

    # case for bipartite graphs
    dst_feats = F.randn((g.number_of_dst_nodes(), 3))
    h = cfconv(g, (src_feats, dst_feats), edge_feats)
    # currently we only do a shape check
    assert h.shape[-1] == out_dim


def myagg(alist, dsttype):
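    # Custom cross-type reducer for HeteroGraphConv: weights the i-th relation's output by (i + 1).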
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst


@parametrize_idtype
@pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg])
@pytest.mark.parametrize("canonical_keys", [False, True])
def test_hetero_conv(agg, idtype, canonical_keys):
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 0, 2, 1], [1, 2, 1, 3]),
            ("user", "plays", "game"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
            ("store", "sells", "game"): ([0, 0, 1, 1], [0, 3, 1, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    if not canonical_keys:
        conv = nn.HeteroGraphConv(
            {
                "follows": nn.GraphConv(2, 3, allow_zero_in_degree=True),
                "plays": nn.GraphConv(2, 4, allow_zero_in_degree=True),
                "sells": nn.GraphConv(3, 4, allow_zero_in_degree=True),
            },
            agg,
        )
    else:
        conv = nn.HeteroGraphConv(
            {
                ("user", "follows", "user"): nn.GraphConv(
                    2, 3, allow_zero_in_degree=True
                ),
                ("user", "plays", "game"): nn.GraphConv(
                    2, 4, allow_zero_in_degree=True
                ),
                ("store", "sells", "game"): nn.GraphConv(
                    3, 4, allow_zero_in_degree=True
                ),
            },
            agg,
        )

    conv = conv.to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    h = conv(block, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    # test with mod args
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv(
        {"follows": mod1, "plays": mod2, "sells": mod3}, agg
    )
    conv = conv.to(F.ctx())
    mod_args = {"follows": (1,), "plays": (1,)}
    mod_kwargs = {"sells": {"arg2": "abc"}}
    h = conv(
        g,
        {"user": uf, "game": gf, "store": sf},
        mod_args=mod_args,
        mod_kwargs=mod_kwargs,
    )
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on graph without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form="eid", etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}


@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_linear(out_dim):
    in_feats = {
        "user": F.randn((2, 1)),
        ("user", "follows", "user"): F.randn((3, 2)),
    }

    layer = nn.HeteroLinear(
        {"user": 1, ("user", "follows", "user"): 2}, out_dim
    )
    layer = layer.to(F.ctx())
    out_feats = layer(in_feats)
    assert out_feats["user"].shape == (2, out_dim)
    assert out_feats[("user", "follows", "user")].shape == (3, out_dim)


@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_embedding(out_dim):
    layer = nn.HeteroEmbedding(
        {"user": 2, ("user", "follows", "user"): 3}, out_dim
    )
    layer = layer.to(F.ctx())

    embeds = layer.weight
    assert embeds["user"].shape == (2, out_dim)
    assert embeds[("user", "follows", "user")].shape == (3, out_dim)

    layer.reset_parameters()
    embeds = layer.weight
    assert embeds["user"].shape == (2, out_dim)
    assert embeds[("user", "follows", "user")].shape == (3, out_dim)

    embeds = layer(
        {
            "user": F.tensor([0], dtype=F.int64),
            ("user", "follows", "user"): F.tensor([0, 2], dtype=F.int64),
        }
    )
    assert embeds["user"].shape == (1, out_dim)
    assert embeds[("user", "follows", "user")].shape == (2, out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_gnnexplainer(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    feat = F.randn((g.num_nodes(), 5))

    class Model(th.nn.Module):
        def __init__(self, in_feats, out_feats, graph=False):
            super(Model, self).__init__()
            self.linear = th.nn.Linear(in_feats, out_feats)
            if graph:
                self.pool = nn.AvgPooling()
            else:
                self.pool = None

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                feat = self.linear(feat)
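                # eweight, when given, rescales each edge's message; the explainer uses it for its edge mask.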
                graph.ndata["h"] = feat
                if eweight is None:
                    graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                else:
                    graph.edata["w"] = eweight
                    graph.update_all(
                        fn.u_mul_e("h", "w", "m"), fn.sum("m", "h")
                    )

                if self.pool:
                    return self.pool(graph, graph.ndata["h"])
                else:
                    return graph.ndata["h"]

    # Explain node prediction
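    # GNNExplainer learns soft feature and edge masks that preserve the
    # model's prediction; it requires the model's forward to accept an
    # optional per-edge weight (eweight), threaded through update_all above.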
    model = Model(5, out_dim)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)

    # Explain graph prediction
    model = Model(5, out_dim, graph=True)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)


@pytest.mark.parametrize("g", get_cases(["hetero"], exclude=["zero-degree"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("output_dim", [1, 2])
def test_heterognnexplainer(g, idtype, input_dim, output_dim):
    g = g.astype(idtype).to(F.ctx())
    device = g.device

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    feat = {
        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)
        for ntype in g.ntypes
    }

    class Model(th.nn.Module):
        def __init__(self, in_dim, num_classes, canonical_etypes, graph=False):
            super(Model, self).__init__()
            self.graph = graph
            self.etype_weights = th.nn.ModuleDict(
                {
                    "_".join(c_etype): th.nn.Linear(in_dim, num_classes)
                    for c_etype in canonical_etypes
                }
            )

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                c_etype_func_dict = {}
                for c_etype in graph.canonical_etypes:
                    src_type, etype, dst_type = c_etype
                    wh = self.etype_weights["_".join(c_etype)](feat[src_type])
                    graph.nodes[src_type].data[f"h_{c_etype}"] = wh
                    if eweight is None:
                        c_etype_func_dict[c_etype] = (
                            fn.copy_u(f"h_{c_etype}", "m"),
                            fn.mean("m", "h"),
                        )
                    else:
                        graph.edges[c_etype].data["w"] = eweight[c_etype]
                        c_etype_func_dict[c_etype] = (
                            fn.u_mul_e(f"h_{c_etype}", "w", "m"),
                            fn.mean("m", "h"),
                        )
                graph.multi_update_all(c_etype_func_dict, "sum")
                if self.graph:
                    hg = 0
                    for ntype in graph.ntypes:
                        if graph.num_nodes(ntype):
                            hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)

                    return hg
                else:
                    return graph.ndata["h"]

    # Explain node prediction
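    # HeteroGNNExplainer mirrors GNNExplainer for heterographs: here eweight
    # is a dict keyed by canonical edge type, matching the Model's forward
    # signature above.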
    model = Model(input_dim, output_dim, g.canonical_etypes)
    model = model.to(F.ctx())
    ntype = g.ntypes[0]
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(
        ntype, 0, g, feat
    )

    # Explain graph prediction
    model = Model(input_dim, output_dim, g.canonical_etypes, graph=True)
    model = model.to(F.ctx())
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(
        ["homo"],
        exclude=[
            "zero-degree",
            "homo-zero-degree",
            "has_feature",
            "has_scalar_e_feature",
            "row_sorted",
            "col_sorted",
            "batched",
        ],
    ),
)
@pytest.mark.parametrize("n_classes", [2])
def test_subgraphx(g, idtype, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))

    class Model(th.nn.Module):
        def __init__(self, in_dim, n_classes):
            super().__init__()
            self.conv = nn.GraphConv(in_dim, n_classes)
            self.pool = nn.AvgPooling()

        def forward(self, g, h):
            h = th.nn.functional.relu(self.conv(g, h))
            return self.pool(g, h)

    model = Model(feat.shape[1], n_classes)
    model = model.to(ctx)
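    # SubgraphX searches for an important subgraph via Monte Carlo tree
    # search, scoring candidate subgraphs with approximate Shapley values;
    # shapley_steps and num_rollouts bound the search budget.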
    explainer = nn.SubgraphX(
        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0
    )
    explainer.explain_graph(g, feat, target_class=0)


def test_jumping_knowledge():
    ctx = F.ctx()
    num_layers = 2
    num_nodes = 3
    num_feats = 4

    feat_list = [
        th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)
    ]
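    # JumpingKnowledge fuses per-layer representations: "cat" concatenates
    # them (num_layers * num_feats), "max" takes an elementwise maximum, and
    # "lstm" attends over layers with a bi-LSTM; the asserts below check the
    # resulting widths.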
    model = nn.JumpingKnowledge("cat").to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_layers * num_feats)

    model = nn.JumpingKnowledge("max").to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)

    model = nn.JumpingKnowledge("lstm", num_feats, num_layers).to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)


@pytest.mark.parametrize("op", ["dot", "cos", "ele", "cat"])
def test_edge_predictor(op):
    ctx = F.ctx()
    num_pairs = 3
    in_feats = 4
    out_feats = 5
    h_src = th.randn((num_pairs, in_feats)).to(ctx)
    h_dst = th.randn((num_pairs, in_feats)).to(ctx)
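    # Expected output widths per op: "dot" and "cos" give one score per
    # pair, "ele" keeps in_feats via an elementwise product, and "cat"
    # concatenates to 2 * in_feats.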

    pred = nn.EdgePredictor(op)
    if op in ["dot", "cos"]:
        assert pred(h_src, h_dst).shape == (num_pairs, 1)
    elif op == "ele":
        assert pred(h_src, h_dst).shape == (num_pairs, in_feats)
    else:
        assert pred(h_src, h_dst).shape == (num_pairs, 2 * in_feats)
    pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
    assert pred(h_src, h_dst).shape == (num_pairs, out_feats)


def test_ke_score_funcs():
    ctx = F.ctx()
    num_edges = 30
    num_rels = 3
    nfeats = 4

    h_src = th.randn((num_edges, nfeats)).to(ctx)
    h_dst = th.randn((num_edges, nfeats)).to(ctx)
    rels = th.randint(low=0, high=num_rels, size=(num_edges,)).to(ctx)
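    # Roughly, TransE scores a triple as -||h + r - t|| under a p-norm,
    # while TransR first projects entities into a relation-specific space of
    # width rfeats before applying the same translation-based score.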

    score_func = nn.TransE(num_rels=num_rels, feats=nfeats).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)

    score_func = nn.TransR(
        num_rels=num_rels, rfeats=nfeats - 1, nfeats=nfeats
    ).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)


def test_twirls():
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    feat = th.ones(6, 10)
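    # TWIRLSConv(in_feats, out_feats, hidden, prop_step) unrolls prop_step
    # propagation iterations; only the output width is checked here.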
    conv = nn.TWIRLSConv(10, 2, 128, prop_step=64)
    res = conv(g, feat)
    assert res.size() == (6, 2)


@pytest.mark.parametrize("feat_size", [4, 32])
@pytest.mark.parametrize(
    "regularizer,num_bases", [(None, None), ("basis", 4), ("bdd", 4)]
)
def test_typed_linear(feat_size, regularizer, num_bases):
    dev = F.ctx()
    num_types = 5
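    # The "basis" and "bdd" (block-diagonal-decomposition) regularizers
    # share parameters across the types through num_bases, as in R-GCN.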
    lin = nn.TypedLinear(
        feat_size,
        feat_size * 2,
        num_types,
        regularizer=regularizer,
        num_bases=num_bases,
    ).to(dev)
    print(lin)
    x = th.randn(100, feat_size).to(dev)
    x_type = th.randint(0, num_types, (100,)).to(dev)
    x_type_sorted, idx = th.sort(x_type)
    _, rev_idx = th.sort(idx)
    x_sorted = x[idx]

    # test unsorted
    y = lin(x, x_type)
    assert y.shape == (100, feat_size * 2)
    # test sorted
    y_sorted = lin(x_sorted, x_type_sorted, sorted_by_type=True)
    assert y_sorted.shape == (100, feat_size * 2)

    assert th.allclose(y, y_sorted[rev_idx], atol=1e-4, rtol=1e-4)

@parametrize_idtype
@pytest.mark.parametrize("in_size", [4])
@pytest.mark.parametrize("num_heads", [1])
def test_hgt(idtype, in_size, num_heads):
    dev = F.ctx()
    num_etypes = 5
    num_ntypes = 2
    head_size = in_size // num_heads

    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.01))
    g = g.astype(idtype).to(dev)
    etype = th.tensor([i % num_etypes for i in range(g.num_edges())]).to(dev)
    ntype = th.tensor([i % num_ntypes for i in range(g.num_nodes())]).to(dev)
    x = th.randn(g.num_nodes(), in_size).to(dev)
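    # HGTConv selects per-node-type and per-edge-type parameters from the
    # ntype/etype vectors; each node ends up with head_size * num_heads
    # output channels.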

    m = nn.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes).to(
        dev
    )

    y = m(g, x, ntype, etype)
    assert y.shape == (g.num_nodes(), head_size * num_heads)
    # presorted
    sorted_ntype, idx_nt = th.sort(ntype)
    sorted_etype, idx_et = th.sort(etype)
    _, rev_idx = th.sort(idx_nt)
    g.ndata["t"] = ntype
    g.ndata["x"] = x
    g.edata["t"] = etype
    sorted_g = dgl.reorder_graph(
        g,
        node_permute_algo="custom",
        edge_permute_algo="custom",
        permute_config={
            "nodes_perm": idx_nt.to(idtype),
            "edges_perm": idx_et.to(idtype),
        },
    )
    print(sorted_g.ndata["t"])
    print(sorted_g.edata["t"])
    sorted_x = sorted_g.ndata["x"]
    sorted_y = m(
        sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=True
    )
    assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
    # mini-batch
    train_idx = th.randperm(100, dtype=idtype)[:10]
    sampler = dgl.dataloading.NeighborSampler([-1])
    train_loader = dgl.dataloading.DataLoader(
        g, train_idx.to(dev), sampler, batch_size=8, device=dev, shuffle=True
    )
    (input_nodes, output_nodes, block) = next(iter(train_loader))
    block = block[0]
    x = x[input_nodes.to(th.long)]
    ntype = ntype[input_nodes.to(th.long)]
    edge = block.edata[dgl.EID]
    etype = etype[edge.to(th.long)]
    y = m(block, x, ntype, etype)
    assert y.shape == (block.number_of_dst_nodes(), head_size * num_heads)
    # TODO(minjie): enable the following check
    # assert th.allclose(y, sorted_y[rev_idx], atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize("self_loop", [True, False])
@pytest.mark.parametrize("get_distances", [True, False])
def test_radius_graph(self_loop, get_distances):
    pos = th.tensor(
        [
            [0.1, 0.3, 0.4],
            [0.5, 0.2, 0.1],
            [0.7, 0.9, 0.5],
            [0.3, 0.2, 0.5],
            [0.2, 0.8, 0.2],
            [0.9, 0.2, 0.1],
            [0.7, 0.4, 0.4],
            [0.2, 0.1, 0.6],
            [0.5, 0.3, 0.5],
            [0.4, 0.2, 0.6],
        ]
    )
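    # RadiusGraph connects every pair of points within Euclidean distance
    # 0.3; with get_distances=True it also returns the per-edge distances,
    # compared against dists_target below.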

    rg = nn.RadiusGraph(0.3, self_loop=self_loop)

    if get_distances:
        g, dists = rg(pos, get_distances=get_distances)
    else:
        g = rg(pos)

    if self_loop:
        src_target = th.tensor(
            [
                0,
                0,
                1,
                2,
                3,
                3,
                3,
                3,
                3,
                4,
                5,
                6,
                6,
                7,
                7,
                7,
                8,
                8,
                8,
                8,
                9,
                9,
                9,
                9,
            ]
        )
        dst_target = th.tensor(
            [
                0,
                3,
                1,
                2,
                0,
                3,
                7,
                8,
                9,
                4,
                5,
                6,
                8,
                3,
                7,
                9,
                3,
                6,
                8,
                9,
                3,
                7,
                8,
                9,
            ]
        )

        if get_distances:
            dists_target = th.tensor(
                [
                    [0.0000],
                    [0.2449],
                    [0.0000],
                    [0.0000],
                    [0.2449],
                    [0.0000],
                    [0.1732],
                    [0.2236],
                    [0.1414],
                    [0.0000],
                    [0.0000],
                    [0.0000],
                    [0.2449],
                    [0.1732],
                    [0.0000],
                    [0.2236],
                    [0.2236],
                    [0.2449],
                    [0.0000],
                    [0.1732],
                    [0.1414],
                    [0.2236],
                    [0.1732],
                    [0.0000],
                ]
            )
    else:
        src_target = th.tensor([0, 3, 3, 3, 3, 6, 7, 7, 8, 8, 8, 9, 9, 9])
        dst_target = th.tensor([3, 0, 7, 8, 9, 8, 3, 9, 3, 6, 9, 3, 7, 8])

        if get_distances:
            dists_target = th.tensor(
                [
                    [0.2449],
                    [0.2449],
                    [0.1732],
                    [0.2236],
                    [0.1414],
                    [0.2449],
                    [0.1732],
                    [0.2236],
                    [0.2236],
                    [0.2449],
                    [0.1732],
                    [0.1414],
                    [0.2236],
                    [0.1732],
                ]
            )

    src, dst = g.edges()

    assert th.equal(src, src_target)
    assert th.equal(dst, dst_target)

    if get_distances:
        assert th.allclose(dists, dists_target, rtol=1e-03)

@parametrize_idtype
def test_group_rev_res(idtype):
    dev = F.ctx()

    num_nodes = 5
    num_edges = 20
    feats = 32
    groups = 2
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, feats).to(dev)
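    # GroupRevRes splits the input into `groups` chunks and wraps the module
    # in a reversible residual, so the inner GraphConv must map
    # feats // groups -> feats // groups; activations are recomputed in the
    # backward pass instead of being stored.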
    conv = nn.GraphConv(feats // groups, feats // groups)
    model = nn.GroupRevRes(conv, groups).to(dev)
    result = model(g, h)
    result.sum().backward()


@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("hidden_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize("edge_feat_size", [16, 10, 0])
def test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    x = th.randn(num_nodes, 3).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
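    # EGNNConv jointly updates node features and 3D coordinates in an
    # E(n)-equivariant fashion; edge_feat_size == 0 exercises the
    # no-edge-feature path.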
    model = nn.EGNNConv(in_size, hidden_size, out_size, edge_feat_size).to(dev)
    model(g, h, x, e)


@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators",
    [
        ["mean", "max", "sum"],
        ["min", "std", "var"],
        ["moment3", "moment4", "moment5"],
    ],
)
@pytest.mark.parametrize(
    "scalers", [["identity"], ["amplification", "attenuation"]]
)
@pytest.mark.parametrize("delta", [2.5, 7.4])
@pytest.mark.parametrize("dropout", [0.0, 0.1])
@pytest.mark.parametrize("num_towers", [1, 4])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
@pytest.mark.parametrize("residual", [True, False])
def test_pna_conv(
    in_size,
    out_size,
    aggregators,
    scalers,
    delta,
    dropout,
    num_towers,
    edge_feat_size,
    residual,
):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
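    # PNAConv combines several aggregators with degree-based scalers;
    # "amplification" and "attenuation" rescale messages by log-degree
    # relative to delta, roughly the average log-degree of training graphs.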
    model = nn.PNAConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        dropout,
        num_towers,
        edge_feat_size,
        residual,
    ).to(dev)
    model(g, h, edge_feat=e)


@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("norm_type", ["sym", "row"])
@pytest.mark.parametrize("clamp", [True, False])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("reset", [True, False])
def test_label_prop(k, alpha, norm_type, clamp, normalize, reset):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    num_classes = 4
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    labels = th.tensor([0, 2, 1, 3, 0]).long().to(dev)
    ml_labels = th.rand(num_nodes, num_classes).to(dev) > 0.7
    mask = th.tensor([0, 1, 1, 1, 0]).bool().to(dev)
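    # Label propagation iterates, in essence,
    #   Y_(t+1) = alpha * norm(A) @ Y_t + (1 - alpha) * Y_0
    # for k steps; mask marks the nodes whose ground-truth labels are known.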
    model = nn.LabelPropagation(k, alpha, norm_type, clamp, normalize, reset)
    model(g, labels, mask)
    # multi-label case
    model(g, ml_labels, mask)


@pytest.mark.parametrize("in_size", [16])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators", [["mean", "max", "dir2-av"], ["min", "std", "dir1-dx"]]
)
@pytest.mark.parametrize("scalers", [["amplification", "attenuation"]])
@pytest.mark.parametrize("delta", [2.5])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
def test_dgn_conv(
    in_size, out_size, aggregators, scalers, delta, edge_feat_size
):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    transform = dgl.LaplacianPE(k=3, feat_name="eig")
    g = transform(g)
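    # DGNConv extends PNA with directional aggregators ("dir1-dx",
    # "dir2-av") steered by Laplacian eigenvectors, supplied below via
    # eig_vec; without eigenvectors only non-directional aggregators apply.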
    eig = g.ndata["eig"]
    model = nn.DGNConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    model(g, h, edge_feat=e, eig_vec=eig)

    aggregators_non_eig = [
        aggr for aggr in aggregators if not aggr.startswith("dir")
    ]
    model = nn.DGNConv(
        in_size,
        out_size,
        aggregators_non_eig,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    model(g, h, edge_feat=e)


def test_DeepWalk():
    dev = F.ctx()
    g = dgl.graph(([0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]))
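    # DeepWalk trains skip-gram embeddings over random walks; model.sample
    # acts as the collate_fn turning a batch of seed nodes into walks, and
    # sparse=True pairs with SparseAdam below (dense Adam in the second run).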
    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=True, sparse=True
    )
    model = model.to(dev)
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
    optim = SparseAdam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()

    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=False, sparse=False
    )
    model = model.to(dev)
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
    optim = Adam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()


@pytest.mark.parametrize("max_degree", [2, 6])
@pytest.mark.parametrize("embedding_dim", [8, 16])
@pytest.mark.parametrize("direction", ["in", "out", "both"])
def test_degree_encoder(max_degree, embedding_dim, direction):
    g = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    )
    # test heterograph
    hg = dgl.heterograph(
        {
            ("drug", "interacts", "drug"): (
                th.tensor([0, 1]),
                th.tensor([1, 2]),
            ),
            ("drug", "interacts", "gene"): (
                th.tensor([0, 1]),
                th.tensor([2, 3]),
            ),
            ("drug", "treats", "disease"): (th.tensor([1]), th.tensor([2])),
        }
    )
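    # DegreeEncoder (from Graphormer) embeds node in/out degrees up to
    # max_degree; de_hg covers all 3 + 4 + 3 = 10 nodes of hg.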
    model = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)
    de_g = model(g)
    de_hg = model(hg)
    assert de_g.shape == (4, embedding_dim)
    assert de_hg.shape == (10, embedding_dim)


@parametrize_idtype
def test_MetaPath2Vec(idtype):
    dev = F.ctx()
    g = dgl.heterograph(
        {
            ("user", "uc", "company"): ([0, 0, 2, 1, 3], [1, 2, 1, 3, 0]),
            ("company", "cp", "product"): (
                [0, 0, 0, 1, 2, 3],
                [0, 2, 3, 0, 2, 1],
            ),
            ("company", "cu", "user"): ([1, 2, 1, 3, 0], [0, 0, 2, 1, 3]),
            ("product", "pc", "company"): (
                [0, 2, 3, 0, 2, 1],
                [0, 0, 0, 1, 2, 3],
            ),
        },
        idtype=idtype,
        device=dev,
    )
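    # MetaPath2Vec runs metapath-guided walks (user -> company -> user via
    # ["uc", "cu"]) and trains skip-gram embeddings; node_embed holds one
    # row per node of the graph, as the assert below checks.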
    model = nn.MetaPath2Vec(g, ["uc", "cu"], window_size=1)
    model = model.to(dev)
    embeds = model.node_embed.weight
    assert embeds.shape[0] == g.num_nodes()


@pytest.mark.parametrize("num_layer", [1, 4])
@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("lpe_dim", [4, 16])
@pytest.mark.parametrize("n_head", [1, 4])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("num_post_layer", [0, 1, 2])
def test_LaplacianPosEnc(
    num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
):
    ctx = F.ctx()
    num_nodes = 4

    EigVals = th.randn((num_nodes, k)).to(ctx)
    EigVecs = th.randn((num_nodes, k)).to(ctx)
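    # LaplacianPosEnc maps k Laplacian eigenvalue/eigenvector channels per
    # node to an lpe_dim positional embedding, encoding the k channels with
    # either a Transformer or a DeepSet model.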

    model = nn.LaplacianPosEnc(
        "Transformer", num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
    ).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)

    model = nn.LaplacianPosEnc(
        "DeepSet",
        num_layer,
        k,
        lpe_dim,
        batch_norm=batch_norm,
        num_post_layer=num_post_layer,
    ).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)


@pytest.mark.parametrize("feat_size", [128, 512])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("attn_drop", [0.1, 0.5])
def test_BiasedMultiheadAttention(
    feat_size, num_heads, bias, attn_bias_type, attn_drop
):
    ndata = th.rand(16, 100, feat_size)
    attn_bias = th.rand(16, 100, 100, num_heads)
    attn_mask = th.rand(16, 100, 100) < 0.5
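    # The per-pair, per-head attn_bias is injected into the attention
    # scores additively ("add") or multiplicatively ("mul") per
    # attn_bias_type; the boolean attn_mask restricts which pairs attend to
    # each other.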

    net = nn.BiasedMultiheadAttention(
        feat_size, num_heads, bias, attn_bias_type, attn_drop
    )
    out = net(ndata, attn_bias, attn_mask)

    assert out.shape == (16, 100, feat_size)


@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("norm_first", [True, False])
def test_GraphormerLayer(attn_bias_type, norm_first):
    batch_size = 16
    num_nodes = 100
    feat_size = 512
    num_heads = 8
    nfeat = th.rand(batch_size, num_nodes, feat_size)
    attn_bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
    attn_mask = th.rand(batch_size, num_nodes, num_nodes) < 0.5

    net = nn.GraphormerLayer(
        feat_size=feat_size,
        hidden_size=2048,
        num_heads=num_heads,
        attn_bias_type=attn_bias_type,
        norm_first=norm_first,
        dropout=0.1,
        activation=th.nn.ReLU(),
    )
    out = net(nfeat, attn_bias, attn_mask)

    assert out.shape == (batch_size, num_nodes, feat_size)


@pytest.mark.parametrize("max_len", [1, 4])
@pytest.mark.parametrize("feat_dim", [16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_PathEncoder(max_len, feat_dim, num_heads):
    dev = F.ctx()
    g1 = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    ).to(dev)
    g2 = dgl.graph(
        (th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3]))
    ).to(dev)
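    # PathEncoder (from Graphormer) embeds edge features along shortest
    # paths, truncated at max_len, into per-head attention biases of shape
    # (batch, max_nodes, max_nodes, num_heads); the larger graph here has 6
    # nodes.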
    bg = dgl.batch([g1, g2])
    edge_feat = th.rand(bg.num_edges(), feat_dim).to(dev)
    model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
    bias = model(bg, edge_feat)
    assert bias.shape == (2, 6, 6, num_heads)


@pytest.mark.parametrize("max_dist", [1, 4])
@pytest.mark.parametrize("num_kernels", [8, 16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_SpatialEncoder(max_dist, num_kernels, num_heads):
    dev = F.ctx()
    g1 = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    ).to(dev)
    g2 = dgl.graph(
        (th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3]))
    ).to(dev)
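    # SpatialEncoder embeds pairwise shortest-path distances clipped at
    # max_dist, while SpatialEncoder3d embeds Euclidean distances between 3D
    # coordinates with num_kernels Gaussian kernels, optionally conditioned
    # on node types (max_node_type).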
    bg = dgl.batch([g1, g2])
    ndata = th.rand(bg.num_nodes(), 3).to(dev)
    num_nodes = bg.num_nodes()
    node_type = th.randint(0, 512, (num_nodes,)).to(dev)
    model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
    model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
    model_3 = nn.SpatialEncoder3d(
        num_kernels, num_heads=num_heads, max_node_type=512
    ).to(dev)
    encoding = model_1(bg)
    encoding3d_1 = model_2(bg, ndata)
    encoding3d_2 = model_3(bg, ndata, node_type)
    assert encoding.shape == (2, 6, 6, num_heads)
    assert encoding3d_1.shape == (2, 6, 6, num_heads)
    assert encoding3d_2.shape == (2, 6, 6, num_heads)