# test_nn.py: unit tests for the modules in dgl.nn.pytorch.
import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
5
import dgl.function as fn
6
import backend as F
7
import pytest
8
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
9
10
from copy import deepcopy

11
12
13
import numpy as np
import scipy as sp

14
15
16
17
18
19
20
def _AXWb(A, X, W, b):
    """Return the reference value ``A @ (X @ W) + b``.

    ``X`` is multiplied by ``W`` first. The product is flattened to two
    dimensions so that ``A`` can be applied along the leading axis, then
    restored to its pre-flattening shape before ``b`` is added.
    """
    projected = th.matmul(X, W)
    flattened = projected.view(projected.shape[0], -1)
    aggregated = th.matmul(A, flattened).view_as(projected)
    return aggregated + b

def test_graph_conv():
    """Exercise ``nn.GraphConv`` on a 3-node path graph.

    With ``norm='none'`` the output is compared against the ``_AXWb``
    reference; with default constructor arguments only the absence of
    leftover node/edge features on the graph is checked. Finally
    ``reset_parameters`` must produce a different weight.
    """
    g = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)

    conv = nn.GraphConv(5, 2, norm='none', bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    # The forward pass must not leave temporary features on the graph.
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # Default constructor arguments. This section used to be duplicated
    # verbatim; one copy is sufficient.
    conv = nn.GraphConv(5, 2)
    conv = conv.to(ctx)

    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)
71

72
73
74
75
76
77
78
79
80
81
@pytest.mark.parametrize('g', get_cases(['path', 'bipartite', 'small'], exclude=['zero-degree']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv2(g, norm, weight, bias):
    """Shape-check ``nn.GraphConv`` across graph cases and constructor options.

    When the layer is built with ``weight=False`` an external weight is
    supplied at call time instead.
    """
    ctx = F.ctx()
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(ctx)
    ext_w = F.randn((5, 2)).to(ctx)

    if isinstance(g, dgl.DGLGraph):
        nsrc = ndst = g.number_of_nodes()
    else:
        nsrc = g.number_of_src_nodes()
        ndst = g.number_of_dst_nodes()

    h = F.randn((nsrc, 5)).to(ctx)
    h_dst = F.randn((ndst, 2)).to(ctx)

    # Only pass a weight at call time when the layer owns none.
    call_kwargs = {} if weight else {'weight': ext_w}

    h_out = conv(g, h, **call_kwargs)
    assert h_out.shape == (ndst, 2)

    is_bipartite = not isinstance(g, dgl.DGLGraph) and len(g.ntypes) == 2
    if is_bipartite:
        # bipartite, should also accept pair of tensors
        h_out2 = conv(g, (h, h_dst), **call_kwargs)
        assert h_out2.shape == (ndst, 2)
        assert F.array_equal(h_out, h_out2)
97

98
99
100
101
102
103
104
105
106
107
108
109
def _S2AXWb(A, N, X, W, b):
    """Reference output for a two-hop, symmetrically normalised propagation.

    Starting from ``X``, each hop scales by ``N``, multiplies by ``A`` (on the
    features flattened to two dimensions) and scales by ``N`` again. The
    input and both hop results are concatenated along the last axis and
    projected with ``W`` before ``b`` is added.

    ``W`` is expected in ``(out_feats, 3 * in_feats)`` layout; the caller in
    this file passes ``conv.lin.weight`` (presumably a ``torch.nn.Linear``
    weight, TODO confirm). The projection therefore uses a plain transpose.
    The previous ``W.rot90()`` also reversed the feature order and only
    agreed with the layer for inputs where hop 0 and hop 2 coincide.
    """
    hops = [X]
    feat = X
    for _ in range(2):
        feat = feat * N
        feat = th.matmul(A, feat.view(feat.shape[0], -1))
        feat = feat * N
        hops.append(feat)

    stacked = th.cat(hops, dim=-1)
    return th.matmul(stacked, W.t()) + b

110
def test_tagconv():
    """Check ``nn.TAGConv`` against the ``_S2AXWb`` reference on a path graph.

    Also checks the output width of a layer built with default arguments and
    that ``reset_parameters`` changes the linear weight.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3))
    adj = g.adjacency_matrix(ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, 2, bias=True).to(ctx)
    print(conv)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # Append singleton axes so the per-node norm broadcasts over features.
    broadcast_shape = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, broadcast_shape).to(ctx)
    expected = _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias)
    assert F.allclose(h1, expected)

    conv = nn.TAGConv(5, 2).to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == 2

    # test reset_parameters
    weight_before = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    weight_after = conv.lin.weight.data
    assert not F.allclose(weight_before, weight_after)

144
def test_set2set():
    """Shape-check ``nn.Set2Set`` on a single graph and on a batch of three."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))

    # hidden size 5, 3 iters, 3 layers
    s2s = nn.Set2Set(5, 3, 3).to(ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.dim() == 2 and tuple(h1.shape) == (1, 10)

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11))
    g2 = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.dim() == 2 and tuple(h1.shape) == (3, 10)

def test_glob_att_pool():
    """Shape-check ``nn.GlobalAttentionPooling`` on one graph and a batch of four."""
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))

    gate_nn = th.nn.Linear(5, 1)
    feat_nn = th.nn.Linear(5, 10)
    gap = nn.GlobalAttentionPooling(gate_nn, feat_nn).to(ctx)
    print(gap)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.dim() == 2 and tuple(h1.shape) == (1, 10)

    # test#2: batched graph
    bg = dgl.batch([g] * 4)
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.dim() == 2 and tuple(h1.shape) == (4, 10)

def test_simple_pool():
    """Check sum/avg/max pooling values and the sort-pooling output shape.

    Runs once on a single 15-node path graph and once on a batch of five
    graphs with 15, 5, 15, 5 and 15 nodes.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)

    for pool, reduce_fn in ((sum_pool, F.sum), (avg_pool, F.mean), (max_pool, F.max)):
        h1 = pool(g, h0)
        assert F.allclose(F.squeeze(h1, 0), reduce_fn(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.dim() == 2 and tuple(h1.shape) == (1, 10 * 5)

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    # Row ranges of h0 that belong to each graph in the batch.
    segments = [(0, 15), (15, 20), (20, 35), (35, 40), (40, 55)]

    for pool, reduce_fn in ((sum_pool, F.sum), (avg_pool, F.mean), (max_pool, F.max)):
        h1 = pool(bg, h0)
        truth = th.stack([reduce_fn(h0[lo:hi], 0) for lo, hi in segments], 0)
        assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.dim() == 2 and tuple(h1.shape) == (5, 10 * 5)

def test_set_trans():
    """Shape-check the set-transformer encoders and decoder.

    The two encoders must preserve the input shape; the decoder must return
    one 200-wide row per graph.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab').to(ctx)
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3).to(ctx)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4).to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.dim() == 2 and tuple(h2.shape) == (1, 200)

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    assert h2.dim() == 2 and tuple(h2.shape) == (3, 200)

274
275
276
277
278
279
def uniform_attention(g, shape):
    """Return a tensor of ``shape`` where each edge holds 1 / in-degree of its destination.

    The in-degrees are reshaped to ``(num_edges, 1, ..., 1)`` so they
    broadcast against the all-ones tensor of the requested shape.
    """
    dst_nodes = g.edges()[1]
    trailing_ones = (1,) * (len(shape) - 1)
    target_shape = (g.number_of_edges(),) + trailing_ones
    dst_in_degree = g.in_degrees(dst_nodes).view(target_shape).float()
    return th.ones(shape) / dst_in_degree

def test_edge_softmax():
    """Check ``nn.edge_softmax`` values and gradients.

    Covers all-ones scores on a 3-node path graph (2-D and 3-D edge data) and
    a forward/backward comparison against ``F.softmax`` on a random graph.
    """
    # Basic
    g = dgl.graph(nx.path_graph(3))
    edata = F.ones((g.number_of_edges(), 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test higher dimension case
    edata = F.ones((g.number_of_edges(), 3, 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test both forward and backward with PyTorch built-in softmax.
    g = dgl.rand_graph(30, 900)

    score = F.randn((900, 1))
    score.requires_grad_()
    grad = F.randn((900, 1))
    y = F.softmax(score.view(30, 30), dim=0).view(-1, 1)
    y.backward(grad)
    # Snapshot the reference gradient. Without the clone, grad_score would
    # alias score.grad, which is zeroed and then accumulated into in place,
    # so the gradient assertion below would compare a tensor with itself.
    grad_score = score.grad.clone()
    score.grad.zero_()
    y_dgl = nn.edge_softmax(g, score)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # check forward
    assert F.allclose(y_dgl, y)
    y_dgl.backward(grad)
    # check gradient
    assert F.allclose(score.grad, grad_score)
    print(score.grad[:10], grad_score[:10])

    # The block below is a disabled test, kept as an inert string literal.
    """
    # Test 2
    def generate_rand_graph(n, m=None, ctor=dgl.DGLGraph):
        if m is None:
            m = n
        arr = (sp.sparse.random(m, n, density=0.1, format='coo') != 0).astype(np.int64)
        return ctor(arr, readonly=True)

    for g in [generate_rand_graph(50),
              generate_rand_graph(50, ctor=dgl.graph),
              generate_rand_graph(100, 50, ctor=dgl.bipartite)]:
        a1 = F.randn((g.number_of_edges(), 1)).requires_grad_()
        a2 = a1.clone().detach().requires_grad_()
        g.edata['s'] = a1
        g.group_apply_edges('dst', lambda edges: {'ss':F.softmax(edges.data['s'], 1)})
        g.edata['ss'].sum().backward()

        builtin_sm = nn.edge_softmax(g, a2)
        builtin_sm.sum().backward()
        print(a1.grad - a2.grad)
        assert len(g.srcdata) == 0
        assert len(g.dstdata) == 0
        assert len(g.edata) == 2
        assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend
    """
340
341

def test_partial_edge_softmax():
    """Compare edge softmax restricted to a subset of edges with the same
    softmax computed on the corresponding edge subgraph, for both outputs
    and gradients.
    """
    g = dgl.rand_graph(30, 900)

    score = F.randn((300, 1))
    score.requires_grad_()
    grad = F.randn((300, 1))
    eids = np.random.choice(900, 300, replace=False).astype('int64')
    eids = F.zerocopy_from_numpy(eids)

    # compute partial edge softmax
    y_1 = nn.edge_softmax(g, score, eids)
    y_1.backward(grad)
    # Clone before zeroing: score.grad is reused in place for the second
    # backward pass, so an alias would make grad_1 and grad_2 the same
    # all-zero tensor and the final assertion meaningless.
    grad_1 = score.grad.clone()
    score.grad.zero_()

    # compute edge softmax on edge subgraph
    subg = g.edge_subgraph(eids)
    y_2 = nn.edge_softmax(subg, score)
    y_2.backward(grad)
    grad_2 = score.grad.clone()
    score.grad.zero_()

    assert F.allclose(y_1, y_2)
    assert F.allclose(grad_1, grad_2)

Minjie Wang's avatar
Minjie Wang committed
365
366
367
368
369
370
371
372
373
374
375
376
377
def test_rgcn():
    """Check that ``nn.RelGraphConv`` gives the same result with and without
    ``low_mem=True`` when both layers share their weights.

    Covers the "basis" and "bdd" regularizers, with and without an edge norm,
    and integer node-id input for "basis".
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    R = 5   # number of edge types
    B = 2
    I = 10
    O = 8
    # Assign edge types round-robin (uses R rather than a repeated literal 5).
    etype = [i % R for i in range(g.number_of_edges())]

    def check(regularizer, make_feat, *forward_args):
        """Build a regular and a low-memory layer that share parameters,
        run both on fresh features and compare the outputs."""
        rgc = nn.RelGraphConv(I, O, R, regularizer, B).to(ctx)
        rgc_low = nn.RelGraphConv(I, O, R, regularizer, B, low_mem=True).to(ctx)
        rgc_low.weight = rgc.weight
        if regularizer == "basis":
            # The original only shares w_comp for the "basis" layers.
            rgc_low.w_comp = rgc.w_comp
        h = make_feat().to(ctx)
        r = th.tensor(etype).to(ctx)
        h_new = rgc(g, h, r, *forward_args)
        h_new_low = rgc_low(g, h, r, *forward_args)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    def dense_feat():
        return th.randn((100, I))

    check("basis", dense_feat)
    check("bdd", dense_feat)

    # with norm
    norm = th.zeros((g.number_of_edges(), 1)).to(ctx)
    check("basis", dense_feat, norm)
    check("bdd", dense_feat, norm)

    # id input
    check("basis", lambda: th.randint(0, I, (100,)))
438

439
440
def test_gat_conv():
    """Shape-check ``nn.GATConv`` on a random homogeneous graph and on a
    bipartite graph with separate source/destination features."""
    ctx = F.ctx()

    # Homogeneous graph: one feature tensor for all nodes.
    g = dgl.rand_graph(100, 1000)
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((100, 5))
    gat = gat.to(ctx)
    out = gat(g, feat)
    assert out.shape == (100, 4, 2)

    # Bipartite graph: a (source, destination) feature pair.
    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    gat = nn.GATConv((5, 10), 2, 4)
    feat = (F.randn((100, 5)), F.randn((200, 10)))
    gat = gat.to(ctx)
    out = gat(g, feat)
    assert out.shape == (200, 4, 2)

@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(aggre_type):
    """Shape-check ``nn.SAGEConv`` for one aggregator on several graph kinds,
    then on an edgeless bipartite graph for every aggregator."""
    ctx = F.ctx()

    # Same check on a DGLGraph and on a dgl.graph built from random matrices.
    for build in (lambda m: dgl.DGLGraph(m, readonly=True), dgl.graph):
        g = build(sp.sparse.random(100, 100, density=0.1))
        sage = nn.SAGEConv(5, 10, aggre_type)
        feat = F.randn((100, 5))
        sage = sage.to(ctx)
        out = sage(g, feat)
        assert out.shape[-1] == 10

    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    dst_dim = 10 if aggre_type == 'gcn' else 5
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
    sage = sage.to(ctx)
    out = sage(g, feat)
    assert out.shape[-1] == 2
    assert out.shape[0] == 200

    # Test the case for graphs without edges
    g = dgl.bipartite([], num_nodes=(5, 3))
    sage = nn.SAGEConv((3, 3), 2, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    out = sage(g, feat)
    assert out.shape[-1] == 2
    assert out.shape[0] == 3
    for other_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), 2, other_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        out = sage(g, feat)
        assert out.shape[-1] == 2
        assert out.shape[0] == 3

497
498
499
500
501
502
def test_sgc_conv():
    """Shape-check ``nn.SGConv`` and check that the cached variant returns the
    same output for different inputs."""
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)

    # not cached
    sgc = nn.SGConv(5, 10, 3)
    feat = F.randn((100, 5))
    sgc = sgc.to(ctx)

    out = sgc(g, feat)
    assert out.shape[-1] == 10

    # cached
    sgc = nn.SGConv(5, 10, 3, True).to(ctx)
    first = sgc(g, feat)
    second = sgc(g, feat + 1)
    # The second call uses shifted features yet must match the first output.
    assert F.allclose(first, second)
    assert first.shape[-1] == 10

def test_appnp_conv():
    """Check that ``nn.APPNPConv`` keeps the feature width unchanged."""
    ctx = F.ctx()
    adj = sp.sparse.random(100, 100, density=0.1)
    g = dgl.DGLGraph(adj, readonly=True)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((100, 5))
    appnp = appnp.to(ctx)

    out = appnp(g, feat)
    assert out.shape[-1] == 5

526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(aggregator_type):
    """Shape-check ``nn.GINConv`` on a homogeneous and on a bipartite graph."""
    ctx = F.ctx()

    # Homogeneous graph.
    g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    feat = F.randn((100, 5))
    gin = gin.to(ctx)
    out = gin(g, feat)
    assert out.shape == (100, 12)

    # Bipartite graph with a (source, destination) feature pair.
    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)
    feat = (F.randn((100, 5)), F.randn((200, 5)))
    gin = gin.to(ctx)
    out = gin(g, feat)
    assert out.shape == (200, 12)
548
549
550

def test_agnn_conv():
    """Shape-check ``nn.AGNNConv`` on a homogeneous and on a bipartite graph."""
    ctx = F.ctx()

    # Homogeneous graph.
    g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
    agnn = nn.AGNNConv(1)
    feat = F.randn((100, 5))
    agnn = agnn.to(ctx)
    out = agnn(g, feat)
    assert out.shape == (100, 5)

    # Bipartite graph with a (source, destination) feature pair.
    g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
    agnn = nn.AGNNConv(1)
    feat = (F.randn((100, 5)), F.randn((200, 5)))
    agnn = agnn.to(ctx)
    out = agnn(g, feat)
    assert out.shape == (200, 5)
564
565
566
567
568
569
570

def test_gated_graph_conv():
    """Shape-check ``nn.GatedGraphConv`` with three round-robin edge types."""
    ctx = F.ctx()
    adj = sp.sparse.random(100, 100, density=0.1)
    g = dgl.DGLGraph(adj, readonly=True)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((100, 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    out = ggconv(g, feat, etypes)
    # currently we only do shape check
    assert out.shape[-1] == 10

def test_nn_conv():
    """Shape-check ``nn.NNConv`` on a DGLGraph, a dgl.graph and a bipartite graph."""
    ctx = F.ctx()

    # Homogeneous inputs: same configuration on both graph constructors.
    for build in (lambda m: dgl.DGLGraph(m, readonly=True), dgl.graph):
        g = build(sp.sparse.random(100, 100, density=0.1))
        edge_func = th.nn.Linear(4, 5 * 10)
        nnconv = nn.NNConv(5, 10, edge_func, 'mean')
        feat = F.randn((100, 5))
        efeat = F.randn((g.number_of_edges(), 4))
        nnconv = nnconv.to(ctx)
        out = nnconv(g, feat, efeat)
        # currently we only do shape check
        assert out.shape[-1] == 10

    # Bipartite input with a (source, destination) feature pair.
    g = dgl.bipartite(sp.sparse.random(50, 100, density=0.1))
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean')
    feat = F.randn((50, 5))
    feat_dst = F.randn((100, 2))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    out = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do shape check
    assert out.shape[-1] == 10

611
612
613
614
615
616
def test_gmm_conv():
    """Shape-check ``nn.GMMConv`` on a DGLGraph, a dgl.graph and a bipartite graph."""
    ctx = F.ctx()

    # Homogeneous inputs: same configuration on both graph constructors.
    for build in (dgl.DGLGraph, dgl.graph):
        g = build(sp.sparse.random(100, 100, density=0.1), readonly=True)
        gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
        feat = F.randn((100, 5))
        pseudo = F.randn((g.number_of_edges(), 3))
        gmmconv = gmmconv.to(ctx)
        out = gmmconv(g, feat, pseudo)
        # currently we only do shape check
        assert out.shape[-1] == 10

    # Bipartite input with a (source, destination) feature pair.
    g = dgl.bipartite(sp.sparse.random(100, 50, density=0.1), readonly=True)
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean')
    feat = F.randn((100, 5))
    feat_dst = F.randn((50, 2))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    out = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do shape check
    assert out.shape[-1] == 10

@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', [random_graph(100), random_bipartite(100, 200)])
def test_dense_graph_conv(norm_type, g):
    """Check that ``nn.DenseGraphConv`` on a dense adjacency matches
    ``nn.GraphConv`` on the graph when both share weight and bias."""
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adjacency_matrix(ctx=ctx).to_dense()

    conv = nn.GraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
    # Copy the sparse layer's parameters into the dense layer.
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data

    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)

    sparse_out = conv(g, feat)
    dense_out = dense_conv(adj, feat)
    assert F.allclose(sparse_out, dense_out)

658
659
@pytest.mark.parametrize('g', [random_graph(100), random_bipartite(100, 200)])
def test_dense_sage_conv(g):
    """Check that ``nn.DenseSAGEConv`` on a dense adjacency matches
    ``nn.SAGEConv`` with the 'gcn' aggregator when both share parameters."""
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx).to_dense()

    sage = nn.SAGEConv(5, 2, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, 2)
    # Copy the neighbour projection of the sparse layer into the dense layer.
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.fc_neigh.bias.data

    if len(g.ntypes) != 2:
        feat = F.randn((g.number_of_nodes(), 5))
    else:
        # Two node types: provide a (source, destination) feature pair.
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5))
        )

    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)

    sparse_out = sage(g, feat)
    dense_out = dense_sage(adj, feat)
    assert F.allclose(sparse_out, dense_out), g

@pytest.mark.parametrize('g', [random_dglgraph(20), random_graph(20), random_bipartite(20, 10)])
def test_edge_conv(g):
    """Shape-check ``nn.EdgeConv`` on homogeneous and bipartite graphs."""
    ctx = F.ctx()

    edge_conv = nn.EdgeConv(5, 2).to(ctx)
    print(edge_conv)

    # test #1: basic
    h0 = F.randn((g.number_of_src_nodes(), 5))
    if g.is_homograph():
        h1 = edge_conv(g, h0)
    else:
        # bipartite: the first 10 source rows serve as destination features
        h1 = edge_conv(g, (h0, h0[:10]))
    assert h1.shape == (g.number_of_dst_nodes(), 2)
694
695
696
697
698
699

def test_dense_cheb_conv():
    """Check that ``nn.DenseChebConv`` matches ``nn.ChebConv`` for k = 1, 2, 3
    when the dense layer is given the sparse layer's parameters."""
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        adj = g.adjacency_matrix(ctx=ctx).to_dense()

        cheb = nn.ChebConv(5, 2, k, None)
        dense_cheb = nn.DenseChebConv(5, 2, k)
        # Transpose the linear weight and view it as (k, 5, 2) for the dense layer.
        shared_weight = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, 2)
        dense_cheb.W.data = shared_weight
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data

        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)

        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)

715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
def test_sequential():
    """Shape-check ``nn.Sequential`` with a single graph and with a list of graphs."""
    ctx = F.ctx()

    # Test single graph
    class ExampleLayer(th.nn.Module):
        """Adds summed neighbour features to the nodes, then adds the
        source + destination node features to the edges."""

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    # Every ordered pair of the 3 nodes, self-loops included.
    src = [0, 1, 2, 0, 1, 2, 0, 1, 2]
    dst = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    g.add_edges(src, dst)
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graph
    class ExampleLayer(th.nn.Module):
        """Adds summed neighbour features to the nodes, then sums consecutive
        node pairs so the node count halves."""

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            half = graph.number_of_nodes() // 2
            return n_feat.view(half, 2, -1).sum(1)

    # Graph sizes 32 -> 16 -> 8 match the halving done by each layer.
    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)

763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
def test_atomic_conv():
    """Shape-check ``nn.AtomicConv`` on a random graph."""
    adj = sp.sparse.random(100, 100, density=0.1)
    g = dgl.DGLGraph(adj, readonly=True)
    aconv = nn.AtomicConv(
        interaction_cutoffs=F.tensor([12.0, 12.0]),
        rbf_kernel_means=F.tensor([0.0, 2.0]),
        rbf_kernel_scaling=F.tensor([4.0, 4.0]),
        features_to_use=F.tensor([6.0, 8.0]),
    )

    ctx = F.ctx()
    # The module is only moved when F.gpu_ctx() is truthy.
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((100, 1))
    dist = F.randn((g.number_of_edges(), 1))

    out = aconv(g, feat, dist)
    # currently we only do shape check
    assert out.shape[-1] == 4

def test_cf_conv():
    """Shape-check ``nn.CFConv`` on a random graph."""
    adj = sp.sparse.random(100, 100, density=0.1)
    g = dgl.DGLGraph(adj, readonly=True)
    cfconv = nn.CFConv(
        node_in_feats=2,
        edge_in_feats=3,
        hidden_feats=2,
        out_feats=3,
    )

    ctx = F.ctx()
    # The module is only moved when F.gpu_ctx() is truthy.
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    node_feats = F.randn((100, 2))
    edge_feats = F.randn((g.number_of_edges(), 3))
    out = cfconv(g, node_feats, edge_feats)
    # currently we only do shape check
    assert out.shape[-1] == 3

798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
def myagg(alist, dsttype):
    """Combine ``alist`` into a weighted sum: element ``i`` (0-based) has weight ``i + 1``.

    ``dsttype`` is accepted but not used.
    """
    total = alist[0]
    for weight, item in enumerate(alist[1:], start=2):
        total = total + weight * item
    return total

@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg):
    """Check ``nn.HeteroGraphConv`` output keys and shapes for each aggregator,
    with dict input, pair input and per-relation module arguments."""
    g = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
        ('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0), (2, 2)],
        ('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]})
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3),
        'plays': nn.GraphConv(2, 4),
        'sells': nn.GraphConv(3, 4)},
        agg)
    if F.gpu_ctx():
        conv = conv.to(F.ctx())
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))
    # uf_dst and gf_dst are not referenced again in this function.
    uf_dst = F.randn((4, 3))
    gf_dst = F.randn((4, 4))

    stacked = agg == 'stack'

    def expected_shape(out_feats, num_rels):
        """Expected output shape; 'stack' keeps one slice per contributing relation."""
        if stacked:
            return (4, num_rels, out_feats)
        return (4, out_feats)

    out = conv(g, {'user': uf})
    assert set(out.keys()) == {'user', 'game'}
    assert out['user'].shape == expected_shape(3, 1)
    assert out['game'].shape == expected_shape(4, 1)

    out = conv(g, {'user': uf, 'store': sf})
    assert set(out.keys()) == {'user', 'game'}
    assert out['user'].shape == expected_shape(3, 1)
    assert out['game'].shape == expected_shape(4, 2)

    out = conv(g, {'store': sf})
    assert set(out.keys()) == {'game'}
    assert out['game'].shape == expected_shape(4, 1)

    # test with pair input
    conv = nn.HeteroGraphConv({
        'follows': nn.SAGEConv(2, 3, 'mean'),
        'plays': nn.SAGEConv((2, 4), 4, 'mean'),
        'sells': nn.SAGEConv(3, 4, 'mean')},
        agg)
    if F.gpu_ctx():
        conv = conv.to(F.ctx())

    out = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
    assert set(out.keys()) == {'user', 'game'}
    assert out['user'].shape == expected_shape(3, 1)
    assert out['game'].shape == expected_shape(4, 1)

    # pair input requires both src and dst type features to be provided
    out = conv(g, ({'user': uf}, {'game' : gf}))
    assert set(out.keys()) == {'game'}
    assert out['game'].shape == expected_shape(4, 1)

    # test with mod args
    class MyMod(th.nn.Module):
        """Counts how often the optional positional / keyword argument is given."""

        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    if F.gpu_ctx():
        conv = conv.to(F.ctx())
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    out = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    # 'follows' and 'plays' got a positional arg; 'sells' got the keyword arg.
    for mod, n_arg1, n_arg2 in ((mod1, 1, 0), (mod2, 1, 0), (mod3, 0, 1)):
        assert mod.carg1 == n_arg1
        assert mod.carg2 == n_arg2

908
909
910
if __name__ == '__main__':
    # Script entry point: run the tests directly, without pytest collection.
    # Parametrized tests take required arguments, so they are called once per
    # value listed in their decorators; calling them bare raises TypeError.
    test_graph_conv()
    test_edge_softmax()
    test_partial_edge_softmax()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_set_trans()
    test_rgcn()
    test_tagconv()
    test_gat_conv()
    for aggre_type in ['mean', 'pool', 'gcn', 'lstm']:
        test_sage_conv(aggre_type)
    test_sgc_conv()
    test_appnp_conv()
    for aggregator_type in ['mean', 'max', 'sum']:
        test_gin_conv(aggregator_type)
    test_agnn_conv()
    test_gated_graph_conv()
    test_nn_conv()
    test_gmm_conv()
    for g in [random_graph(100), random_bipartite(100, 200)]:
        for norm_type in ['both', 'right', 'none']:
            test_dense_graph_conv(norm_type, g)
        test_dense_sage_conv(g)
    test_dense_cheb_conv()
    test_sequential()
    test_atomic_conv()
    test_cf_conv()