import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
import dgl.function as fn
import backend as F
import pytest

from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype

from copy import deepcopy

import numpy as np
import scipy as sp

def _AXWb(A, X, W, b):
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b

20
21
def test_graph_conv0():
    """GraphConv sanity checks on a 3-node path graph: no-norm output matches the
    dense reference, extra feature dims work, and reset_parameters re-draws weights."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = g.adjacency_matrix(ctx=ctx)

    conv = nn.GraphConv(5, 2, norm='none', bias=True).to(ctx)
    print(conv)
    # test#1: basic
    x = F.ones((3, 5))
    y = conv(g, x)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(y, _AXWb(adj, x, conv.weight, conv.bias))
    # test#2: more-dim
    x = F.ones((3, 5, 5))
    y = conv(g, x)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(y, _AXWb(adj, x, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2).to(ctx)
    # test#3: basic (default settings)
    y = conv(g, F.ones((3, 5)))
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim
    y = conv(g, F.ones((3, 5, 5)))
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    conv = nn.GraphConv(5, 2).to(ctx)
    # repeat test#3/#4 on a freshly initialized module
    y = conv(g, F.ones((3, 5)))
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    y = conv(g, F.ones((3, 5, 5)))
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # reset_parameters should re-draw the weights
    w_before = deepcopy(conv.weight.data)
    conv.reset_parameters()
    assert not F.allclose(w_before, conv.weight.data)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv(idtype, g, norm, weight, bias):
    """GraphConv forward with a single tensor input; checks the output shape only."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(ctx)
    ext_w = F.randn((5, 2)).to(ctx)
    n_src = g.number_of_src_nodes()
    n_dst = g.number_of_dst_nodes()
    x = F.randn((n_src, 5)).to(ctx)
    # without an internal weight, an external one must be supplied at call time
    out = conv(g, x) if weight else conv(g, x, weight=ext_w)
    assert out.shape == (n_dst, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_bi(idtype, g, norm, weight, bias):
    """GraphConv forward with a (src, dst) feature pair on bipartite graphs."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(ctx)
    ext_w = F.randn((5, 2)).to(ctx)
    n_src = g.number_of_src_nodes()
    n_dst = g.number_of_dst_nodes()
    x_src = F.randn((n_src, 5)).to(ctx)
    x_dst = F.randn((n_dst, 2)).to(ctx)
    # without an internal weight, an external one must be supplied at call time
    out = conv(g, (x_src, x_dst)) if weight else conv(g, (x_src, x_dst), weight=ext_w)
    assert out.shape == (n_dst, 2)
def _S2AXWb(A, N, X, W, b):
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b

124
def test_tagconv():
    """TAGConv against the dense 2-hop reference, plus shape and reset checks."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = g.adjacency_matrix(ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, 2, bias=True).to(ctx)
    print(conv)

    # test#1: output must match the dense reference
    x = F.ones((3, 5))
    y = conv(g, x)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (x.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)
    assert F.allclose(y, _S2AXWb(adj, norm, x, conv.lin.weight, conv.lin.bias))

    conv = nn.TAGConv(5, 2).to(ctx)
    # test#2: shape only
    y = conv(g, F.ones((3, 5)))
    assert y.shape[-1] == 2

    # reset_parameters should re-draw the weights
    w_before = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    assert not F.allclose(w_before, conv.lin.weight.data)
def test_set2set():
    """Set2Set readout on a single graph and on a batched graph."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10)).to(ctx)

    s2s = nn.Set2Set(5, 3, 3).to(ctx)  # hidden size 5, 3 iters, 3 layers
    print(s2s)

    # single graph: one readout row of width 2 * hidden
    x = F.randn((g.number_of_nodes(), 5))
    y = s2s(g, x)
    assert y.dim() == 2 and y.shape[0] == 1 and y.shape[1] == 10

    # batched graph: one readout row per component
    g1 = dgl.DGLGraph(nx.path_graph(11)).to(ctx)
    g2 = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    bg = dgl.batch([g, g1, g2])
    x = F.randn((bg.number_of_nodes(), 5))
    y = s2s(bg, x)
    assert y.dim() == 2 and y.shape[0] == 3 and y.shape[1] == 10
def test_glob_att_pool():
    """GlobalAttentionPooling readout on a single graph and on a batched graph."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10)).to(ctx)

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10)).to(ctx)
    print(gap)

    # single graph: one readout row of width 10
    x = F.randn((g.number_of_nodes(), 5))
    y = gap(g, x)
    assert y.dim() == 2 and y.shape[0] == 1 and y.shape[1] == 10

    # batched graph: one readout row per component
    bg = dgl.batch([g, g, g, g])
    x = F.randn((bg.number_of_nodes(), 5))
    y = gap(bg, x)
    assert y.dim() == 2 and y.shape[0] == 4 and y.shape[1] == 10
def test_simple_pool():
    """Sum/Avg/Max/Sort pooling on a single graph and per-component on a batch."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(ctx)

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # single graph: pooled row equals the corresponding reduction over nodes
    x = F.randn((g.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    assert F.allclose(F.squeeze(sum_pool(g, x), 0), F.sum(x, 0))
    assert F.allclose(F.squeeze(avg_pool(g, x), 0), F.mean(x, 0))
    assert F.allclose(F.squeeze(max_pool(g, x), 0), F.max(x, 0))
    y = sort_pool(g, x)
    assert y.dim() == 2 and y.shape[0] == 1 and y.shape[1] == 10 * 5

    # batched graph: reductions are taken per component (15/5/15/5/15 nodes)
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    bg = dgl.batch([g, g_, g, g_, g])
    x = F.randn((bg.number_of_nodes(), 5))
    bounds = [(0, 15), (15, 20), (20, 35), (35, 40), (40, 55)]

    y = sum_pool(bg, x)
    assert F.allclose(y, th.stack([F.sum(x[a:b], 0) for a, b in bounds], 0))

    y = avg_pool(bg, x)
    assert F.allclose(y, th.stack([F.mean(x[a:b], 0) for a, b in bounds], 0))

    y = max_pool(bg, x)
    assert F.allclose(y, th.stack([F.max(x[a:b], 0) for a, b in bounds], 0))

    y = sort_pool(bg, x)
    assert y.dim() == 2 and y.shape[0] == 5 and y.shape[1] == 10 * 5
def test_set_trans():
    """SetTransformer encoder (sab/isab) and decoder shapes on single and batched graphs."""
    ctx = F.ctx()
    # NOTE(review): the graphs are deliberately left on CPU here, as in the
    # original test; only the modules are moved to ctx.
    g = dgl.DGLGraph(nx.path_graph(15))

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab').to(ctx)
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3).to(ctx)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4).to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # single graph: encoders preserve shape, decoder emits one row of width 200
    x = F.randn((g.number_of_nodes(), 50))
    y = st_enc_0(g, x)
    assert y.shape == x.shape
    y = st_enc_1(g, x)
    assert y.shape == x.shape
    z = st_dec(g, y)
    assert z.dim() == 2 and z.shape[0] == 1 and z.shape[1] == 200

    # batched graph: decoder emits one row per component
    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    x = F.randn((bg.number_of_nodes(), 50))
    y = st_enc_0(bg, x)
    assert y.shape == x.shape
    y = st_enc_1(bg, x)
    assert y.shape == x.shape

    z = st_dec(bg, y)
    assert z.dim() == 2 and z.shape[0] == 3 and z.shape[1] == 200
def test_rgcn():
    """RelGraphConv: the low-mem path must agree with the regular path for
    basis and bdd regularizers, with and without an edge norm, and with
    integer (id) node input."""
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    # 5 etypes
    R = 5
    etype = [i % 5 for i in range(g.number_of_edges())]
    B = 2
    I = 10
    O = 8

    def make_pair(regularizer, share_w_comp):
        # Build a regular module and a low-mem twin sharing its parameters.
        ref = nn.RelGraphConv(I, O, R, regularizer, B).to(ctx)
        low = nn.RelGraphConv(I, O, R, regularizer, B, low_mem=True).to(ctx)
        low.weight = ref.weight
        if share_w_comp:
            low.w_comp = ref.w_comp  # only the basis regularizer has w_comp
        low.loop_weight = ref.loop_weight
        return ref, low

    # dense features, no norm
    for regularizer, share in (("basis", True), ("bdd", False)):
        ref, low = make_pair(regularizer, share)
        h = th.randn((100, I)).to(ctx)
        r = th.tensor(etype).to(ctx)
        out_ref = ref(g, h, r)
        out_low = low(g, h, r)
        assert list(out_ref.shape) == [100, O]
        assert list(out_low.shape) == [100, O]
        assert F.allclose(out_ref, out_low)

    # with norm
    norm = th.zeros((g.number_of_edges(), 1)).to(ctx)
    for regularizer, share in (("basis", True), ("bdd", False)):
        ref, low = make_pair(regularizer, share)
        h = th.randn((100, I)).to(ctx)
        r = th.tensor(etype).to(ctx)
        out_ref = ref(g, h, r, norm)
        out_low = low(g, h, r, norm)
        assert list(out_ref.shape) == [100, O]
        assert list(out_low.shape) == [100, O]
        assert F.allclose(out_ref, out_low)

    # id (integer) input
    ref, low = make_pair("basis", True)
    h = th.randint(0, I, (100,)).to(ctx)
    r = th.tensor(etype).to(ctx)
    out_ref = ref(g, h, r)
    out_low = low(g, h, r)
    assert list(out_ref.shape) == [100, O]
    assert list(out_low.shape) == [100, O]
    assert F.allclose(out_ref, out_low)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gat_conv(g, idtype):
    """GATConv output shape (nodes, heads, out_feats) on homo / block-bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    gat = nn.GATConv(5, 2, 4).to(F.ctx())
    x = F.randn((g.number_of_nodes(), 5))
    out = gat(g, x)
    assert out.shape == (g.number_of_nodes(), 4, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gat_conv_bi(g, idtype):
    """GATConv output shape with a (src, dst) feature pair on bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    gat = nn.GATConv(5, 2, 4).to(F.ctx())
    pair = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    out = gat(g, pair)
    assert out.shape == (g.number_of_dst_nodes(), 4, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(idtype, g, aggre_type):
    """SAGEConv output width for every aggregator type."""
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type).to(F.ctx())
    x = F.randn((g.number_of_nodes(), 5))
    out = sage(g, x)
    assert out.shape[-1] == 10
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv_bi(idtype, g, aggre_type):
    """SAGEConv with (src, dst) inputs; 'gcn' requires matching src/dst widths."""
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 10 if aggre_type == 'gcn' else 5
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type).to(F.ctx())
    pair = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    out = sage(g, pair)
    assert out.shape[-1] == 2
    assert out.shape[0] == g.number_of_dst_nodes()
@parametrize_dtype
def test_sage_conv2(idtype):
    """SAGEConv on a bipartite graph with no edges at all."""
    # TODO: add test for blocks
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    sage = nn.SAGEConv((3, 3), 2, 'gcn').to(ctx)
    pair = (F.randn((5, 3)), F.randn((3, 3)))
    out = sage(g, (F.copy_to(pair[0], F.ctx()), F.copy_to(pair[1], F.ctx())))
    assert out.shape[-1] == 2
    assert out.shape[0] == 3

    for aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), 2, aggre_type).to(ctx)
        out = sage(g, (F.randn((5, 3)), F.randn((3, 1))))
        assert out.shape[-1] == 2
        assert out.shape[0] == 3
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_sgc_conv(g, idtype):
    """SGConv shape check; the cached variant must ignore feature changes."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)

    # not cached
    sgc = nn.SGConv(5, 10, 3).to(ctx)
    x = F.randn((g.number_of_nodes(), 5))
    out = sgc(g, x)
    assert out.shape[-1] == 10

    # cached: second call reuses the first propagation result
    sgc = nn.SGConv(5, 10, 3, True).to(ctx)
    first = sgc(g, x)
    second = sgc(g, x + 1)
    assert F.allclose(first, second)
    assert first.shape[-1] == 10
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv(g, idtype):
    """APPNPConv preserves the feature width."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1).to(ctx)
    x = F.randn((g.number_of_nodes(), 5))

    out = appnp(g, x)
    assert out.shape[-1] == 5
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """GINConv output shape for every aggregator type."""
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    ).to(F.ctx())
    x = F.randn((g.number_of_nodes(), 5))
    out = gin(g, x)
    assert out.shape == (g.number_of_nodes(), 12)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """GINConv with a (src, dst) feature pair on bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    ).to(F.ctx())
    pair = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    out = gin(g, pair)
    assert out.shape == (g.number_of_dst_nodes(), 12)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_agnn_conv(g, idtype):
    """AGNNConv preserves the feature width."""
    g = g.astype(idtype).to(F.ctx())
    agnn = nn.AGNNConv(1).to(F.ctx())
    x = F.randn((g.number_of_nodes(), 5))
    out = agnn(g, x)
    assert out.shape == (g.number_of_nodes(), 5)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_agnn_conv_bi(g, idtype):
    """AGNNConv with a (src, dst) feature pair on bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    agnn = nn.AGNNConv(1).to(F.ctx())
    pair = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    out = agnn(g, pair)
    assert out.shape == (g.number_of_dst_nodes(), 5)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gated_graph_conv(g, idtype):
    """GatedGraphConv with 3 edge types cycled over the edges."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3).to(ctx)
    etypes = (th.arange(g.number_of_edges()) % 3).to(ctx)
    x = F.randn((g.number_of_nodes(), 5))

    out = ggconv(g, x, etypes)
    # currently we only do shape check
    assert out.shape[-1] == 10
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_nn_conv(g, idtype):
    """NNConv with an edge-conditioned linear filter network."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean').to(ctx)
    x = F.randn((g.number_of_nodes(), 5))
    e = F.randn((g.number_of_edges(), 4))
    out = nnconv(g, x, e)
    # currently we only do shape check
    assert out.shape[-1] == 10
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_nn_conv_bi(g, idtype):
    """NNConv with (src, dst) features on bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean').to(ctx)
    x_src = F.randn((g.number_of_src_nodes(), 5))
    x_dst = F.randn((g.number_of_dst_nodes(), 2))
    e = F.randn((g.number_of_edges(), 4))
    out = nnconv(g, (x_src, x_dst), e)
    # currently we only do shape check
    assert out.shape[-1] == 10
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gmm_conv(g, idtype):
    """GMMConv with 3-dim pseudo-coordinates and 4 kernels."""
    g = g.astype(idtype).to(F.ctx())
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean').to(F.ctx())
    x = F.randn((g.number_of_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 3))
    out = gmmconv(g, x, pseudo)
    # currently we only do shape check
    assert out.shape[-1] == 10
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite', 'block-bipartite'], exclude=['zero-degree']))
def test_gmm_conv_bi(g, idtype):
    """GMMConv with (src, dst) features on (block-)bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean').to(F.ctx())
    x_src = F.randn((g.number_of_src_nodes(), 5))
    x_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.number_of_edges(), 3))
    out = gmmconv(g, (x_src, x_dst), pseudo)
    # currently we only do shape check
    assert out.shape[-1] == 10
@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
def test_dense_graph_conv(norm_type, g, idtype):
    """DenseGraphConv on the dense adjacency must match sparse GraphConv."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adjacency_matrix(ctx=ctx).to_dense()
    conv = nn.GraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
    # share parameters so the two modules compute the same function
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    x = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    assert F.allclose(conv(g, x), dense_conv(adj, x))
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite']))
def test_dense_sage_conv(g, idtype):
    """DenseSAGEConv on the dense adjacency must match SAGEConv('gcn')."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, 2, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, 2)
    # share parameters so the two modules compute the same function
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.fc_neigh.bias.data
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5))
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    assert F.allclose(sage(g, feat), dense_sage(adj, feat)), g
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_edge_conv(g, idtype):
    """EdgeConv output shape on homo / block-bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    edge_conv = nn.EdgeConv(5, 2).to(F.ctx())
    print(edge_conv)
    x = F.randn((g.number_of_nodes(), 5))
    out = edge_conv(g, x)
    assert out.shape == (g.number_of_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_edge_conv_bi(g, idtype):
    """EdgeConv with a (src, dst) feature pair on bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    edge_conv = nn.EdgeConv(5, 2).to(F.ctx())
    print(edge_conv)
    x_src = F.randn((g.number_of_src_nodes(), 5))
    x_dst = F.randn((g.number_of_dst_nodes(), 5))
    out = edge_conv(g, (x_src, x_dst))
    assert out.shape == (g.number_of_dst_nodes(), 2)
def test_dense_cheb_conv():
    """DenseChebConv on the dense adjacency must match ChebConv for k = 1..3."""
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        g = g.to(F.ctx())
        adj = g.adjacency_matrix(ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, 2, k, None)
        dense_cheb = nn.DenseChebConv(5, 2, k)
        # copy the fused linear weight into the per-hop dense weight tensor
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, 2)
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        x = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        out_sparse = cheb(g, x, [2.0])
        out_dense = dense_cheb(adj, x, 2.0)
        print(k, out_sparse, out_dense)
        assert F.allclose(out_sparse, out_dense)
def test_sequential():
    """nn.Sequential with graph-aware layers: single-graph and per-layer-graph modes."""
    ctx = F.ctx()

    # Mode 1: one graph shared by all layers, (node, edge) features threaded through.
    class NodeEdgeLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(NodeEdgeLayer(), NodeEdgeLayer(), NodeEdgeLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Mode 2: one graph per layer; each layer halves the node count.
    class HalvingLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())
    net = nn.Sequential(HalvingLayer(), HalvingLayer(), HalvingLayer())
    net = net.to(ctx)
    out = net([g1, g2, g3], F.randn((32, 4)))
    assert out.shape == (4, 4)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_atomic_conv(g, idtype):
    """AtomicConv shape check with 2 cutoffs/kernels and 2 atom types -> 4 channels."""
    g = g.astype(idtype).to(F.ctx())
    aconv = nn.AtomicConv(interaction_cutoffs=F.tensor([12.0, 12.0]),
                          rbf_kernel_means=F.tensor([0.0, 2.0]),
                          rbf_kernel_scaling=F.tensor([4.0, 4.0]),
                          features_to_use=F.tensor([6.0, 8.0]))

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    x = F.randn((g.number_of_nodes(), 1))
    dist = F.randn((g.number_of_edges(), 1))

    out = aconv(g, x, dist)
    # current we only do shape check
    assert out.shape[-1] == 4
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_cf_conv(g, idtype):
    """CFConv (SchNet-style) shape check with node and edge features."""
    g = g.astype(idtype).to(F.ctx())
    cfconv = nn.CFConv(node_in_feats=2,
                       edge_in_feats=3,
                       hidden_feats=2,
                       out_feats=3)

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    node_feats = F.randn((g.number_of_nodes(), 2))
    edge_feats = F.randn((g.number_of_edges(), 3))
    out = cfconv(g, node_feats, edge_feats)
    # current we only do shape check
    assert out.shape[-1] == 3
def myagg(alist, dsttype):
    """Custom cross-relation aggregator: weighted sum where element i gets weight i + 1."""
    out = alist[0]
    for weight, item in enumerate(alist[1:], start=2):
        out = out + weight * item
    return out
@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """HeteroGraphConv: per-relation modules, dict and pair inputs, and
    mod_args/mod_kwargs routing to the right relation module."""
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    conv = conv.to(F.ctx())

    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    # 'stack' keeps one slice per contributing relation; other aggs reduce them
    stacked = agg == 'stack'

    out = conv(g, {'user': uf})
    assert set(out.keys()) == {'user', 'game'}
    assert out['user'].shape == ((4, 1, 3) if stacked else (4, 3))
    assert out['game'].shape == ((4, 1, 4) if stacked else (4, 4))

    out = conv(g, {'user': uf, 'store': sf})
    assert set(out.keys()) == {'user', 'game'}
    assert out['user'].shape == ((4, 1, 3) if stacked else (4, 3))
    # 'game' receives from both 'plays' and 'sells'
    assert out['game'].shape == ((4, 2, 4) if stacked else (4, 4))

    out = conv(g, {'store': sf})
    assert set(out.keys()) == {'game'}
    assert out['game'].shape == ((4, 1, 4) if stacked else (4, 4))

    # test with pair input
    conv = nn.HeteroGraphConv({
        'follows': nn.SAGEConv(2, 3, 'mean'),
        'plays': nn.SAGEConv((2, 4), 4, 'mean'),
        'sells': nn.SAGEConv(3, 4, 'mean')},
        agg)
    conv = conv.to(F.ctx())

    out = conv(g, ({'user': uf}, {'user': uf, 'game': gf}))
    assert set(out.keys()) == {'user', 'game'}
    assert out['user'].shape == ((4, 1, 3) if stacked else (4, 3))
    assert out['game'].shape == ((4, 1, 4) if stacked else (4, 4))

    # pair input requires both src and dst type features to be provided
    out = conv(g, ({'user': uf}, {'game': gf}))
    assert set(out.keys()) == {'game'}
    assert out['game'].shape == ((4, 1, 4) if stacked else (4, 4))

    # test with mod args
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None, *, arg2=None):
            # count how often the positional / keyword extras actually arrive
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    conv = conv.to(F.ctx())
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1
if __name__ == '__main__':
    # Only the tests that take no arguments can be invoked directly.
    # The pytest-parametrized tests (test_graph_conv, test_gat_conv,
    # test_sage_conv, ...) require fixtures/parameters and must be run
    # via pytest; calling them bare raised TypeError before this fix.
    test_graph_conv0()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_set_trans()
    test_rgcn()
    test_tagconv()
    test_dense_cheb_conv()
    test_sequential()