import io
import pickle
import random
import re

from copy import deepcopy

import backend as F

import dgl
import dgl.function as fn

import dgl.nn.pytorch as nn
import networkx as nx

import numpy as np  # For setting seed for scipy
import pytest
import scipy as sp

import torch
import torch as th

from dgl import shortest_dist
from torch.nn.utils.rnn import pad_sequence

from torch.optim import Adam, SparseAdam
from torch.utils.data import DataLoader
from utils import parametrize_idtype
from utils.graph_cases import (
    get_cases,
    random_bipartite,
    random_dglgraph,
    random_graph,
)

# Set seeds to make tests fully reproducible.
SEED = 12345  # random.randint(1, 99999)
random.seed(SEED)  # For networkx
np.random.seed(SEED)  # For scipy
dgl.seed(SEED)
F.seed(SEED)

# Shared in-memory buffer used by the recurring pickle round-trip checks.
tmp_buffer = io.BytesIO()


def _AXWb(A, X, W, b):
    # Dense reference for GraphConv with norm="none": Y = A (X W) + b.
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b


def graph_with_nodes(num_nodes, ctx=None):
    # Small path graph used as a smoke-test fixture; moved to ctx if given.
    g = dgl.from_networkx(nx.path_graph(num_nodes))
    return g.to(ctx) if ctx else g
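

# A minimal sketch (an illustration, not part of the original suite) of the
# save/load round trip that the recurring "test pickle" lines exercise. It
# assumes torch.load may unpickle full modules (the pre-2.6 default).
def example_pickle_roundtrip(module):
    buf = io.BytesIO()
    th.save(module, buf)
    buf.seek(0)
    return th.load(buf)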


@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv0(out_dim):
    ctx = F.ctx()
    g = graph_with_nodes(3, ctx)
    adj = g.adj_external(transpose=True, ctx=ctx)

    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, out_dim)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)
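

# A standalone sketch (not part of the original suite) of the identity that
# test_graph_conv0 checks: with norm="none", GraphConv computes
# Y = A^T X W + b, which _AXWb reproduces on the external adjacency matrix.
def example_graph_conv_dense_equivalence():
    ctx = F.ctx()
    g = graph_with_nodes(4, ctx)
    conv = nn.GraphConv(3, 2, norm="none", bias=True).to(ctx)
    h0 = F.ones((4, 3))
    adj = g.adj_external(transpose=True, ctx=ctx)
    assert F.allclose(conv(g, h0), _AXWb(adj, h0, conv.weight, conv.bias))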


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right", "left"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(idtype, g, norm, weight, bias, out_dim):
    # Test one tensor input
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )
    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    e_w = g.edata["scalar_w"]
    if weight:
        h_out = conv(g, h, edge_weight=e_w)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=e_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["has_scalar_e_feature"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    edgenorm = nn.EdgeWeightNorm(norm=norm)
    norm_weight = edgenorm(g, g.edata["scalar_w"])
    if weight:
        h_out = conv(g, h, edge_weight=norm_weight)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=norm_weight)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
    # Test a pair of tensor inputs
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(
        F.ctx()
    )

    # test pickle
    th.save(conv, tmp_buffer)

    ext_w = F.randn((5, out_dim)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    h_dst = F.randn((ndst, out_dim)).to(F.ctx())
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


def _S2AXWb(A, N, X, W, b):
    # Dense reference for TAGConv with K=2: with the symmetrically
    # normalized adjacency A_hat = N A N, stack [X, A_hat X, A_hat^2 X]
    # and apply the linear layer.
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b


@pytest.mark.parametrize("out_dim", [1, 2])
def test_tagconv(out_dim):
    ctx = F.ctx()
    g = graph_with_nodes(3, ctx)
    adj = g.adj_external(transpose=True, ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, out_dim, bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test pickle
    th.save(conv, tmp_buffer)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(
        h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias)
    )

    conv = nn.TAGConv(5, out_dim)
    conv = conv.to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == out_dim

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)


def test_set2set():
    ctx = F.ctx()
    g = graph_with_nodes(10, ctx)

    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    s2s = s2s.to(ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    g1 = graph_with_nodes(11, ctx)
    g2 = graph_with_nodes(5, ctx)
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2


def test_glob_att_pool():
    ctx = F.ctx()
    g = graph_with_nodes(10, ctx)

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)

    # test pickle
    th.save(gap, tmp_buffer)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2


def test_simple_pool():
    ctx = F.ctx()
    g = graph_with_nodes(15, ctx)

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = graph_with_nodes(5, ctx)
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = th.stack(
        [
            F.sum(h0[:15], 0),
            F.sum(h0[15:20], 0),
            F.sum(h0[20:35], 0),
            F.sum(h0[35:40], 0),
            F.sum(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = th.stack(
        [
            F.mean(h0[:15], 0),
            F.mean(h0[15:20], 0),
            F.mean(h0[20:35], 0),
            F.mean(h0[35:40], 0),
            F.mean(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = th.stack(
        [
            F.max(h0[:15], 0),
            F.max(h0[15:20], 0),
            F.max(h0[20:35], 0),
            F.max(h0[35:40], 0),
            F.max(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
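

# A minimal sketch (an illustration, not part of the original suite): the
# batched pooling results verified above can equivalently be computed with
# dgl.readout_nodes, which segments node features by bg.batch_num_nodes().
def example_sum_pool_vs_readout():
    ctx = F.ctx()
    bg = dgl.batch([graph_with_nodes(3, ctx), graph_with_nodes(4, ctx)])
    h = F.randn((bg.num_nodes(), 5))
    bg.ndata["h"] = h
    sum_pool = nn.SumPooling().to(ctx)
    assert F.allclose(sum_pool(bg, h), dgl.readout_nodes(bg, "h", op="sum"))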


def test_set_trans():
    ctx = F.ctx()
    g = graph_with_nodes(15)

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "sab")
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "isab", 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = graph_with_nodes(5)
    g2 = graph_with_nodes(10)
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.num_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2


@parametrize_idtype
@pytest.mark.parametrize("O", [1, 8, 32])
def test_rgcn(idtype, O):
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.num_edges()):
        etype.append(i % 5)
    B = 2
    I = 10

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.num_edges(), 1)).to(ctx)
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )
    sorted_norm = norm[idx]

    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % B == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % B == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % B == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)
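

# A small sketch (an illustration, not part of the original suite): the
# Python loop above assigns edge types round-robin, which is equivalent to
# an arange modulo the number of relations.
def example_round_robin_etypes(num_edges, num_rels=5):
    return th.arange(num_edges) % num_rels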


@parametrize_idtype
@pytest.mark.parametrize("O", [1, 10, 40])
def test_rgcn_default_nbasis(idtype, O):
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.num_edges()):
        etype.append(i % 5)
    I = 10

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.num_edges(), 1)).to(ctx)
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(
        g,
        edge_permute_algo="custom",
        permute_config={"edges_perm": idx.to(idtype)},
    )
    sorted_norm = norm[idx]

    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis").to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % R == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd").to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % R == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % R == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % R == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_edge_weight(g, idtype, out_dim, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    ew = F.randn((g.num_edges(),))
    h = gat(g, feat, edge_weight=ew)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape[0] == ew.shape[0]
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=10,
        in_edge_feats=5,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    nfeat = F.randn((g.num_nodes(), 10))
    efeat = F.randn((g.num_edges(), 5))
    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)

    # test pickle
    th.save(egat, tmp_buffer)

    assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    _, _, attn = egat(g, nfeat, efeat, get_attention=True)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=(10, 15),
        in_edge_feats=7,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    nfeat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), 15)),
    )
    efeat = F.randn((g.num_edges(), 7))
    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)

    # test pickle
    th.save(egat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    _, _, attn = egat(g, nfeat, efeat, get_attention=True)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_edge_weight(
    g, idtype, out_node_feats, out_edge_feats, num_heads
):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    egat = nn.EGATConv(
        in_node_feats=10,
        in_edge_feats=5,
        out_node_feats=out_node_feats,
        out_edge_feats=out_edge_feats,
        num_heads=num_heads,
    )
    egat = egat.to(ctx)
    nfeat = F.randn((g.num_nodes(), 10))
    efeat = F.randn((g.num_edges(), 5))
    ew = F.randn((g.num_edges(),))

    h, f, attn = egat(g, nfeat, efeat, edge_weight=ew, get_attention=True)

    assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
    assert attn.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv(g, idtype, out_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    edgegat = nn.EdgeGATConv(
        in_feats=10, edge_feats=5, out_feats=out_feats, num_heads=num_heads
    )
    nfeat = F.randn((g.number_of_nodes(), 10))
    efeat = F.randn((g.number_of_edges(), 5))
    edgegat = edgegat.to(ctx)
    h = edgegat(g, nfeat, efeat)

    th.save(edgegat, tmp_buffer)  # test pickle

    assert h.shape == (g.number_of_nodes(), num_heads, out_feats)
    _, attn = edgegat(g, nfeat, efeat, get_attention=True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv_bi(g, idtype, out_feats, num_heads):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    edgegat = nn.EdgeGATConv(
        in_feats=(10, 15),
        edge_feats=7,
        out_feats=out_feats,
        num_heads=num_heads,
    )
    nfeat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), 15)),
    )
    efeat = F.randn((g.number_of_edges(), 7))
    edgegat = edgegat.to(ctx)
    h = edgegat(g, nfeat, efeat)

    th.save(edgegat, tmp_buffer)  # test pickle

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_feats)
    _, attn = edgegat(g, nfeat, efeat, get_attention=True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
def test_sage_conv(idtype, g, aggre_type):
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    sage = sage.to(F.ctx())
    # test pickle
    th.save(sage, tmp_buffer)
    h = sage(g, feat)
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != "gcn" else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), dst_dim)),
    )
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()


@parametrize_idtype
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv2(idtype, out_dim):
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.heterograph({("_U", "_E", "_V"): ([], [])}, {"_U": 5, "_V": 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), out_dim, "gcn")
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    for aggre_type in ["mean", "pool", "lstm"]:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)

    # test pickle
    th.save(sgc, tmp_buffer)

    feat = F.randn((g.num_nodes(), 5))
    sgc = sgc.to(ctx)

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # cached
    sgc = nn.SGConv(5, out_dim, 3, True)
    sgc = sgc.to(ctx)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.num_nodes(), 5))
    appnp = appnp.to(ctx)

    # test pickle
    th.save(appnp, tmp_buffer)

    h = appnp(g, feat)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    appnp = appnp.to(ctx)

    h = appnp(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("bias", [True, False])
def test_gcn2conv_e_weight(g, idtype, bias):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gcn2conv = nn.GCN2Conv(
        5, layer=2, alpha=0.5, bias=bias, project_initial_features=True
    )
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    gcn2conv = gcn2conv.to(ctx)
    res = feat
    h = gcn2conv(g, res, feat, edge_weight=eweight)
    assert h.shape[-1] == 5
    assert re.match(
        re.compile(".*GCN2Conv.*in=.*, alpha=.*, beta=.*"), str(gcn2conv)
    )


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_sgconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    sgconv = nn.SGConv(5, 5, 3)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    sgconv = sgconv.to(ctx)
    h = sgconv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_tagconv_e_weight(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.TAGConv(5, 5, bias=True)
    conv = conv.to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    eweight = F.ones((g.num_edges(),))
    h = conv(g, feat, edge_weight=eweight)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    th.save(gin, tmp_buffer)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gin = gin.to(ctx)
    h = gin(g, feat)

    # test pickle
    th.save(gin, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), 12)

    gin = nn.GINConv(None, aggregator_type)
    th.save(gin, tmp_buffer)
    gin = gin.to(ctx)
    h = gin(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
def test_gine_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gine = nn.GINEConv(th.nn.Linear(5, 12))
    th.save(gine, tmp_buffer)
    nfeat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.num_edges(), 5))
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)

    # test pickle
    th.save(gine, tmp_buffer)
    assert h.shape == (g.number_of_dst_nodes(), 12)

    gine = nn.GINEConv(None)
    th.save(gine, tmp_buffer)
    gine = gine.to(ctx)
    h = gine(g, nfeat, efeat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv_bi(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_agnn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_src_nodes(), 5))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_agnn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.num_edges()) % 3
    feat = F.randn((g.num_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gated_graph_conv_one_etype(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 1)
    etypes = th.zeros(g.num_edges())
    feat = F.randn((g.num_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    h2 = ggconv(g, feat)
    # with one edge type, etypes can be omitted and the results must match
    assert F.allclose(h, h2)
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_nn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    efeat = F.randn((g.num_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_nn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    efeat = F.randn((g.num_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_gmm_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv(5, 10, 3, 4, "mean")
    feat = F.randn((g.num_nodes(), 5))
    pseudo = F.randn((g.num_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite", "block-bipartite"], exclude=["zero-degree"])
)
def test_gmm_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, "mean")
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.num_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_idtype
@pytest.mark.parametrize("norm_type", ["both", "right", "none"])
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_graph_conv(norm_type, g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)

    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "bipartite"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_sage_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, out_dim, "gcn")
    dense_sage = nn.DenseSAGEConv(5, out_dim)
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.bias.data
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)),
        )
    else:
        feat = F.randn((g.num_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)

    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage), g


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)

    # test pickle
    th.save(edge_conv, tmp_buffer)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    dotgat = dotgat.to(ctx)

    # test pickle
    th.save(dotgat, tmp_buffer)

    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_dotgat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    dotgat = nn.DotGatConv((5, 5), out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    dotgat = dotgat.to(ctx)
    h = dotgat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = dotgat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_cheb_conv(out_dim):
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
        g = g.to(F.ctx())
        adj = g.adj_external(transpose=True, ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, out_dim, k, None)
        dense_cheb = nn.DenseChebConv(5, out_dim, k)
        # for i in range(len(cheb.fc)):
        #    dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(
            k, 5, out_dim
        )
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)


def test_sequential():
    ctx = F.ctx()

    # Test single graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
            graph.apply_edges(fn.u_add_v("h", "h", "e"))
            e_feat += graph.edata["e"]
            return n_feat, e_feat

    g = dgl.graph(([], []))
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graphs
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
            return n_feat.view(graph.num_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.from_networkx(nx.erdos_renyi_graph(32, 0.05)).to(ctx)
    g2 = dgl.from_networkx(nx.erdos_renyi_graph(16, 0.2)).to(ctx)
    g3 = dgl.from_networkx(nx.erdos_renyi_graph(8, 0.8)).to(ctx)
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)
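

# A minimal sketch (an illustration, not part of the original suite) of the
# multiple-graph mode exercised above: nn.Sequential feeds graph i of the
# list to layer i, so each layer may operate on a differently sized graph.
def example_sequential_multigraph():
    ctx = F.ctx()

    class HalveLayer(th.nn.Module):
        # Sums each consecutive pair of rows, halving the node dimension.
        def forward(self, graph, n_feat):
            return n_feat.view(graph.num_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.from_networkx(nx.path_graph(8)).to(ctx)
    g2 = dgl.from_networkx(nx.path_graph(4)).to(ctx)
    net = nn.Sequential(HalveLayer(), HalveLayer()).to(ctx)
    out = net([g1, g2], F.randn((8, 4)))
    assert out.shape == (2, 4)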


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_atomic_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    aconv = nn.AtomicConv(
        interaction_cutoffs=F.tensor([12.0, 12.0]),
        rbf_kernel_means=F.tensor([0.0, 2.0]),
        rbf_kernel_scaling=F.tensor([4.0, 4.0]),
        features_to_use=F.tensor([6.0, 8.0]),
    )

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((g.num_nodes(), 1))
    dist = F.randn((g.num_edges(), 1))

    h = aconv(g, feat, dist)

    # currently we only do shape check
    assert h.shape[-1] == 4


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 3])
def test_cf_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    cfconv = nn.CFConv(
        node_in_feats=2, edge_in_feats=3, hidden_feats=2, out_feats=out_dim
    )

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    src_feats = F.randn((g.number_of_src_nodes(), 2))
    edge_feats = F.randn((g.num_edges(), 3))
    h = cfconv(g, src_feats, edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim

    # case for bipartite graphs
    dst_feats = F.randn((g.number_of_dst_nodes(), 3))
    h = cfconv(g, (src_feats, dst_feats), edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == out_dim


def myagg(alist, dsttype):
    # Custom cross-type aggregator for HeteroGraphConv: weights the i-th
    # relation's output by (i + 1) and sums.
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst
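

# Any callable with the signature (per_relation_outputs, dsttype) -> tensor
# can serve as the cross-type reducer, e.g.
# nn.HeteroGraphConv(mods, aggregate=myagg); the string forms ("sum", "max",
# "min", "mean", "stack") select the built-in reducers.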

1405

nv-dlasalle's avatar
nv-dlasalle committed
1406
@parametrize_idtype
1407
1408
@pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg])
@pytest.mark.parametrize("canonical_keys", [False, True])
1409
def test_hetero_conv(agg, idtype, canonical_keys):
1410
1411
1412
1413
1414
1415
1416
1417
1418
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 0, 2, 1], [1, 2, 1, 3]),
            ("user", "plays", "game"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
            ("store", "sells", "game"): ([0, 0, 1, 1], [0, 3, 1, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
1419
    if not canonical_keys:
1420
1421
1422
1423
1424
1425
1426
1427
        conv = nn.HeteroGraphConv(
            {
                "follows": nn.GraphConv(2, 3, allow_zero_in_degree=True),
                "plays": nn.GraphConv(2, 4, allow_zero_in_degree=True),
                "sells": nn.GraphConv(3, 4, allow_zero_in_degree=True),
            },
            agg,
        )
1428
    else:
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
        conv = nn.HeteroGraphConv(
            {
                ("user", "follows", "user"): nn.GraphConv(
                    2, 3, allow_zero_in_degree=True
                ),
                ("user", "plays", "game"): nn.GraphConv(
                    2, 4, allow_zero_in_degree=True
                ),
                ("store", "sells", "game"): nn.GraphConv(
                    3, 4, allow_zero_in_degree=True
                ),
            },
            agg,
        )

    conv = conv.to(F.ctx())

    # test pickle
    th.save(conv, tmp_buffer)

    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

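    # HeteroGraphConv should also work on a block (message flow graph),
    # both with a (src_inputs, dst_inputs) pair of feature dicts and, below,
    # with a single feature dict.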
    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    h = conv(block, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    # test with mod args
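    # MyMod counts how many times it is called with a positional argument
    # (carg1) and with a keyword argument (carg2), so we can verify that
    # mod_args and mod_kwargs are routed to the right per-relation modules.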
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv(
        {"follows": mod1, "plays": mod2, "sells": mod3}, agg
    )
    conv = conv.to(F.ctx())
    mod_args = {"follows": (1,), "plays": (1,)}
    mod_kwargs = {"sells": {"arg2": "abc"}}
    h = conv(
        g,
        {"user": uf, "game": gf, "store": sf},
        mod_args=mod_args,
        mod_kwargs=mod_kwargs,
    )
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on a graph without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form="eid", etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}


@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_linear(out_dim):
    in_feats = {
        "user": F.randn((2, 1)),
        ("user", "follows", "user"): F.randn((3, 2)),
    }

    layer = nn.HeteroLinear(
        {"user": 1, ("user", "follows", "user"): 2}, out_dim
    )
    layer = layer.to(F.ctx())
    out_feats = layer(in_feats)
    assert out_feats["user"].shape == (2, out_dim)
    assert out_feats[("user", "follows", "user")].shape == (3, out_dim)


@pytest.mark.parametrize("out_dim", [1, 2, 100])
def test_hetero_embedding(out_dim):
    layer = nn.HeteroEmbedding(
        {"user": 2, ("user", "follows", "user"): 3}, out_dim
    )
    layer = layer.to(F.ctx())

    embeds = layer.weight
    assert embeds["user"].shape == (2, out_dim)
    assert embeds[("user", "follows", "user")].shape == (3, out_dim)

    layer.reset_parameters()
    embeds = layer.weight
    assert embeds["user"].shape == (2, out_dim)
    assert embeds[("user", "follows", "user")].shape == (3, out_dim)

    embeds = layer(
        {
            "user": F.tensor([0], dtype=F.int64),
            ("user", "follows", "user"): F.tensor([0, 2], dtype=F.int64),
        }
    )
    assert embeds["user"].shape == (1, out_dim)
    assert embeds[("user", "follows", "user")].shape == (2, out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_gnnexplainer(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    feat = F.randn((g.num_nodes(), 5))

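    # A minimal one-layer model whose forward takes an optional eweight
    # argument; GNNExplainer uses eweight to mask edges during explanation.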
    class Model(th.nn.Module):
        def __init__(self, in_feats, out_feats, graph=False):
            super(Model, self).__init__()
            self.linear = th.nn.Linear(in_feats, out_feats)
            if graph:
                self.pool = nn.AvgPooling()
            else:
                self.pool = None

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                feat = self.linear(feat)
                graph.ndata["h"] = feat
                if eweight is None:
                    graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                else:
                    graph.edata["w"] = eweight
                    graph.update_all(
                        fn.u_mul_e("h", "w", "m"), fn.sum("m", "h")
                    )

                if self.pool:
                    return self.pool(graph, graph.ndata["h"])
                else:
                    return graph.ndata["h"]

    # Explain node prediction
    model = Model(5, out_dim)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)

    # Explain graph prediction
    model = Model(5, out_dim, graph=True)
    model = model.to(F.ctx())
    explainer = nn.GNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)


@pytest.mark.parametrize("g", get_cases(["hetero"], exclude=["zero-degree"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("output_dim", [1, 2])
def test_heterognnexplainer(g, idtype, input_dim, output_dim):
    g = g.astype(idtype).to(F.ctx())
    device = g.device

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    feat = {
        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)
        for ntype in g.ntypes
    }

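    # A minimal per-relation linear model; forward accepts an optional
    # eweight dict keyed by canonical edge type, which HeteroGNNExplainer
    # uses to mask edges.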
    class Model(th.nn.Module):
        def __init__(self, in_dim, num_classes, canonical_etypes, graph=False):
            super(Model, self).__init__()
            self.graph = graph
            self.etype_weights = th.nn.ModuleDict(
                {
                    "_".join(c_etype): th.nn.Linear(in_dim, num_classes)
                    for c_etype in canonical_etypes
                }
            )

        def forward(self, graph, feat, eweight=None):
            with graph.local_scope():
                c_etype_func_dict = {}
                for c_etype in graph.canonical_etypes:
                    src_type, etype, dst_type = c_etype
                    wh = self.etype_weights["_".join(c_etype)](feat[src_type])
                    graph.nodes[src_type].data[f"h_{c_etype}"] = wh
                    if eweight is None:
                        c_etype_func_dict[c_etype] = (
                            fn.copy_u(f"h_{c_etype}", "m"),
                            fn.mean("m", "h"),
                        )
                    else:
                        graph.edges[c_etype].data["w"] = eweight[c_etype]
                        c_etype_func_dict[c_etype] = (
                            fn.u_mul_e(f"h_{c_etype}", "w", "m"),
                            fn.mean("m", "h"),
                        )
                graph.multi_update_all(c_etype_func_dict, "sum")
                if self.graph:
                    hg = 0
                    for ntype in graph.ntypes:
                        if graph.num_nodes(ntype):
                            hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)

                    return hg
                else:
                    return graph.ndata["h"]

    # Explain node prediction
    model = Model(input_dim, output_dim, g.canonical_etypes)
    model = model.to(F.ctx())
    ntype = g.ntypes[0]
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(
        ntype, 0, g, feat
    )

    # Explain graph prediction
    model = Model(input_dim, output_dim, g.canonical_etypes, graph=True)
    model = model.to(F.ctx())
    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)
    feat_mask, edge_mask = explainer.explain_graph(g, feat)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(
        ["homo"],
        exclude=[
            "zero-degree",
            "homo-zero-degree",
            "has_feature",
            "has_scalar_e_feature",
            "row_sorted",
            "col_sorted",
            "batched",
        ],
    ),
)
@pytest.mark.parametrize("n_classes", [2])
def test_subgraphx(g, idtype, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))

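    # SubgraphX explains graph-level predictions, so the model pools node
    # representations into a single graph representation.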
    class Model(th.nn.Module):
        def __init__(self, in_dim, n_classes):
            super().__init__()
            self.conv = nn.GraphConv(in_dim, n_classes)
            self.pool = nn.AvgPooling()

        def forward(self, g, h):
            h = th.nn.functional.relu(self.conv(g, h))
            return self.pool(g, h)

    model = Model(feat.shape[1], n_classes)
    model = model.to(ctx)
    explainer = nn.SubgraphX(
        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0
    )
    explainer.explain_graph(g, feat, target_class=0)


@pytest.mark.parametrize("g", get_cases(["hetero"], exclude=["zero-degree"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("n_classes", [2])
def test_heterosubgraphx(g, idtype, input_dim, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    device = g.device

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    feat = {
        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)
        for ntype in g.ntypes
    }

    class Model(th.nn.Module):
        def __init__(self, in_dim, n_classes, canonical_etypes):
            super(Model, self).__init__()
            self.etype_weights = th.nn.ModuleDict(
                {
                    "_".join(c_etype): th.nn.Linear(in_dim, n_classes)
                    for c_etype in canonical_etypes
                }
            )

        def forward(self, graph, feat):
            with graph.local_scope():
                c_etype_func_dict = {}
                for c_etype in graph.canonical_etypes:
                    src_type, etype, dst_type = c_etype
                    wh = self.etype_weights["_".join(c_etype)](feat[src_type])
                    graph.nodes[src_type].data[f"h_{c_etype}"] = wh
                    c_etype_func_dict[c_etype] = (
                        fn.copy_u(f"h_{c_etype}", "m"),
                        fn.mean("m", "h"),
                    )
                graph.multi_update_all(c_etype_func_dict, "sum")
                hg = 0
                for ntype in graph.ntypes:
                    if graph.num_nodes(ntype):
                        hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)

                return hg

    model = Model(input_dim, n_classes, g.canonical_etypes)
    model = model.to(ctx)
    explainer = nn.HeteroSubgraphX(
        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0
    )
    explainer.explain_graph(g, feat, target_class=0)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(
        ["homo"],
        exclude=[
            "zero-degree",
            "homo-zero-degree",
            "has_feature",
            "has_scalar_e_feature",
            "row_sorted",
            "col_sorted",
        ],
    ),
)
@pytest.mark.parametrize("n_classes", [2])
def test_pgexplainer(g, idtype, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    g.ndata["attr"] = feat

    # add reverse edges
    transform = dgl.transforms.AddReverse(copy_edata=True)
    g = transform(g)

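    # PGExplainer calls the model with embed=True to obtain node embeddings
    # and with edge_weight to mask edges, so forward must support both.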
    class Model(th.nn.Module):
        def __init__(self, in_feats, out_feats, graph=False):
            super(Model, self).__init__()
            self.graph = graph
            self.conv = nn.GraphConv(in_feats, out_feats)
            self.fc = th.nn.Linear(out_feats, out_feats)
            th.nn.init.xavier_uniform_(self.fc.weight)

        def forward(self, g, h, embed=False, edge_weight=None):
            h = self.conv(g, h, edge_weight=edge_weight)

            if not self.graph or embed:
                return h

            with g.local_scope():
                g.ndata["h"] = h
                hg = dgl.mean_nodes(g, "h")
                return self.fc(hg)

    # graph explainer
    model = Model(feat.shape[1], n_classes, graph=True)
    model = model.to(ctx)
    explainer = nn.PGExplainer(model, n_classes)
    explainer.train_step(g, g.ndata["attr"], 5.0)

    probs, edge_weight = explainer.explain_graph(g, feat)

    # node explainer
    model = Model(feat.shape[1], n_classes, graph=False)
    model = model.to(ctx)
    explainer = nn.PGExplainer(
        model, n_classes, num_hops=1, explain_graph=False
    )
    explainer.train_step_node(0, g, g.ndata["attr"], 5.0)
    explainer.train_step_node([0, 1], g, g.ndata["attr"], 5.0)
    explainer.train_step_node(th.tensor(0), g, g.ndata["attr"], 5.0)
    explainer.train_step_node(th.tensor([0, 1]), g, g.ndata["attr"], 5.0)

    probs, edge_weight, bg, inverse_indices = explainer.explain_node(0, g, feat)
    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
        [0, 1], g, feat
    )
    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
        th.tensor(0), g, feat
    )
    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
        th.tensor([0, 1]), g, feat
    )


@pytest.mark.parametrize("g", get_cases(["hetero"]))
@pytest.mark.parametrize("idtype", [F.int64])
@pytest.mark.parametrize("input_dim", [5])
@pytest.mark.parametrize("n_classes", [2])
def test_heteropgexplainer(g, idtype, input_dim, n_classes):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = {
        ntype: F.randn((g.num_nodes(ntype), input_dim)) for ntype in g.ntypes
    }

    # add self-loop and reverse edges
    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)
    g = transform1(g)
    transform2 = dgl.transforms.AddReverse(copy_edata=True)
    g = transform2(g)

    class Model(th.nn.Module):
        def __init__(
            self, in_feats, embed_dim, out_feats, canonical_etypes, graph=True
        ):
            super(Model, self).__init__()
            self.graph = graph
            self.conv = nn.HeteroGraphConv(
                {
                    c_etype: nn.GraphConv(in_feats, embed_dim)
                    for c_etype in canonical_etypes
                }
            )
            self.fc = th.nn.Linear(embed_dim, out_feats)

        def forward(self, g, h, embed=False, edge_weight=None):
            if edge_weight is not None:
                mod_kwargs = {
                    etype: {"edge_weight": mask}
                    for etype, mask in edge_weight.items()
                }
                h = self.conv(g, h, mod_kwargs=mod_kwargs)
            else:
                h = self.conv(g, h)

            if not self.graph or embed:
                return h

            with g.local_scope():
                g.ndata["h"] = h
                hg = 0
                for ntype in g.ntypes:
                    hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
                return self.fc(hg)

    embed_dim = input_dim

    # graph explainer
    model = Model(
        input_dim, embed_dim, n_classes, g.canonical_etypes, graph=True
    )
    model = model.to(ctx)
    explainer = nn.HeteroPGExplainer(model, embed_dim)
    explainer.train_step(g, feat, 5.0)

    probs, edge_weight = explainer.explain_graph(g, feat)

    # node explainer
    model = Model(
        input_dim, embed_dim, n_classes, g.canonical_etypes, graph=False
    )
    model = model.to(ctx)
    explainer = nn.HeteroPGExplainer(
        model, embed_dim, num_hops=1, explain_graph=False
    )
    explainer.train_step_node({g.ntypes[0]: [0]}, g, feat, 5.0)
    explainer.train_step_node({g.ntypes[0]: th.tensor([0, 1])}, g, feat, 5.0)

    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
        {g.ntypes[0]: [0]}, g, feat
    )
    probs, edge_weight, bg, inverse_indices = explainer.explain_node(
        {g.ntypes[0]: th.tensor([0, 1])}, g, feat
    )


def test_jumping_knowledge():
    ctx = F.ctx()
    num_layers = 2
    num_nodes = 3
    num_feats = 4

    feat_list = [
        th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)
    ]

    model = nn.JumpingKnowledge("cat").to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_layers * num_feats)

    model = nn.JumpingKnowledge("max").to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)

    model = nn.JumpingKnowledge("lstm", num_feats, num_layers).to(ctx)
    model.reset_parameters()
    assert model(feat_list).shape == (num_nodes, num_feats)


@pytest.mark.parametrize("op", ["dot", "cos", "ele", "cat"])
def test_edge_predictor(op):
    ctx = F.ctx()
    num_pairs = 3
    in_feats = 4
    out_feats = 5
    h_src = th.randn((num_pairs, in_feats)).to(ctx)
    h_dst = th.randn((num_pairs, in_feats)).to(ctx)

    pred = nn.EdgePredictor(op)
    if op in ["dot", "cos"]:
        assert pred(h_src, h_dst).shape == (num_pairs, 1)
    elif op == "ele":
        assert pred(h_src, h_dst).shape == (num_pairs, in_feats)
    else:
        assert pred(h_src, h_dst).shape == (num_pairs, 2 * in_feats)
    pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
    assert pred(h_src, h_dst).shape == (num_pairs, out_feats)


def test_ke_score_funcs():
    ctx = F.ctx()
    num_edges = 30
    num_rels = 3
    nfeats = 4

    h_src = th.randn((num_edges, nfeats)).to(ctx)
    h_dst = th.randn((num_edges, nfeats)).to(ctx)
    rels = th.randint(low=0, high=num_rels, size=(num_edges,)).to(ctx)

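    # TransE scores a triple roughly as the negative distance between
    # h + r and t; here we only sanity-check the score shapes.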
    score_func = nn.TransE(num_rels=num_rels, feats=nfeats).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)

    score_func = nn.TransR(
        num_rels=num_rels, rfeats=nfeats - 1, nfeats=nfeats
    ).to(ctx)
    score_func.reset_parameters()
    assert score_func(h_src, h_dst, rels).shape == (num_edges,)


def test_twirls():
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    feat = th.ones(6, 10)
    conv = nn.TWIRLSConv(10, 2, 128, prop_step=64)
    res = conv(g, feat)
    assert res.size() == (6, 2)


@pytest.mark.parametrize("feat_size", [4, 32])
@pytest.mark.parametrize(
    "regularizer,num_bases", [(None, None), ("basis", 4), ("bdd", 4)]
)
def test_typed_linear(feat_size, regularizer, num_bases):
    dev = F.ctx()
    num_types = 5
    lin = nn.TypedLinear(
        feat_size,
        feat_size * 2,
        num_types,
        regularizer=regularizer,
        num_bases=num_bases,
    ).to(dev)
    print(lin)
    x = th.randn(100, feat_size).to(dev)
    x_type = th.randint(0, num_types, (100,)).to(dev)
    x_type_sorted, idx = th.sort(x_type)
    _, rev_idx = th.sort(idx)
    x_sorted = x[idx]

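    # TypedLinear offers a fast path for inputs pre-sorted by type
    # (sorted_by_type=True); both paths must agree up to numerical tolerance.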
    # test unsorted
    y = lin(x, x_type)
    assert y.shape == (100, feat_size * 2)
    # test sorted
    y_sorted = lin(x_sorted, x_type_sorted, sorted_by_type=True)
    assert y_sorted.shape == (100, feat_size * 2)

    assert th.allclose(y, y_sorted[rev_idx], atol=1e-4, rtol=1e-4)


@parametrize_idtype
@pytest.mark.parametrize("in_size", [4])
@pytest.mark.parametrize("num_heads", [1])
def test_hgt(idtype, in_size, num_heads):
    dev = F.ctx()
    num_etypes = 5
    num_ntypes = 2
    head_size = in_size // num_heads

    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.01))
    g = g.astype(idtype).to(dev)
    etype = th.tensor([i % num_etypes for i in range(g.num_edges())]).to(dev)
    ntype = th.tensor([i % num_ntypes for i in range(g.num_nodes())]).to(dev)
    x = th.randn(g.num_nodes(), in_size).to(dev)

    m = nn.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes).to(
        dev
    )

    y = m(g, x, ntype, etype)
    assert y.shape == (g.num_nodes(), head_size * num_heads)
    # presorted
    sorted_ntype, idx_nt = th.sort(ntype)
    sorted_etype, idx_et = th.sort(etype)
    _, rev_idx = th.sort(idx_nt)
    g.ndata["t"] = ntype
    g.ndata["x"] = x
    g.edata["t"] = etype
    sorted_g = dgl.reorder_graph(
        g,
        node_permute_algo="custom",
        edge_permute_algo="custom",
        permute_config={
            "nodes_perm": idx_nt.to(idtype),
            "edges_perm": idx_et.to(idtype),
        },
    )
    print(sorted_g.ndata["t"])
    print(sorted_g.edata["t"])
    sorted_x = sorted_g.ndata["x"]
    sorted_y = m(
        sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=False
    )
    assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
    # mini-batch
    train_idx = th.randperm(100, dtype=idtype)[:10]
    sampler = dgl.dataloading.NeighborSampler([-1])
    train_loader = dgl.dataloading.DataLoader(
        g, train_idx.to(dev), sampler, batch_size=8, device=dev, shuffle=True
    )
    (input_nodes, output_nodes, block) = next(iter(train_loader))
    block = block[0]
    x = x[input_nodes.to(th.long)]
    ntype = ntype[input_nodes.to(th.long)]
    edge = block.edata[dgl.EID]
    etype = etype[edge.to(th.long)]
    y = m(block, x, ntype, etype)
    assert y.shape == (block.number_of_dst_nodes(), head_size * num_heads)
    # TODO(minjie): enable the following check
    # assert th.allclose(y, sorted_y[rev_idx], atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize("self_loop", [True, False])
@pytest.mark.parametrize("get_distances", [True, False])
def test_radius_graph(self_loop, get_distances):
    pos = th.tensor(
        [
            [0.1, 0.3, 0.4],
            [0.5, 0.2, 0.1],
            [0.7, 0.9, 0.5],
            [0.3, 0.2, 0.5],
            [0.2, 0.8, 0.2],
            [0.9, 0.2, 0.1],
            [0.7, 0.4, 0.4],
            [0.2, 0.1, 0.6],
            [0.5, 0.3, 0.5],
            [0.4, 0.2, 0.6],
        ]
    )

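    # RadiusGraph connects every pair of points within Euclidean distance
    # 0.3; with self_loop=True each point is also connected to itself.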
    rg = nn.RadiusGraph(0.3, self_loop=self_loop)

    if get_distances:
        g, dists = rg(pos, get_distances=get_distances)
    else:
        g = rg(pos)

    if self_loop:
        src_target = th.tensor(
            [
                0,
                0,
                1,
                2,
                3,
                3,
                3,
                3,
                3,
                4,
                5,
                6,
                6,
                7,
                7,
                7,
                8,
                8,
                8,
                8,
                9,
                9,
                9,
                9,
            ]
        )
        dst_target = th.tensor(
            [
                0,
                3,
                1,
                2,
                0,
                3,
                7,
                8,
                9,
                4,
                5,
                6,
                8,
                3,
                7,
                9,
                3,
                6,
                8,
                9,
                3,
                7,
                8,
                9,
            ]
        )

        if get_distances:
            dists_target = th.tensor(
                [
                    [0.0000],
                    [0.2449],
                    [0.0000],
                    [0.0000],
                    [0.2449],
                    [0.0000],
                    [0.1732],
                    [0.2236],
                    [0.1414],
                    [0.0000],
                    [0.0000],
                    [0.0000],
                    [0.2449],
                    [0.1732],
                    [0.0000],
                    [0.2236],
                    [0.2236],
                    [0.2449],
                    [0.0000],
                    [0.1732],
                    [0.1414],
                    [0.2236],
                    [0.1732],
                    [0.0000],
                ]
            )
    else:
        src_target = th.tensor([0, 3, 3, 3, 3, 6, 7, 7, 8, 8, 8, 9, 9, 9])
        dst_target = th.tensor([3, 0, 7, 8, 9, 8, 3, 9, 3, 6, 9, 3, 7, 8])

        if get_distances:
            dists_target = th.tensor(
                [
                    [0.2449],
                    [0.2449],
                    [0.1732],
                    [0.2236],
                    [0.1414],
                    [0.2449],
                    [0.1732],
                    [0.2236],
                    [0.2236],
                    [0.2449],
                    [0.1732],
                    [0.1414],
                    [0.2236],
                    [0.1732],
                ]
            )

    src, dst = g.edges()

    assert th.equal(src, src_target)
    assert th.equal(dst, dst_target)

    if get_distances:
        assert th.allclose(dists, dists_target, rtol=1e-03)


@parametrize_idtype
def test_group_rev_res(idtype):
    dev = F.ctx()

    num_nodes = 5
    num_edges = 20
    feats = 32
    groups = 2
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, feats).to(dev)
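    # GroupRevRes splits the features into `groups` partitions, so the
    # wrapped convolution operates on feats // groups channels.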
    conv = nn.GraphConv(feats // groups, feats // groups)
    model = nn.GroupRevRes(conv, groups).to(dev)
    result = model(g, h)
    result.sum().backward()


@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("hidden_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize("edge_feat_size", [16, 10, 0])
def test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    x = th.randn(num_nodes, 3).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    model = nn.EGNNConv(in_size, hidden_size, out_size, edge_feat_size).to(dev)
    model(g, h, x, e)


@pytest.mark.parametrize("in_size", [16, 32])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators",
    [
        ["mean", "max", "sum"],
        ["min", "std", "var"],
        ["moment3", "moment4", "moment5"],
    ],
)
@pytest.mark.parametrize(
    "scalers", [["identity"], ["amplification", "attenuation"]]
)
@pytest.mark.parametrize("delta", [2.5, 7.4])
@pytest.mark.parametrize("dropout", [0.0, 0.1])
@pytest.mark.parametrize("num_towers", [1, 4])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
@pytest.mark.parametrize("residual", [True, False])
def test_pna_conv(
    in_size,
    out_size,
    aggregators,
    scalers,
    delta,
    dropout,
    num_towers,
    edge_feat_size,
    residual,
):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    model = nn.PNAConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        dropout,
        num_towers,
        edge_feat_size,
        residual,
    ).to(dev)
    model(g, h, edge_feat=e)


@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("norm_type", ["sym", "row"])
@pytest.mark.parametrize("clamp", [True, False])
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("reset", [True, False])
def test_label_prop(k, alpha, norm_type, clamp, normalize, reset):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    num_classes = 4
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    labels = th.tensor([0, 2, 1, 3, 0]).long().to(dev)
    ml_labels = th.rand(num_nodes, num_classes).to(dev) > 0.7
    mask = th.tensor([0, 1, 1, 1, 0]).bool().to(dev)
    model = nn.LabelPropagation(k, alpha, norm_type, clamp, normalize, reset)
    model(g, labels, mask)
    # multi-label case
    model(g, ml_labels, mask)


@pytest.mark.parametrize("in_size", [16])
@pytest.mark.parametrize("out_size", [16, 32])
@pytest.mark.parametrize(
    "aggregators", [["mean", "max", "dir2-av"], ["min", "std", "dir1-dx"]]
)
@pytest.mark.parametrize("scalers", [["amplification", "attenuation"]])
@pytest.mark.parametrize("delta", [2.5])
@pytest.mark.parametrize("edge_feat_size", [16, 0])
def test_dgn_conv(
    in_size, out_size, aggregators, scalers, delta, edge_feat_size
):
    dev = F.ctx()
    num_nodes = 5
    num_edges = 20
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, in_size).to(dev)
    e = th.randn(num_edges, edge_feat_size).to(dev)
    transform = dgl.LapPE(k=3, feat_name="eig")
    g = transform(g)
    eig = g.ndata["eig"]
    model = nn.DGNConv(
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    model(g, h, edge_feat=e, eig_vec=eig)

    aggregators_non_eig = [
        aggr for aggr in aggregators if not aggr.startswith("dir")
    ]
    model = nn.DGNConv(
        in_size,
        out_size,
        aggregators_non_eig,
        scalers,
        delta,
        edge_feat_size=edge_feat_size,
    ).to(dev)
    model(g, h, edge_feat=e)


def test_DeepWalk():
    dev = F.ctx()
    g = dgl.graph(([0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]))
    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=True, sparse=True
    )
    model = model.to(dev)
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
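    # Sparse embeddings require a sparse-aware optimizer such as SparseAdam;
    # the dense variant below uses a regular Adam.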
    optim = SparseAdam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()

    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=False, sparse=False
    )
    model = model.to(dev)
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
    optim = Adam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()


@pytest.mark.parametrize("max_degree", [2, 6])
@pytest.mark.parametrize("embedding_dim", [8, 16])
@pytest.mark.parametrize("direction", ["in", "out", "both"])
def test_degree_encoder(max_degree, embedding_dim, direction):
    g1 = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    )
    g2 = dgl.graph(
        (
            th.tensor([0, 1]),
            th.tensor([1, 0]),
        )
    )
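    # Pad the per-graph degree sequences to a common length so the two
    # graphs can be encoded as a single batch.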
    in_degree = pad_sequence(
        [g1.in_degrees(), g2.in_degrees()], batch_first=True
    )
    out_degree = pad_sequence(
        [g1.out_degrees(), g2.out_degrees()], batch_first=True
    )
    model = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)
    if direction == "in":
        de_g = model(in_degree)
    elif direction == "out":
        de_g = model(out_degree)
    elif direction == "both":
        de_g = model(th.stack((in_degree, out_degree)))
    assert de_g.shape == (2, 4, embedding_dim)


@parametrize_idtype
def test_MetaPath2Vec(idtype):
    dev = F.ctx()
    g = dgl.heterograph(
        {
            ("user", "uc", "company"): ([0, 0, 2, 1, 3], [1, 2, 1, 3, 0]),
            ("company", "cp", "product"): (
                [0, 0, 0, 1, 2, 3],
                [0, 2, 3, 0, 2, 1],
            ),
            ("company", "cu", "user"): ([1, 2, 1, 3, 0], [0, 0, 2, 1, 3]),
            ("product", "pc", "company"): (
                [0, 2, 3, 0, 2, 1],
                [0, 0, 0, 1, 2, 3],
            ),
        },
        idtype=idtype,
        device=dev,
    )
    model = nn.MetaPath2Vec(g, ["uc", "cu"], window_size=1)
    model = model.to(dev)
    embeds = model.node_embed.weight
    assert embeds.shape[0] == g.num_nodes()


@pytest.mark.parametrize("num_layer", [1, 4])
@pytest.mark.parametrize("k", [3, 5])
@pytest.mark.parametrize("lpe_dim", [4, 16])
@pytest.mark.parametrize("n_head", [2, 4])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("num_post_layer", [0, 1, 2])
def test_LapPosEncoder(
    num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
):
    ctx = F.ctx()
    num_nodes = 4

    EigVals = th.randn((num_nodes, k)).to(ctx)
    EigVecs = th.randn((num_nodes, k)).to(ctx)

    model = nn.LapPosEncoder(
        "Transformer", num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer
    ).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)

    model = nn.LapPosEncoder(
        "DeepSet",
        num_layer,
        k,
        lpe_dim,
        batch_norm=batch_norm,
        num_post_layer=num_post_layer,
    ).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)


@pytest.mark.parametrize("feat_size", [128, 512])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("attn_drop", [0.1, 0.5])
def test_BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop):
    ndata = th.rand(16, 100, feat_size)
    attn_bias = th.rand(16, 100, 100, num_heads)
    attn_mask = th.rand(16, 100, 100) < 0.5

    net = nn.BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop)
    out = net(ndata, attn_bias, attn_mask)

    assert out.shape == (16, 100, feat_size)


@pytest.mark.parametrize("edge_update", [True, False])
def test_EGTLayer(edge_update):
    batch_size = 16
    num_nodes = 100
    feat_size, edge_feat_size = 128, 32
    nfeat = th.rand(batch_size, num_nodes, feat_size)
    efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
    mask = (th.rand(batch_size, num_nodes, num_nodes) < 0.5) * -1e9

    net = nn.EGTLayer(
        feat_size=feat_size,
        edge_feat_size=edge_feat_size,
        num_heads=8,
        num_virtual_nodes=4,
        edge_update=edge_update,
    )

    if edge_update:
        out_nfeat, out_efeat = net(nfeat, efeat, mask)
        assert out_nfeat.shape == (batch_size, num_nodes, feat_size)
        assert out_efeat.shape == (
            batch_size,
            num_nodes,
            num_nodes,
            edge_feat_size,
        )
    else:
        out_nfeat = net(nfeat, efeat, mask)
        assert out_nfeat.shape == (batch_size, num_nodes, feat_size)


@pytest.mark.parametrize("attn_bias_type", ["add", "mul"])
@pytest.mark.parametrize("norm_first", [True, False])
def test_GraphormerLayer(attn_bias_type, norm_first):
    batch_size = 16
    num_nodes = 100
    feat_size = 512
    num_heads = 8
    nfeat = th.rand(batch_size, num_nodes, feat_size)
    attn_bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
    attn_mask = th.rand(batch_size, num_nodes, num_nodes) < 0.5

    net = nn.GraphormerLayer(
        feat_size=feat_size,
        hidden_size=2048,
        num_heads=num_heads,
        attn_bias_type=attn_bias_type,
        norm_first=norm_first,
        dropout=0.1,
        attn_dropout=0.1,
        activation=th.nn.ReLU(),
    )
    out = net(nfeat, attn_bias, attn_mask)

    assert out.shape == (batch_size, num_nodes, feat_size)


@pytest.mark.parametrize("max_len", [1, 2])
@pytest.mark.parametrize("feat_dim", [16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_PathEncoder(max_len, feat_dim, num_heads):
    dev = F.ctx()
    g = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    ).to(dev)
    edge_feat = th.rand(g.num_edges(), feat_dim).to(dev)
    edge_feat = th.cat((edge_feat, th.zeros(1, 16).to(dev)), dim=0)
    dist, path = shortest_dist(g, root=None, return_paths=True)
    path_data = edge_feat[path[:, :, :max_len]]
    model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
    bias = model(dist.unsqueeze(0), path_data.unsqueeze(0))
    assert bias.shape == (1, 4, 4, num_heads)


@pytest.mark.parametrize("max_dist", [1, 4])
@pytest.mark.parametrize("num_kernels", [4, 16])
@pytest.mark.parametrize("num_heads", [1, 8])
def test_SpatialEncoder(max_dist, num_kernels, num_heads):
    dev = F.ctx()
    # 3D encoding on a single graph
    num_nodes = 4
    coord = th.rand(1, num_nodes, 3).to(dev)
    node_type = th.tensor([[1, 0, 2, 1]]).to(dev)
    spatial_encoder = nn.SpatialEncoder3d(
        num_kernels=num_kernels, num_heads=num_heads, max_node_type=3
    ).to(dev)
    out = spatial_encoder(coord, node_type=node_type)
    assert out.shape == (1, num_nodes, num_nodes, num_heads)

    # encoding on a batch of graphs
    g1 = dgl.graph(
        (
            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),
            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),
        )
    ).to(dev)
    g2 = dgl.graph(
        (th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3]))
    ).to(dev)
    bsz, max_num_nodes = 2, 6
    # 2d encoding
    dist = -th.ones((bsz, max_num_nodes, max_num_nodes), dtype=th.long).to(dev)
    dist[0, :4, :4] = shortest_dist(g1, root=None, return_paths=False)
    dist[1, :6, :6] = shortest_dist(g2, root=None, return_paths=False)
    model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)
    encoding = model_1(dist)
    assert encoding.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
    # 3d encoding
    coord = th.rand(bsz, max_num_nodes, 3).to(dev)
    node_type = th.randint(
        0,
        512,
        (
            bsz,
            max_num_nodes,
        ),
    ).to(dev)
    model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)
    model_3 = nn.SpatialEncoder3d(
        num_kernels, num_heads=num_heads, max_node_type=512
    ).to(dev)
    encoding3d_1 = model_2(coord)
    encoding3d_2 = model_3(coord, node_type)
    assert encoding3d_1.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
    assert encoding3d_2.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)