"tests/python/pytorch/optim/test_optim.py" did not exist on "3c387988d7addc4a6b92785c12b64566d164bb55"
test_nn.py 78.1 KB
Newer Older
1
import io
2
import pickle
3
import random
4
5
6
7
from copy import deepcopy

import backend as F

import dgl
import dgl.function as fn
import dgl.nn.pytorch as nn
import networkx as nx
import numpy as np  # For setting seed for scipy
import pytest
import scipy as sp
import torch
import torch as th

from dgl import shortest_dist
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam, SparseAdam
from torch.utils.data import DataLoader
from utils import parametrize_idtype
from utils.graph_cases import (
    get_cases,
    random_bipartite,
    random_dglgraph,
    random_graph,
)

# Set seeds to make tests fully reproducible.
SEED = 12345  # random.randint(1, 99999)
random.seed(SEED)  # For networkx
np.random.seed(SEED)  # For scipy
dgl.seed(SEED)
F.seed(SEED)

tmp_buffer = io.BytesIO()


def _AXWb(A, X, W, b):
    # Dense reference: Y = A @ (X W) + b, used to verify GraphConv outputs.
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b


@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv0(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adj_external(transpose=True, ctx=ctx)

    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, out_dim)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)
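

# Illustrative sketch (not part of the original tests): the typical
# GraphConv call pattern the test above exercises. `_demo_graph_conv` is a
# hypothetical helper name; pytest does not collect it, so it never runs.
def _demo_graph_conv():
    g = dgl.DGLGraph(nx.path_graph(4)).to(F.ctx())
    conv = nn.GraphConv(8, 4, norm="both", bias=True).to(F.ctx())
    feat = F.randn((g.num_nodes(), 8))
    return conv(g, feat)  # one output row per node, 4 features each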


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right", "left"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(idtype, g, norm, weight, bias, out_dim):
    # Test one tensor input
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    e_w = g.edata["scalar_w"]
    if weight:
        h_out = conv(g, h, edge_weight=e_w)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=e_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    edgenorm = nn.EdgeWeightNorm(norm=norm)
    norm_weight = edgenorm(g, g.edata["scalar_w"])
    if weight:
        h_out = conv(g, h, edge_weight=norm_weight)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=norm_weight)
    assert h_out.shape == (ndst, out_dim)
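

# Illustrative sketch (not part of the original tests): normalizing scalar
# edge weights with EdgeWeightNorm before feeding them to GraphConv, as the
# test above does. `_demo_edge_weight_norm` is a hypothetical helper name;
# th.abs is used only to keep the sketch's weights positive for "both".
def _demo_edge_weight_norm():
    g = dgl.DGLGraph(nx.path_graph(4)).to(F.ctx())
    e_w = th.abs(F.randn((g.num_edges(),))) + 1e-6
    norm_w = nn.EdgeWeightNorm(norm="both")(g, e_w)
    conv = nn.GraphConv(8, 4, norm="none").to(F.ctx())
    feat = F.randn((g.num_nodes(), 8))
    return conv(g, feat, edge_weight=norm_w)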


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
    # Test a pair of tensor inputs
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    h_dst = F.randn((ndst, out_dim)).to(F.ctx())
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


def _S2AXWb(A, N, X, W, b):
    # Dense reference for TAGConv with k=2: mix normalized 1-hop and 2-hop
    # aggregations with the raw features, then apply the linear map.
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b


@pytest.mark.parametrize("out_dim", [1, 2])
def test_tagconv(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3))
    g = g.to(F.ctx())
    ctx = F.ctx()
    adj = g.adj_external(transpose=True, ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, out_dim, bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(
        h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias)
    )

    conv = nn.TAGConv(5, out_dim)
    conv = conv.to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == out_dim

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)


def test_set2set():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    s2s = s2s.to(ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2
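

# Illustrative sketch (not part of the original tests): Set2Set readout
# produces a 2x-width graph-level vector (here 2 * 5 = 10 per graph).
# `_demo_set2set` is a hypothetical helper name; pytest never collects it.
def _demo_set2set():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(6)).to(ctx)
    s2s = nn.Set2Set(5, n_iters=2, n_layers=1).to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    return s2s(g, feat)  # shape: (1, 10)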


def test_glob_att_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)

    # test pickle
    th.save(gap, tmp_buffer)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2


def test_simple_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))
    g = g.to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = th.stack(
        [
            F.sum(h0[:15], 0),
            F.sum(h0[15:20], 0),
            F.sum(h0[20:35], 0),
            F.sum(h0[35:40], 0),
            F.sum(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = th.stack(
        [
            F.mean(h0[:15], 0),
            F.mean(h0[15:20], 0),
            F.mean(h0[20:35], 0),
            F.mean(h0[35:40], 0),
            F.mean(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = th.stack(
        [
            F.max(h0[:15], 0),
            F.max(h0[15:20], 0),
            F.max(h0[20:35], 0),
            F.max(h0[35:40], 0),
            F.max(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
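

# Illustrative sketch (not part of the original tests): graph-level readout
# with SumPooling on a batched graph, one pooled vector per graph in the
# batch. `_demo_readout` is a hypothetical helper name.
def _demo_readout():
    g1 = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g1, g2])
    feat = F.randn((bg.num_nodes(), 7))
    pooled = nn.SumPooling()(bg, feat)
    assert pooled.shape == (2, 7)  # one row per graph
    return pooled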


def test_set_trans():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "sab")
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "isab", 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.num_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2


@parametrize_idtype
@pytest.mark.parametrize("O", [1, 8, 32])
def test_rgcn(idtype, O):
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.num_edges()):
        etype.append(i % 5)
    B = 2
    I = 10

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.num_edges(), 1)).to(ctx)
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )
    sorted_norm = norm[idx]

    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % B == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % B == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % B == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)
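

# Illustrative sketch (not part of the original tests): RelGraphConv takes a
# per-edge relation-type vector alongside node features. The names and sizes
# in `_demo_rel_graph_conv` are illustrative only; pytest never collects it.
def _demo_rel_graph_conv():
    ctx = F.ctx()
    g = dgl.from_scipy(sp.sparse.random(20, 20, density=0.2)).to(ctx)
    etypes = th.arange(g.num_edges()).to(ctx) % 3  # 3 relation types
    feat = th.randn((20, 8)).to(ctx)
    rgc = nn.RelGraphConv(8, 4, 3, "basis", 2).to(ctx)
    return rgc(g, feat, etypes)  # shape: (20, 4)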


@parametrize_idtype
@pytest.mark.parametrize("O", [1, 10, 40])
def test_rgcn_default_nbasis(idtype, O):
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.num_edges()):
        etype.append(i % 5)
    I = 10

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.num_edges(), 1)).to(ctx)
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )
    sorted_norm = norm[idx]

    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis").to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % R == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd").to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % R == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % R == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % R == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)
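

# Illustrative sketch (not part of the original tests): pulling per-edge
# attention weights out of GATConv with get_attention=True, the pattern the
# assertions above rely on. `_demo_gat_attention` is a hypothetical name.
def _demo_gat_attention():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(4)).to(ctx)
    gat = nn.GATConv(8, 4, num_heads=2).to(ctx)
    feat = F.randn((g.num_nodes(), 8))
    h, attn = gat(g, feat, get_attention=True)
    # h: (num_nodes, 2, 4); attn: (num_edges, 2, 1)
    return h, attn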


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_edge_weight(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    ew = F.randn((g.num_edges(),))
    h = gat(g, feat, edge_weight=ew)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape[0] == ew.shape[0]
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=10,
        in_edge_feats=5,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    nfeat = F.randn((g.num_nodes(), 10))
    efeat = F.randn((g.num_edges(), 5))
    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)

    th.save(egat, tmp_buffer)

    assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    _, _, attn = egat(g, nfeat, efeat, get_attention=True)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=(10, 15),
        in_edge_feats=7,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    nfeat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), 15)),
    )
    efeat = F.randn((g.num_edges(), 7))
    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)

    th.save(egat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    _, _, attn = egat(g, nfeat, efeat, get_attention=True)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_edge_weight(
    g, idtype, out_node_feats, out_edge_feats, num_heads
):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=10,
        in_edge_feats=5,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    egat = egat.to(ctx)
    nfeat = F.randn((g.num_nodes(), 10))
    efeat = F.randn((g.num_edges(), 5))
    ew = F.randn((g.num_edges(),))

    h, f, attn = egat(g, nfeat, efeat, edge_weight=ew, get_attention=True)

    assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv(g, idtype, out_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    edgegat = nn.EdgeGATConv(
        in_feats=10, edge_feats=5, out_feats=out_feats, num_heads=num_heads
    )
    nfeat = F.randn((g.number_of_nodes(), 10))
    efeat = F.randn((g.number_of_edges(), 5))
    edgegat = edgegat.to(ctx)
    h = edgegat(g, nfeat, efeat)

    th.save(edgegat, tmp_buffer)

    assert h.shape == (g.number_of_nodes(), num_heads, out_feats)
    _, attn = edgegat(g, nfeat, efeat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv_bi(g, idtype, out_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    edgegat = nn.EdgeGATConv(
        in_feats=(10, 15),
        edge_feats=7,
        out_feats=out_feats,
        num_heads=num_heads,
    )
    nfeat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), 15)),
    )
    efeat = F.randn((g.number_of_edges(), 7))
    edgegat = edgegat.to(ctx)
    h = edgegat(g, nfeat, efeat)

    th.save(edgegat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_feats)
    _, attn = edgegat(g, nfeat, efeat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
def test_sage_conv(idtype, g, aggre_type):
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    sage = sage.to(F.ctx())

    # test pickle
    th.save(sage, tmp_buffer)

    h = sage(g, feat)
    assert h.shape[-1] == 10
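

# Illustrative sketch (not part of the original tests): SAGEConv with the
# "mean" aggregator on a homogeneous graph. `_demo_sage_conv` is a
# hypothetical helper name; pytest never collects it.
def _demo_sage_conv():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(6)).to(ctx)
    sage = nn.SAGEConv(8, 4, "mean").to(ctx)
    feat = F.randn((g.num_nodes(), 8))
    return sage(g, feat)  # shape: (num_nodes, 4)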


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != "gcn" else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), dst_dim)),
    )
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()


@parametrize_idtype
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv2(idtype, out_dim):
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.heterograph({("_U", "_E", "_V"): ([], [])}, {"_U": 5, "_V": 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), out_dim, "gcn")
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    for aggre_type in ["mean", "pool", "lstm"]:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)

    # test pickle
    th.save(sgc, tmp_buffer)

    feat = F.randn((g.num_nodes(), 5))
    sgc = sgc.to(ctx)

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # cached
    sgc = nn.SGConv(5, out_dim, 3, True)
    sgc = sgc.to(ctx)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.num_nodes(), 5))
    appnp = appnp.to(ctx)

    # test pickle
    th.save(appnp, tmp_buffer)

    h = appnp(g, feat)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    appnp = appnp.to(ctx)

    h = appnp(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("bias", [True, False])
def test_gcn2conv_e_weight(g, idtype, bias):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gcn2conv = nn.GCN2Conv(
        5, layer=2, alpha=0.5, bias=bias, project_initial_features=True
    )
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    gcn2conv = gcn2conv.to(ctx)
    res = feat
    h = gcn2conv(g, res, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_sgconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    sgconv = nn.SGConv(5, 5, 3)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    sgconv = sgconv.to(ctx)
    h = sgconv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_tagconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.TAGConv(5, 5, bias=True)
    conv = conv.to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    conv = conv.to(ctx)
    h = conv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    th.save(gin, tmp_buffer)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gin = gin.to(ctx)
    h = gin(g, feat)

    # test pickle
    th.save(gin, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), 12)

    gin = nn.GINConv(None, aggregator_type)
    th.save(gin, tmp_buffer)
    gin = gin.to(ctx)
    h = gin(g, feat)
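

# Illustrative sketch (not part of the original tests): GINConv wraps an
# arbitrary apply function (a Linear here, as in the test above); passing
# None skips the apply step. `_demo_gin_conv` is a hypothetical name.
def _demo_gin_conv():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(4)).to(ctx)
    gin = nn.GINConv(th.nn.Linear(8, 4), "sum").to(ctx)
    feat = F.randn((g.num_nodes(), 8))
    return gin(g, feat)  # shape: (num_nodes, 4)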


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
def test_gine_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gine = nn.GINEConv(th.nn.Linear(5, 12))
    th.save(gine, tmp_buffer)
    nfeat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.num_edges(), 5))
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)

    # test pickle
    th.save(gine, tmp_buffer)
    assert h.shape == (g.number_of_dst_nodes(), 12)

    gine = nn.GINEConv(None)
    th.save(gine, tmp_buffer)
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv_bi(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_agnn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_src_nodes(), 5))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_agnn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.num_edges()) % 3
    feat = F.randn((g.num_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv_one_etype(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 1)
    etypes = th.zeros(g.num_edges())
    feat = F.randn((g.num_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    h2 = ggconv(g, feat)
    # with a single etype, passing etypes explicitly should be equivalent
    assert F.allclose(h, h2)
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_nn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.num_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_nn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    efeat = F.randn((g.num_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gmm_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv(5, 10, 3, 4, "mean")
    feat = F.randn((g.num_nodes(), 5))
    pseudo = F.randn((g.num_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite", "block-bipartite"], exclude=["zero-degree"])
)
def test_gmm_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.num_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("norm_type", ["both", "right", "none"])
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_graph_conv(norm_type, g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "bipartite"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_sage_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, out_dim, "gcn")
    dense_sage = nn.DenseSAGEConv(5, out_dim)
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.bias.data
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)),
        )
    else:
        feat = F.randn((g.num_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage), g


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)

    # test pickle
    th.save(edge_conv, tmp_buffer)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    dotgat = dotgat.to(ctx)

    # test pickle
    th.save(dotgat, tmp_buffer)

    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv((5, 5), out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    dotgat = dotgat.to(ctx)
    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_cheb_conv(out_dim):
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        g = g.to(F.ctx())
        adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, out_dim, k, None)
        dense_cheb = nn.DenseChebConv(5, out_dim, k)
        # for i in range(len(cheb.fc)):
        #    dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(
            k, 5, out_dim
        )
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)


def test_sequential():
    ctx = F.ctx()

    # Test single graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
            graph.apply_edges(fn.u_add_v("h", "h", "e"))
            e_feat += graph.edata["e"]
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graphs
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
            return n_feat.view(graph.num_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)
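

# Illustrative sketch (not part of the original tests): nn.Sequential chains
# graph modules that each take (graph, feat), threading features through.
# `_demo_sequential` is a hypothetical helper name; pytest never collects it.
def _demo_sequential():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(4)).to(ctx)
    net = nn.Sequential(
        nn.GraphConv(8, 8, allow_zero_in_degree=True),
        nn.GraphConv(8, 4, allow_zero_in_degree=True),
    ).to(ctx)
    feat = F.randn((g.num_nodes(), 8))
    return net(g, feat)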


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_atomic_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    aconv = nn.AtomicConv(
        interaction_cutoffs=F.tensor([12.0, 12.0]),
        rbf_kernel_means=F.tensor([0.0, 2.0]),
        rbf_kernel_scaling=F.tensor([4.0, 4.0]),
        features_to_use=F.tensor([6.0, 8.0]),
    )

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((g.num_nodes(), 1))
    dist = F.randn((g.num_edges(), 1))

    h = aconv(g, feat, dist)
    # currently we only do shape check
    assert h.shape[-1] == 4


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 3])
def test_cf_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    cfconv = nn.CFConv(
        node_in_feats=2, edge_in_feats=3, hidden_feats=2, out_feats=out_dim
    )

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    src_feats = F.randn((g.number_of_src_nodes(), 2))
    edge_feats = F.randn((g.num_edges(), 3))
    h = cfconv(g, src_feats, edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim

    # case for bipartite graphs
    dst_feats = F.randn((g.number_of_dst_nodes(), 3))
    h = cfconv(g, (src_feats, dst_feats), edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim


def myagg(alist, dsttype):
    # Custom cross-type aggregator: weights the i-th incoming result by (i + 1).
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst


@parametrize_idtype
1402
1403
@pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg])
@pytest.mark.parametrize("canonical_keys", [False, True])
1404
def test_hetero_conv(agg, idtype, canonical_keys):
1405
1406
1407
1408
1409
1410
1411
1412
1413
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 0, 2, 1], [1, 2, 1, 3]),
            ("user", "plays", "game"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
            ("store", "sells", "game"): ([0, 0, 1, 1], [0, 3, 1, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
1414
    if not canonical_keys:
1415
1416
1417
1418
1419
1420
1421
1422
        conv = nn.HeteroGraphConv(
            {
                "follows": nn.GraphConv(2, 3, allow_zero_in_degree=True),
                "plays": nn.GraphConv(2, 4, allow_zero_in_degree=True),
                "sells": nn.GraphConv(3, 4, allow_zero_in_degree=True),
            },
            agg,
        )
1423
    else:
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
        conv = nn.HeteroGraphConv(
            {
                ("user", "follows", "user"): nn.GraphConv(
                    2, 3, allow_zero_in_degree=True
                ),
                ("user", "plays", "game"): nn.GraphConv(
                    2, 4, allow_zero_in_degree=True
                ),
                ("store", "sells", "game"): nn.GraphConv(
                    3, 4, allow_zero_in_degree=True
                ),
            },
            agg,
        )

    conv = conv.to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

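    # For a block (message flow graph), features can be passed as a
    # (src_feats, dst_feats) pair of dicts; "store" has no destination
    # nodes, hence the empty sf[0:0] slice.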
    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    h = conv(block, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    # test with mod args
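    # mod_args/mod_kwargs are dispatched per relation: each module below
    # counts how often it receives its positional (carg1) or keyword (carg2)
    # argument.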
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv(
        {"follows": mod1, "plays": mod2, "sells": mod3}, agg
    )
    conv = conv.to(F.ctx())
    mod_args = {"follows": (1,), "plays": (1,)}
    mod_kwargs = {"sells": {"arg2": "abc"}}
    h = conv(
        g,
        {"user": uf, "game": gf, "store": sf},
        mod_args=mod_args,
        mod_kwargs=mod_kwargs,
    )
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on a graph without any edges
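    # (the output dict should still contain every destination node type)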
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form="eid", etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}


@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_linear(out_dim):
    in_feats = {
        "user": F.randn((2, 1)),
        ("user", "follows", "user"): F.randn((3, 2)),
    }

    layer = nn.HeteroLinear(
        {"user": 1, ("user", "follows", "user"): 2}, out_dim
    )
    layer = layer.to(F.ctx())
    out_feats = layer(in_feats)
    assert out_feats["user"].shape == (2, out_dim)
    assert out_feats[("user", "follows", "user")].shape == (3, out_dim)


@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_embedding(out_dim):
    layer = nn.HeteroEmbedding(
        {"user": 2, ("user", "follows", "user"): 3}, out_dim
    )
    layer = layer.to(F.ctx())

    embeds = layer.weight
    assert embeds["user"].shape == (2, out_dim)
    assert embeds[("user", "follows", "user")].shape == (3, out_dim)

    layer.reset_parameters()
    embeds = layer.weight
    assert embeds["user"].shape == (2, out_dim)
    assert embeds[("user", "follows", "user")].shape == (3, out_dim)

    embeds = layer(
        {
            "user": F.tensor([0], dtype=F.int64),
            ("user", "follows", "user"): F.tensor([0, 2], dtype=F.int64),
        }
    )
    assert embeds["user"].shape == (1, out_dim)
    assert embeds[("user", "follows", "user")].shape == (2, out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_gnnexplainer(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    feat = F.randn((g.num_nodes(), 5))

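    # GNNExplainer requires the model to accept an optional eweight argument
    # that scales messages edge-wise; the toy model below supports both
    # node-level and graph-level tasks via the `graph` flag.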
    class Model(th.nn.Module):
        def __init__(self, in_feats, out_feats, graph=False):
            super(Model, self).__init__()
            self.linear = th.nn.Linear(in_feats, out_feats)
            if graph:
                self.pool = nn.AvgPooling()
            else:
                self.pool = None

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                feat = self.linear(feat)
                graph.ndata["h"] = feat
                if eweight is None:
                    graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                else:
                    graph.edata["w"] = eweight
                    graph.update_all(
                        fn.u_mul_e("h", "w", "m"), fn.sum("m", "h")
                    )

                if self.pool:
                    return self.pool(graph, graph.ndata["h"])
                else:
                    return graph.ndata["h"]

    # Explain node prediction
    model = Model(5, out_dim)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)

    # Explain graph prediction
    model = Model(5, out_dim, graph=True)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)


@pytest.mark.parametrize("g", get_cases(["hetero"], exclude=["zero-degree"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("output_dim", [1, 2])
def test_heterognnexplainer(g, idtype, input_dim, output_dim):
    g = g.astype(idtype).to(F.ctx())
    device = g.device

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    feat = {
        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)
        for ntype in g.ntypes
    }

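    # HeteroGNNExplainer passes eweight as a dict keyed by canonical edge
    # type, mirrored by the forward() signature below.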
    class Model(th.nn.Module):
        def __init__(self, in_dim, num_classes, canonical_etypes, graph=False):
            super(Model, self).__init__()
            self.graph = graph
            self.etype_weights = th.nn.ModuleDict(
                {
                    "_".join(c_etype): th.nn.Linear(in_dim, num_classes)
                    for c_etype in canonical_etypes
                }
            )

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                c_etype_func_dict = {}
                for c_etype in graph.canonical_etypes:
                    src_type, etype, dst_type = c_etype
                    wh = self.etype_weights["_".join(c_etype)](feat[src_type])
                    graph.nodes[src_type].data[f"h_{c_etype}"] = wh
                    if eweight is None:
                        c_etype_func_dict[c_etype] = (
                            fn.copy_u(f"h_{c_etype}", "m"),
                            fn.mean("m", "h"),
                        )
                    else:
                        graph.edges[c_etype].data["w"] = eweight[c_etype]
                        c_etype_func_dict[c_etype] = (
                            fn.u_mul_e(f"h_{c_etype}", "w", "m"),
                            fn.mean("m", "h"),
                        )
                graph.multi_update_all(c_etype_func_dict, "sum")
                if self.graph:
                    hg = 0
                    for ntype in graph.ntypes:
                        if graph.num_nodes(ntype):
                            hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)

                    return hg
                else:
                    return graph.ndata["h"]

    # Explain node prediction
    model = Model(input_dim, output_dim, g.canonical_etypes)
    model = model.to(F.ctx())
    ntype = g.ntypes[0]
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(
        ntype, 0, g, feat
    )

    # Explain graph prediction
    model = Model(input_dim, output_dim, g.canonical_etypes, graph=True)
    model = model.to(F.ctx())
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(
        ["homo"],
        exclude=[
            "zero-degree",
            "homo-zero-degree",
            "has_feature",
            "has_scalar_e_feature",
            "row_sorted",
            "col_sorted",
            "batched",
        ],
    ),
)
@pytest.mark.parametrize("n_classes", [2])
def test_subgraphx(g, idtype, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))

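    # SubgraphX explains a graph-level prediction by searching subgraphs
    # with Monte Carlo tree search, scoring candidates via Shapley values.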
    class Model(th.nn.Module):
        def __init__(self, in_dim, n_classes):
            super().__init__()
            self.conv = nn.GraphConv(in_dim, n_classes)
            self.pool = nn.AvgPooling()

        def forward(self, g, h):
            h = th.nn.functional.relu(self.conv(g, h))
            return self.pool(g, h)

    model = Model(feat.shape[1], n_classes)
    model = model.to(ctx)
    explainer = nn.SubgraphX(
        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0
    )
    explainer.explain_graph(g, feat, target_class=0)


@pytest.mark.parametrize("g", get_cases(["hetero"], exclude=["zero-degree"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("n_classes", [2])
def test_heterosubgraphx(g, idtype, input_dim, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    device = g.device

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    feat = {
        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)
        for ntype in g.ntypes
    }

    class Model(th.nn.Module):
        def __init__(self, in_dim, n_classes, canonical_etypes):
            super(Model, self).__init__()
            self.etype_weights = th.nn.ModuleDict(
                {
                    "_".join(c_etype): th.nn.Linear(in_dim, n_classes)
                    for c_etype in canonical_etypes
                }
            )

        def forward(self, graph, feat):
            with graph.local_scope():
                c_etype_func_dict = {}
                for c_etype in graph.canonical_etypes:
                    src_type, etype, dst_type = c_etype
                    wh = self.etype_weights["_".join(c_etype)](feat[src_type])
                    graph.nodes[src_type].data[f"h_{c_etype}"] = wh
                    c_etype_func_dict[c_etype] = (
                        fn.copy_u(f"h_{c_etype}", "m"),
                        fn.mean("m", "h"),
                    )
                graph.multi_update_all(c_etype_func_dict, "sum")
                hg = 0
                for ntype in graph.ntypes:
                    if graph.num_nodes(ntype):
                        hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)

                return hg

    model = Model(input_dim, n_classes, g.canonical_etypes)
    model = model.to(ctx)
    explainer = nn.HeteroSubgraphX(
        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0
    )
    explainer.explain_graph(g, feat, target_class=0)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(
        ["homo"],
        exclude=[
            "zero-degree",
            "homo-zero-degree",
            "has_feature",
            "has_scalar_e_feature",
            "row_sorted",
            "col_sorted",
        ],
    ),
)
@pytest.mark.parametrize("n_classes", [2])
def test_pgexplainer(g, idtype, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    g.ndata["attr"] = feat

    # add reverse edges
    transform = dgl.transforms.AddReverse(copy_edata=True)
    g = transform(g)

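    # PGExplainer queries the model twice: with embed=True to obtain node
    # embeddings, then normally with edge_weight applied inside the conv.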
    class Model(th.nn.Module):
        def __init__(self, in_feats, out_feats):
            super(Model, self).__init__()
            self.conv = nn.GraphConv(in_feats, out_feats)
            self.fc = th.nn.Linear(out_feats, out_feats)
            th.nn.init.xavier_uniform_(self.fc.weight)

        def forward(self, g, h, embed=False, edge_weight=None):
            h = self.conv(g, h, edge_weight=edge_weight)

            if embed:
                return h

            with g.local_scope():
                g.ndata["h"] = h
                hg = dgl.mean_nodes(g, "h")
                return self.fc(hg)

    model = Model(feat.shape[1], n_classes)
    model = model.to(ctx)

    explainer = nn.PGExplainer(model, n_classes)
    explainer.train_step(g, g.ndata["attr"], 5.0)

    probs, edge_weight = explainer.explain_graph(g, feat)


@pytest.mark.parametrize("g", get_cases(["hetero"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("n_classes", [2])
def test_heteropgexplainer(g, idtype, input_dim, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = {
        ntype: F.randn((g.num_nodes(ntype), input_dim)) for ntype in g.ntypes
    }

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

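    # In the hetero variant, per-relation edge masks are forwarded through
    # HeteroGraphConv's mod_kwargs, as mirrored in forward() below.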
    class Model(th.nn.Module):
        def __init__(self, in_feats, embed_dim, out_feats, canonical_etypes):
            super(Model, self).__init__()
            self.conv = nn.HeteroGraphConv(
                {
                    c_etype: nn.GraphConv(in_feats, embed_dim)
                    for c_etype in canonical_etypes
                }
            )
            self.fc = th.nn.Linear(embed_dim, out_feats)

        def forward(self, g, h, embed=False, edge_weight=None):
            if edge_weight is not None:
                mod_kwargs = {
                    etype: {"edge_weight": mask}
                    for etype, mask in edge_weight.items()
                }
                h = self.conv(g, h, mod_kwargs=mod_kwargs)
            else:
                h = self.conv(g, h)

            if embed:
                return h

            with g.local_scope():
                g.ndata["h"] = h
                hg = 0
                for ntype in g.ntypes:
                    hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
                return self.fc(hg)

    embed_dim = input_dim
    model = Model(input_dim, embed_dim, n_classes, g.canonical_etypes)
    model = model.to(ctx)

    explainer = nn.HeteroPGExplainer(model, embed_dim)
    explainer.train_step(g, feat, 5.0)

    probs, edge_weight = explainer.explain_graph(g, feat)


def test_jumping_knowledge():
    ctx = F.ctx()
    num_layers = 2
    num_nodes = 3
    num_feats = 4

    feat_list = [
        th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)
    ]

    model = nn.JumpingKnowledge("cat").to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_layers * num_feats)

    model = nn.JumpingKnowledge("max").to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)

    model = nn.JumpingKnowledge("lstm", num_feats, num_layers).to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)


@pytest.mark.parametrize("op", ["dot", "cos", "ele", "cat"])
def test_edge_predictor(op):
    ctx = F.ctx()
    num_pairs = 3
    in_feats = 4
    out_feats = 5
    h_src = th.randn((num_pairs, in_feats)).to(ctx)
    h_dst = th.randn((num_pairs, in_feats)).to(ctx)

    pred = nn.EdgePredictor(op)
    if op in ["dot", "cos"]:
        assert pred(h_src, h_dst).shape == (num_pairs, 1)
    elif op == "ele":
        assert pred(h_src, h_dst).shape == (num_pairs, in_feats)
    else:
        assert pred(h_src, h_dst).shape == (num_pairs, 2 * in_feats)
    pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
    assert pred(h_src, h_dst).shape == (num_pairs, out_feats)


def test_ke_score_funcs():
    ctx = F.ctx()
    num_edges = 30
    num_rels = 3
    nfeats = 4

    h_src = th.randn((num_edges, nfeats)).to(ctx)
    h_dst = th.randn((num_edges, nfeats)).to(ctx)
    rels = th.randint(low=0, high=num_rels, size=(num_edges,)).to(ctx)

    score_func = nn.TransE(num_rels=num_rels, feats=nfeats).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)

    score_func = nn.TransR(
        num_rels=num_rels, rfeats=nfeats - 1, nfeats=nfeats
    ).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)


def test_twirls():
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    feat = th.ones(6, 10)
    conv = nn.TWIRLSConv(10, 2, 128, prop_step=64)
    res = conv(g, feat)
    assert res.size() == (6, 2)


@pytest.mark.parametrize("feat_size", [4, 32])
@pytest.mark.parametrize(
    "regularizer,num_bases", [(None, None), ("basis", 4), ("bdd", 4)]
)
def test_typed_linear(feat_size, regularizer, num_bases):
    dev = F.ctx()
    num_types = 5
    lin = nn.TypedLinear(
        feat_size,
        feat_size * 2,
        5,
        regularizer=regularizer,
        num_bases=num_bases,
    ).to(dev)
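    # "basis" composes each per-type weight from num_bases shared bases,
    # while "bdd" uses a block-diagonal decomposition (as in R-GCN).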
    print(lin)
    x = th.randn(100, feat_size).to(dev)
    x_type = th.randint(0, 5, (100,)).to(dev)
    x_type_sorted, idx = th.sort(x_type)
    _, rev_idx = th.sort(idx)
    x_sorted = x[idx]

    # test unsorted
    y = lin(x, x_type)
    assert y.shape == (100, feat_size * 2)
    # test sorted
    y_sorted = lin(x_sorted, x_type_sorted, sorted_by_type=True)
    assert y_sorted.shape == (100, feat_size * 2)

    assert th.allclose(y, y_sorted[rev_idx], atol=1e-4, rtol=1e-4)


@parametrize_idtype
@pytest.mark.parametrize("in_size", [4])
@pytest.mark.parametrize("num_heads", [1])
def test_hgt(idtype, in_size, num_heads):
    dev = F.ctx()
    num_etypes = 5
    num_ntypes = 2
    head_size = in_size // num_heads

    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.01))
    g = g.astype(idtype).to(dev)
    etype = th.tensor([i % num_etypes for i in range(g.num_edges())]).to(dev)
    ntype = th.tensor([i % num_ntypes for i in range(g.num_nodes())]).to(dev)
    x = th.randn(g.num_nodes(), in_size).to(dev)

    m = nn.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes).to(
        dev
    )

    y = m(g, x, ntype, etype)
    assert y.shape == (g.num_nodes(), head_size * num_heads)
    # presorted
    sorted_ntype, idx_nt = th.sort(ntype)
    sorted_etype, idx_et = th.sort(etype)
    _, rev_idx = th.sort(idx_nt)
    g.ndata["t"] = ntype
    g.ndata["x"] = x
    g.edata["t"] = etype
    sorted_g = dgl.reorder_graph(
        g,
        node_permute_algo="custom",
        edge_permute_algo="custom",
        permute_config={
            "nodes_perm": idx_nt.to(idtype),
            "edges_perm": idx_et.to(idtype),
        },
    )
    print(sorted_g.ndata["t"])
    print(sorted_g.edata["t"])
    sorted_x = sorted_g.ndata["x"]
    sorted_y = m(
        sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=False
    )
    assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
    # mini-batch
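    # (node features and type vectors are sliced to the block's input nodes
    # and edges before the forward pass)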
    train_idx = th.randperm(100, dtype=idtype)[:10]
    sampler = dgl.dataloading.NeighborSampler([-1])
    train_loader = dgl.dataloading.DataLoader(
        g, train_idx.to(dev), sampler, batch_size=8, device=dev, shuffle=True
    )
    (input_nodes, output_nodes, block) = next(iter(train_loader))
    block = block[0]
    x = x[input_nodes.to(th.long)]
    ntype = ntype[input_nodes.to(th.long)]
    edge = block.edata[dgl.EID]
    etype = etype[edge.to(th.long)]
    y = m(block, x, ntype, etype)
    assert y.shape == (block.number_of_dst_nodes(), head_size * num_heads)
    # TODO(minjie): enable the following check
    # assert th.allclose(y, sorted_y[rev_idx], atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize("self_loop", [True, False])
@pytest.mark.parametrize("get_distances", [True, False])
def test_radius_graph(self_loop, get_distances):
    pos = th.tensor(
        [
            [0.1, 0.3, 0.4],
            [0.5, 0.2, 0.1],
            [0.7, 0.9, 0.5],
            [0.3, 0.2, 0.5],
            [0.2, 0.8, 0.2],
            [0.9, 0.2, 0.1],
            [0.7, 0.4, 0.4],
            [0.2, 0.1, 0.6],
            [0.5, 0.3, 0.5],
            [0.4, 0.2, 0.6],
        ]
    )

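    # RadiusGraph connects every pair of points within Euclidean distance
    # 0.3; the expected edges and distances below were precomputed for the
    # ten points above.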
    rg = nn.RadiusGraph(0.3, self_loop=self_loop)

    if get_distances:
        g, dists = rg(pos, get_distances=get_distances)
    else:
        g = rg(pos)

    if self_loop:
        src_target = th.tensor(
            [
                0,
                0,
                1,
                2,
                3,
                3,
                3,
                3,
                3,
                4,
                5,
                6,
                6,
                7,
                7,
                7,
                8,
                8,
                8,
                8,
                9,
                9,
                9,
                9,
            ]
        )
        dst_target = th.tensor(
            [
                0,
                3,
                1,
                2,
                0,
                3,
                7,
                8,
                9,
                4,
                5,
                6,
                8,
                3,
                7,
                9,
                3,
                6,
                8,
                9,
                3,
                7,
                8,
                9,
            ]
        )

        if get_distances:
            dists_target = th.tensor(
                [
                    [0.0000],
                    [0.2449],
                    [0.0000],
                    [0.0000],
                    [0.2449],
                    [0.0000],
                    [0.1732],
                    [0.2236],
                    [0.1414],
                    [0.0000],
                    [0.0000],
                    [0.0000],
                    [0.2449],
                    [0.1732],
                    [0.0000],
                    [0.2236],
                    [0.2236],
                    [0.2449],
                    [0.0000],
                    [0.1732],
                    [0.1414],
                    [0.2236],
                    [0.1732],
                    [0.0000],
                ]
            )
    else:
        src_target = th.tensor([0, 3, 3, 3, 3, 6, 7, 7, 8, 8, 8, 9, 9, 9])
        dst_target = th.tensor([3, 0, 7, 8, 9, 8, 3, 9, 3, 6, 9, 3, 7, 8])

        if get_distances:
            dists_target = th.tensor(
                [
                    [0.2449],
                    [0.2449],
                    [0.1732],
                    [0.2236],
                    [0.1414],
                    [0.2449],
                    [0.1732],
                    [0.2236],
                    [0.2236],
                    [0.2449],
                    [0.1732],
                    [0.1414],
                    [0.2236],
                    [0.1732],
                ]
            )

    src, dst = g.edges()

    assert th.equal(src, src_target)
    assert th.equal(dst, dst_target)

    if get_distances:
        assert th.allclose(dists, dists_target, rtol=1e-03)


@parametrize_idtype
def test_group_rev_res(idtype):
    dev = F.ctx()

    num_nodes = 5
    num_edges = 20
    feats = 32
    groups = 2
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, feats).to(dev)
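    # GroupRevRes splits the feature dimension into `groups` partitions, so
    # the wrapped conv operates on feats // groups channels per group.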
    conv = nn.GraphConv(feats // groups, feats // groups)
    model = nn.GroupRevRes(conv, groups).to(dev)
    result = model(g, h)
    result.sum().backward()


@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("hidden_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize("edge_feat_size", [16, 10, 0])
def test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    x = th.randn(num_nodes, 3).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    model = nn.EGNNConv(in_size, hidden_size, out_size, edge_feat_size).to(dev)
    model(g, h, x, e)


@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators",
    [
        ["mean", "max", "sum"],
        ["min", "std", "var"],
        ["moment3", "moment4", "moment5"],
    ],
)
@pytest.mark.parametrize(
    "scalers", [["identity"], ["amplification", "attenuation"]]
)
@pytest.mark.parametrize("delta", [2.5, 7.4])
@pytest.mark.parametrize("dropout", [0.0, 0.1])
@pytest.mark.parametrize("num_towers", [1, 4])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
@pytest.mark.parametrize("residual", [True, False])
def test_pna_conv(
    in_size,
    out_size,
    aggregators,
    scalers,
    delta,
    dropout,
    num_towers,
    edge_feat_size,
    residual,
):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    model = nn.PNAConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        dropout,
        num_towers,
        edge_feat_size,
        residual,
    ).to(dev)
    model(g, h, edge_feat=e)


@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("norm_type", ["sym", "row"])
@pytest.mark.parametrize("clamp", [True, False])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("reset", [True, False])
def test_label_prop(k, alpha, norm_type, clamp, normalize, reset):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    num_classes = 4
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    labels = th.tensor([0, 2, 1, 3, 0]).long().to(dev)
    ml_labels = th.rand(num_nodes, num_classes).to(dev) > 0.7
    mask = th.tensor([0, 1, 1, 1, 0]).bool().to(dev)
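    # `mask` marks the labeled seed nodes whose labels get propagated to the
    # remaining nodes.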
    model = nn.LabelPropagation(k, alpha, norm_type, clamp, normalize, reset)
    model(g, labels, mask)
    # multi-label case
    model(g, ml_labels, mask)


@pytest.mark.parametrize("in_size", [16])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators", [["mean", "max", "dir2-av"], ["min", "std", "dir1-dx"]]
)
@pytest.mark.parametrize("scalers", [["amplification", "attenuation"]])
@pytest.mark.parametrize("delta", [2.5])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
def test_dgn_conv(
    in_size, out_size, aggregators, scalers, delta, edge_feat_size
):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
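    # DGN's directional aggregators (dir1-dx, dir2-av) require Laplacian
    # eigenvectors, supplied here via the LaplacianPE transform.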
    transform = dgl.LaplacianPE(k=3, feat_name="eig")
    g = transform(g)
    eig = g.ndata["eig"]
    model = nn.DGNConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    model(g, h, edge_feat=e, eig_vec=eig)

    aggregators_non_eig = [
        aggr for aggr in aggregators if not aggr.startswith("dir")
    ]
    model = nn.DGNConv(
        in_size,
        out_size,
        aggregators_non_eig,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    model(g, h, edge_feat=e)


def test_DeepWalk():
    dev = F.ctx()
    g = dgl.graph(([0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]))
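    # First variant: negatives drawn from a pre-built table (fast_neg=True)
    # with sparse gradients and SparseAdam; the second variant below uses
    # fast_neg=False with a dense Adam optimizer.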
    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=True, sparse=True
    )
    model = model.to(dev)
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
    optim = SparseAdam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()

    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=False, sparse=False
    )
    model = model.to(dev)
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
    optim = Adam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()


@pytest.mark.parametrize("max_degree", [2, 6])
@pytest.mark.parametrize("embedding_dim", [8, 16])
@pytest.mark.parametrize("direction", ["in", "out", "both"])
def test_degree_encoder(max_degree, embedding_dim, direction):
    g1 = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    )
    g2 = dgl.graph(
        (
            th.tensor([0, 1]),
            th.tensor([1, 0]),
        )
    )
    in_degree = pad_sequence(
        [g1.in_degrees(), g2.in_degrees()], batch_first=True
    )
    out_degree = pad_sequence(
        [g1.out_degrees(), g2.out_degrees()], batch_first=True
    )
    model = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)
    if direction == "in":
        de_g = model(in_degree)
    elif direction == "out":
        de_g = model(out_degree)
    elif direction == "both":
        de_g = model(th.stack((in_degree, out_degree)))
    assert de_g.shape == (2, 4, embedding_dim)


@parametrize_idtype
def test_MetaPath2Vec(idtype):
    dev = F.ctx()
    g = dgl.heterograph(
        {
            ("user", "uc", "company"): ([0, 0, 2, 1, 3], [1, 2, 1, 3, 0]),
            ("company", "cp", "product"): (
                [0, 0, 0, 1, 2, 3],
                [0, 2, 3, 0, 2, 1],
            ),
            ("company", "cu", "user"): ([1, 2, 1, 3, 0], [0, 0, 2, 1, 3]),
            ("product", "pc", "company"): (
                [0, 2, 3, 0, 2, 1],
                [0, 0, 0, 1, 2, 3],
            ),
        },
        idtype=idtype,
        device=dev,
    )
    model = nn.MetaPath2Vec(g, ["uc", "cu"], window_size=1)
    model = model.to(dev)
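    # The embedding table covers every node of every type, even though the
    # ["uc", "cu"] metapath only visits users and companies.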
    embeds = model.node_embed.weight
    assert embeds.shape[0] == g.num_nodes()


@pytest.mark.parametrize("num_layer", [1, 4])
@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("lpe_dim", [4, 16])
@pytest.mark.parametrize("n_head", [1, 4])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("num_post_layer", [0, 1, 2])
def test_LapPosEncoder(
    num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
):
    ctx = F.ctx()
    num_nodes = 4

    EigVals = th.randn((num_nodes, k)).to(ctx)
    EigVecs = th.randn((num_nodes, k)).to(ctx)

    model = nn.LapPosEncoder(
        "Transformer", num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
    ).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)

    model = nn.LapPosEncoder(
        "DeepSet",
        num_layer,
        k,
        lpe_dim,
        batch_norm=batch_norm,
        num_post_layer=num_post_layer,
    ).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)


@pytest.mark.parametrize("feat_size", [128, 512])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("attn_drop", [0.1, 0.5])
def test_BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop):
    ndata = th.rand(16, 100, feat_size)
    attn_bias = th.rand(16, 100, 100, num_heads)
    attn_mask = th.rand(16, 100, 100) < 0.5

    net = nn.BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop)
    out = net(ndata, attn_bias, attn_mask)

    assert out.shape == (16, 100, feat_size)


@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("norm_first", [True, False])
def test_GraphormerLayer(attn_bias_type, norm_first):
    batch_size = 16
    num_nodes = 100
    feat_size = 512
    num_heads = 8
    nfeat = th.rand(batch_size, num_nodes, feat_size)
    attn_bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
    attn_mask = th.rand(batch_size, num_nodes, num_nodes) < 0.5

    net = nn.GraphormerLayer(
        feat_size=feat_size,
        hidden_size=2048,
        num_heads=num_heads,
        attn_bias_type=attn_bias_type,
        norm_first=norm_first,
        dropout=0.1,
        attn_dropout=0.1,
        activation=th.nn.ReLU(),
    )
    out = net(nfeat, attn_bias, attn_mask)

    assert out.shape == (batch_size, num_nodes, feat_size)


@pytest.mark.parametrize("max_len", [1, 2])
@pytest.mark.parametrize("feat_dim", [16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_PathEncoder(max_len, feat_dim, num_heads):
    dev = F.ctx()
    g = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    ).to(dev)
    edge_feat = th.rand(g.num_edges(), feat_dim).to(dev)
    edge_feat = th.cat((edge_feat, th.zeros(1, feat_dim).to(dev)), dim=0)
    dist, path = shortest_dist(g, root=None, return_paths=True)
    path_data = edge_feat[path[:, :, :max_len]]
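    # shortest_dist marks missing edges in `path` with -1, so indexing
    # edge_feat with -1 picks up the zero row appended above as padding.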
    model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
    bias = model(dist.unsqueeze(0), path_data.unsqueeze(0))
    assert bias.shape == (1, 4, 4, num_heads)


@pytest.mark.parametrize("max_dist", [1, 4])
@pytest.mark.parametrize("num_kernels", [8, 16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_SpatialEncoder(max_dist, num_kernels, num_heads):
    dev = F.ctx()
    g1 = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    ).to(dev)
    g2 = dgl.graph(
        (th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3]))
    ).to(dev)
    bg = dgl.batch([g1, g2])
    ndata = th.rand(bg.num_nodes(), 3).to(dev)
    num_nodes = bg.num_nodes()
    node_type = th.randint(0, 512, (num_nodes,)).to(dev)
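    # dist is initialized to -1 so that both graphs fit into one padded
    # (2, 6, 6) batch of pairwise shortest distances.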
    dist = -th.ones((2, 6, 6), dtype=th.long).to(dev)
    dist[0, :4, :4] = shortest_dist(g1, root=None, return_paths=False)
    dist[1, :6, :6] = shortest_dist(g2, root=None, return_paths=False)
    model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
    model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
    model_3 = nn.SpatialEncoder3d(
        num_kernels, num_heads=num_heads, max_node_type=512
    ).to(dev)
    encoding = model_1(dist)
    encoding3d_1 = model_2(bg, ndata)
    encoding3d_2 = model_3(bg, ndata, node_type)
    assert encoding.shape == (2, 6, 6, num_heads)
    assert encoding3d_1.shape == (2, 6, 6, num_heads)
    assert encoding3d_2.shape == (2, 6, 6, num_heads)