import io
import pickle
from copy import deepcopy

import backend as F

import dgl
import dgl.function as fn
import dgl.nn.pytorch as nn
import networkx as nx
import pytest
import scipy as sp
import torch
import torch as th
from dgl import shortest_dist
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam, SparseAdam
from torch.utils.data import DataLoader
from utils import parametrize_idtype
from utils.graph_cases import (
    get_cases,
    random_bipartite,
    random_dglgraph,
    random_graph,
)

tmp_buffer = io.BytesIO()
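# The tests repeatedly th.save(...) modules into this shared buffer as a
# cheap "is it picklable" smoke test; nothing is ever read back. A full
# round-trip (a sketch, not something this suite does) would rewind and
# reload:
#
#     buf = io.BytesIO()
#     th.save(module, buf)
#     buf.seek(0)
#     module2 = th.load(buf)  # newer torch may need weights_only=False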

def _AXWb(A, X, W, b):
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b
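
# _AXWb is the dense reference for GraphConv with norm="none" and bias: it
# computes Y = A (X W) + b, where A is the transposed adjacency matrix
# obtained from g.adj_external(transpose=True, ...).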


@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv0(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adj_external(transpose=True, ctx=ctx)

    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, out_dim)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right", "left"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(idtype, g, norm, weight, bias, out_dim):
    # Test one tensor input
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    # if the module owns a weight, use it; otherwise supply one externally
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    e_w = g.edata["scalar_w"]
    if weight:
        h_out = conv(g, h, edge_weight=e_w)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=e_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    edgenorm = nn.EdgeWeightNorm(norm=norm)
    norm_weight = edgenorm(g, g.edata["scalar_w"])
    if weight:
        h_out = conv(g, h, edge_weight=norm_weight)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=norm_weight)
    assert h_out.shape == (ndst, out_dim)
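
# A note on EdgeWeightNorm (editorial, not part of the original suite): it
# pre-normalizes positive scalar edge weights the way GraphConv normalizes
# messages, e.g. with norm="both" each weight is scaled by the inverse
# square root of the weighted degrees of its endpoints, so the normalized
# weights can be fed straight into GraphConv via edge_weight:
#
#     edgenorm = nn.EdgeWeightNorm(norm="both")
#     w = edgenorm(g, g.edata["scalar_w"])
#     h = conv(g, feat, edge_weight=w)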


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
    # Test a pair of tensor inputs
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    h_dst = F.randn((ndst, out_dim)).to(F.ctx())
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


def _S2AXWb(A, N, X, W, b):
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b
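
# _S2AXWb is the dense reference for TAGConv with k=2: with N = D^{-1/2}
# it stacks [X, N A N X, (N A N)^2 X] and applies one fused linear layer.
# The W.rot90() reordering lines the fused weight up with this
# concatenation; it holds for the symmetric all-ones input used below
# (an observation about the test fixture, not a general identity).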


@pytest.mark.parametrize("out_dim", [1, 2])
def test_tagconv(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3))
    g = g.to(F.ctx())
    ctx = F.ctx()
    adj = g.adj_external(transpose=True, ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, out_dim, bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(
        h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias)
    )

    conv = nn.TAGConv(5, out_dim)
    conv = conv.to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == out_dim

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)


def test_set2set():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    s2s = s2s.to(ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2
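
# Width 10 comes from Set2Set itself: the readout concatenates the LSTM
# query with the attention-weighted sum of node features, giving
# 2 * input_dim (= 2 * 5) per graph, one row per graph in the batch.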

def test_glob_att_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)

    # test pickle
    th.save(gap, tmp_buffer)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2


def test_simple_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))
    g = g.to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = th.stack(
        [
            F.sum(h0[:15], 0),
            F.sum(h0[15:20], 0),
            F.sum(h0[20:35], 0),
            F.sum(h0[35:40], 0),
            F.sum(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = th.stack(
        [
            F.mean(h0[:15], 0),
            F.mean(h0[15:20], 0),
            F.mean(h0[20:35], 0),
            F.mean(h0[35:40], 0),
            F.mean(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = th.stack(
        [
            F.max(h0[:15], 0),
            F.max(h0[15:20], 0),
            F.max(h0[20:35], 0),
            F.max(h0[35:40], 0),
            F.max(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2


def test_set_trans():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "sab")
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "isab", 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.num_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2


@parametrize_idtype
@pytest.mark.parametrize("O", [1, 8, 32])
def test_rgcn(idtype, O):
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.num_edges()):
        etype.append(i % 5)
    B = 2
    I = 10

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.num_edges(), 1)).to(ctx)
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )
    sorted_norm = norm[idx]

    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % B == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % B == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % B == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)
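
# Regularizer notes (editorial): "basis" shares B basis matrices across
# the R relations, while "bdd" (block-diagonal decomposition) splits each
# relation's weight into B blocks, which is why it is only built when the
# output dim is divisible by B (the input dim I = 10 already is).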


@parametrize_idtype
@pytest.mark.parametrize("O", [1, 10, 40])
def test_rgcn_default_nbasis(idtype, O):
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.num_edges()):
        etype.append(i % 5)
    I = 10

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.num_edges(), 1)).to(ctx)
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )
    sorted_norm = norm[idx]

    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis").to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % R == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd").to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % R == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % R == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % R == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)
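
# residual=True adds a skip connection from the input features to each
# head's output (linearly projected when the shapes differ); the call
# above only verifies that the residual path runs.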


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_edge_weight(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    ew = F.randn((g.num_edges(),))
    h = gat(g, feat, edge_weight=ew)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape[0] == ew.shape[0]
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=10,
        in_edge_feats=5,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    nfeat = F.randn((g.num_nodes(), 10))
    efeat = F.randn((g.num_edges(), 5))
    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)

    th.save(egat, tmp_buffer)

    assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    _, _, attn = egat(g, nfeat, efeat, get_attention=True)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=(10, 15),
        in_edge_feats=7,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    nfeat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), 15)),
    )
    efeat = F.randn((g.num_edges(), 7))
    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)

    th.save(egat, tmp_buffer)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    _, _, attn = egat(g, nfeat, efeat, get_attention=True)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_edge_weight(
    g, idtype, out_node_feats, out_edge_feats, num_heads
):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=10,
        in_edge_feats=5,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    egat = egat.to(ctx)
    nfeat = F.randn((g.num_nodes(), 10))
    efeat = F.randn((g.num_edges(), 5))
    ew = F.randn((g.num_edges(),))

    h, f, attn = egat(g, nfeat, efeat, edge_weight=ew, get_attention=True)

    assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv(g, idtype, out_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    edgegat = nn.EdgeGATConv(
        in_feats=10, edge_feats=5, out_feats=out_feats, num_heads=num_heads
    )
    nfeat = F.randn((g.number_of_nodes(), 10))
    efeat = F.randn((g.number_of_edges(), 5))
    edgegat = edgegat.to(ctx)
    h = edgegat(g, nfeat, efeat)

    th.save(edgegat, tmp_buffer)

    assert h.shape == (g.number_of_nodes(), num_heads, out_feats)
    _, attn = edgegat(g, nfeat, efeat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv_bi(g, idtype, out_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    edgegat = nn.EdgeGATConv(
        in_feats=(10, 15),
        edge_feats=7,
        out_feats=out_feats,
        num_heads=num_heads,
    )
    nfeat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), 15)),
    )
    efeat = F.randn((g.number_of_edges(), 7))
    edgegat = edgegat.to(ctx)
    h = edgegat(g, nfeat, efeat)

    th.save(edgegat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_feats)
    _, attn = edgegat(g, nfeat, efeat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
def test_sage_conv(idtype, g, aggre_type):
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    sage = sage.to(F.ctx())
    # test pickle
    th.save(sage, tmp_buffer)
    h = sage(g, feat)
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != "gcn" else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), dst_dim)),
    )
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()


@parametrize_idtype
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv2(idtype, out_dim):
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.heterograph({("_U", "_E", "_V"): ([], [])}, {"_U": 5, "_V": 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), out_dim, "gcn")
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    for aggre_type in ["mean", "pool", "lstm"]:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)

    # test pickle
    th.save(sgc, tmp_buffer)

    feat = F.randn((g.num_nodes(), 5))
    sgc = sgc.to(ctx)

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # cached
    sgc = nn.SGConv(5, out_dim, 3, True)
    sgc = sgc.to(ctx)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.num_nodes(), 5))
    appnp = appnp.to(ctx)

    # test pickle
    th.save(appnp, tmp_buffer)

    h = appnp(g, feat)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    appnp = appnp.to(ctx)

    h = appnp(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("bias", [True, False])
def test_gcn2conv_e_weight(g, idtype, bias):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gcn2conv = nn.GCN2Conv(
        5, layer=2, alpha=0.5, bias=bias, project_initial_features=True
    )
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    gcn2conv = gcn2conv.to(ctx)
    res = feat
    h = gcn2conv(g, res, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_sgconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    sgconv = nn.SGConv(5, 5, 3)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    sgconv = sgconv.to(ctx)
    h = sgconv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_tagconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.TAGConv(5, 5, bias=True)
    conv = conv.to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    h = conv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    th.save(gin, tmp_buffer)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gin = gin.to(ctx)
    h = gin(g, feat)

    # test pickle
    th.save(gin, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), 12)

    gin = nn.GINConv(None, aggregator_type)
    th.save(gin, tmp_buffer)
    gin = gin.to(ctx)
    h = gin(g, feat)
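
# With apply_func=None, GINConv returns (1 + eps) * h_self plus the
# aggregated neighbor features without a learned projection, so the
# feature width stays at 5 here.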


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
def test_gine_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gine = nn.GINEConv(th.nn.Linear(5, 12))
    th.save(gine, tmp_buffer)
    nfeat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.num_edges(), 5))
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)

    # test pickle
    th.save(gine, tmp_buffer)
    assert h.shape == (g.number_of_dst_nodes(), 12)

    gine = nn.GINEConv(None)
    th.save(gine, tmp_buffer)
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)

@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv_bi(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_agnn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_src_nodes(), 5))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_agnn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.num_edges()) % 3
    feat = F.randn((g.num_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv_one_etype(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 1)
    etypes = th.zeros(g.num_edges())
    feat = F.randn((g.num_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    h2 = ggconv(g, feat)
    # with one edge type, passing etypes explicitly must not change the result
    assert F.allclose(h, h2)
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_nn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.num_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_nn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    efeat = F.randn((g.num_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gmm_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv(5, 10, 3, 4, "mean")
    feat = F.randn((g.num_nodes(), 5))
    pseudo = F.randn((g.num_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite", "block-bipartite"], exclude=["zero-degree"])
)
def test_gmm_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.num_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("norm_type", ["both", "right", "none"])
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_graph_conv(norm_type, g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "bipartite"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_sage_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, out_dim, "gcn")
    dense_sage = nn.DenseSAGEConv(5, out_dim)
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.bias.data
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)),
        )
    else:
        feat = F.randn((g.num_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage), g


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)

    # test pickle
    th.save(edge_conv, tmp_buffer)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    dotgat = dotgat.to(ctx)

    # test pickle
    th.save(dotgat, tmp_buffer)

    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv((5, 5), out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    dotgat = dotgat.to(ctx)
    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_cheb_conv(out_dim):
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        g = g.to(F.ctx())
        adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, out_dim, k, None)
        dense_cheb = nn.DenseChebConv(5, out_dim, k)
        # for i in range(len(cheb.fc)):
        #    dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(
            k, 5, out_dim
        )
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)


def test_sequential():
    ctx = F.ctx()

    # Test single graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
            graph.apply_edges(fn.u_add_v("h", "h", "e"))
            e_feat += graph.edata["e"]
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
            return n_feat.view(graph.num_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)
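
# Editorial note: when nn.Sequential receives a list of graphs, the i-th
# layer consumes the i-th graph, so the 32 input rows shrink to 16, 8, and
# finally 4 through the three halving ExampleLayers above.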


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_atomic_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    aconv = nn.AtomicConv(
        interaction_cutoffs=F.tensor([12.0, 12.0]),
        rbf_kernel_means=F.tensor([0.0, 2.0]),
        rbf_kernel_scaling=F.tensor([4.0, 4.0]),
        features_to_use=F.tensor([6.0, 8.0]),
    )

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((g.num_nodes(), 1))
    dist = F.randn((g.num_edges(), 1))

    h = aconv(g, feat, dist)

    # currently we only do shape check
    assert h.shape[-1] == 4
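
# Why 4 (editorial): AtomicConv emits one output column per (radial
# kernel, atom type) pair, and this instance has 2 RBF kernels and 2
# entries in features_to_use, i.e. 2 * 2 = 4.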

@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 3])
def test_cf_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    cfconv = nn.CFConv(
        node_in_feats=2, edge_in_feats=3, hidden_feats=2, out_feats=out_dim
    )

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    src_feats = F.randn((g.number_of_src_nodes(), 2))
    edge_feats = F.randn((g.num_edges(), 3))
    h = cfconv(g, src_feats, edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim

    # case for bipartite graphs
    dst_feats = F.randn((g.number_of_dst_nodes(), 3))
    h = cfconv(g, (src_feats, dst_feats), edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim


def myagg(alist, dsttype):
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst
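
# myagg is a custom cross-type reducer for nn.HeteroGraphConv: it receives
# the list of per-relation results for one destination node type plus that
# type's name, and here weights the i-th relation's output by (i + 1)
# before summing.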


@parametrize_idtype
@pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg])
@pytest.mark.parametrize("canonical_keys", [False, True])
def test_hetero_conv(agg, idtype, canonical_keys):
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 0, 2, 1], [1, 2, 1, 3]),
            ("user", "plays", "game"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
            ("store", "sells", "game"): ([0, 0, 1, 1], [0, 3, 1, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    if not canonical_keys:
        conv = nn.HeteroGraphConv(
            {
                "follows": nn.GraphConv(2, 3, allow_zero_in_degree=True),
                "plays": nn.GraphConv(2, 4, allow_zero_in_degree=True),
                "sells": nn.GraphConv(3, 4, allow_zero_in_degree=True),
            },
            agg,
        )
    else:
        conv = nn.HeteroGraphConv(
            {
                ("user", "follows", "user"): nn.GraphConv(
                    2, 3, allow_zero_in_degree=True
                ),
                ("user", "plays", "game"): nn.GraphConv(
                    2, 4, allow_zero_in_degree=True
                ),
                ("store", "sells", "game"): nn.GraphConv(
                    3, 4, allow_zero_in_degree=True
                ),
            },
            agg,
        )

    conv = conv.to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

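    # Repeat the check on a message flow graph (block); passing a
    # (src_feats, dst_feats) tuple exercises the bipartite code path, with an
    # empty destination feature for "store", which has no destination nodes
    # in this block.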
    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    h = conv(block, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    # test with mod args
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv(
        {"follows": mod1, "plays": mod2, "sells": mod3}, agg
    )
    conv = conv.to(F.ctx())
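    # mod_args (positional) and mod_kwargs (keyword) are routed only to the
    # submodule of the matching relation; the carg counters verify which
    # module saw which extra argument.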
    mod_args = {"follows": (1,), "plays": (1,)}
    mod_kwargs = {"sells": {"arg2": "abc"}}
    h = conv(
        g,
        {"user": uf, "game": gf, "store": sf},
        mod_args=mod_args,
        mod_kwargs=mod_kwargs,
    )
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on graph without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form="eid", etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}


@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_linear(out_dim):
    in_feats = {
        "user": F.randn((2, 1)),
        ("user", "follows", "user"): F.randn((3, 2)),
    }

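    # HeteroLinear accepts both node-type keys and canonical-edge-type keys,
    # each mapped to its own input size.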
    layer = nn.HeteroLinear(
        {"user": 1, ("user", "follows", "user"): 2}, out_dim
    )
    layer = layer.to(F.ctx())
    out_feats = layer(in_feats)
    assert out_feats["user"].shape == (2, out_dim)
    assert out_feats[("user", "follows", "user")].shape == (3, out_dim)


@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_embedding(out_dim):
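    # As with HeteroLinear, keys may be node types or canonical edge types.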
    layer = nn.HeteroEmbedding(
        {"user": 2, ("user", "follows", "user"): 3}, out_dim
    )
    layer = layer.to(F.ctx())

    embeds = layer.weight
    assert embeds["user"].shape == (2, out_dim)
    assert embeds[("user", "follows", "user")].shape == (3, out_dim)

    layer.reset_parameters()
    embeds = layer.weight
    assert embeds["user"].shape == (2, out_dim)
    assert embeds[("user", "follows", "user")].shape == (3, out_dim)

    embeds = layer(
        {
            "user": F.tensor([0], dtype=F.int64),
            ("user", "follows", "user"): F.tensor([0, 2], dtype=F.int64),
        }
    )
    assert embeds["user"].shape == (1, out_dim)
    assert embeds[("user", "follows", "user")].shape == (2, out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_gnnexplainer(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    feat = F.randn((g.num_nodes(), 5))

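    # GNNExplainer expects the model's forward to take an optional eweight
    # argument so edges can be soft-masked during explanation; the pooling
    # branch turns the same model into a graph-level predictor.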
    class Model(th.nn.Module):
        def __init__(self, in_feats, out_feats, graph=False):
            super(Model, self).__init__()
            self.linear = th.nn.Linear(in_feats, out_feats)
            if graph:
                self.pool = nn.AvgPooling()
            else:
                self.pool = None

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                feat = self.linear(feat)
                graph.ndata["h"] = feat
                if eweight is None:
                    graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                else:
                    graph.edata["w"] = eweight
                    graph.update_all(
                        fn.u_mul_e("h", "w", "m"), fn.sum("m", "h")
                    )

                if self.pool:
                    return self.pool(graph, graph.ndata["h"])
                else:
                    return graph.ndata["h"]

    # Explain node prediction
    model = Model(5, out_dim)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)

    # Explain graph prediction
    model = Model(5, out_dim, graph=True)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)


@pytest.mark.parametrize("g", get_cases(["hetero"], exclude=["zero-degree"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("output_dim", [1, 2])
def test_heterognnexplainer(g, idtype, input_dim, output_dim):
    g = g.astype(idtype).to(F.ctx())
    device = g.device

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    feat = {
        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)
        for ntype in g.ntypes
    }

    class Model(th.nn.Module):
        def __init__(self, in_dim, num_classes, canonical_etypes, graph=False):
            super(Model, self).__init__()
            self.graph = graph
            self.etype_weights = th.nn.ModuleDict(
                {
                    "_".join(c_etype): th.nn.Linear(in_dim, num_classes)
                    for c_etype in canonical_etypes
                }
            )

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                c_etype_func_dict = {}
                for c_etype in graph.canonical_etypes:
                    src_type, etype, dst_type = c_etype
                    wh = self.etype_weights["_".join(c_etype)](feat[src_type])
                    graph.nodes[src_type].data[f"h_{c_etype}"] = wh
                    if eweight is None:
                        c_etype_func_dict[c_etype] = (
                            fn.copy_u(f"h_{c_etype}", "m"),
                            fn.mean("m", "h"),
                        )
                    else:
                        graph.edges[c_etype].data["w"] = eweight[c_etype]
                        c_etype_func_dict[c_etype] = (
                            fn.u_mul_e(f"h_{c_etype}", "w", "m"),
                            fn.mean("m", "h"),
                        )
                graph.multi_update_all(c_etype_func_dict, "sum")
                if self.graph:
                    hg = 0
                    for ntype in graph.ntypes:
                        if graph.num_nodes(ntype):
                            hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)

                    return hg
                else:
                    return graph.ndata["h"]

    # Explain node prediction
    model = Model(input_dim, output_dim, g.canonical_etypes)
    model = model.to(F.ctx())
    ntype = g.ntypes[0]
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(
        ntype, 0, g, feat
    )

    # Explain graph prediction
    model = Model(input_dim, output_dim, g.canonical_etypes, graph=True)
    model = model.to(F.ctx())
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(
        ["homo"],
        exclude=[
            "zero-degree",
            "homo-zero-degree",
            "has_feature",
            "has_scalar_e_feature",
            "row_sorted",
            "col_sorted",
            "batched",
        ],
    ),
)
@pytest.mark.parametrize("n_classes", [2])
def test_subgraphx(g, idtype, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))

    class Model(th.nn.Module):
        def __init__(self, in_dim, n_classes):
            super().__init__()
            self.conv = nn.GraphConv(in_dim, n_classes)
            self.pool = nn.AvgPooling()

        def forward(self, g, h):
            h = th.nn.functional.relu(self.conv(g, h))
            return self.pool(g, h)

    model = Model(feat.shape[1], n_classes)
    model = model.to(ctx)
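    # SubgraphX searches for an explanatory subgraph with Monte Carlo tree
    # search, scoring candidates via approximate Shapley values; the small
    # step and rollout counts just keep the test fast.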
    explainer = nn.SubgraphX(
        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0
    )
    explainer.explain_graph(g, feat, target_class=0)


@pytest.mark.parametrize("g", get_cases(["hetero"], exclude=["zero-degree"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("n_classes", [2])
def test_heterosubgraphx(g, idtype, input_dim, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    device = g.device

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    feat = {
        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)
        for ntype in g.ntypes
    }

    class Model(th.nn.Module):
        def __init__(self, in_dim, n_classes, canonical_etypes):
            super(Model, self).__init__()
            self.etype_weights = th.nn.ModuleDict(
                {
                    "_".join(c_etype): th.nn.Linear(in_dim, n_classes)
                    for c_etype in canonical_etypes
                }
            )

        def forward(self, graph, feat):
            with graph.local_scope():
                c_etype_func_dict = {}
                for c_etype in graph.canonical_etypes:
                    src_type, etype, dst_type = c_etype
                    wh = self.etype_weights["_".join(c_etype)](feat[src_type])
                    graph.nodes[src_type].data[f"h_{c_etype}"] = wh
                    c_etype_func_dict[c_etype] = (
                        fn.copy_u(f"h_{c_etype}", "m"),
                        fn.mean("m", "h"),
                    )
                graph.multi_update_all(c_etype_func_dict, "sum")
                hg = 0
                for ntype in graph.ntypes:
                    if graph.num_nodes(ntype):
                        hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)

                return hg

    model = Model(input_dim, n_classes, g.canonical_etypes)
    model = model.to(ctx)
    explainer = nn.HeteroSubgraphX(
        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0
    )
    explainer.explain_graph(g, feat, target_class=0)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(
        ["homo"],
        exclude=[
            "zero-degree",
            "homo-zero-degree",
            "has_feature",
            "has_scalar_e_feature",
            "row_sorted",
            "col_sorted",
        ],
    ),
)
@pytest.mark.parametrize("n_classes", [2])
def test_pgexplainer(g, idtype, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    g.ndata["attr"] = feat

    # add reverse edges
    transform = dgl.transforms.AddReverse(copy_edata=True)
    g = transform(g)

    class Model(th.nn.Module):
        def __init__(self, in_feats, out_feats):
            super(Model, self).__init__()
            self.conv = nn.GraphConv(in_feats, out_feats)
            self.fc = th.nn.Linear(out_feats, out_feats)
            th.nn.init.xavier_uniform_(self.fc.weight)

        def forward(self, g, h, embed=False, edge_weight=None):
            h = self.conv(g, h, edge_weight=edge_weight)

            if embed:
                return h

            with g.local_scope():
                g.ndata["h"] = h
                hg = dgl.mean_nodes(g, "h")
                return self.fc(hg)

    model = Model(feat.shape[1], n_classes)
    model = model.to(ctx)

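    # PGExplainer is trained before it can explain; the 5.0 passed to
    # train_step below is presumably the temperature used to relax the
    # discrete edge mask during training.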
    explainer = nn.PGExplainer(model, n_classes)
    explainer.train_step(g, g.ndata["attr"], 5.0)

    probs, edge_weight = explainer.explain_graph(g, feat)


@pytest.mark.parametrize("g", get_cases(["hetero"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("n_classes", [2])
def test_heteropgexplainer(g, idtype, input_dim, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = {
        ntype: F.randn((g.num_nodes(ntype), input_dim)) for ntype in g.ntypes
    }

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    class Model(th.nn.Module):
        def __init__(self, in_feats, embed_dim, out_feats, canonical_etypes):
            super(Model, self).__init__()
            self.conv = nn.HeteroGraphConv(
                {
                    c_etype: nn.GraphConv(in_feats, embed_dim)
                    for c_etype in canonical_etypes
                }
            )
            self.fc = th.nn.Linear(embed_dim, out_feats)

        def forward(self, g, h, embed=False, edge_weight=None):
            if edge_weight is not None:
                mod_kwargs = {
                    etype: {"edge_weight": mask}
                    for etype, mask in edge_weight.items()
                }
                h = self.conv(g, h, mod_kwargs=mod_kwargs)
            else:
                h = self.conv(g, h)

            if embed:
                return h

            with g.local_scope():
                g.ndata["h"] = h
                hg = 0
                for ntype in g.ntypes:
                    hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
                return self.fc(hg)

    embed_dim = input_dim
    model = Model(input_dim, embed_dim, n_classes, g.canonical_etypes)
    model = model.to(ctx)

    explainer = nn.HeteroPGExplainer(model, embed_dim)
    explainer.train_step(g, feat, 5.0)

    probs, edge_weight = explainer.explain_graph(g, feat)


def test_jumping_knowledge():
    ctx = F.ctx()
    num_layers = 2
    num_nodes = 3
    num_feats = 4

    feat_list = [
        th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)
    ]

    model = nn.JumpingKnowledge("cat").to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_layers * num_feats)

    model = nn.JumpingKnowledge("max").to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)

    model = nn.JumpingKnowledge("lstm", num_feats, num_layers).to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)


@pytest.mark.parametrize("op", ["dot", "cos", "ele", "cat"])
def test_edge_predictor(op):
    ctx = F.ctx()
    num_pairs = 3
    in_feats = 4
    out_feats = 5
    h_src = th.randn((num_pairs, in_feats)).to(ctx)
    h_dst = th.randn((num_pairs, in_feats)).to(ctx)

    pred = nn.EdgePredictor(op)
    if op in ["dot", "cos"]:
        assert pred(h_src, h_dst).shape == (num_pairs, 1)
    elif op == "ele":
        assert pred(h_src, h_dst).shape == (num_pairs, in_feats)
    else:
        assert pred(h_src, h_dst).shape == (num_pairs, 2 * in_feats)
    pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
    assert pred(h_src, h_dst).shape == (num_pairs, out_feats)


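# TransE and TransR are knowledge-graph embedding scorers; each edge gets a
# scalar score computed from its (head, tail, relation) triple.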
def test_ke_score_funcs():
    ctx = F.ctx()
    num_edges = 30
    num_rels = 3
    nfeats = 4

    h_src = th.randn((num_edges, nfeats)).to(ctx)
    h_dst = th.randn((num_edges, nfeats)).to(ctx)
    rels = th.randint(low=0, high=num_rels, size=(num_edges,)).to(ctx)

    score_func = nn.TransE(num_rels=num_rels, feats=nfeats).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)

    score_func = nn.TransR(
        num_rels=num_rels, rfeats=nfeats - 1, nfeats=nfeats
    ).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)


def test_twirls():
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    feat = th.ones(6, 10)
    conv = nn.TWIRLSConv(10, 2, 128, prop_step=64)
    res = conv(g, feat)
    assert res.size() == (6, 2)


@pytest.mark.parametrize("feat_size", [4, 32])
@pytest.mark.parametrize(
    "regularizer,num_bases", [(None, None), ("basis", 4), ("bdd", 4)]
)
def test_typed_linear(feat_size, regularizer, num_bases):
    dev = F.ctx()
    num_types = 5
    lin = nn.TypedLinear(
        feat_size,
        feat_size * 2,
        5,
        regularizer=regularizer,
        num_bases=num_bases,
    ).to(dev)
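    # TypedLinear applies a per-type weight, optionally factorized by a
    # "basis" or block-diagonal ("bdd") regularizer; inputs sorted by type
    # take a grouped fast path that must match the unsorted path.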
    print(lin)
    x = th.randn(100, feat_size).to(dev)
    x_type = th.randint(0, 5, (100,)).to(dev)
    x_type_sorted, idx = th.sort(x_type)
    _, rev_idx = th.sort(idx)
    x_sorted = x[idx]

    # test unsorted
    y = lin(x, x_type)
    assert y.shape == (100, feat_size * 2)
    # test sorted
    y_sorted = lin(x_sorted, x_type_sorted, sorted_by_type=True)
    assert y_sorted.shape == (100, feat_size * 2)

    assert th.allclose(y, y_sorted[rev_idx], atol=1e-4, rtol=1e-4)


@parametrize_idtype
@pytest.mark.parametrize("in_size", [4])
@pytest.mark.parametrize("num_heads", [1])
def test_hgt(idtype, in_size, num_heads):
    dev = F.ctx()
    num_etypes = 5
    num_ntypes = 2
    head_size = in_size // num_heads

    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.01))
    g = g.astype(idtype).to(dev)
    etype = th.tensor([i % num_etypes for i in range(g.num_edges())]).to(dev)
    ntype = th.tensor([i % num_ntypes for i in range(g.num_nodes())]).to(dev)
    x = th.randn(g.num_nodes(), in_size).to(dev)

    m = nn.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes).to(
        dev
    )

    y = m(g, x, ntype, etype)
    assert y.shape == (g.num_nodes(), head_size * num_heads)
    # test with inputs sorted by node/edge type
    sorted_ntype, idx_nt = th.sort(ntype)
    sorted_etype, idx_et = th.sort(etype)
    _, rev_idx = th.sort(idx_nt)
    g.ndata["t"] = ntype
    g.ndata["x"] = x
    g.edata["t"] = etype
    sorted_g = dgl.reorder_graph(
        g,
        node_permute_algo="custom",
        edge_permute_algo="custom",
        permute_config={
            "nodes_perm": idx_nt.to(idtype),
            "edges_perm": idx_et.to(idtype),
        },
    )
    print(sorted_g.ndata["t"])
    print(sorted_g.edata["t"])
    sorted_x = sorted_g.ndata["x"]
    sorted_y = m(
        sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=False
    )
    assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
    # mini-batch
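    # Sample the full one-hop neighborhood and run HGTConv on a block,
    # gathering node and edge types for the sampled subgraph via input_nodes
    # and the block's dgl.EID edge mapping.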
    train_idx = th.randperm(100, dtype=idtype)[:10]
    sampler = dgl.dataloading.NeighborSampler([-1])
    train_loader = dgl.dataloading.DataLoader(
        g, train_idx.to(dev), sampler, batch_size=8, device=dev, shuffle=True
    )
    (input_nodes, output_nodes, block) = next(iter(train_loader))
    block = block[0]
    x = x[input_nodes.to(th.long)]
    ntype = ntype[input_nodes.to(th.long)]
    edge = block.edata[dgl.EID]
    etype = etype[edge.to(th.long)]
    y = m(block, x, ntype, etype)
    assert y.shape == (block.number_of_dst_nodes(), head_size * num_heads)
    # TODO(minjie): enable the following check
    # assert th.allclose(y, sorted_y[rev_idx], atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize("self_loop", [True, False])
@pytest.mark.parametrize("get_distances", [True, False])
def test_radius_graph(self_loop, get_distances):
    pos = th.tensor(
        [
            [0.1, 0.3, 0.4],
            [0.5, 0.2, 0.1],
            [0.7, 0.9, 0.5],
            [0.3, 0.2, 0.5],
            [0.2, 0.8, 0.2],
            [0.9, 0.2, 0.1],
            [0.7, 0.4, 0.4],
            [0.2, 0.1, 0.6],
            [0.5, 0.3, 0.5],
            [0.4, 0.2, 0.6],
        ]
    )

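    # RadiusGraph links every pair of points within Euclidean distance 0.3,
    # optionally with self-loops, and can also return the pairwise distances.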
    rg = nn.RadiusGraph(0.3, self_loop=self_loop)

    if get_distances:
        g, dists = rg(pos, get_distances=get_distances)
    else:
        g = rg(pos)

    if self_loop:
        src_target = th.tensor(
            [
                0,
                0,
                1,
                2,
                3,
                3,
                3,
                3,
                3,
                4,
                5,
                6,
                6,
                7,
                7,
                7,
                8,
                8,
                8,
                8,
                9,
                9,
                9,
                9,
            ]
        )
        dst_target = th.tensor(
            [
                0,
                3,
                1,
                2,
                0,
                3,
                7,
                8,
                9,
                4,
                5,
                6,
                8,
                3,
                7,
                9,
                3,
                6,
                8,
                9,
                3,
                7,
                8,
                9,
            ]
        )

        if get_distances:
            dists_target = th.tensor(
                [
                    [0.0000],
                    [0.2449],
                    [0.0000],
                    [0.0000],
                    [0.2449],
                    [0.0000],
                    [0.1732],
                    [0.2236],
                    [0.1414],
                    [0.0000],
                    [0.0000],
                    [0.0000],
                    [0.2449],
                    [0.1732],
                    [0.0000],
                    [0.2236],
                    [0.2236],
                    [0.2449],
                    [0.0000],
                    [0.1732],
                    [0.1414],
                    [0.2236],
                    [0.1732],
                    [0.0000],
                ]
            )
    else:
        src_target = th.tensor([0, 3, 3, 3, 3, 6, 7, 7, 8, 8, 8, 9, 9, 9])
        dst_target = th.tensor([3, 0, 7, 8, 9, 8, 3, 9, 3, 6, 9, 3, 7, 8])

        if get_distances:
            dists_target = th.tensor(
                [
                    [0.2449],
                    [0.2449],
                    [0.1732],
                    [0.2236],
                    [0.1414],
                    [0.2449],
                    [0.1732],
                    [0.2236],
                    [0.2236],
                    [0.2449],
                    [0.1732],
                    [0.1414],
                    [0.2236],
                    [0.1732],
                ]
            )

    src, dst = g.edges()

    assert th.equal(src, src_target)
    assert th.equal(dst, dst_target)

    if get_distances:
        assert th.allclose(dists, dists_target, rtol=1e-03)


@parametrize_idtype
def test_group_rev_res(idtype):
    dev = F.ctx()

    num_nodes = 5
    num_edges = 20
    feats = 32
    groups = 2
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, feats).to(dev)
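    # GroupRevRes makes the wrapped module reversible by splitting node
    # features into `groups` partitions, hence the conv width of
    # feats // groups.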
    conv = nn.GraphConv(feats // groups, feats // groups)
    model = nn.GroupRevRes(conv, groups).to(dev)
    result = model(g, h)
    result.sum().backward()


@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("hidden_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize("edge_feat_size", [16, 10, 0])
def test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    x = th.randn(num_nodes, 3).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    model = nn.EGNNConv(in_size, hidden_size, out_size, edge_feat_size).to(dev)
    model(g, h, x, e)


@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators",
    [
        ["mean", "max", "sum"],
        ["min", "std", "var"],
        ["moment3", "moment4", "moment5"],
    ],
)
@pytest.mark.parametrize(
    "scalers", [["identity"], ["amplification", "attenuation"]]
)
@pytest.mark.parametrize("delta", [2.5, 7.4])
@pytest.mark.parametrize("dropout", [0.0, 0.1])
@pytest.mark.parametrize("num_towers", [1, 4])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
@pytest.mark.parametrize("residual", [True, False])
def test_pna_conv(
    in_size,
    out_size,
    aggregators,
    scalers,
    delta,
    dropout,
    num_towers,
    edge_feat_size,
    residual,
):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    model = nn.PNAConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        dropout,
        num_towers,
        edge_feat_size,
        residual,
    ).to(dev)
    model(g, h, edge_feat=e)


@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("norm_type", ["sym", "row"])
@pytest.mark.parametrize("clamp", [True, False])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("reset", [True, False])
def test_label_prop(k, alpha, norm_type, clamp, normalize, reset):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    num_classes = 4
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    labels = th.tensor([0, 2, 1, 3, 0]).long().to(dev)
    ml_labels = th.rand(num_nodes, num_classes).to(dev) > 0.7
    mask = th.tensor([0, 1, 1, 1, 0]).bool().to(dev)
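    # mask flags the nodes whose labels are known; LabelPropagation diffuses
    # them over the graph and also accepts multi-hot (multi-label) input.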
    model = nn.LabelPropagation(k, alpha, norm_type, clamp, normalize, reset)
    model(g, labels, mask)
    # multi-label case
    model(g, ml_labels, mask)


@pytest.mark.parametrize("in_size", [16])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators", [["mean", "max", "dir2-av"], ["min", "std", "dir1-dx"]]
)
@pytest.mark.parametrize("scalers", [["amplification", "attenuation"]])
@pytest.mark.parametrize("delta", [2.5])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
def test_dgn_conv(
    in_size, out_size, aggregators, scalers, delta, edge_feat_size
):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    transform = dgl.LaplacianPE(k=3, feat_name="eig")
    g = transform(g)
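    # The directional aggregators ("dir1-dx", "dir2-av") require Laplacian
    # eigenvectors; the second model below drops them, so no eig_vec input
    # is needed.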
    eig = g.ndata["eig"]
    model = nn.DGNConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    model(g, h, edge_feat=e, eig_vec=eig)

    aggregators_non_eig = [
        aggr for aggr in aggregators if not aggr.startswith("dir")
    ]
    model = nn.DGNConv(
        in_size,
        out_size,
        aggregators_non_eig,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    model(g, h, edge_feat=e)


def test_DeepWalk():
    dev = F.ctx()
    g = dgl.graph(([0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]))
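    # sparse=True yields sparse embedding gradients, which pair with
    # SparseAdam; the dense variant below trains with plain Adam.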
    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=True, sparse=True
    )
    model = model.to(dev)
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
    optim = SparseAdam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()

    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=False, sparse=False
    )
    model = model.to(dev)
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
    optim = Adam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()


@pytest.mark.parametrize("max_degree", [2, 6])
@pytest.mark.parametrize("embedding_dim", [8, 16])
@pytest.mark.parametrize("direction", ["in", "out", "both"])
def test_degree_encoder(max_degree, embedding_dim, direction):
    g1 = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    )
    g2 = dgl.graph(
        (
            th.tensor([0, 1]),
            th.tensor([1, 0]),
        )
    )
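    # Pad the two graphs' degree sequences to a common length so the encoder
    # sees a dense (batch_size, max_num_nodes) tensor.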
    in_degree = pad_sequence(
        [g1.in_degrees(), g2.in_degrees()], batch_first=True
    )
    out_degree = pad_sequence(
        [g1.out_degrees(), g2.out_degrees()], batch_first=True
    )
    model = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)
    if direction == "in":
        de_g = model(in_degree)
    elif direction == "out":
        de_g = model(out_degree)
    elif direction == "both":
        de_g = model(th.stack((in_degree, out_degree)))
    assert de_g.shape == (2, 4, embedding_dim)


@parametrize_idtype
def test_MetaPath2Vec(idtype):
    dev = F.ctx()
    g = dgl.heterograph(
        {
            ("user", "uc", "company"): ([0, 0, 2, 1, 3], [1, 2, 1, 3, 0]),
            ("company", "cp", "product"): (
                [0, 0, 0, 1, 2, 3],
                [0, 2, 3, 0, 2, 1],
            ),
            ("company", "cu", "user"): ([1, 2, 1, 3, 0], [0, 0, 2, 1, 3]),
            ("product", "pc", "company"): (
                [0, 2, 3, 0, 2, 1],
                [0, 0, 0, 1, 2, 3],
            ),
        },
        idtype=idtype,
        device=dev,
    )
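    # The metapath ["uc", "cu"] walks user -> company -> user; node_embed
    # holds one row per node of the graph, as asserted below.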
    model = nn.MetaPath2Vec(g, ["uc", "cu"], window_size=1)
    model = model.to(dev)
    embeds = model.node_embed.weight
    assert embeds.shape[0] == g.num_nodes()


@pytest.mark.parametrize("num_layer", [1, 4])
@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("lpe_dim", [4, 16])
@pytest.mark.parametrize("n_head", [1, 4])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("num_post_layer", [0, 1, 2])
def test_LapPosEncoder(
    num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
):
    ctx = F.ctx()
    num_nodes = 4

    EigVals = th.randn((num_nodes, k)).to(ctx)
    EigVecs = th.randn((num_nodes, k)).to(ctx)

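    # Both the Transformer- and DeepSet-style encoders map each node's k
    # eigenpairs to an lpe_dim-dimensional positional encoding.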
    model = nn.LapPosEncoder(
        "Transformer", num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
    ).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)

    model = nn.LapPosEncoder(
        "DeepSet",
        num_layer,
        k,
        lpe_dim,
        batch_norm=batch_norm,
        num_post_layer=num_post_layer,
    ).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)


@pytest.mark.parametrize("feat_size", [128, 512])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("attn_drop", [0.1, 0.5])
def test_BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop):
    ndata = th.rand(16, 100, feat_size)
    attn_bias = th.rand(16, 100, 100, num_heads)
    attn_mask = th.rand(16, 100, 100) < 0.5

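    # attn_bias injects per-head pairwise biases (added or multiplied, per
    # attn_bias_type), while the boolean attn_mask hides roughly half of the
    # pairs.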
    net = nn.BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop)
    out = net(ndata, attn_bias, attn_mask)

    assert out.shape == (16, 100, feat_size)


@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("norm_first", [True, False])
def test_GraphormerLayer(attn_bias_type, norm_first):
    batch_size = 16
    num_nodes = 100
    feat_size = 512
    num_heads = 8
    nfeat = th.rand(batch_size, num_nodes, feat_size)
    attn_bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
    attn_mask = th.rand(batch_size, num_nodes, num_nodes) < 0.5

    net = nn.GraphormerLayer(
        feat_size=feat_size,
        hidden_size=2048,
        num_heads=num_heads,
        attn_bias_type=attn_bias_type,
        norm_first=norm_first,
        dropout=0.1,
        attn_dropout=0.1,
        activation=th.nn.ReLU(),
    )
    out = net(nfeat, attn_bias, attn_mask)

    assert out.shape == (batch_size, num_nodes, feat_size)


@pytest.mark.parametrize("max_len", [1, 2])
@pytest.mark.parametrize("feat_dim", [16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_PathEncoder(max_len, feat_dim, num_heads):
    dev = F.ctx()
    g = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    ).to(dev)
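    # shortest_dist pads paths of nonexistent edges with -1; appending a zero
    # row lets index -1 select an all-zero edge feature for those slots.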
    edge_feat = th.rand(g.num_edges(), feat_dim).to(dev)
    edge_feat = th.cat((edge_feat, th.zeros(1, feat_dim).to(dev)), dim=0)
    dist, path = shortest_dist(g, root=None, return_paths=True)
    path_data = edge_feat[path[:, :, :max_len]]
    model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
    bias = model(dist.unsqueeze(0), path_data.unsqueeze(0))
    assert bias.shape == (1, 4, 4, num_heads)


@pytest.mark.parametrize("max_dist", [1, 4])
@pytest.mark.parametrize("num_kernels", [8, 16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_SpatialEncoder(max_dist, num_kernels, num_heads):
    dev = F.ctx()
    g1 = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    ).to(dev)
    g2 = dgl.graph(
        (th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3]))
    ).to(dev)
    bg = dgl.batch([g1, g2])
    ndata = th.rand(bg.num_nodes(), 3).to(dev)
    num_nodes = bg.num_nodes()
    node_type = th.randint(0, 512, (num_nodes,)).to(dev)
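    # Distances are collected into a -1-initialized (2, 6, 6) tensor so node
    # pairs missing from the smaller graph stay marked as unreachable.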
    dist = -th.ones((2, 6, 6), dtype=th.long).to(dev)
    dist[0, :4, :4] = shortest_dist(g1, root=None, return_paths=False)
    dist[1, :6, :6] = shortest_dist(g2, root=None, return_paths=False)
    model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
    model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
    model_3 = nn.SpatialEncoder3d(
        num_kernels, num_heads=num_heads, max_node_type=512
    ).to(dev)
    encoding = model_1(dist)
    encoding3d_1 = model_2(bg, ndata)
    encoding3d_2 = model_3(bg, ndata, node_type)
    assert encoding.shape == (2, 6, 6, num_heads)
    assert encoding3d_1.shape == (2, 6, 6, num_heads)
    assert encoding3d_2.shape == (2, 6, 6, num_heads)