# test_nn.py: tests for the PyTorch modules in dgl.nn.pytorch.
import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
5
import dgl.function as fn
6
import backend as F
7
import pytest
8
9
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
10
11
from copy import deepcopy

12
13
14
import numpy as np
import scipy as sp

15
16
17
18
19
def _AXWb(A, X, W, b):
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b

20
21
def test_graph_conv0():
    """GraphConv on a 3-node path graph.

    Checks exact values against ``_AXWb`` for ``norm='none'`` and output
    shapes for the default normalization (2-D and 3-D input). After every
    call it also checks that no node/edge data is left on ``g``. Finally it
    checks that ``reset_parameters`` changes the weight.
    """
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)

    conv = nn.GraphConv(5, 2, norm='none', bias=True)
    conv = conv.to(ctx)
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    # forward must not leave features behind on the input graph
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert h1.shape == (3, 2)
    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert h1.shape == (3, 5, 2)

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv(idtype, g, norm, weight, bias):
    """GraphConv with a single feature tensor.

    The output must have one row per destination node. When the layer owns
    no weight (``weight=False``), an external weight is passed to forward.
    """
    # Test one tensor input
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, 2)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, 2)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_bi(idtype, g, norm, weight, bias):
    """GraphConv with a (source, destination) feature pair on bipartite graphs.

    Only the output shape is checked. ``weight=False`` exercises the
    external-weight argument of forward.
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(ctx)
    ext_w = F.randn((5, 2)).to(ctx)
    num_src, num_dst = g.number_of_src_nodes(), g.number_of_dst_nodes()
    feat_pair = (F.randn((num_src, 5)).to(ctx), F.randn((num_dst, 2)).to(ctx))
    extra = {} if weight else {'weight': ext_w}
    h_out = conv(g, feat_pair, **extra)
    assert h_out.shape == (num_dst, 2)


def _S2AXWb(A, N, X, W, b):
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b

124
def test_tagconv():
    """TAGConv on a 3-node path graph, compared against ``_S2AXWb``.

    Also checks the default-constructed layer's output width and that
    ``reset_parameters`` changes the weight.
    """
    g = dgl.DGLGraph(nx.path_graph(3))
    g = g.to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)
    # in-degree ** -0.5 normalization used by the reference computation
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, 2, bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # reshape the per-node norm so it broadcasts over the feature dims
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))

    conv = nn.TAGConv(5, 2)
    conv = conv.to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == 2

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)


def test_set2set():
    """Set2Set readout returns one row of width 10 per graph.

    Covers a single graph and a batch of three graphs.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    s2s = s2s.to(ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2


def test_glob_att_pool():
    """GlobalAttentionPooling returns one 10-wide row per graph.

    The gate network maps 5 -> 1 and the feature network maps 5 -> 10.
    Covers a single graph and a batch of four graphs.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2


def test_simple_pool():
    """Sum/Avg/Max pooling against per-graph reductions; SortPooling shape.

    Covers a single 15-node graph and a batch of five graphs
    (15, 5, 15, 5, 15 nodes).
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))
    g = g.to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    # node-index range [start, end) of each graph inside the batch
    bounds = [(0, 15), (15, 20), (20, 35), (35, 40), (40, 55)]

    h1 = sum_pool(bg, h0)
    truth = th.stack([F.sum(h0[s:e], 0) for s, e in bounds], 0)
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = th.stack([F.mean(h0[s:e], 0) for s, e in bounds], 0)
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = th.stack([F.max(h0[s:e], 0) for s, e in bounds], 0)
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2


def test_set_trans():
    """SetTransformer encoders keep the feature shape.

    The decoder yields one row of width 200 per graph. Covers a single
    graph and a batch of three graphs.
    """
    ctx = F.ctx()
    # Move graphs to the test device, as every other test does. The
    # features below come from F.randn, which presumably allocates on
    # F.ctx().
    g = dgl.DGLGraph(nx.path_graph(15)).to(ctx)

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    g2 = dgl.DGLGraph(nx.path_graph(10)).to(ctx)
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2


def test_rgcn():
    """RelGraphConv: the low_mem variant must match the default one.

    Covers the 'basis' and 'bdd' regularizers, with and without an edge
    norm, plus integer node-id input. In each case the low-mem module is
    given the reference module's parameter objects before outputs are
    compared.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    # 5 etypes
    R = 5
    # assign edge types 0..R-1 cyclically over the edges
    etype = [i % R for i in range(g.number_of_edges())]
    B = 2
    I = 10
    O = 8

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
    rgc_bdd_low.weight = rgc_bdd.weight
    rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r)
    h_new_low = rgc_bdd_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    # with norm
    norm = th.zeros((g.number_of_edges(), 1)).to(ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r, norm)
    h_new_low = rgc_basis_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
    rgc_bdd_low.weight = rgc_bdd.weight
    rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r, norm)
    h_new_low = rgc_bdd_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    # id input: node features given as integer ids in [0, I)
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = th.randint(0, I, (100,)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gat_conv(g, idtype):
    """GATConv(5, 2, 4) on homogeneous graphs: output is (num_nodes, 4, 2)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((g.number_of_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_nodes(), 4, 2)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gat_conv_bi(g, idtype):
    """GATConv(5, 2, 4) with (src, dst) features: output is (num_dst, 4, 2)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, 2, 4)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 4, 2)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(idtype, g, aggre_type):
    """SAGEConv(5, 10) produces 10 output features for every aggregator."""
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == 10


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv_bi(idtype, g, aggre_type):
    """SAGEConv with differently sized (src, dst) features on bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    # 'gcn' is given a destination size equal to the source size (10);
    # presumably that aggregator requires matching sizes.
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == g.number_of_dst_nodes()


@parametrize_dtype
def test_sage_conv2(idtype):
    """SAGEConv on an edgeless bipartite graph (allow_zero_in_degree=True).

    Covers the 'gcn' aggregator, then 'mean', 'pool' and 'lstm'.
    """
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.bipartite([], num_nodes=(5, 3))
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), 2, 'gcn', allow_zero_in_degree=True)
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    assert h.shape[-1] == 2
    assert h.shape[0] == 3
    for aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), 2, aggre_type, allow_zero_in_degree=True)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        h = sage(g, feat)
        assert h.shape[-1] == 2
        assert h.shape[0] == 3


def test_sgc_conv():
    """SGConv output width, uncached and cached."""
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    # not cached
    sgc = nn.SGConv(5, 10, 3)
    feat = F.randn((100, 5))
    sgc = sgc.to(ctx)

    h = sgc(g, feat)
    assert h.shape[-1] == 10

    # cached
    sgc = nn.SGConv(5, 10, 3, True)
    sgc = sgc.to(ctx)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    # with caching on, the second call is expected to reuse the first result
    # even though the input changed
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == 10


def test_appnp_conv():
    """APPNPConv keeps the input feature width (5)."""
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((100, 5))
    appnp = appnp.to(ctx)

    h = appnp(g, feat)
    assert h.shape[-1] == 5


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """GINConv with a Linear(5, 12) apply function: output is (num_nodes, 12)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    feat = F.randn((g.number_of_nodes(), 5))
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_nodes(), 12)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """GINConv with (src, dst) features: output is (num_dst_nodes, 12)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_agnn_conv(g, idtype):
    """AGNNConv keeps the feature width: output is (num_nodes, 5)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_nodes(), 5))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_nodes(), 5)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_agnn_conv_bi(g, idtype):
    """AGNNConv with (src, dst) features: output is (num_dst_nodes, 5)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)


def test_gated_graph_conv():
    """GatedGraphConv with 3 edge types assigned cyclically (shape check)."""
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((100, 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_nn_conv(g, idtype):
    """NNConv: the edge network maps 4-dim edge features to a 5x10 weight."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean')
    feat = F.randn((g.number_of_nodes(), 5))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_nn_conv_bi(g, idtype):
    """NNConv with (src, dst) features of sizes (5, 2) (shape check)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gmm_conv(g, idtype):
    """GMMConv with 3-dim pseudo-coordinates per edge (shape check)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
    feat = F.randn((g.number_of_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite', 'block-bipartite'], exclude=['zero-degree']))
def test_gmm_conv_bi(g, idtype):
    """GMMConv with (src, dst) features of sizes (5, 2) (shape check)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10


@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
def test_dense_graph_conv(norm_type, g, idtype):
    """DenseGraphConv on the dense adjacency matches GraphConv on the graph.

    Both layers share the same weight and bias.
    """
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adjacency_matrix(ctx=ctx).to_dense()
    conv = nn.GraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
def test_dense_sage_conv(g, idtype):
    """DenseSAGEConv matches SAGEConv('gcn') when they share fc parameters."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, 2, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, 2)
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.fc_neigh.bias.data
    # graphs with two node types get separate (src, dst) features
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5))
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage), g


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_edge_conv(g, idtype):
    """EdgeConv(5, 2): output is (num_nodes, 2)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, 2).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_nodes(), 2)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_edge_conv_bi(g, idtype):
    """EdgeConv(5, 2) with (src, dst) features: output is (num_dst_nodes, 2)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, 2).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), 2)


def test_dense_cheb_conv():
    """DenseChebConv matches ChebConv for k = 1, 2, 3 with shared parameters."""
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        g = g.to(F.ctx())
        adj = g.adjacency_matrix(ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, 2, k, None)
        dense_cheb = nn.DenseChebConv(5, 2, k)
        # ChebConv keeps the weights of all k terms in one linear layer;
        # transpose them and split them into DenseChebConv's (k, 5, 2) tensor
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, 2)
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)

        # NOTE(review): 2.0 is presumably the lambda_max estimate; confirm
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)


def test_sequential():
    """nn.Sequential with one shared graph, and with one graph per layer."""
    ctx = F.ctx()
    # Test single graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    # every ordered pair of the 3 nodes, self-loops included (9 edges)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            # sum consecutive node pairs, halving the number of rows
            return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)

    # node counts 32 -> 16 -> 8 match each layer halving its input
    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)


def test_atomic_conv():
    """AtomicConv output-shape check on a random 100-node graph."""
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    aconv = nn.AtomicConv(interaction_cutoffs=F.tensor([12.0, 12.0]),
                          rbf_kernel_means=F.tensor([0.0, 2.0]),
                          rbf_kernel_scaling=F.tensor([4.0, 4.0]),
                          features_to_use=F.tensor([6.0, 8.0]))

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((100, 1))
    dist = F.randn((g.number_of_edges(), 1))

    h = aconv(g, feat, dist)
    # currently we only do shape check
    assert h.shape[-1] == 4


def test_cf_conv():
    """CFConv output-shape check on a random 100-node graph."""
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    cfconv = nn.CFConv(node_in_feats=2,
                       edge_in_feats=3,
                       hidden_feats=2,
                       out_feats=3)

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    node_feats = F.randn((100, 2))
    edge_feats = F.randn((g.number_of_edges(), 3))
    h = cfconv(g, node_feats, edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == 3


def myagg(alist, dsttype):
    """Custom aggregator: position-weighted sum of the per-relation results.

    ``alist[0]`` is taken as is, and every later ``alist[i]`` is scaled by
    ``i + 1`` before being added. ``dsttype`` is accepted (presumably
    supplied by HeteroGraphConv) but not used.
    """
    rst = alist[0]
    for weight, item in enumerate(alist[1:], start=2):
        rst = rst + weight * item
    return rst


@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """HeteroGraphConv on a user/game/store graph.

    Checks the output node types and per-type shapes for every aggregator
    (single and pair inputs), and that mod_args/mod_kwargs reach only the
    named relations' modules.
    """
    g = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
        ('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0), (2, 2)],
        ('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    conv = conv.to(F.ctx())
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {'user': uf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 1, 4)

    h = conv(g, {'user': uf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    h = conv(g, {'store': sf})
    assert set(h.keys()) == {'game'}
    if agg != 'stack':
        assert h['game'].shape == (4, 4)
    else:
        assert h['game'].shape == (4, 1, 4)

    # test with pair input
    conv = nn.HeteroGraphConv({
        'follows': nn.SAGEConv(2, 3, 'mean', allow_zero_in_degree=True),
        'plays': nn.SAGEConv((2, 4), 4, 'mean', allow_zero_in_degree=True),
        'sells': nn.SAGEConv(3, 4, 'mean', allow_zero_in_degree=True)},
        agg)
    conv = conv.to(F.ctx())

    h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 1, 4)

    # pair input requires both src and dst type features to be provided
    h = conv(g, ({'user': uf}, {'game' : gf}))
    assert set(h.keys()) == {'game'}
    if agg != 'stack':
        assert h['game'].shape == (4, 4)
    else:
        assert h['game'].shape == (4, 1, 4)

    # test with mod args
    class MyMod(th.nn.Module):
        """Counts how often each optional forward argument is supplied."""
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    conv = conv.to(F.ctx())
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1


if __name__ == '__main__':
    # Most tests here take pytest-parametrized arguments (idtype, g, ...)
    # and raise TypeError when called bare, so run the module through pytest
    # and propagate its exit status.
    raise SystemExit(pytest.main([__file__]))