import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
import dgl.function as fn
import backend as F
import pytest

from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph, \
    random_block

from copy import deepcopy

import numpy as np
import scipy as sp

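# Dense reference for GraphConv with norm='none': Y = A @ (X W) + b,
# where A is the graph's adjacency matrix.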
def _AXWb(A, X, W, b):
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b

def test_graph_conv():
    g = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)

    conv = nn.GraphConv(5, 2, norm='none', bias=True)
    conv = conv.to(ctx)
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)

@pytest.mark.parametrize('g', get_cases(['path', 'bipartite', 'small', 'block'], exclude=['zero-degree']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv2(g, norm, weight, bias):
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, 2)).to(F.ctx())
    nsrc = g.number_of_nodes() if isinstance(g, dgl.DGLGraph) else g.number_of_src_nodes()
    ndst = g.number_of_nodes() if isinstance(g, dgl.DGLGraph) else g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    h_dst = F.randn((ndst, 2)).to(F.ctx())
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, 2)

    if not isinstance(g, dgl.DGLGraph) and len(g.ntypes) == 2:
        # bipartite, should also accept pair of tensors
        if weight:
            h_out2 = conv(g, (h, h_dst))
        else:
            h_out2 = conv(g, (h, h_dst), weight=ext_w)
        assert h_out2.shape == (ndst, 2)
        assert F.array_equal(h_out, h_out2)

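# Dense reference for TAGConv with k=2 hops; N carries the D^-1/2 in-degree
# normalization applied on both sides of each propagation step.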
def _S2AXWb(A, N, X, W, b):
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b

def test_tagconv():
    g = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, 2, bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))

    conv = nn.TAGConv(5, 2)
    conv = conv.to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == 2

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)

def test_set2set():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))

    s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
    s2s = s2s.to(ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11))
    g2 = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2

def test_glob_att_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2

def test_simple_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = th.stack([F.sum(h0[:15], 0),
                      F.sum(h0[15:20], 0),
                      F.sum(h0[20:35], 0),
                      F.sum(h0[35:40], 0),
                      F.sum(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = th.stack([F.mean(h0[:15], 0),
                      F.mean(h0[15:20], 0),
                      F.mean(h0[20:35], 0),
                      F.mean(h0[35:40], 0),
                      F.mean(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = th.stack([F.max(h0[:15], 0),
                      F.max(h0[15:20], 0),
                      F.max(h0[20:35], 0),
                      F.max(h0[35:40], 0),
                      F.max(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2

def test_set_trans():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2

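# With identical logits on every edge, softmax over each node's incoming edges
# reduces to 1 / in-degree; build that expected tensor for comparison.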
def uniform_attention(g, shape):
    a = th.ones(shape)
    target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
    return a / g.in_degrees(g.edges()[1]).view(target_shape).float()

def test_edge_softmax():
    # Basic
    g = dgl.graph(nx.path_graph(3))
    edata = F.ones((g.number_of_edges(), 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test higher dimension case
    edata = F.ones((g.number_of_edges(), 3, 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test both forward and backward with PyTorch built-in softmax.
    g = dgl.rand_graph(30, 900)

    score = F.randn((900, 1))
    score.requires_grad_()
    grad = F.randn((900, 1))
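    # Assumption: with 900 edges on 30 nodes the random graph covers every node
    # pair, and the edge order makes a dim-0 softmax of the 30x30 view match a
    # per-destination edge softmax.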
    y = F.softmax(score.view(30, 30), dim=0).view(-1, 1)
    y.backward(grad)
    grad_score = score.grad.clone()
    score.grad.zero_()
    y_dgl = nn.edge_softmax(g, score)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # check forward
    assert F.allclose(y_dgl, y)
    y_dgl.backward(grad)
    # check gradient
    assert F.allclose(score.grad, grad_score)
    print(score.grad[:10], grad_score[:10])
    
    """
    # Test 2
    def generate_rand_graph(n, m=None, ctor=dgl.DGLGraph):
        if m is None:
            m = n
        arr = (sp.sparse.random(m, n, density=0.1, format='coo') != 0).astype(np.int64)
        return ctor(arr, readonly=True)

    for g in [generate_rand_graph(50),
              generate_rand_graph(50, ctor=dgl.graph),
              generate_rand_graph(100, 50, ctor=dgl.bipartite)]:
        a1 = F.randn((g.number_of_edges(), 1)).requires_grad_()
        a2 = a1.clone().detach().requires_grad_()
        g.edata['s'] = a1
        g.group_apply_edges('dst', lambda edges: {'ss':F.softmax(edges.data['s'], 1)})
        g.edata['ss'].sum().backward()
        
        builtin_sm = nn.edge_softmax(g, a2)
        builtin_sm.sum().backward()
        print(a1.grad - a2.grad)
        assert len(g.srcdata) == 0
        assert len(g.dstdata) == 0
        assert len(g.edata) == 2
        assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend
    """

def test_partial_edge_softmax():
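    # edge_softmax restricted to a subset of edges (eids) should agree with
    # running edge_softmax on the corresponding edge subgraph, in both the
    # forward values and the gradients.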
    g = dgl.rand_graph(30, 900)

    score = F.randn((300, 1))
    score.requires_grad_()
    grad = F.randn((300, 1))
    eids = np.random.choice(900, 300, replace=False).astype('int64')
    eids = F.zerocopy_from_numpy(eids)
    # compute partial edge softmax
    y_1 = nn.edge_softmax(g, score, eids)
    y_1.backward(grad)
    grad_1 = score.grad.clone()
    score.grad.zero_()
    # compute edge softmax on edge subgraph
    subg = g.edge_subgraph(eids)
    y_2 = nn.edge_softmax(subg, score)
    y_2.backward(grad)
    grad_2 = score.grad.clone()
    score.grad.zero_()

    assert F.allclose(y_1, y_2)
    assert F.allclose(grad_1, grad_2)

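# RelGraphConv sanity check: the low_mem=True code path should match the
# default implementation once their parameters are tied.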
def test_rgcn():
    ctx = F.ctx()
    etype = []
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    B = 2
    I = 10
    O = 8

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
    rgc_bdd_low.weight = rgc_bdd.weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r)
    h_new_low = rgc_bdd_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    # with norm
    norm = th.zeros((g.number_of_edges(), 1)).to(ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r, norm)
    h_new_low = rgc_basis_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
    rgc_bdd_low.weight = rgc_bdd.weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r, norm)
    h_new_low = rgc_bdd_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = th.randint(0, I, (100,)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

def test_gat_conv():
    ctx = F.ctx()
    g = dgl.rand_graph(100, 1000)
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((100, 5))
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (100, 4, 2)

    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    gat = nn.GATConv((5, 10), 2, 4)
    feat = (F.randn((100, 5)), F.randn((200, 10)))
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (200, 4, 2)

    g = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    seed_nodes = th.unique(g.edges()[1])
    block = dgl.to_block(g, seed_nodes)
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((block.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(block, feat)
    assert h.shape == (block.number_of_dst_nodes(), 4, 2)

@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(aggre_type):
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((100, 5))
    sage = sage.to(ctx)
    h = sage(g, feat)
    assert h.shape[-1] == 10

    g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((100, 5))
    sage = sage.to(ctx)
    h = sage(g, feat)
    assert h.shape[-1] == 10

    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
    sage = sage.to(ctx)
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 200

    g = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    seed_nodes = th.unique(g.edges()[1])
    block = dgl.to_block(g, seed_nodes)
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((block.number_of_src_nodes(), 5))
    sage = sage.to(ctx)
    h = sage(block, feat)
    assert h.shape[0] == block.number_of_dst_nodes()
    assert h.shape[-1] == 10

    # Test the case for graphs without edges
    g = dgl.bipartite([], num_nodes=(5, 3))
    sage = nn.SAGEConv((3, 3), 2, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 3
    for aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), 2, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        h = sage(g, feat)
        assert h.shape[-1] == 2
        assert h.shape[0] == 3

def test_sgc_conv():
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    # not cached
    sgc = nn.SGConv(5, 10, 3)
    feat = F.randn((100, 5))
    sgc = sgc.to(ctx)

    h = sgc(g, feat)
    assert h.shape[-1] == 10

    # cached
    sgc = nn.SGConv(5, 10, 3, True)
    sgc = sgc.to(ctx)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == 10

def test_appnp_conv():
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((100, 5))
    appnp = appnp.to(ctx)

    h = appnp(g, feat)
    assert h.shape[-1] == 5

@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(aggregator_type):
    ctx = F.ctx()
    g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    feat = F.randn((100, 5))
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (100, 12)

    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    feat = (F.randn((100, 5)), F.randn((200, 5)))
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (200, 12)

    g = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    seed_nodes = th.unique(g.edges()[1])
    block = dgl.to_block(g, seed_nodes)
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    feat = F.randn((block.number_of_src_nodes(), 5))
    gin = gin.to(ctx)
    h = gin(block, feat)
    assert h.shape == (block.number_of_dst_nodes(), 12)

def test_agnn_conv():
    ctx = F.ctx()
    g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
    agnn = nn.AGNNConv(1)
    feat = F.randn((100, 5))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (100, 5)

    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    agnn = nn.AGNNConv(1)
    feat = (F.randn((100, 5)), F.randn((200, 5)))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (200, 5)

    g = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    seed_nodes = th.unique(g.edges()[1])
    block = dgl.to_block(g, seed_nodes)
    agnn = nn.AGNNConv(1)
    feat = F.randn((block.number_of_src_nodes(), 5))
    agnn = agnn.to(ctx)
    h = agnn(block, feat)
    assert h.shape == (block.number_of_dst_nodes(), 5)

def test_gated_graph_conv():
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((100, 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do shape check
    assert h.shape[-1] == 10

def test_nn_conv():
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean')
    feat = F.randn((100, 5))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10

    g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean')
    feat = F.randn((100, 5))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10

    g = dgl.bipartite(sp.sparse.random(50, 100, density=0.1))
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean')
    feat = F.randn((50, 5))
    feat_dst = F.randn((100, 2))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10

    g = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    seed_nodes = th.unique(g.edges()[1])
    block = dgl.to_block(g, seed_nodes)
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean')
    feat = F.randn((block.number_of_src_nodes(), 5))
    efeat = F.randn((block.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(block, feat, efeat)
    assert h.shape[0] == block.number_of_dst_nodes()
    assert h.shape[-1] == 10

def test_gmm_conv():
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
    feat = F.randn((100, 5))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10

    g = dgl.graph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
    feat = F.randn((100, 5))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10

    g = dgl.bipartite(sp.sparse.random(100, 50, density=0.1), readonly=True)
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean')
    feat = F.randn((100, 5))
    feat_dst = F.randn((50, 2))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10

    g = dgl.graph(sp.sparse.random(100, 100, density=0.001))
    seed_nodes = th.unique(g.edges()[1])
    block = dgl.to_block(g, seed_nodes)
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
    feat = F.randn((block.number_of_src_nodes(), 5))
    pseudo = F.randn((block.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(block, feat, pseudo)
    assert h.shape[0] == block.number_of_dst_nodes()
    assert h.shape[-1] == 10

@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', [random_graph(100), random_bipartite(100, 200)])
def test_dense_graph_conv(norm_type, g):
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adjacency_matrix(ctx=ctx).to_dense()
    conv = nn.GraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)

@pytest.mark.parametrize('g', [random_graph(100), random_bipartite(100, 200)])
def test_dense_sage_conv(g):
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, 2, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, 2)
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.fc_neigh.bias.data
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5))
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage), g

@pytest.mark.parametrize('g', [random_dglgraph(20), random_graph(20), random_bipartite(20, 10), random_block(20)])
def test_edge_conv(g):
    ctx = F.ctx()

    edge_conv = nn.EdgeConv(5, 2).to(ctx)
    print(edge_conv)

    # test #1: basic
    h0 = F.randn((g.number_of_src_nodes(), 5))
    if not g.is_homograph() and not g.is_block:
        # bipartite
        h1 = edge_conv(g, (h0, h0[:10]))
    else:
        h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), 2)

def test_dense_cheb_conv():
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        adj = g.adjacency_matrix(ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, 2, k, None)
        dense_cheb = nn.DenseChebConv(5, 2, k)
        #for i in range(len(cheb.fc)):
        #    dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
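        # ChebConv keeps all k Chebyshev weights in one (2, k*5) linear layer;
        # reshape it into DenseChebConv's stacked (k, 5, 2) layout.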
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, 2)
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)

def test_sequential():
    ctx = F.ctx()
    # Test single graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)

def test_atomic_conv():
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    aconv = nn.AtomicConv(interaction_cutoffs=F.tensor([12.0, 12.0]),
                          rbf_kernel_means=F.tensor([0.0, 2.0]),
                          rbf_kernel_scaling=F.tensor([4.0, 4.0]),
                          features_to_use=F.tensor([6.0, 8.0]))

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((100, 1))
    dist = F.randn((g.number_of_edges(), 1))

    h = aconv(g, feat, dist)
    # currently we only do shape check
    assert h.shape[-1] == 4

def test_cf_conv():
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    cfconv = nn.CFConv(node_in_feats=2,
                       edge_in_feats=3,
                       hidden_feats=2,
                       out_feats=3)

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    node_feats = F.randn((100, 2))
    edge_feats = F.randn((g.number_of_edges(), 3))
    h = cfconv(g, node_feats, edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == 3

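# Custom cross-type aggregator for HeteroGraphConv: weight the i-th incoming
# result by (i + 1) so the aggregation is order-sensitive and observable.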
def myagg(alist, dsttype):
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst

@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg):
    g = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
        ('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0), (2, 2)],
        ('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]})
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3),
        'plays': nn.GraphConv(2, 4),
        'sells': nn.GraphConv(3, 4)},
        agg)
    if F.gpu_ctx():
        conv = conv.to(F.ctx())
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))
    uf_dst = F.randn((4, 3))
    gf_dst = F.randn((4, 4))

    h = conv(g, {'user': uf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 1, 4)

    h = conv(g, {'user': uf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    h = conv(g, {'store': sf})
    assert set(h.keys()) == {'game'}
    if agg != 'stack':
        assert h['game'].shape == (4, 4)
    else:
        assert h['game'].shape == (4, 1, 4)

    # test with pair input
    conv = nn.HeteroGraphConv({
        'follows': nn.SAGEConv(2, 3, 'mean'),
        'plays': nn.SAGEConv((2, 4), 4, 'mean'),
        'sells': nn.SAGEConv(3, 4, 'mean')},
        agg)
    if F.gpu_ctx():
        conv = conv.to(F.ctx())

    h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 1, 4)

    # pair input requires both src and dst type features to be provided
    h = conv(g, ({'user': uf}, {'game' : gf}))
    assert set(h.keys()) == {'game'}
    if agg != 'stack':
        assert h['game'].shape == (4, 4)
    else:
        assert h['game'].shape == (4, 1, 4)

    # test with mod args
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2
        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))
    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    if F.gpu_ctx():
        conv = conv.to(F.ctx())
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

if __name__ == '__main__':
    test_graph_conv()
    test_edge_softmax()
    test_partial_edge_softmax()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_set_trans()
    test_rgcn()
    test_tagconv()
    test_gat_conv()
    for aggre_type in ['mean', 'pool', 'gcn', 'lstm']:
        test_sage_conv(aggre_type)
    test_sgc_conv()
    test_appnp_conv()
    for aggregator_type in ['mean', 'max', 'sum']:
        test_gin_conv(aggregator_type)
    test_agnn_conv()
    test_gated_graph_conv()
    test_nn_conv()
    test_gmm_conv()
    for norm_type in ['both', 'right', 'none']:
        test_dense_graph_conv(norm_type, random_graph(100))
    test_dense_sage_conv(random_graph(100))
    test_dense_cheb_conv()
    test_sequential()
    test_atomic_conv()
    test_cf_conv()