# test_nn.py: unit tests for the dgl.nn.pytorch modules.
import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
5
import dgl.function as fn
6
import backend as F
7
import pytest
8
9
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
10
11
from copy import deepcopy

12
13
14
import numpy as np
import scipy as sp

15
16
17
18
19
def _AXWb(A, X, W, b):
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b

20
21
def test_graph_conv0():
    """GraphConv on a 3-node path graph.

    With ``norm='none'`` the output is compared against the dense ``_AXWb``
    reference; with default settings only side effects are checked. Finally
    ``reset_parameters`` must re-draw the weight.
    """
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=False, ctx=ctx)

    conv = nn.GraphConv(5, 2, norm='none', bias=True)
    conv = conv.to(ctx)
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    # the forward pass must not leave features behind on the input graph
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # default normalization; the original file repeated this section twice
    # verbatim, a single copy covers the same cases
    conv = nn.GraphConv(5, 2)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)
72

73
74
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv(idtype, g, norm, weight, bias):
    """GraphConv fed a single feature tensor yields one row per destination node.

    When the module is built without its own weight (``weight=False``), an
    external weight matrix is supplied to the forward call instead.
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(ctx)
    ext_w = F.randn((5, 2)).to(ctx)
    num_src, num_dst = g.number_of_src_nodes(), g.number_of_dst_nodes()
    feat = F.randn((num_src, 5)).to(ctx)
    call_kwargs = {} if weight else {'weight': ext_w}
    h_out = conv(g, feat, **call_kwargs)
    assert h_out.shape == (num_dst, 2)

92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_bi(idtype, g, norm, weight, bias):
    """GraphConv fed a (src, dst) feature pair yields one row per destination node."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(ctx)
    ext_w = F.randn((5, 2)).to(ctx)
    num_src, num_dst = g.number_of_src_nodes(), g.number_of_dst_nodes()
    # tuple elements are evaluated left to right: source features first
    feats = (F.randn((num_src, 5)).to(ctx), F.randn((num_dst, 2)).to(ctx))
    # without an internal weight, the weight is supplied at call time
    h_out = conv(g, feats) if weight else conv(g, feats, weight=ext_w)
    assert h_out.shape == (num_dst, 2)
111

112
113
114
115
116
117
118
119
120
121
122
123
def _S2AXWb(A, N, X, W, b):
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b

124
def test_tagconv():
    """TAGConv on a 3-node path graph, checked against the ``_S2AXWb`` reference."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = g.adjacency_matrix(transpose=False, ctx=ctx)
    # in-degree ** -0.5, applied on both sides of A by the reference
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, 2, bias=True).to(ctx)
    print(conv)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # append singleton dims so norm broadcasts over the feature dimensions
    norm = th.reshape(norm, norm.shape + (1,) * (h0.dim() - 1)).to(ctx)
    expected = _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias)
    assert F.allclose(h1, expected)

    # test#2: basic
    conv = nn.TAGConv(5, 2).to(ctx)
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == 2

    # test reset_parameters
    weight_before = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    assert not F.allclose(weight_before, conv.lin.weight.data)

159
def test_set2set():
    """Set2Set readout yields a (num_graphs, 10) tensor for 5-dim node features."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10)).to(ctx)

    s2s = nn.Set2Set(5, 3, 3).to(ctx)  # hidden size 5, 3 iters, 3 layers
    print(s2s)

    # test#1: basic
    feats = F.randn((g.number_of_nodes(), 5))
    out = s2s(g, feats)
    assert out.shape[0] == 1 and out.shape[1] == 10 and out.dim() == 2

    # test#2: batched graph
    extra_graphs = [dgl.DGLGraph(nx.path_graph(n)).to(ctx) for n in (11, 5)]
    bg = dgl.batch([g] + extra_graphs)
    feats = F.randn((bg.number_of_nodes(), 5))
    out = s2s(bg, feats)
    assert out.shape[0] == 3 and out.shape[1] == 10 and out.dim() == 2

def test_glob_att_pool():
    """GlobalAttentionPooling yields a (num_graphs, 10) tensor."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10)).to(ctx)

    gate_nn = th.nn.Linear(5, 1)
    feat_nn = th.nn.Linear(5, 10)
    gap = nn.GlobalAttentionPooling(gate_nn, feat_nn).to(ctx)
    print(gap)

    # test#1: basic
    feats = F.randn((g.number_of_nodes(), 5))
    out = gap(g, feats)
    assert out.shape[0] == 1 and out.shape[1] == 10 and out.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g] * 4)
    feats = F.randn((bg.number_of_nodes(), 5))
    out = gap(bg, feats)
    assert out.shape[0] == 4 and out.shape[1] == 10 and out.dim() == 2

def test_simple_pool():
    """Sum/avg/max/sort pooling on a single graph and on a batch of graphs."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(ctx)

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    sum_pool, avg_pool, max_pool, sort_pool = [
        pool.to(ctx) for pool in (sum_pool, avg_pool, max_pool, sort_pool)]
    reductions = ((sum_pool, F.sum), (avg_pool, F.mean), (max_pool, F.max))
    # on a single graph each readout equals the plain reduction over all nodes
    for pool, reduce_fn in reductions:
        h1 = pool(g, h0)
        assert F.allclose(F.squeeze(h1, 0), reduce_fn(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    # node-index range of each member graph (15, 5, 15, 5, 15 nodes)
    segments = [(0, 15), (15, 20), (20, 35), (35, 40), (40, 55)]
    for pool, reduce_fn in reductions:
        h1 = pool(bg, h0)
        truth = th.stack([reduce_fn(h0[lo:hi], 0) for lo, hi in segments], 0)
        assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2

def test_set_trans():
    """SetTransformer encoders and decoder on single and batched path graphs.

    Encoders must preserve the node-feature shape; the decoder must yield one
    200-dimensional row per graph.
    """
    ctx = F.ctx()
    # Graphs are moved to the test context like in every other test in this
    # file, so they live on the same device as the modules below.
    g = dgl.DGLGraph(nx.path_graph(15)).to(ctx)

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    g2 = dgl.DGLGraph(nx.path_graph(10)).to(ctx)
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2

# (blame annotation: Minjie Wang committed)
292
293
294
295
def test_rgcn():
    """RelGraphConv: the ``low_mem=True`` path must reproduce the default path.

    Each case builds a default and a ``low_mem`` layer. The low-mem layer reuses
    the default layer's parameters, and the two outputs are compared. Cases:

    * the ``basis`` and ``bdd`` regularizers;
    * with and without an edge norm;
    * integer node-id input.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(ctx)
    num_nodes = g.number_of_nodes()
    R = 5   # number of relation (edge) types
    B = 2   # number of bases / blocks passed to the regularizer
    I = 10  # input feature size
    O = 8   # output feature size
    # relation types assigned round-robin; derived from R instead of a
    # duplicated literal so the two cannot drift apart
    etype = [i % R for i in range(g.number_of_edges())]
    r = th.tensor(etype).to(ctx)

    def _check_low_mem(regularizer, h, norm=None):
        """Assert that default and low_mem RelGraphConv agree on input ``h``."""
        rgc = nn.RelGraphConv(I, O, R, regularizer, B).to(ctx)
        rgc_low = nn.RelGraphConv(I, O, R, regularizer, B, low_mem=True).to(ctx)
        rgc_low.weight = rgc.weight
        if regularizer == "basis":
            # w_comp is only shared for the 'basis' regularizer
            rgc_low.w_comp = rgc.w_comp
        rgc_low.loop_weight = rgc.loop_weight
        args = (g, h, r) if norm is None else (g, h, r, norm)
        h_new = rgc(*args)
        h_new_low = rgc_low(*args)
        assert list(h_new.shape) == [num_nodes, O]
        assert list(h_new_low.shape) == [num_nodes, O]
        assert F.allclose(h_new, h_new_low)

    # dense feature input
    for regularizer in ("basis", "bdd"):
        _check_low_mem(regularizer, th.randn((num_nodes, I)).to(ctx))

    # with norm: random rather than all-zero, so the normalised neighbour
    # messages actually contribute to the outputs being compared
    norm = th.rand((g.number_of_edges(), 1)).to(ctx)
    for regularizer in ("basis", "bdd"):
        _check_low_mem(regularizer, th.randn((num_nodes, I)).to(ctx), norm)

    # id input
    _check_low_mem("basis", th.randint(0, I, (num_nodes,)).to(ctx))
371

372
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gat_conv(g, idtype):
    """GATConv(5, 2, 4): output (nodes, 4, 2); attention (edges, 4, 1)."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((g.number_of_nodes(), 5))
    gat = gat.to(ctx)

    out = gat(g, feat)
    assert out.shape == (g.number_of_nodes(), 4, 2)

    _, attn = gat(g, feat, get_attention=True)
    assert attn.shape == (g.number_of_edges(), 4, 1)
384

385
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gat_conv_bi(g, idtype):
    """GATConv with a (src, dst) feature pair: one output row per destination node."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gat = nn.GATConv(5, 2, 4)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    feat = (src_feat, dst_feat)
    gat = gat.to(ctx)

    out = gat(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), 4, 2)

    _, attn = gat(g, feat, get_attention=True)
    assert attn.shape == (g.number_of_edges(), 4, 1)
397

398
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(idtype, g, aggre_type):
    """SAGEConv maps 5-dim node features to 10 dims for every aggregator."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_nodes(), 5))
    out = sage.to(ctx)(g, feat)
    assert out.shape[-1] == 10

409
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv_bi(idtype, g, aggre_type):
    """SAGEConv with 10-dim source features and a separate destination size."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # 'gcn' is exercised with equal source/destination sizes
    dst_dim = 10 if aggre_type == 'gcn' else 5
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    src_feat = F.randn((g.number_of_src_nodes(), 10))
    dst_feat = F.randn((g.number_of_dst_nodes(), dst_dim))
    out = sage.to(ctx)(g, (src_feat, dst_feat))
    assert out.shape[-1] == 2
    assert out.shape[0] == g.number_of_dst_nodes()
421

422
423
424
@parametrize_dtype
def test_sage_conv2(idtype):
    """SAGEConv on an edgeless bipartite graph (5 source, 3 destination nodes)."""
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    sage = nn.SAGEConv((3, 3), 2, 'gcn')
    src_feat, dst_feat = F.randn((5, 3)), F.randn((3, 3))
    sage = sage.to(ctx)
    out = sage(g, (F.copy_to(src_feat, F.ctx()), F.copy_to(dst_feat, F.ctx())))
    assert out.shape[-1] == 2
    assert out.shape[0] == 3

    for aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), 2, aggre_type)
        pair = (F.randn((5, 3)), F.randn((3, 1)))
        out = sage.to(ctx)(g, pair)
        assert out.shape[-1] == 2
        assert out.shape[0] == 3

443
444
445
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_sgc_conv(g, idtype):
    """SGConv shape check, plus cached mode returning the first result again."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)

    # not cached
    sgc = nn.SGConv(5, 10, 3)
    feat = F.randn((g.number_of_nodes(), 5))
    out = sgc.to(ctx)(g, feat)
    assert out.shape[-1] == 10

    # cached (4th positional argument): a second call with a shifted input
    # must return the same output as the first call
    sgc = nn.SGConv(5, 10, 3, True).to(ctx)
    first = sgc(g, feat)
    second = sgc(g, feat + 1)
    assert F.allclose(first, second)
    assert first.shape[-1] == 10

464
465
466
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv(g, idtype):
    """APPNPConv propagation keeps the node-feature size unchanged."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)  # presumably k=10 steps, alpha=0.1
    feat = F.randn((g.number_of_nodes(), 5))
    out = appnp.to(ctx)(g, feat)
    assert out.shape[-1] == 5

476
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """GINConv with a Linear(5, 12) apply function gives (num_nodes, 12) outputs."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    apply_func = th.nn.Linear(5, 12)
    gin = nn.GINConv(apply_func, aggregator_type)
    feat = F.randn((g.number_of_nodes(), 5))
    out = gin.to(ctx)(g, feat)
    assert out.shape == (g.number_of_nodes(), 12)
490

491
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """GINConv with a (src, dst) feature pair: one 12-dim row per destination node."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    apply_func = th.nn.Linear(5, 12)
    gin = nn.GINConv(apply_func, aggregator_type)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    out = gin.to(ctx)(g, (src_feat, dst_feat))
    assert out.shape == (g.number_of_dst_nodes(), 12)
505

506
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_agnn_conv(g, idtype):
    """AGNNConv preserves the (num_nodes, 5) feature shape."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_nodes(), 5))
    out = agnn.to(ctx)(g, feat)
    assert out.shape == (g.number_of_nodes(), 5)
516

517
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_agnn_conv_bi(g, idtype):
    """AGNNConv with a (src, dst) feature pair: one 5-dim row per destination node."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    agnn = nn.AGNNConv(1)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    out = agnn.to(ctx)(g, (src_feat, dst_feat))
    assert out.shape == (g.number_of_dst_nodes(), 5)
527

528
529
530
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gated_graph_conv(g, idtype):
    """GatedGraphConv with 3 edge types maps 5-dim features to 10 dims."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    # edge types 0, 1, 2 assigned round-robin
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    out = ggconv(g, feat, etypes.to(ctx))
    # current we only do shape check
    assert out.shape[-1] == 10

543
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_nn_conv(g, idtype):
    """NNConv with a Linear edge network; shape check only."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # edge network output size is in_feats * out_feats (5 * 10)
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean')
    node_feat = F.randn((g.number_of_nodes(), 5))
    edge_feat = F.randn((g.number_of_edges(), 4))
    out = nnconv.to(ctx)(g, node_feat, edge_feat)
    # currently we only do shape check
    assert out.shape[-1] == 10

557
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_nn_conv_bi(g, idtype):
    """NNConv with separate source (5) and destination (2) feature sizes."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean')
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 2))
    edge_feat = F.randn((g.number_of_edges(), 4))
    out = nnconv.to(ctx)(g, (src_feat, dst_feat), edge_feat)
    # currently we only do shape check
    assert out.shape[-1] == 10

572
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gmm_conv(g, idtype):
    """GMMConv with 3-dim edge pseudo-coordinates; shape check only."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
    feat = F.randn((g.number_of_nodes(), 5))
    # last dim matches the third constructor argument
    pseudo = F.randn((g.number_of_edges(), 3))
    out = gmmconv.to(ctx)(g, feat, pseudo)
    # currently we only do shape check
    assert out.shape[-1] == 10

585
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite', 'block-bipartite'], exclude=['zero-degree']))
def test_gmm_conv_bi(g, idtype):
    """GMMConv with separate source (5) and destination (2) feature sizes."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean')
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.number_of_edges(), 3))
    out = gmmconv.to(ctx)(g, (src_feat, dst_feat), pseudo)
    # currently we only do shape check
    assert out.shape[-1] == 10

599
@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
def test_dense_graph_conv(norm_type, g, idtype):
    """DenseGraphConv on the dense adjacency must match GraphConv on the graph."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # TODO(minjie): enable the following option after #1385
    adj = g.adjacency_matrix(transpose=False, ctx=ctx).to_dense()
    conv = nn.GraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
    # share parameters so both modules compute the same function
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    sparse_out = conv(g, feat)
    dense_out = dense_conv(adj, feat)
    assert F.allclose(sparse_out, dense_out)

618
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite']))
def test_dense_sage_conv(g, idtype):
    """DenseSAGEConv on the dense adjacency must match SAGEConv('gcn')."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    adj = g.adjacency_matrix(transpose=False, ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, 2, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, 2)
    # share the neighbour projection so both modules compute the same function
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.fc_neigh.bias.data
    has_two_ntypes = len(g.ntypes) == 2
    if not has_two_ntypes:
        feat = F.randn((g.number_of_nodes(), 5))
    else:
        feat = (F.randn((g.number_of_src_nodes(), 5)),
                F.randn((g.number_of_dst_nodes(), 5)))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    sparse_out = sage(g, feat)
    dense_out = dense_sage(adj, feat)
    assert F.allclose(sparse_out, dense_out), g

641
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_edge_conv(g, idtype):
    """EdgeConv maps 5-dim node features to 2 dims."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    edge_conv = nn.EdgeConv(5, 2).to(ctx)
    print(edge_conv)
    node_feat = F.randn((g.number_of_nodes(), 5))
    out = edge_conv(g, node_feat)
    assert out.shape == (g.number_of_nodes(), 2)
651

652
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_edge_conv_bi(g, idtype):
    """EdgeConv with a (src, dst) feature pair: one 2-dim row per destination node."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    edge_conv = nn.EdgeConv(5, 2).to(ctx)
    print(edge_conv)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    out = edge_conv(g, (src_feat, dst_feat))
    assert out.shape == (g.number_of_dst_nodes(), 2)
663
664
665
666
667

def test_dense_cheb_conv():
    """DenseChebConv on a dense adjacency must match ChebConv for K = 1, 2, 3."""
    ctx = F.ctx()  # loop-invariant, fetched once
    for k in range(1, 4):
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        g = g.to(ctx)
        adj = g.adjacency_matrix(transpose=False, ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, 2, k, None)
        dense_cheb = nn.DenseChebConv(5, 2, k)
        # ChebConv keeps all K hop weights in one Linear layer; transposing and
        # viewing it as (k, 5, 2) gives DenseChebConv's per-hop weight layout.
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, 2)
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        # one feature row per node of the generated graph
        feat = F.randn((g.number_of_nodes(), 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        # both variants receive the same fixed value 2.0 (list form for ChebConv)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)

685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
def test_sequential():
    """nn.Sequential with one shared graph, and with one graph per layer."""
    ctx = F.ctx()

    # Test single graph
    class _SingleGraphLayer(th.nn.Module):
        """Adds neighbour sums to node features and endpoint sums to edge features."""

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    # every ordered pair over 3 nodes, self-loops included: 9 edges
    g.add_edges([0, 1, 2] * 3, [0] * 3 + [1] * 3 + [2] * 3)
    g = g.to(ctx)

    net = nn.Sequential(_SingleGraphLayer(), _SingleGraphLayer(), _SingleGraphLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graph
    class _HalvingLayer(th.nn.Module):
        """Aggregates neighbours, then sums consecutive node pairs (halving the rows)."""

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)

    # 32 -> 16 -> 8 -> 4 rows, one graph per layer
    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(ctx)
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(ctx)
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(ctx)
    net = nn.Sequential(_HalvingLayer(), _HalvingLayer(), _HalvingLayer()).to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)

734
735
736
737
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_atomic_conv(g, idtype):
    """AtomicConv forward pass; only the output width (4) is checked."""
    g = g.astype(idtype).to(F.ctx())
    conv_kwargs = dict(
        interaction_cutoffs=F.tensor([12.0, 12.0]),
        rbf_kernel_means=F.tensor([0.0, 2.0]),
        rbf_kernel_scaling=F.tensor([4.0, 4.0]),
        features_to_use=F.tensor([6.0, 8.0]),
    )
    aconv = nn.AtomicConv(**conv_kwargs)

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    node_feat = F.randn((g.number_of_nodes(), 1))
    dist = F.randn((g.number_of_edges(), 1))
    out = aconv(g, node_feat, dist)
    # current we only do shape check
    assert out.shape[-1] == 4

754
755
756
757
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_cf_conv(g, idtype):
    """CFConv maps 2-dim node / 3-dim edge features to 3-dim outputs."""
    g = g.astype(idtype).to(F.ctx())
    cfconv = nn.CFConv(node_in_feats=2, edge_in_feats=3,
                       hidden_feats=2, out_feats=3)

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    node_feats = F.randn((g.number_of_nodes(), 2))
    edge_feats = F.randn((g.number_of_edges(), 3))
    out = cfconv(g, node_feats, edge_feats)
    # current we only do shape check
    assert out.shape[-1] == 3
772

773
774
775
776
777
778
def myagg(alist, dsttype):
    """Custom HeteroGraphConv aggregator: position-weighted sum of the inputs.

    Returns ``alist[0] + 2 * alist[1] + 3 * alist[2] + ...``. ``dsttype`` is
    accepted for the aggregator signature but not used. An empty ``alist``
    raises ``IndexError``.
    """
    total = alist[0]
    for weight, item in enumerate(alist[1:], start=2):
        total = total + weight * item
    return total

779
@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """HeteroGraphConv: per-relation modules, output aggregation, pair input, mod args."""
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())

    def check_shapes(h, expected):
        """Assert ``h`` has exactly the node types of ``expected`` with matching shapes.

        ``expected`` maps node type -> (rows, stacked, feats); the ``stacked``
        dimension only exists when ``agg == 'stack'``.
        """
        assert set(h.keys()) == set(expected)
        for ntype, (rows, stacked, feats) in expected.items():
            want = (rows, stacked, feats) if agg == 'stack' else (rows, feats)
            assert h[ntype].shape == want

    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    conv = conv.to(F.ctx())
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    check_shapes(conv(g, {'user': uf}),
                 {'user': (4, 1, 3), 'game': (4, 1, 4)})
    # 'game' receives from both 'plays' and 'sells' once 'store' is provided
    check_shapes(conv(g, {'user': uf, 'store': sf}),
                 {'user': (4, 1, 3), 'game': (4, 2, 4)})
    check_shapes(conv(g, {'store': sf}),
                 {'game': (4, 1, 4)})

    # test with pair input
    conv = nn.HeteroGraphConv({
        'follows': nn.SAGEConv(2, 3, 'mean'),
        'plays': nn.SAGEConv((2, 4), 4, 'mean'),
        'sells': nn.SAGEConv(3, 4, 'mean')},
        agg)
    conv = conv.to(F.ctx())

    check_shapes(conv(g, ({'user': uf}, {'user': uf, 'game': gf})),
                 {'user': (4, 1, 3), 'game': (4, 1, 4)})
    # pair input requires both src and dst type features to be provided
    check_shapes(conv(g, ({'user': uf}, {'game': gf})),
                 {'game': (4, 1, 4)})

    # test with mod args
    class MyMod(th.nn.Module):
        """Stub module counting how often each optional forward argument is passed."""

        def __init__(self, s1, s2):
            super().__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None, *, arg2=None):
            self.carg1 += int(arg1 is not None)
            self.carg2 += int(arg2 is not None)
            return th.zeros((g.number_of_dst_nodes(), self.s2))

    mod1, mod2, mod3 = MyMod(2, 3), MyMod(2, 4), MyMod(3, 4)
    conv = nn.HeteroGraphConv({'follows': mod1, 'plays': mod2, 'sells': mod3}, agg)
    conv = conv.to(F.ctx())
    mod_args = {'follows': (1,), 'plays': (1,)}
    mod_kwargs = {'sells': {'arg2': 'abc'}}
    conv(g, {'user': uf, 'store': sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    assert (mod1.carg1, mod1.carg2) == (1, 0)
    assert (mod2.carg1, mod2.carg2) == (1, 0)
    assert (mod3.carg1, mod3.carg2) == (0, 1)

880
881
if __name__ == '__main__':
    # Most tests in this module take pytest-parametrized arguments (idtype, g,
    # norm, ...), so calling them directly with no arguments raises TypeError.
    # Delegate to pytest so every parametrization runs, and propagate its exit
    # status to the shell.
    raise SystemExit(pytest.main([__file__]))