test_nn.py 34.1 KB
Newer Older
1
2
3
4
import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
5
import dgl.function as fn
6
import backend as F
7
import pytest
8
9
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
10
11
from copy import deepcopy

12
13
import scipy as sp

14
15
16
17
18
def _AXWb(A, X, W, b):
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b

19
20
def test_graph_conv0():
    """GraphConv sanity test: output matches the dense reference ``_AXWb``
    and the forward pass leaves the graph's feature frames untouched.

    Fix: the original contained a byte-for-byte duplicate of the test#3/#4
    section (the second ``conv = nn.GraphConv(5, 2)`` block); the duplicate
    is removed.
    """
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=False, ctx=ctx)

    conv = nn.GraphConv(5, 2, norm='none', bias=True)
    conv = conv.to(ctx)
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    conv = conv.to(ctx)
    # test#3: basic (default norm)
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim (default norm)
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters: weights must actually be re-initialized
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)
71

72
73
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv(idtype, g, norm, weight, bias):
    """GraphConv forward with a single feature tensor; shape check only."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(dev)
    ext_w = F.randn((5, 2)).to(dev)
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    feat = F.randn((nsrc, 5)).to(dev)
    # Use the module's own weight when it has one, otherwise pass one in.
    if weight:
        out = conv(g, feat)
    else:
        out = conv(g, feat, weight=ext_w)
    assert out.shape == (ndst, 2)

91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias):
    """GraphConv forward with per-edge scalar weights; shape check only."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(dev)
    ext_w = F.randn((5, 2)).to(dev)
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    feat = F.randn((nsrc, 5)).to(dev)
    e_w = g.edata['scalar_w']
    if weight:
        out = conv(g, feat, edge_weight=e_w)
    else:
        out = conv(g, feat, weight=ext_w, edge_weight=e_w)
    assert out.shape == (ndst, 2)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias):
    """GraphConv with edge weights pre-normalized by EdgeWeightNorm."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(dev)
    ext_w = F.randn((5, 2)).to(dev)
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    feat = F.randn((nsrc, 5)).to(dev)
    # Normalize the raw scalar edge weights with the matching scheme first.
    norm_weight = nn.EdgeWeightNorm(norm=norm)(g, g.edata['scalar_w'])
    if weight:
        out = conv(g, feat, edge_weight=norm_weight)
    else:
        out = conv(g, feat, weight=ext_w, edge_weight=norm_weight)
    assert out.shape == (ndst, 2)

130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_bi(idtype, g, norm, weight, bias):
    """GraphConv forward with a (src, dst) feature pair; shape check only."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(dev)
    ext_w = F.randn((5, 2)).to(dev)
    ndst = g.number_of_dst_nodes()
    pair = (F.randn((g.number_of_src_nodes(), 5)).to(dev),
            F.randn((ndst, 2)).to(dev))
    if weight:
        out = conv(g, pair)
    else:
        out = conv(g, pair, weight=ext_w)
    assert out.shape == (ndst, 2)
149

150
151
152
153
154
155
156
157
158
159
160
161
def _S2AXWb(A, N, X, W, b):
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b

162
def test_tagconv():
    """TAGConv against the dense reference ``_S2AXWb`` plus reset check."""
    g = dgl.DGLGraph(nx.path_graph(3))
    g = g.to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=False, ctx=ctx)
    # Symmetric normalization factor D^{-1/2}.
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, 2, bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test#1: compare against the dense implementation
    feat = F.ones((3, 5))
    out = conv(g, feat)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (feat.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)
    assert F.allclose(out, _S2AXWb(adj, norm, feat, conv.lin.weight, conv.lin.bias))

    # test#2: default construction; only check the output dimension
    conv = nn.TAGConv(5, 2)
    conv = conv.to(ctx)
    out = conv(g, F.ones((3, 5)))
    assert out.shape[-1] == 2

    # test reset_parameters: the weight must actually be re-initialized
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    assert not F.allclose(old_weight, conv.lin.weight.data)

197
def test_set2set():
    """Set2Set readout: single graph and batched graphs."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    s2s = s2s.to(ctx)
    print(s2s)

    # test#1: a single graph pools to one (1, 2 * hidden) row.
    feat = F.randn((g.number_of_nodes(), 5))
    out = s2s(g, feat)
    assert out.dim() == 2 and out.shape[0] == 1 and out.shape[1] == 10

    # test#2: a batch pools to one row per member graph.
    bg = dgl.batch([g,
                    dgl.DGLGraph(nx.path_graph(11)).to(F.ctx()),
                    dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())])
    feat = F.randn((bg.number_of_nodes(), 5))
    out = s2s(bg, feat)
    assert out.dim() == 2 and out.shape[0] == 3 and out.shape[1] == 10

def test_glob_att_pool():
    """GlobalAttentionPooling on a single graph and on a batch."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)

    # test#1: single graph pools to one row.
    feat = F.randn((g.number_of_nodes(), 5))
    out = gap(g, feat)
    assert out.dim() == 2 and out.shape[0] == 1 and out.shape[1] == 10

    # test#2: a batch of 4 graphs pools to 4 rows.
    bg = dgl.batch([g, g, g, g])
    feat = F.randn((bg.number_of_nodes(), 5))
    out = gap(bg, feat)
    assert out.dim() == 2 and out.shape[0] == 4 and out.shape[1] == 10

def test_simple_pool():
    """Sum/Avg/Max/Sort pooling on a single graph and on a batch."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))
    g = g.to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: single graph — compare against direct tensor reductions.
    feat = F.randn((g.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    assert F.allclose(F.squeeze(sum_pool(g, feat), 0), F.sum(feat, 0))
    assert F.allclose(F.squeeze(avg_pool(g, feat), 0), F.mean(feat, 0))
    assert F.allclose(F.squeeze(max_pool(g, feat), 0), F.max(feat, 0))
    out = sort_pool(g, feat)
    assert out.dim() == 2 and out.shape[0] == 1 and out.shape[1] == 10 * 5

    # test#2: batched graph — per-graph reductions over node segments.
    # The batch is [15, 5, 15, 5, 15] nodes, hence these slice boundaries.
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    feat = F.randn((bg.number_of_nodes(), 5))
    segments = [(0, 15), (15, 20), (20, 35), (35, 40), (40, 55)]
    for pool, reduce in ((sum_pool, F.sum), (avg_pool, F.mean), (max_pool, F.max)):
        truth = th.stack([reduce(feat[a:b], 0) for a, b in segments], 0)
        assert F.allclose(pool(bg, feat), truth)
    out = sort_pool(bg, feat)
    assert out.dim() == 2 and out.shape[0] == 5 and out.shape[1] == 10 * 5

def test_set_trans():
    """SetTransformer encoder/decoder shape checks, single and batched."""
    ctx = F.ctx()
    # NOTE(review): the graphs here are intentionally left on the default
    # context, matching the original test.
    g = dgl.DGLGraph(nx.path_graph(15))

    enc_sab = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    enc_isab = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    enc_sab = enc_sab.to(ctx)
    enc_isab = enc_isab.to(ctx)
    dec = dec.to(ctx)
    print(enc_sab, enc_isab, dec)

    # test#1: encoders keep the feature shape; decoder emits (1, 200).
    feat = F.randn((g.number_of_nodes(), 50))
    out = enc_sab(g, feat)
    assert out.shape == feat.shape
    out = enc_isab(g, feat)
    assert out.shape == feat.shape
    dec_out = dec(g, out)
    assert dec_out.dim() == 2 and dec_out.shape[0] == 1 and dec_out.shape[1] == 200

    # test#2: batched graphs — decoder emits one row per graph.
    bg = dgl.batch([g, dgl.DGLGraph(nx.path_graph(5)), dgl.DGLGraph(nx.path_graph(10))])
    feat = F.randn((bg.number_of_nodes(), 50))
    out = enc_sab(bg, feat)
    assert out.shape == feat.shape
    out = enc_isab(bg, feat)
    assert out.shape == feat.shape
    dec_out = dec(bg, out)
    assert dec_out.dim() == 2 and dec_out.shape[0] == 3 and dec_out.shape[1] == 200

Minjie Wang's avatar
Minjie Wang committed
330
331
332
333
def test_rgcn():
    """RelGraphConv: the low-mem implementation must match the default one
    for basis/bdd regularizers, with and without an edge norm, and for
    integer node-id input."""
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    # 5 etypes
    R = 5
    etype = [i % 5 for i in range(g.number_of_edges())]
    B = 2
    I = 10
    O = 8
    r = th.tensor(etype).to(ctx)

    def _compare(regularizer, feat, norm=None):
        # Build regular and low-mem layers with tied parameters so their
        # outputs are directly comparable.
        base = nn.RelGraphConv(I, O, R, regularizer, B).to(ctx)
        low = nn.RelGraphConv(I, O, R, regularizer, B, low_mem=True).to(ctx)
        low.weight = base.weight
        if regularizer == "basis":
            low.w_comp = base.w_comp
        low.loop_weight = base.loop_weight
        args = (g, feat, r) if norm is None else (g, feat, r, norm)
        out = base(*args)
        out_low = low(*args)
        assert list(out.shape) == [100, O]
        assert list(out_low.shape) == [100, O]
        assert F.allclose(out, out_low)

    _compare("basis", th.randn((100, I)).to(ctx))
    _compare("bdd", th.randn((100, I)).to(ctx))

    # with norm
    norm = th.rand((g.number_of_edges(), 1)).to(ctx)
    _compare("basis", th.randn((100, I)).to(ctx), norm)
    _compare("bdd", th.randn((100, I)).to(ctx), norm)

    # id (integer node-id) input
    _compare("basis", th.randint(0, I, (100,)).to(ctx))
409

410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448

def test_rgcn_sorted():
    """Same low-mem vs. default comparison as test_rgcn, but passing etypes
    in the 'sorted' format: a plain list of per-type edge counts."""
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    # 5 etypes, 200 edges each (density 0.1 on 100x100 -> 1000 edges)
    R = 5
    etype = [200, 200, 200, 200, 200]
    B = 2
    I = 10
    O = 8

    def _compare(regularizer, feat, norm=None):
        # Tie parameters between the regular and low-mem layers.
        base = nn.RelGraphConv(I, O, R, regularizer, B).to(ctx)
        low = nn.RelGraphConv(I, O, R, regularizer, B, low_mem=True).to(ctx)
        low.weight = base.weight
        if regularizer == "basis":
            low.w_comp = base.w_comp
        low.loop_weight = base.loop_weight
        args = (g, feat, etype) if norm is None else (g, feat, etype, norm)
        out = base(*args)
        out_low = low(*args)
        assert list(out.shape) == [100, O]
        assert list(out_low.shape) == [100, O]
        assert F.allclose(out, out_low)

    _compare("basis", th.randn((100, I)).to(ctx))
    _compare("bdd", th.randn((100, I)).to(ctx))

    # with norm
    norm = th.rand((g.number_of_edges(), 1)).to(ctx)
    _compare("basis", th.randn((100, I)).to(ctx), norm)
    _compare("bdd", th.randn((100, I)).to(ctx), norm)

    # id input
    _compare("basis", th.randint(0, I, (100,)).to(ctx))


491
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gat_conv(g, idtype):
    """GATConv: output is (nodes, heads, out_dim); attention is per edge."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((g.number_of_nodes(), 5))
    gat = gat.to(dev)
    out = gat(g, feat)
    assert out.shape == (g.number_of_nodes(), 4, 2)
    _, attn = gat(g, feat, get_attention=True)
    assert attn.shape == (g.number_of_edges(), 4, 1)
503

504
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gat_conv_bi(g, idtype):
    """GATConv on bipartite graphs with a (src, dst) feature pair."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    gat = nn.GATConv(5, 2, 4)
    pair = (F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)))
    gat = gat.to(dev)
    out = gat(g, pair)
    assert out.shape == (g.number_of_dst_nodes(), 4, 2)
    _, attn = gat(g, pair, get_attention=True)
    assert attn.shape == (g.number_of_edges(), 4, 1)
516

517
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(idtype, g, aggre_type):
    """SAGEConv forward for every aggregator; output dim check only."""
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(F.ctx())
    assert sage(g, feat).shape[-1] == 10

528
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv_bi(idtype, g, aggre_type):
    """SAGEConv on bipartite graphs; 'gcn' needs matching src/dst dims."""
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 10 if aggre_type == 'gcn' else 5
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    pair = (F.randn((g.number_of_src_nodes(), 10)),
            F.randn((g.number_of_dst_nodes(), dst_dim)))
    sage = sage.to(F.ctx())
    out = sage(g, pair)
    assert out.shape[-1] == 2
    assert out.shape[0] == g.number_of_dst_nodes()
540

541
542
543
@parametrize_dtype
def test_sage_conv2(idtype):
    # TODO: add test for blocks
    """SAGEConv on a bipartite graph that has no edges at all."""
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    sage = nn.SAGEConv((3, 3), 2, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    out = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    assert out.shape[-1] == 2
    assert out.shape[0] == 3

    for aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), 2, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        out = sage(g, feat)
        assert out.shape[-1] == 2
        assert out.shape[0] == 3

562
563
564
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_sgc_conv(g, idtype):
    """SGConv with and without caching; a cached layer must ignore changed
    input features on the second call."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)

    # not cached
    sgc = nn.SGConv(5, 10, 3)
    feat = F.randn((g.number_of_nodes(), 5))
    sgc = sgc.to(ctx)
    assert sgc(g, feat).shape[-1] == 10

    # cached: the second forward reuses the stored propagation result
    sgc = nn.SGConv(5, 10, 3, True)
    sgc = sgc.to(ctx)
    out_a = sgc(g, feat)
    out_b = sgc(g, feat + 1)
    assert F.allclose(out_a, out_b)
    assert out_a.shape[-1] == 10

583
584
585
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv(g, idtype):
    """APPNPConv propagates without changing the feature dimension."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.number_of_nodes(), 5))
    appnp = appnp.to(ctx)
    assert appnp(g, feat).shape[-1] == 5

595
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """GINConv with a linear apply function; output shape check."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    feat = F.randn((g.number_of_nodes(), 5))
    gin = gin.to(dev)
    assert gin(g, feat).shape == (g.number_of_nodes(), 12)
609

610
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """GINConv on bipartite graphs with a (src, dst) feature pair."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    pair = (F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)))
    gin = gin.to(dev)
    assert gin(g, pair).shape == (g.number_of_dst_nodes(), 12)
624

625
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_agnn_conv(g, idtype):
    """AGNNConv preserves the feature shape."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_nodes(), 5))
    agnn = agnn.to(dev)
    assert agnn(g, feat).shape == (g.number_of_nodes(), 5)
635

636
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_agnn_conv_bi(g, idtype):
    """AGNNConv on bipartite graphs; one output row per dst node."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    agnn = nn.AGNNConv(1)
    pair = (F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)))
    agnn = agnn.to(dev)
    assert agnn(g, pair).shape == (g.number_of_dst_nodes(), 5)
646

647
648
649
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gated_graph_conv(g, idtype):
    """GatedGraphConv with 3 edge types."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    # Assign each edge one of the 3 edge types round-robin.
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)
    # current we only do shape check
    assert ggconv(g, feat, etypes).shape[-1] == 10

662
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_nn_conv(g, idtype):
    """NNConv with a linear edge function; shape check only."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean')
    feat = F.randn((g.number_of_nodes(), 5))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(dev)
    # currently we only do shape check
    assert nnconv(g, feat, efeat).shape[-1] == 10

676
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_nn_conv_bi(g, idtype):
    """NNConv on bipartite graphs with distinct src/dst feature dims."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean')
    feat_src = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(dev)
    # currently we only do shape check
    assert nnconv(g, (feat_src, feat_dst), efeat).shape[-1] == 10

691
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gmm_conv(g, idtype):
    """GMMConv forward with 3-d pseudo-coordinates; shape check only."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
    feat = F.randn((g.number_of_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(dev)
    # currently we only do shape check
    assert gmmconv(g, feat, pseudo).shape[-1] == 10

704
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite', 'block-bipartite'], exclude=['zero-degree']))
def test_gmm_conv_bi(g, idtype):
    """GMMConv on (block-)bipartite graphs with a (src, dst) pair."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean')
    feat_src = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(dev)
    # currently we only do shape check
    assert gmmconv(g, (feat_src, feat_dst), pseudo).shape[-1] == 10

718
@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
def test_dense_graph_conv(norm_type, g, idtype):
    """DenseGraphConv must match GraphConv given the dense adjacency."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adjacency_matrix(transpose=False, ctx=ctx).to_dense()
    conv = nn.GraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
    # Share parameters so the two outputs are directly comparable.
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    assert F.allclose(conv(g, feat), dense_conv(adj, feat))

737
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite']))
def test_dense_sage_conv(g, idtype):
    """DenseSAGEConv must match SAGEConv('gcn') given the dense adjacency."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=False, ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, 2, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, 2)
    # Share the neighbor-aggregation parameters.
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.fc_neigh.bias.data
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5))
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    assert F.allclose(sage(g, feat), dense_sage(adj, feat)), g

760
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_edge_conv(g, idtype):
    """EdgeConv forward; one output row per node."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    edge_conv = nn.EdgeConv(5, 2).to(dev)
    print(edge_conv)
    feat = F.randn((g.number_of_nodes(), 5))
    assert edge_conv(g, feat).shape == (g.number_of_nodes(), 2)
770

771
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_edge_conv_bi(g, idtype):
    """EdgeConv on bipartite graphs with a (src, dst) feature pair."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    edge_conv = nn.EdgeConv(5, 2).to(dev)
    print(edge_conv)
    pair = (F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)))
    assert edge_conv(g, pair).shape == (g.number_of_dst_nodes(), 2)
782
783
784
785
786

def test_dense_cheb_conv():
    """DenseChebConv must match ChebConv for k = 1..3 with tied weights."""
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        g = g.to(F.ctx())
        adj = g.adjacency_matrix(transpose=False, ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, 2, k, None)
        dense_cheb = nn.DenseChebConv(5, 2, k)
        # ChebConv fuses all k hop weights into one linear layer; unpack it
        # into the (k, in, out) layout DenseChebConv expects.
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, 2)
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)

804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
def test_sequential():
    """nn.Sequential must thread (graph, features) through stacked layers.

    Case 1 feeds one graph to every layer; case 2 passes a per-layer list of
    graphs, with each layer halving the number of nodes via a view + sum.
    """
    device = F.ctx()

    # Case 1: a single shared graph; each layer updates node and edge features.
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, nfeat, efeat):
            graph = graph.local_var()
            graph.ndata['h'] = nfeat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            nfeat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            efeat += graph.edata['e']
            return nfeat, efeat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    # Fully connected 3-node graph (9 directed edges, self-loops included).
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(device)
    nfeat = F.randn((3, 4))
    efeat = F.randn((9, 4))
    nfeat, efeat = net(g, nfeat, efeat)
    assert nfeat.shape == (3, 4)
    assert efeat.shape == (9, 4)

    # Case 2: one graph per layer; each layer pools node pairs, so the node
    # count must drop 32 -> 16 -> 8 -> 4 across the three layers.
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, nfeat):
            graph = graph.local_var()
            graph.ndata['h'] = nfeat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            nfeat += graph.ndata['h']
            return nfeat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(device)
    nfeat = F.randn((32, 4))
    nfeat = net([g1, g2, g3], nfeat)
    assert nfeat.shape == (4, 4)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_atomic_conv(g, idtype):
    """AtomicConv smoke test: 2 kernels x 2 feature types -> last output dim is 4."""
    g = g.astype(idtype).to(F.ctx())
    conv = nn.AtomicConv(interaction_cutoffs=F.tensor([12.0, 12.0]),
                         rbf_kernel_means=F.tensor([0.0, 2.0]),
                         rbf_kernel_scaling=F.tensor([4.0, 4.0]),
                         features_to_use=F.tensor([6.0, 8.0]))

    device = F.ctx()
    if F.gpu_ctx():
        conv = conv.to(device)

    node_feat = F.randn((g.number_of_nodes(), 1))
    edge_dist = F.randn((g.number_of_edges(), 1))

    out = conv(g, node_feat, edge_dist)
    # Shape check only; numerical behavior is not pinned down here.
    assert out.shape[-1] == 4

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_cf_conv(g, idtype):
    """CFConv smoke test: node (2-d) and edge (3-d) features yield 3-d outputs."""
    g = g.astype(idtype).to(F.ctx())
    conv = nn.CFConv(node_in_feats=2,
                     edge_in_feats=3,
                     hidden_feats=2,
                     out_feats=3)

    device = F.ctx()
    if F.gpu_ctx():
        conv = conv.to(device)

    node_feat = F.randn((g.number_of_nodes(), 2))
    edge_feat = F.randn((g.number_of_edges(), 3))
    out = conv(g, node_feat, edge_feat)
    # Shape check only; numerical behavior is not pinned down here.
    assert out.shape[-1] == 3

def myagg(alist, dsttype):
    """Custom cross-type aggregator: position-weighted sum of per-relation results.

    Element i of ``alist`` is scaled by (i + 1); element 0 keeps weight 1.
    ``dsttype`` is accepted for the HeteroGraphConv aggregator signature but unused.
    """
    total = alist[0]
    for weight, item in enumerate(alist[1:], start=2):
        total = total + weight * item
    return total

@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """HeteroGraphConv: per-relation modules, all aggregators, partial inputs,
    (src, dst) pair inputs, and per-relation mod_args / mod_kwargs routing."""
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    conv = conv.to(F.ctx())

    user_feat = F.randn((4, 2))
    game_feat = F.randn((4, 4))
    store_feat = F.randn((2, 3))

    # Only 'user' features: 'follows' and 'plays' fire, producing user + game.
    out = conv(g, {'user': user_feat})
    assert set(out.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert out['user'].shape == (4, 3)
        assert out['game'].shape == (4, 4)
    else:
        assert out['user'].shape == (4, 1, 3)
        assert out['game'].shape == (4, 1, 4)

    # 'user' + 'store': game now receives two relations, so 'stack' has depth 2.
    out = conv(g, {'user': user_feat, 'store': store_feat})
    assert set(out.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert out['user'].shape == (4, 3)
        assert out['game'].shape == (4, 4)
    else:
        assert out['user'].shape == (4, 1, 3)
        assert out['game'].shape == (4, 2, 4)

    # Only 'store' features: just the 'sells' relation fires.
    out = conv(g, {'store': store_feat})
    assert set(out.keys()) == {'game'}
    if agg != 'stack':
        assert out['game'].shape == (4, 4)
    else:
        assert out['game'].shape == (4, 1, 4)

    # (src, dst) pair input with SAGEConv modules.
    conv = nn.HeteroGraphConv({
        'follows': nn.SAGEConv(2, 3, 'mean'),
        'plays': nn.SAGEConv((2, 4), 4, 'mean'),
        'sells': nn.SAGEConv(3, 4, 'mean')},
        agg)
    conv = conv.to(F.ctx())

    out = conv(g, ({'user': user_feat}, {'user' : user_feat, 'game' : game_feat}))
    assert set(out.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert out['user'].shape == (4, 3)
        assert out['game'].shape == (4, 4)
    else:
        assert out['user'].shape == (4, 1, 3)
        assert out['game'].shape == (4, 1, 4)

    # Pair input requires both src and dst type features to be provided.
    out = conv(g, ({'user': user_feat}, {'game' : game_feat}))
    assert set(out.keys()) == {'game'}
    if agg != 'stack':
        assert out['game'].shape == (4, 4)
    else:
        assert out['game'].shape == (4, 1, 4)

    # mod_args / mod_kwargs must be dispatched to the matching relation only.
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2
        def forward(self, g, h, arg1=None, *, arg2=None):
            # Count how often each optional argument actually arrives.
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))

    mod_follows = MyMod(2, 3)
    mod_plays = MyMod(2, 4)
    mod_sells = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod_follows,
        'plays': mod_plays,
        'sells': mod_sells},
        agg)
    conv = conv.to(F.ctx())
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    out = conv(g, {'user' : user_feat, 'store' : store_feat},
               mod_args=mod_args, mod_kwargs=mod_kwargs)
    assert mod_follows.carg1 == 1
    assert mod_follows.carg2 == 0
    assert mod_plays.carg1 == 1
    assert mod_plays.carg2 == 0
    assert mod_sells.carg1 == 0
    assert mod_sells.carg2 == 1

if __name__ == '__main__':
    # Ad-hoc driver: run the non-parametrized tests in sequence.
    # NOTE(review): some entries (e.g. test_atomic_conv, test_cf_conv) are
    # pytest-parametrized and require arguments — run under pytest instead.
    for case in (
        test_graph_conv,
        test_graph_conv_e_weight,
        test_graph_conv_e_weight_norm,
        test_set2set,
        test_glob_att_pool,
        test_simple_pool,
        test_set_trans,
        test_rgcn,
        test_rgcn_sorted,
        test_tagconv,
        test_gat_conv,
        test_sage_conv,
        test_sgc_conv,
        test_appnp_conv,
        test_gin_conv,
        test_agnn_conv,
        test_gated_graph_conv,
        test_nn_conv,
        test_gmm_conv,
        test_dense_graph_conv,
        test_dense_sage_conv,
        test_dense_cheb_conv,
        test_sequential,
        test_atomic_conv,
        test_cf_conv,
    ):
        case()