import io
import pickle
import random
from copy import deepcopy

import backend as F

import dgl
import dgl.function as fn
import dgl.nn.pytorch as nn
import networkx as nx
import numpy as np  # For setting seed for scipy
import pytest
import scipy as sp
import torch
import torch as th
from dgl import shortest_dist
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam, SparseAdam
from torch.utils.data import DataLoader
from utils import parametrize_idtype
from utils.graph_cases import (
    get_cases,
    random_bipartite,
    random_dglgraph,
    random_graph,
)

# Set seeds to make tests fully reproducible.
SEED = 12345  # random.randint(1, 99999)
random.seed(SEED)  # For networkx
np.random.seed(SEED)  # For scipy
dgl.seed(SEED)
F.seed(SEED)

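# Shared in-memory buffer: modules are serialized into it with th.save()
# throughout the tests to verify that they pickle correctly.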
tmp_buffer = io.BytesIO()


def _AXWb(A, X, W, b):
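    """Dense reference for GraphConv with norm="none": Y = A @ (X @ W) + b."""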
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b


def graph_with_nodes(num_nodes, ctx=None):
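    """Return a path graph with ``num_nodes`` nodes, moved to ``ctx`` if given."""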
    g = dgl.from_networkx(nx.path_graph(num_nodes))
    return g.to(ctx) if ctx else g


@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv0(out_dim):
    ctx = F.ctx()
    g = graph_with_nodes(3, ctx)
    adj = g.adj_external(transpose=True, ctx=ctx)

    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, out_dim)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right", "left"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(idtype, g, norm, weight, bias, out_dim):
    # Test one tensor input
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    e_w = g.edata["scalar_w"]
    if weight:
        h_out = conv(g, h, edge_weight=e_w)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=e_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    edgenorm = nn.EdgeWeightNorm(norm=norm)
    norm_weight = edgenorm(g, g.edata["scalar_w"])
    if weight:
        h_out = conv(g, h, edge_weight=norm_weight)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=norm_weight)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
    # Test a pair of tensor inputs
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    h_dst = F.randn((ndst, out_dim)).to(F.ctx())
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


def _S2AXWb(A, N, X, W, b):
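    """Dense reference for a two-hop TAGConv: stack [X, A_hat X, A_hat^2 X]
    with A_hat = N A N (N = D^-1/2 applied elementwise), then apply the
    linear layer, whose weight is rotated to match the concatenation order.
    """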
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b


@pytest.mark.parametrize("out_dim", [1, 2])
def test_tagconv(out_dim):
    ctx = F.ctx()
    g = graph_with_nodes(3, ctx)
    adj = g.adj_external(transpose=True, ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, out_dim, bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(
        h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias)
    )

    conv = nn.TAGConv(5, out_dim)
    conv = conv.to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == out_dim

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)


def test_set2set():
    ctx = F.ctx()
    g = graph_with_nodes(10, ctx)

    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    s2s = s2s.to(ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    g1 = graph_with_nodes(11, ctx)
    g2 = graph_with_nodes(5, ctx)
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2


def test_glob_att_pool():
    ctx = F.ctx()
    g = graph_with_nodes(10, ctx)

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)

    # test pickle
    th.save(gap, tmp_buffer)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2


def test_simple_pool():
    ctx = F.ctx()
    g = graph_with_nodes(15, ctx)

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
343
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2
344
345

    # test#2: batched graph
    g_ = graph_with_nodes(5, ctx)
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = th.stack(
        [
            F.sum(h0[:15], 0),
            F.sum(h0[15:20], 0),
            F.sum(h0[20:35], 0),
            F.sum(h0[35:40], 0),
            F.sum(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = th.stack(
        [
            F.mean(h0[:15], 0),
            F.mean(h0[15:20], 0),
            F.mean(h0[20:35], 0),
            F.mean(h0[35:40], 0),
            F.mean(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = th.stack(
        [
            F.max(h0[:15], 0),
            F.max(h0[15:20], 0),
            F.max(h0[20:35], 0),
            F.max(h0[35:40], 0),
            F.max(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2


def test_set_trans():
    ctx = F.ctx()
    g = graph_with_nodes(15)

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "sab")
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "isab", 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = graph_with_nodes(5)
    g2 = graph_with_nodes(10)
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.num_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2


@parametrize_idtype
@pytest.mark.parametrize("O", [1, 8, 32])
def test_rgcn(idtype, O):
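    """RelGraphConv smoke test: plain, "basis" and (when O % B == 0) "bdd"
    regularizers; checks shapes, pickling, presorted etypes and edge norms.
    """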
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.num_edges()):
        etype.append(i % 5)
    B = 2
    I = 10

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.num_edges(), 1)).to(ctx)
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )
    sorted_norm = norm[idx]

    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % B == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % B == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % B == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)


@parametrize_idtype
@pytest.mark.parametrize("O", [1, 10, 40])
def test_rgcn_default_nbasis(idtype, O):
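    """Same as test_rgcn but with the number of bases/blocks left at its
    default (the O % R checks assume it falls back to the relation count R).
    """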
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.num_edges()):
        etype.append(i % 5)
    I = 10

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.num_edges(), 1)).to(ctx)
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )
    sorted_norm = norm[idx]

    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis").to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % R == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd").to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % R == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % R == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % R == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_edge_weight(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    ew = F.randn((g.num_edges(),))
    h = gat(g, feat, edge_weight=ew)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape[0] == ew.shape[0]
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=10,
        in_edge_feats=5,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    nfeat = F.randn((g.num_nodes(), 10))
    efeat = F.randn((g.num_edges(), 5))
    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)

    th.save(egat, tmp_buffer)

    assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    _, _, attn = egat(g, nfeat, efeat, get_attention=True)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=(10, 15),
        in_edge_feats=7,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    nfeat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), 15)),
    )
    efeat = F.randn((g.num_edges(), 7))
    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)

    th.save(egat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    _, _, attn = egat(g, nfeat, efeat, get_attention=True)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_edge_weight(
    g, idtype, out_node_feats, out_edge_feats, num_heads
):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=10,
        in_edge_feats=5,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    egat = egat.to(ctx)
    nfeat = F.randn((g.num_nodes(), 10))
    efeat = F.randn((g.num_edges(), 5))
    ew = F.randn((g.num_edges(),))

    h, f, attn = egat(g, nfeat, efeat, edge_weight=ew, get_attention=True)

    assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv(g, idtype, out_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    edgegat = nn.EdgeGATConv(
        in_feats=10, edge_feats=5, out_feats=out_feats, num_heads=num_heads
    )
    nfeat = F.randn((g.number_of_nodes(), 10))
    efeat = F.randn((g.number_of_edges(), 5))
    edgegat = edgegat.to(ctx)
    h = edgegat(g, nfeat, efeat)

    th.save(edgegat, tmp_buffer)

    assert h.shape == (g.number_of_nodes(), num_heads, out_feats)
    _, attn = edgegat(g, nfeat, efeat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv_bi(g, idtype, out_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    edgegat = nn.EdgeGATConv(
        in_feats=(10, 15),
        edge_feats=7,
        out_feats=out_feats,
        num_heads=num_heads,
    )
    nfeat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), 15)),
    )
    efeat = F.randn((g.number_of_edges(), 7))
    edgegat = edgegat.to(ctx)
    h = edgegat(g, nfeat, efeat)

    th.save(edgegat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_feats)
    _, attn = edgegat(g, nfeat, efeat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
def test_sage_conv(idtype, g, aggre_type):
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    sage = sage.to(F.ctx())
    # test pickle
    th.save(sage, tmp_buffer)
    h = sage(g, feat)
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != "gcn" else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), dst_dim)),
    )
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()


@parametrize_idtype
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv2(idtype, out_dim):
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.heterograph({("_U", "_E", "_V"): ([], [])}, {"_U": 5, "_V": 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), out_dim, "gcn")
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    for aggre_type in ["mean", "pool", "lstm"]:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)

    # test pickle
    th.save(sgc, tmp_buffer)

    feat = F.randn((g.num_nodes(), 5))
    sgc = sgc.to(ctx)

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # cached
    sgc = nn.SGConv(5, out_dim, 3, True)
    sgc = sgc.to(ctx)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.num_nodes(), 5))
    appnp = appnp.to(ctx)

    # test pickle
    th.save(appnp, tmp_buffer)

    h = appnp(g, feat)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    appnp = appnp.to(ctx)

    h = appnp(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("bias", [True, False])
def test_gcn2conv_e_weight(g, idtype, bias):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gcn2conv = nn.GCN2Conv(
        5, layer=2, alpha=0.5, bias=bias, project_initial_features=True
    )
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    gcn2conv = gcn2conv.to(ctx)
    res = feat
    h = gcn2conv(g, res, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_sgconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    sgconv = nn.SGConv(5, 5, 3)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    sgconv = sgconv.to(ctx)
    h = sgconv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_tagconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.TAGConv(5, 5, bias=True)
    conv = conv.to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    conv = conv.to(ctx)
    h = conv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    th.save(gin, tmp_buffer)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gin = gin.to(ctx)
    h = gin(g, feat)

    # test pickle
    th.save(gin, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), 12)

    gin = nn.GINConv(None, aggregator_type)
    th.save(gin, tmp_buffer)
    gin = gin.to(ctx)
    h = gin(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
def test_gine_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gine = nn.GINEConv(th.nn.Linear(5, 12))
    th.save(gine, tmp_buffer)
    nfeat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.num_edges(), 5))
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)

    # test pickle
    th.save(gine, tmp_buffer)
    assert h.shape == (g.number_of_dst_nodes(), 12)

    gine = nn.GINEConv(None)
    th.save(gine, tmp_buffer)
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv_bi(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_agnn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_src_nodes(), 5))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_agnn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.num_edges()) % 3
    feat = F.randn((g.num_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv_one_etype(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 1)
    etypes = th.zeros(g.num_edges())
    feat = F.randn((g.num_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    h2 = ggconv(g, feat)
    # with a single edge type, omitting etypes must give the same result
    assert F.allclose(h, h2)
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_nn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.num_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_nn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    efeat = F.randn((g.num_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gmm_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv(5, 10, 3, 4, "mean")
    feat = F.randn((g.num_nodes(), 5))
    pseudo = F.randn((g.num_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite", "block-bipartite"], exclude=["zero-degree"])
)
def test_gmm_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.num_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("norm_type", ["both", "right", "none"])
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_graph_conv(norm_type, g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "bipartite"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_sage_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, out_dim, "gcn")
    dense_sage = nn.DenseSAGEConv(5, out_dim)
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.bias.data
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)),
        )
    else:
        feat = F.randn((g.num_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage), g


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)

    # test pickle
    th.save(edge_conv, tmp_buffer)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    dotgat = dotgat.to(ctx)

    # test pickle
    th.save(dotgat, tmp_buffer)

    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv((5, 5), out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    dotgat = dotgat.to(ctx)
    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_cheb_conv(out_dim):
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
        g = g.to(F.ctx())
        adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, out_dim, k, None)
        dense_cheb = nn.DenseChebConv(5, out_dim, k)
        # for i in range(len(cheb.fc)):
        #    dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(
            k, 5, out_dim
        )
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)


def test_sequential():
    ctx = F.ctx()

    # Test single graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
            graph.apply_edges(fn.u_add_v("h", "h", "e"))
            e_feat += graph.edata["e"]
            return n_feat, e_feat

    g = dgl.graph([])
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graphs
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
            return n_feat.view(graph.num_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.from_networkx(nx.erdos_renyi_graph(32, 0.05)).to(ctx)
    g2 = dgl.from_networkx(nx.erdos_renyi_graph(16, 0.2)).to(ctx)
    g3 = dgl.from_networkx(nx.erdos_renyi_graph(8, 0.8)).to(ctx)
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_atomic_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    aconv = nn.AtomicConv(
        interaction_cutoffs=F.tensor([12.0, 12.0]),
        rbf_kernel_means=F.tensor([0.0, 2.0]),
        rbf_kernel_scaling=F.tensor([4.0, 4.0]),
        features_to_use=F.tensor([6.0, 8.0]),
    )

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((g.num_nodes(), 1))
    dist = F.randn((g.num_edges(), 1))

    h = aconv(g, feat, dist)

    # currently we only do shape check
    assert h.shape[-1] == 4


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 3])
def test_cf_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    cfconv = nn.CFConv(
        node_in_feats=2, edge_in_feats=3, hidden_feats=2, out_feats=out_dim
    )

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    src_feats = F.randn((g.number_of_src_nodes(), 2))
    edge_feats = F.randn((g.num_edges(), 3))
    h = cfconv(g, src_feats, edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim

    # case for bipartite graphs
    dst_feats = F.randn((g.number_of_dst_nodes(), 3))
    h = cfconv(g, (src_feats, dst_feats), edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim


def myagg(alist, dsttype):
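    """Custom cross-type reducer for HeteroGraphConv:
    alist[0] + sum_{i >= 1} (i + 1) * alist[i].
    """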
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst


@parametrize_idtype
@pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg])
@pytest.mark.parametrize("canonical_keys", [False, True])
def test_hetero_conv(agg, idtype, canonical_keys):
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 0, 2, 1], [1, 2, 1, 3]),
            ("user", "plays", "game"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
            ("store", "sells", "game"): ([0, 0, 1, 1], [0, 3, 1, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    if not canonical_keys:
        conv = nn.HeteroGraphConv(
            {
                "follows": nn.GraphConv(2, 3, allow_zero_in_degree=True),
                "plays": nn.GraphConv(2, 4, allow_zero_in_degree=True),
                "sells": nn.GraphConv(3, 4, allow_zero_in_degree=True),
            },
            agg,
        )
    else:
        conv = nn.HeteroGraphConv(
            {
                ("user", "follows", "user"): nn.GraphConv(
                    2, 3, allow_zero_in_degree=True
                ),
                ("user", "plays", "game"): nn.GraphConv(
                    2, 4, allow_zero_in_degree=True
                ),
                ("store", "sells", "game"): nn.GraphConv(
                    3, 4, allow_zero_in_degree=True
                ),
            },
            agg,
        )

    conv = conv.to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

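    # Only node types that receive messages appear in the output dict;
    # "store" is never a destination type, so it is absent from h.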
    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    h = conv(block, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    # test with mod args
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv(
        {"follows": mod1, "plays": mod2, "sells": mod3}, agg
    )
    conv = conv.to(F.ctx())
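    # mod_args/mod_kwargs route extra positional and keyword arguments to the
    # per-relation submodules; the carg counters below verify the routing.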
    mod_args = {"follows": (1,), "plays": (1,)}
    mod_kwargs = {"sells": {"arg2": "abc"}}
    h = conv(
        g,
        {"user": uf, "game": gf, "store": sf},
        mod_args=mod_args,
        mod_kwargs=mod_kwargs,
    )
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on a graph without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form="eid", etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}


@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_linear(out_dim):
    in_feats = {
        "user": F.randn((2, 1)),
        ("user", "follows", "user"): F.randn((3, 2)),
    }

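    # HeteroLinear keeps one Linear per key; keys may be node types or
    # canonical edge types, each with its own input size.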
    layer = nn.HeteroLinear(
        {"user": 1, ("user", "follows", "user"): 2}, out_dim
    )
    layer = layer.to(F.ctx())
    out_feats = layer(in_feats)
    assert out_feats["user"].shape == (2, out_dim)
    assert out_feats[("user", "follows", "user")].shape == (3, out_dim)


@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_embedding(out_dim):
    layer = nn.HeteroEmbedding(
        {"user": 2, ("user", "follows", "user"): 3}, out_dim
    )
    layer = layer.to(F.ctx())

    embeds = layer.weight
    assert embeds["user"].shape == (2, out_dim)
    assert embeds[("user", "follows", "user")].shape == (3, out_dim)

    layer.reset_parameters()
    embeds = layer.weight
    assert embeds["user"].shape == (2, out_dim)
    assert embeds[("user", "follows", "user")].shape == (3, out_dim)

    embeds = layer(
        {
            "user": F.tensor([0], dtype=F.int64),
            ("user", "follows", "user"): F.tensor([0, 2], dtype=F.int64),
        }
    )
    assert embeds["user"].shape == (1, out_dim)
    assert embeds[("user", "follows", "user")].shape == (2, out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_gnnexplainer(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    feat = F.randn((g.num_nodes(), 5))

    class Model(th.nn.Module):
        def __init__(self, in_feats, out_feats, graph=False):
            super(Model, self).__init__()
            self.linear = th.nn.Linear(in_feats, out_feats)
            if graph:
                self.pool = nn.AvgPooling()
            else:
                self.pool = None

        def forward(self, graph, feat, eweight=None):
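            # GNNExplainer passes its learned edge mask via ``eweight``, so
            # the model must apply it during message passing.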
            with graph.local_scope():
                feat = self.linear(feat)
                graph.ndata["h"] = feat
                if eweight is None:
                    graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                else:
                    graph.edata["w"] = eweight
                    graph.update_all(
                        fn.u_mul_e("h", "w", "m"), fn.sum("m", "h")
                    )

                if self.pool:
                    return self.pool(graph, graph.ndata["h"])
                else:
                    return graph.ndata["h"]

    # Explain node prediction
    model = Model(5, out_dim)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)

    # Explain graph prediction
    model = Model(5, out_dim, graph=True)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)


@pytest.mark.parametrize("g", get_cases(["hetero"], exclude=["zero-degree"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("output_dim", [1, 2])
def test_heterognnexplainer(g, idtype, input_dim, output_dim):
    g = g.astype(idtype).to(F.ctx())
    device = g.device

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    feat = {
        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)
        for ntype in g.ntypes
    }

    class Model(th.nn.Module):
        def __init__(self, in_dim, num_classes, canonical_etypes, graph=False):
            super(Model, self).__init__()
            self.graph = graph
            self.etype_weights = th.nn.ModuleDict(
                {
                    "_".join(c_etype): th.nn.Linear(in_dim, num_classes)
                    for c_etype in canonical_etypes
                }
            )

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                c_etype_func_dict = {}
                for c_etype in graph.canonical_etypes:
                    src_type, etype, dst_type = c_etype
                    wh = self.etype_weights["_".join(c_etype)](feat[src_type])
                    graph.nodes[src_type].data[f"h_{c_etype}"] = wh
                    if eweight is None:
                        c_etype_func_dict[c_etype] = (
                            fn.copy_u(f"h_{c_etype}", "m"),
                            fn.mean("m", "h"),
                        )
                    else:
                        graph.edges[c_etype].data["w"] = eweight[c_etype]
                        c_etype_func_dict[c_etype] = (
                            fn.u_mul_e(f"h_{c_etype}", "w", "m"),
                            fn.mean("m", "h"),
                        )
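                # Run all per-relation updates at once; "sum" is the
                # cross-type reducer for node types updated by multiple
                # relations.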
                graph.multi_update_all(c_etype_func_dict, "sum")
                if self.graph:
                    hg = 0
                    for ntype in graph.ntypes:
                        if graph.num_nodes(ntype):
                            hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)

                    return hg
                else:
                    return graph.ndata["h"]

    # Explain node prediction
    model = Model(input_dim, output_dim, g.canonical_etypes)
    model = model.to(F.ctx())
    ntype = g.ntypes[0]
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(
        ntype, 0, g, feat
    )

    # Explain graph prediction
    model = Model(input_dim, output_dim, g.canonical_etypes, graph=True)
    model = model.to(F.ctx())
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(
        ["homo"],
        exclude=[
            "zero-degree",
            "homo-zero-degree",
            "has_feature",
            "has_scalar_e_feature",
            "row_sorted",
            "col_sorted",
            "batched",
        ],
    ),
)
@pytest.mark.parametrize("n_classes", [2])
def test_subgraphx(g, idtype, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))

    class Model(th.nn.Module):
        def __init__(self, in_dim, n_classes):
            super().__init__()
            self.conv = nn.GraphConv(in_dim, n_classes)
            self.pool = nn.AvgPooling()

        def forward(self, g, h):
            h = th.nn.functional.relu(self.conv(g, h))
            return self.pool(g, h)

    model = Model(feat.shape[1], n_classes)
    model = model.to(ctx)
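    # SubgraphX searches for an important subgraph with Monte Carlo tree
    # search, scoring candidates via Shapley-value estimation; the small
    # shapley_steps/num_rollouts only keep the test fast.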
    explainer = nn.SubgraphX(
        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0
    )
    explainer.explain_graph(g, feat, target_class=0)


@pytest.mark.parametrize("g", get_cases(["hetero"], exclude=["zero-degree"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("n_classes", [2])
def test_heterosubgraphx(g, idtype, input_dim, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    device = g.device

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    feat = {
        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)
        for ntype in g.ntypes
    }

    class Model(th.nn.Module):
        def __init__(self, in_dim, n_classes, canonical_etypes):
            super(Model, self).__init__()
            self.etype_weights = th.nn.ModuleDict(
                {
                    "_".join(c_etype): th.nn.Linear(in_dim, n_classes)
                    for c_etype in canonical_etypes
                }
            )

        def forward(self, graph, feat):
            with graph.local_scope():
                c_etype_func_dict = {}
                for c_etype in graph.canonical_etypes:
                    src_type, etype, dst_type = c_etype
                    wh = self.etype_weights["_".join(c_etype)](feat[src_type])
                    graph.nodes[src_type].data[f"h_{c_etype}"] = wh
                    c_etype_func_dict[c_etype] = (
                        fn.copy_u(f"h_{c_etype}", "m"),
                        fn.mean("m", "h"),
                    )
                graph.multi_update_all(c_etype_func_dict, "sum")
                hg = 0
                for ntype in graph.ntypes:
                    if graph.num_nodes(ntype):
                        hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)

                return hg

    model = Model(input_dim, n_classes, g.canonical_etypes)
    model = model.to(ctx)
    explainer = nn.HeteroSubgraphX(
        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0
    )
    explainer.explain_graph(g, feat, target_class=0)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(
        ["homo"],
        exclude=[
            "zero-degree",
            "homo-zero-degree",
            "has_feature",
            "has_scalar_e_feature",
            "row_sorted",
            "col_sorted",
        ],
    ),
)
@pytest.mark.parametrize("n_classes", [2])
def test_pgexplainer(g, idtype, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    g.ndata["attr"] = feat

    # add reverse edges
    transform = dgl.transforms.AddReverse(copy_edata=True)
    g = transform(g)

    class Model(th.nn.Module):
        def __init__(self, in_feats, out_feats, graph=False):
            super(Model, self).__init__()
            self.graph = graph
            self.conv = nn.GraphConv(in_feats, out_feats)
            self.fc = th.nn.Linear(out_feats, out_feats)
            th.nn.init.xavier_uniform_(self.fc.weight)

        def forward(self, g, h, embed=False, edge_weight=None):
            h = self.conv(g, h, edge_weight=edge_weight)

            if not self.graph or embed:
                return h

            with g.local_scope():
                g.ndata["h"] = h
                hg = dgl.mean_nodes(g, "h")
                return self.fc(hg)

    # graph explainer
    model = Model(feat.shape[1], n_classes, graph=True)
    model = model.to(ctx)
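    # PGExplainer learns a parameterized edge-mask predictor, so train_step
    # must be called before explain_graph; 5.0 is the temperature used when
    # sampling the edge mask.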
    explainer = nn.PGExplainer(model, n_classes)
    explainer.train_step(g, g.ndata["attr"], 5.0)

    probs, edge_weight = explainer.explain_graph(g, feat)

    # node explainer
    model = Model(feat.shape[1], n_classes, graph=False)
    model = model.to(ctx)
    explainer = nn.PGExplainer(
        model, n_classes, num_hops=1, explain_graph=False
    )
    explainer.train_step_node(0, g, g.ndata["attr"], 5.0)
    explainer.train_step_node([0, 1], g, g.ndata["attr"], 5.0)
    explainer.train_step_node(th.tensor(0), g, g.ndata["attr"], 5.0)
    explainer.train_step_node(th.tensor([0, 1]), g, g.ndata["attr"], 5.0)

    probs, edge_weight, bg, inverse_indices = explainer.explain_node(0, g, feat)
    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
        [0, 1], g, feat
    )
    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
        th.tensor(0), g, feat
    )
    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
        th.tensor([0, 1]), g, feat
    )


@pytest.mark.parametrize("g", get_cases(["hetero"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("n_classes", [2])
def test_heteropgexplainer(g, idtype, input_dim, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = {
        ntype: F.randn((g.num_nodes(ntype), input_dim)) for ntype in g.ntypes
    }

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    class Model(th.nn.Module):
        def __init__(self, in_feats, embed_dim, out_feats, canonical_etypes):
            super(Model, self).__init__()
            self.conv = nn.HeteroGraphConv(
                {
                    c_etype: nn.GraphConv(in_feats, embed_dim)
                    for c_etype in canonical_etypes
                }
            )
            self.fc = th.nn.Linear(embed_dim, out_feats)

        def forward(self, g, h, embed=False, edge_weight=None):
            if edge_weight is not None:
                mod_kwargs = {
                    etype: {"edge_weight": mask}
                    for etype, mask in edge_weight.items()
                }
                h = self.conv(g, h, mod_kwargs=mod_kwargs)
            else:
                h = self.conv(g, h)

            if embed:
                return h

            with g.local_scope():
                g.ndata["h"] = h
                hg = 0
                for ntype in g.ntypes:
                    hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
                return self.fc(hg)

    embed_dim = input_dim

    # graph explainer
    model = Model(input_dim, embed_dim, n_classes, g.canonical_etypes)
    model = model.to(ctx)
    explainer = nn.HeteroPGExplainer(model, embed_dim)
    explainer.train_step(g, feat, 5.0)

    probs, edge_weight = explainer.explain_graph(g, feat)


def test_jumping_knowledge():
    ctx = F.ctx()
    num_layers = 2
    num_nodes = 3
    num_feats = 4

    feat_list = [
        th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)
    ]

    model = nn.JumpingKnowledge("cat").to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_layers * num_feats)

    model = nn.JumpingKnowledge("max").to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)

    model = nn.JumpingKnowledge("lstm", num_feats, num_layers).to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)


@pytest.mark.parametrize("op", ["dot", "cos", "ele", "cat"])
def test_edge_predictor(op):
    ctx = F.ctx()
    num_pairs = 3
    in_feats = 4
    out_feats = 5
    h_src = th.randn((num_pairs, in_feats)).to(ctx)
    h_dst = th.randn((num_pairs, in_feats)).to(ctx)

    pred = nn.EdgePredictor(op)
    if op in ["dot", "cos"]:
        assert pred(h_src, h_dst).shape == (num_pairs, 1)
    elif op == "ele":
        assert pred(h_src, h_dst).shape == (num_pairs, in_feats)
    else:
        assert pred(h_src, h_dst).shape == (num_pairs, 2 * in_feats)
    pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
    assert pred(h_src, h_dst).shape == (num_pairs, out_feats)


def test_ke_score_funcs():
    ctx = F.ctx()
    num_edges = 30
    num_rels = 3
    nfeats = 4

    h_src = th.randn((num_edges, nfeats)).to(ctx)
    h_dst = th.randn((num_edges, nfeats)).to(ctx)
    rels = th.randint(low=0, high=num_rels, size=(num_edges,)).to(ctx)

    score_func = nn.TransE(num_rels=num_rels, feats=nfeats).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)

    score_func = nn.TransR(
        num_rels=num_rels, rfeats=nfeats - 1, nfeats=nfeats
    ).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)


def test_twirls():
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    feat = th.ones(6, 10)
    conv = nn.TWIRLSConv(10, 2, 128, prop_step=64)
    res = conv(g, feat)
    assert res.size() == (6, 2)


@pytest.mark.parametrize("feat_size", [4, 32])
@pytest.mark.parametrize(
    "regularizer,num_bases", [(None, None), ("basis", 4), ("bdd", 4)]
)
def test_typed_linear(feat_size, regularizer, num_bases):
    dev = F.ctx()
    num_types = 5
    lin = nn.TypedLinear(
        feat_size,
        feat_size * 2,
        num_types,
        regularizer=regularizer,
        num_bases=num_bases,
    ).to(dev)
    print(lin)
    x = th.randn(100, feat_size).to(dev)
    x_type = th.randint(0, num_types, (100,)).to(dev)
    x_type_sorted, idx = th.sort(x_type)
    _, rev_idx = th.sort(idx)
    x_sorted = x[idx]

    # test unsorted
    y = lin(x, x_type)
    assert y.shape == (100, feat_size * 2)
    # test sorted
    y_sorted = lin(x_sorted, x_type_sorted, sorted_by_type=True)
    assert y_sorted.shape == (100, feat_size * 2)

    assert th.allclose(y, y_sorted[rev_idx], atol=1e-4, rtol=1e-4)


@parametrize_idtype
@pytest.mark.parametrize("in_size", [4])
@pytest.mark.parametrize("num_heads", [1])
def test_hgt(idtype, in_size, num_heads):
    dev = F.ctx()
    num_etypes = 5
    num_ntypes = 2
    head_size = in_size // num_heads

    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.01))
    g = g.astype(idtype).to(dev)
    etype = th.tensor([i % num_etypes for i in range(g.num_edges())]).to(dev)
    ntype = th.tensor([i % num_ntypes for i in range(g.num_nodes())]).to(dev)
    x = th.randn(g.num_nodes(), in_size).to(dev)

    m = nn.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes).to(
        dev
    )

    y = m(g, x, ntype, etype)
    assert y.shape == (g.num_nodes(), head_size * num_heads)
    # sorted-input variant (the equivalence check is disabled; see TODO below)
    sorted_ntype, idx_nt = th.sort(ntype)
    sorted_etype, idx_et = th.sort(etype)
    _, rev_idx = th.sort(idx_nt)
    g.ndata["t"] = ntype
    g.ndata["x"] = x
    g.edata["t"] = etype
    sorted_g = dgl.reorder_graph(
        g,
        node_permute_algo="custom",
        edge_permute_algo="custom",
        permute_config={
            "nodes_perm": idx_nt.to(idtype),
            "edges_perm": idx_et.to(idtype),
        },
    )
    print(sorted_g.ndata["t"])
    print(sorted_g.edata["t"])
    sorted_x = sorted_g.ndata["x"]
    sorted_y = m(
        sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=False
    )
    assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
    # mini-batch
    train_idx = th.randperm(100, dtype=idtype)[:10]
    sampler = dgl.dataloading.NeighborSampler([-1])
    train_loader = dgl.dataloading.DataLoader(
        g, train_idx.to(dev), sampler, batch_size=8, device=dev, shuffle=True
    )
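    # A full-neighbor sampler ([-1]) yields one block per batch; node and
    # edge type vectors must be gathered for the sampled ids before calling
    # the module on the block.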
    (input_nodes, output_nodes, block) = next(iter(train_loader))
    block = block[0]
    x = x[input_nodes.to(th.long)]
    ntype = ntype[input_nodes.to(th.long)]
    edge = block.edata[dgl.EID]
    etype = etype[edge.to(th.long)]
    y = m(block, x, ntype, etype)
    assert y.shape == (block.number_of_dst_nodes(), head_size * num_heads)
    # TODO(minjie): enable the following check
    # assert th.allclose(y, sorted_y[rev_idx], atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize("self_loop", [True, False])
@pytest.mark.parametrize("get_distances", [True, False])
def test_radius_graph(self_loop, get_distances):
    pos = th.tensor(
        [
            [0.1, 0.3, 0.4],
            [0.5, 0.2, 0.1],
            [0.7, 0.9, 0.5],
            [0.3, 0.2, 0.5],
            [0.2, 0.8, 0.2],
            [0.9, 0.2, 0.1],
            [0.7, 0.4, 0.4],
            [0.2, 0.1, 0.6],
            [0.5, 0.3, 0.5],
            [0.4, 0.2, 0.6],
        ]
    )

    rg = nn.RadiusGraph(0.3, self_loop=self_loop)
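    # RadiusGraph connects every pair of points within Euclidean distance
    # 0.3; with self_loop=True each point also gets a zero-distance edge to
    # itself.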

    if get_distances:
        g, dists = rg(pos, get_distances=get_distances)
    else:
        g = rg(pos)

    if self_loop:
        src_target = th.tensor(
            [
                0,
                0,
                1,
                2,
                3,
                3,
                3,
                3,
                3,
                4,
                5,
                6,
                6,
                7,
                7,
                7,
                8,
                8,
                8,
                8,
                9,
                9,
                9,
                9,
            ]
        )
        dst_target = th.tensor(
            [
                0,
                3,
                1,
                2,
                0,
                3,
                7,
                8,
                9,
                4,
                5,
                6,
                8,
                3,
                7,
                9,
                3,
                6,
                8,
                9,
                3,
                7,
                8,
                9,
            ]
        )

        if get_distances:
            dists_target = th.tensor(
                [
                    [0.0000],
                    [0.2449],
                    [0.0000],
                    [0.0000],
                    [0.2449],
                    [0.0000],
                    [0.1732],
                    [0.2236],
                    [0.1414],
                    [0.0000],
                    [0.0000],
                    [0.0000],
                    [0.2449],
                    [0.1732],
                    [0.0000],
                    [0.2236],
                    [0.2236],
                    [0.2449],
                    [0.0000],
                    [0.1732],
                    [0.1414],
                    [0.2236],
                    [0.1732],
                    [0.0000],
                ]
            )
    else:
        src_target = th.tensor([0, 3, 3, 3, 3, 6, 7, 7, 8, 8, 8, 9, 9, 9])
        dst_target = th.tensor([3, 0, 7, 8, 9, 8, 3, 9, 3, 6, 9, 3, 7, 8])

        if get_distances:
            dists_target = th.tensor(
                [
                    [0.2449],
                    [0.2449],
                    [0.1732],
                    [0.2236],
                    [0.1414],
                    [0.2449],
                    [0.1732],
                    [0.2236],
                    [0.2236],
                    [0.2449],
                    [0.1732],
                    [0.1414],
                    [0.2236],
                    [0.1732],
                ]
            )

    src, dst = g.edges()

    assert th.equal(src, src_target)
    assert th.equal(dst, dst_target)

    if get_distances:
        assert th.allclose(dists, dists_target, rtol=1e-03)


@parametrize_idtype
def test_group_rev_res(idtype):
    dev = F.ctx()

    num_nodes = 5
    num_edges = 20
    feats = 32
    groups = 2
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, feats).to(dev)
    conv = nn.GraphConv(feats // groups, feats // groups)
    model = nn.GroupRevRes(conv, groups).to(dev)
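    # GroupRevRes splits the input into `groups` channel partitions and runs
    # the wrapped module reversibly (hence feats // groups in/out sizes);
    # backward() exercises the activation-recomputation path.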
    result = model(g, h)
    result.sum().backward()


@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("hidden_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize("edge_feat_size", [16, 10, 0])
def test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    x = th.randn(num_nodes, 3).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    model = nn.EGNNConv(in_size, hidden_size, out_size, edge_feat_size).to(dev)
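    # EGNNConv takes node features, 3D coordinates, and optional edge
    # features, and returns updated features and coordinates.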
    model(g, h, x, e)


@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators",
    [
        ["mean", "max", "sum"],
        ["min", "std", "var"],
        ["moment3", "moment4", "moment5"],
    ],
)
@pytest.mark.parametrize(
    "scalers", [["identity"], ["amplification", "attenuation"]]
)
@pytest.mark.parametrize("delta", [2.5, 7.4])
@pytest.mark.parametrize("dropout", [0.0, 0.1])
@pytest.mark.parametrize("num_towers", [1, 4])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
@pytest.mark.parametrize("residual", [True, False])
def test_pna_conv(
    in_size,
    out_size,
    aggregators,
    scalers,
    delta,
    dropout,
    num_towers,
    edge_feat_size,
    residual,
):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    model = nn.PNAConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        dropout,
        num_towers,
        edge_feat_size,
        residual,
    ).to(dev)
    model(g, h, edge_feat=e)


@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("norm_type", ["sym", "row"])
@pytest.mark.parametrize("clamp", [True, False])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("reset", [True, False])
def test_label_prop(k, alpha, norm_type, clamp, normalize, reset):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    num_classes = 4
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    labels = th.tensor([0, 2, 1, 3, 0]).long().to(dev)
    ml_labels = th.rand(num_nodes, num_classes).to(dev) > 0.7
    mask = th.tensor([0, 1, 1, 1, 0]).bool().to(dev)
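    # mask selects the seed nodes whose ground-truth labels are propagated
    # to the rest of the graph.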
    model = nn.LabelPropagation(k, alpha, norm_type, clamp, normalize, reset)
    model(g, labels, mask)
    # multi-label case
    model(g, ml_labels, mask)


@pytest.mark.parametrize("in_size", [16])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators", [["mean", "max", "dir2-av"], ["min", "std", "dir1-dx"]]
)
@pytest.mark.parametrize("scalers", [["amplification", "attenuation"]])
@pytest.mark.parametrize("delta", [2.5])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
def test_dgn_conv(
    in_size, out_size, aggregators, scalers, delta, edge_feat_size
):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    transform = dgl.LapPE(k=3, feat_name="eig")
    g = transform(g)
    eig = g.ndata["eig"]
    model = nn.DGNConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    model(g, h, edge_feat=e, eig_vec=eig)

    aggregators_non_eig = [
        aggr for aggr in aggregators if not aggr.startswith("dir")
    ]
    model = nn.DGNConv(
        in_size,
        out_size,
        aggregators_non_eig,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    model(g, h, edge_feat=e)


def test_DeepWalk():
    dev = F.ctx()
    g = dgl.graph(([0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]))
    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=True, sparse=True
    )
    model = model.to(dev)
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
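    # sparse=True produces sparse embedding gradients, which require
    # SparseAdam; the dense configuration below uses a regular Adam.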
    optim = SparseAdam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()

    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=False, sparse=False
    )
    model = model.to(dev)
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
    optim = Adam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()


@pytest.mark.parametrize("max_degree", [2, 6])
@pytest.mark.parametrize("embedding_dim", [8, 16])
@pytest.mark.parametrize("direction", ["in", "out", "both"])
def test_degree_encoder(max_degree, embedding_dim, direction):
    g1 = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    )
    g2 = dgl.graph(
        (
            th.tensor([0, 1]),
            th.tensor([1, 0]),
        )
    )
    in_degree = pad_sequence(
        [g1.in_degrees(), g2.in_degrees()], batch_first=True
    )
    out_degree = pad_sequence(
        [g1.out_degrees(), g2.out_degrees()], batch_first=True
    )
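    # pad_sequence batches the two graphs' degree sequences into a single
    # (batch, max_num_nodes) tensor; for "both", in- and out-degrees are
    # stacked and embedded jointly.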
    model = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)
    if direction == "in":
        de_g = model(in_degree)
    elif direction == "out":
        de_g = model(out_degree)
    elif direction == "both":
        de_g = model(th.stack((in_degree, out_degree)))
    assert de_g.shape == (2, 4, embedding_dim)


@parametrize_idtype
def test_MetaPath2Vec(idtype):
    dev = F.ctx()
    g = dgl.heterograph(
        {
            ("user", "uc", "company"): ([0, 0, 2, 1, 3], [1, 2, 1, 3, 0]),
            ("company", "cp", "product"): (
                [0, 0, 0, 1, 2, 3],
                [0, 2, 3, 0, 2, 1],
            ),
            ("company", "cu", "user"): ([1, 2, 1, 3, 0], [0, 0, 2, 1, 3]),
            ("product", "pc", "company"): (
                [0, 2, 3, 0, 2, 1],
                [0, 0, 0, 1, 2, 3],
            ),
        },
        idtype=idtype,
        device=dev,
    )
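    # The metapath ["uc", "cu"] generates user -> company -> user walks;
    # node_embed holds one embedding per node of every type, which the
    # assert below checks.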
    model = nn.MetaPath2Vec(g, ["uc", "cu"], window_size=1)
    model = model.to(dev)
    embeds = model.node_embed.weight
    assert embeds.shape[0] == g.num_nodes()


@pytest.mark.parametrize("num_layer", [1, 4])
@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("lpe_dim", [4, 16])
@pytest.mark.parametrize("n_head", [2, 4])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("num_post_layer", [0, 1, 2])
def test_LapPosEncoder(
    num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
):
    ctx = F.ctx()
    num_nodes = 4

    EigVals = th.randn((num_nodes, k)).to(ctx)
    EigVecs = th.randn((num_nodes, k)).to(ctx)

    model = nn.LapPosEncoder(
        "Transformer", num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
    ).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)

    model = nn.LapPosEncoder(
        "DeepSet",
        num_layer,
        k,
        lpe_dim,
        batch_norm=batch_norm,
        num_post_layer=num_post_layer,
    ).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)


@pytest.mark.parametrize("feat_size", [128, 512])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("attn_drop", [0.1, 0.5])
def test_BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop):
    ndata = th.rand(16, 100, feat_size)
    attn_bias = th.rand(16, 100, 100, num_heads)
    attn_mask = th.rand(16, 100, 100) < 0.5
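    # attn_bias is added to ("add") or multiplied with ("mul") the per-head
    # attention logits; attn_mask marks node pairs to drop from attention.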

    net = nn.BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop)
    out = net(ndata, attn_bias, attn_mask)

    assert out.shape == (16, 100, feat_size)


@pytest.mark.parametrize("edge_update", [True, False])
def test_EGTLayer(edge_update):
    batch_size = 16
    num_nodes = 100
    feat_size, edge_feat_size = 128, 32
    nfeat = th.rand(batch_size, num_nodes, feat_size)
    efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
    mask = (th.rand(batch_size, num_nodes, num_nodes) < 0.5) * -1e9

    net = nn.EGTLayer(
        feat_size=feat_size,
        edge_feat_size=edge_feat_size,
        num_heads=8,
        num_virtual_nodes=4,
        edge_update=edge_update,
    )

    if edge_update:
        out_nfeat, out_efeat = net(nfeat, efeat, mask)
        assert out_nfeat.shape == (batch_size, num_nodes, feat_size)
        assert out_efeat.shape == (
            batch_size,
            num_nodes,
            num_nodes,
            edge_feat_size,
        )
    else:
        out_nfeat = net(nfeat, efeat, mask)
        assert out_nfeat.shape == (batch_size, num_nodes, feat_size)


@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("norm_first", [True, False])
def test_GraphormerLayer(attn_bias_type, norm_first):
    batch_size = 16
    num_nodes = 100
    feat_size = 512
    num_heads = 8
    nfeat = th.rand(batch_size, num_nodes, feat_size)
    attn_bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
    attn_mask = th.rand(batch_size, num_nodes, num_nodes) < 0.5

    net = nn.GraphormerLayer(
        feat_size=feat_size,
        hidden_size=2048,
        num_heads=num_heads,
        attn_bias_type=attn_bias_type,
        norm_first=norm_first,
        dropout=0.1,
        attn_dropout=0.1,
        activation=th.nn.ReLU(),
    )
    out = net(nfeat, attn_bias, attn_mask)

    assert out.shape == (batch_size, num_nodes, feat_size)


@pytest.mark.parametrize("max_len", [1, 2])
@pytest.mark.parametrize("feat_dim", [16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_PathEncoder(max_len, feat_dim, num_heads):
    dev = F.ctx()
    g = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    ).to(dev)
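    # shortest_dist pads missing paths with edge id -1, so a zero row is
    # appended to edge_feat below, letting index -1 resolve to zeros.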
    edge_feat = th.rand(g.num_edges(), feat_dim).to(dev)
    edge_feat = th.cat((edge_feat, th.zeros(1, feat_dim).to(dev)), dim=0)
    dist, path = shortest_dist(g, root=None, return_paths=True)
    path_data = edge_feat[path[:, :, :max_len]]
    model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
    bias = model(dist.unsqueeze(0), path_data.unsqueeze(0))
    assert bias.shape == (1, 4, 4, num_heads)


@pytest.mark.parametrize("max_dist", [1, 4])
@pytest.mark.parametrize("num_kernels", [4, 16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_SpatialEncoder(max_dist, num_kernels, num_heads):
    dev = F.ctx()
    # 3D encoding on a single graph
    num_nodes = 4
    coord = th.rand(1, num_nodes, 3).to(dev)
    node_type = th.tensor([[1, 0, 2, 1]]).to(dev)
    spatial_encoder = nn.SpatialEncoder3d(
        num_kernels=num_kernels, num_heads=num_heads, max_node_type=3
    ).to(dev)
    out = spatial_encoder(coord, node_type=node_type)
    assert out.shape == (1, num_nodes, num_nodes, num_heads)

    # encoding on a batch of graphs
    g1 = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    ).to(dev)
    g2 = dgl.graph(
        (th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3]))
    ).to(dev)
    bsz, max_num_nodes = 2, 6
    # 2d encoding
    dist = -th.ones((bsz, max_num_nodes, max_num_nodes), dtype=th.long).to(dev)
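    # -1 marks padded or unreachable node pairs; actual shortest distances
    # for each graph are filled in below.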
    dist[0, :4, :4] = shortest_dist(g1, root=None, return_paths=False)
    dist[1, :6, :6] = shortest_dist(g2, root=None, return_paths=False)
    model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
    encoding = model_1(dist)
    assert encoding.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
    # 3d encoding
    coord = th.rand(bsz, max_num_nodes, 3).to(dev)
    node_type = th.randint(
        0,
        512,
        (
            bsz,
            max_num_nodes,
        ),
    ).to(dev)
    model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
    model_3 = nn.SpatialEncoder3d(
        num_kernels, num_heads=num_heads, max_node_type=512
    ).to(dev)
    encoding3d_1 = model_2(coord)
    encoding3d_2 = model_3(coord, node_type)
    assert encoding3d_1.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
    assert encoding3d_2.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)