"vscode:/vscode.git/clone" did not exist on "6208d622ca74789f329fb4e9041a600e1f96659b"
test_nn.py 26.8 KB
Newer Older
1
2
3
import mxnet as mx
import networkx as nx
import numpy as np
Minjie Wang's avatar
Minjie Wang committed
4
import scipy as sp
5
import pytest
6
7
import dgl
import dgl.nn.mxnet as nn
8
import dgl.function as fn
9
import backend as F
10
11
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
Minjie Wang's avatar
Minjie Wang committed
12
from mxnet import autograd, gluon, nd
13

14
15
def check_close(a, b):
    """Assert two NDArray-like values are element-wise close (rtol/atol 1e-4)."""
    lhs = a.asnumpy()
    rhs = b.asnumpy()
    assert np.allclose(lhs, rhs, rtol=1e-4, atol=1e-4)
16
17
18
19
20
21

def _AXWb(A, X, W, b):
    """Dense reference for an un-normalized graph convolution: A @ (X @ W) + b.

    W and b are gluon Parameters; their data is fetched on X's context.
    """
    ctx = X.context
    xw = mx.nd.dot(X, W.data(ctx))
    axw = mx.nd.dot(A, xw.reshape(xw.shape[0], -1)).reshape(xw.shape)
    return axw + b.data(ctx)

22
@parametrize_dtype
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv(idtype, out_dim):
    """GraphConv on a 3-node path graph: compares the un-normalized output
    against the dense reference _AXWb, checks 2-D and 3-D inputs, train mode,
    and that the layer never pollutes the graph's node/edge frames."""
    g = dgl.from_networkx(nx.path_graph(3))
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)

    def assert_frames_untouched():
        # The layer must not leave temporary fields in the graph.
        assert len(g.ndata) == 0
        assert len(g.edata) == 0

    conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
    conv.initialize(ctx=ctx)
    # test#1: basic
    feat = F.ones((3, 5))
    out = conv(g, feat)
    assert_frames_untouched()
    check_close(out, _AXWb(adj, feat, conv.weight, conv.bias))
    # test#2: more-dim
    feat = F.ones((3, 5, 5))
    out = conv(g, feat)
    assert_frames_untouched()
    check_close(out, _AXWb(adj, feat, conv.weight, conv.bias))

    conv = nn.GraphConv(5, out_dim)
    conv.initialize(ctx=ctx)
    # test#3: basic (default normalization, no reference check)
    feat = F.ones((3, 5))
    out = conv(g, feat)
    assert_frames_untouched()
    # test#4: more-dim
    feat = F.ones((3, 5, 5))
    out = conv(g, feat)
    assert_frames_untouched()

    conv = nn.GraphConv(5, out_dim)
    conv.initialize(ctx=ctx)
    with autograd.train_mode():
        # test#3: basic, under training mode
        feat = F.ones((3, 5))
        out = conv(g, feat)
        assert_frames_untouched()
        # test#4: more-dim, under training mode
        feat = F.ones((3, 5, 5))
        out = conv(g, feat)
        assert_frames_untouched()

    # test not override features
    g.ndata["h"] = 2 * F.ones((3, 1))
    out = conv(g, feat)
    assert len(g.ndata) == 1
    assert len(g.edata) == 0
    assert "h" in g.ndata
    check_close(g.ndata['h'], 2 * F.ones((3, 1)))
81

82
83
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
    """GraphConv over every norm/weight combination on homogeneous and
    block-bipartite graphs; an external weight is passed when the layer
    owns none."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    conv.initialize(ctx=dev)
    ext_w = F.randn((5, out_dim)).as_in_context(dev)
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).as_in_context(dev)
    h_out = conv(g, h) if weight else conv(g, h, ext_w)
    assert h_out.shape == (ndst, out_dim)
101

102
103
104
105
106
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
    """GraphConv on bipartite graphs: forward with a (src, dst) feature pair,
    optionally supplying an external weight."""
    dev = F.ctx()
    g = g.astype(idtype).to(dev)
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    conv.initialize(ctx=dev)
    ext_w = F.randn((5, out_dim)).as_in_context(dev)
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).as_in_context(dev)
    h_dst = F.randn((ndst, out_dim)).as_in_context(dev)
    h_out = conv(g, (h, h_dst)) if weight else conv(g, (h, h_dst), ext_w)
    assert h_out.shape == (ndst, out_dim)
122

123
124
125
126
127
128
129
130
131
132
133
134
def _S2AXWb(A, N, X, W, b):
    """Dense reference for a 2-hop TAGConv: concatenate X with two
    symmetrically-normalized propagations of it, then apply a linear layer.

    N is the per-node normalization factor (deg^-0.5), broadcastable to X.
    """
    def propagate(h):
        # one hop: normalize, aggregate over A, normalize again
        h = h * N
        h = mx.nd.dot(A, h.reshape(h.shape[0], -1))
        return h * N

    hop1 = propagate(X)
    hop2 = propagate(hop1)
    stacked = mx.nd.concat(X, hop1, hop2, dim=-1)
    return mx.nd.dot(stacked, W) + b

135
136
@pytest.mark.parametrize('out_dim', [1, 2])
def test_tagconv(out_dim):
    """TAGConv on a 3-node path graph: checks the output against the dense
    reference _S2AXWb and the output dimension of a default-config layer."""
    g = dgl.from_networkx(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)
    norm = mx.nd.power(g.in_degrees().astype('float32'), -0.5)

    conv = nn.TAGConv(5, out_dim, bias=True)
    conv.initialize(ctx=ctx)
    print(conv)

    # test#1: compare against the dense reference
    feat = F.ones((3, 5))
    out = conv(g, feat)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (feat.ndim - 1)
    norm = norm.reshape(shp).as_in_context(feat.context)
    assert F.allclose(out, _S2AXWb(adj, norm, feat, conv.lin.data(ctx), conv.h_bias.data(ctx)))

    # test#2: default-config layer only needs the right output width
    conv = nn.TAGConv(5, out_dim)
    conv.initialize(ctx=ctx)
    out = conv(g, F.ones((3, 5)))
    assert out.shape[-1] == out_dim
163

164
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 20])
@pytest.mark.parametrize('num_heads', [1, 5])
def test_gat_conv(g, idtype, out_dim, num_heads):
    """GATConv: output/attention shapes, plus the residual variant."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(10, out_dim, num_heads)
    gat.initialize(ctx=ctx)
    print(gat)
    feat = F.randn((g.number_of_src_nodes(), 10))
    out = gat(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, attn = gat(g, feat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(10, out_dim, num_heads, residual=True)
    gat.initialize(ctx=ctx)
    out = gat(g, feat)

185
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    """GATConv on bipartite graphs with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    gat.initialize(ctx=ctx)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    out = gat(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, attn = gat(g, feat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)
199

200
201
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
    """SAGEConv forward: the output feature dimension must match out_dim."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv(5, out_dim, aggre_type)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    sage.initialize(ctx=ctx)
    out = sage(g, src_feat)
    assert out.shape[-1] == out_dim
212

213
214
215
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    """SAGEConv on bipartite graphs with distinct src/dst input sizes."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # 'gcn' aggregation requires the dst input size to match the src size
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    sage.initialize(ctx=ctx)
    out = sage(g, feat)
    assert out.shape[-1] == out_dim
    assert out.shape[0] == g.number_of_dst_nodes()
227

228
229
@parametrize_dtype
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi2(idtype, aggre_type, out_dim):
    """SAGEConv must still produce correctly-shaped output on a bipartite
    graph with no edges at all."""
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), out_dim, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage.initialize(ctx=ctx)
    out = sage(g, feat)
    assert out.shape[-1] == out_dim
    assert out.shape[0] == 3
    # 'agg' renamed from the parametrized name to avoid shadowing it
    for agg in ['mean', 'pool']:
        sage = nn.SAGEConv((3, 1), out_dim, agg)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage.initialize(ctx=ctx)
        out = sage(g, feat)
        assert out.shape[-1] == out_dim
        assert out.shape[0] == 3

250
def test_gg_conv():
    """GatedGraphConv (3 steps, 4 edge types) forward shape check."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    gg_conv = nn.GatedGraphConv(10, 20, 3, 4) # n_step = 3, n_etypes = 4
    gg_conv.initialize(ctx=ctx)
    print(gg_conv)

    # test#1: basic
    feat = F.randn((20, 10))
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    out = gg_conv(g, feat, etypes)
    assert out.shape == (20, 20)

264
265
@pytest.mark.parametrize('out_dim', [1, 20])
def test_cheb_conv(out_dim):
    """ChebConv (k=3) forward shape check on a random graph."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    cheb = nn.ChebConv(10, out_dim, 3) # k = 3
    cheb.initialize(ctx=ctx)
    print(cheb)

    # test#1: basic
    feat = F.randn((20, 10))
    out = cheb(g, feat)
    assert out.shape == (20, out_dim)
277

278
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_agnn_conv(g, idtype):
    """AGNNConv preserves the feature width; one output row per dst node."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn_conv = nn.AGNNConv(0.1, True)
    agnn_conv.initialize(ctx=ctx)
    print(agnn_conv)
    feat = F.randn((g.number_of_src_nodes(), 10))
    out = agnn_conv(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), 10)
289

290
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_agnn_conv_bi(g, idtype):
    """AGNNConv on bipartite graphs with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn_conv = nn.AGNNConv(0.1, True)
    agnn_conv.initialize(ctx=ctx)
    print(agnn_conv)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    out = agnn_conv(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), 5)
301

302
def test_appnp_conv():
    """APPNPConv (k=3, alpha=0.1) preserves the input feature shape."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    appnp_conv = nn.APPNPConv(3, 0.1, 0)
    appnp_conv.initialize(ctx=ctx)
    print(appnp_conv)

    # test#1: basic
    feat = F.randn((20, 10))
    out = appnp_conv(g, feat)
    assert out.shape == (20, 10)

315
316
@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim):
    """DenseChebConv must agree with ChebConv when their parameters are tied,
    for k = 1..3."""
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.3)).to(F.ctx())
        adj = g.adjacency_matrix(transpose=True, ctx=ctx).tostype('default')
        cheb = nn.ChebConv(5, out_dim, k)
        dense_cheb = nn.DenseChebConv(5, out_dim, k)
        cheb.initialize(ctx=ctx)
        dense_cheb.initialize(ctx=ctx)

        # copy the sparse layer's parameters into the dense layer
        for sparse_fc, dense_fc in zip(cheb.fc, dense_cheb.fc):
            dense_fc.weight.set_data(sparse_fc.weight.data())
        if cheb.bias is not None:
            dense_cheb.bias.set_data(cheb.bias.data())

        feat = F.randn((100, 5))
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        assert F.allclose(out_cheb, out_dense_cheb)

338
@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_graph_conv(idtype, g, norm_type, out_dim):
    """DenseGraphConv must match GraphConv when sharing the same parameters."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).tostype('default')
    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
    conv.initialize(ctx=ctx)
    dense_conv.initialize(ctx=ctx)
    # tie the dense layer's parameters to the sparse layer's
    dense_conv.weight.set_data(conv.weight.data())
    dense_conv.bias.set_data(conv.bias.data())
    feat = F.randn((g.number_of_src_nodes(), 5))
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)

359
360
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite', 'block-bipartite']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_sage_conv(idtype, g, out_dim):
    """DenseSAGEConv must match SAGEConv('gcn') with tied parameters."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).tostype('default')
    sage = nn.SAGEConv(5, out_dim, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, out_dim)
    sage.initialize(ctx=ctx)
    dense_sage.initialize(ctx=ctx)
    # tie the dense layer's parameters to the sparse layer's neighbor FC
    dense_sage.fc.weight.set_data(sage.fc_neigh.weight.data())
    dense_sage.fc.bias.set_data(sage.fc_neigh.bias.data())
    if len(g.ntypes) == 2:
        # bipartite case: separate src/dst features
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5))
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))

    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage)

386
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv(g, idtype, out_dim):
    """EdgeConv forward shape check."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim)
    edge_conv.initialize(ctx=ctx)
    print(edge_conv)
    # test #1: basic
    feat = F.randn((g.number_of_src_nodes(), 5))
    out = edge_conv(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), out_dim)
399

400
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    """EdgeConv on bipartite graphs with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, out_dim)
    edge_conv.initialize(ctx=ctx)
    print(edge_conv)
    # test #1: basic
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    out = edge_conv(g, (src_feat, dst_feat))
    assert out.shape == (g.number_of_dst_nodes(), out_dim)
414

415
416
417
418
419
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """GINConv with an identity apply-function keeps the feature width."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    gin_conv = nn.GINConv(lambda x: x, aggregator_type, 0.1)
    gin_conv.initialize(ctx=ctx)
    print(gin_conv)

    # test #1: basic
    feat = F.randn((g.number_of_src_nodes(), 5))
    out = gin_conv(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), 5)
430
431
432
433
434
435
436
437
438
439
440

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """GINConv on bipartite graphs: output has one row per dst node and keeps
    the feature width."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    gin_conv = nn.GINConv(lambda x: x, aggregator_type, 0.1)
    gin_conv.initialize(ctx=ctx)
    print(gin_conv)

    # test #2: bipartite
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = gin_conv(g, feat)
    # BUG FIX: this used `return h.shape == ...`, which silently discarded the
    # comparison so the shape was never actually checked; assert instead.
    assert h.shape == (g.number_of_dst_nodes(), 5)
446

447

448
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gmm_conv(g, idtype):
    """GMMConv ('max' aggregator) forward shape check."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max')
    gmm_conv.initialize(ctx=ctx)
    feat = F.randn((g.number_of_src_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 5))
    out = gmm_conv(g, feat, pseudo)
    assert out.shape == (g.number_of_dst_nodes(), 2)
459

460
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gmm_conv_bi(g, idtype):
    """GMMConv on bipartite graphs with distinct src/dst input sizes."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmm_conv = nn.GMMConv((5, 4), 2, 5, 3, 'max')
    gmm_conv.initialize(ctx=ctx)
    # test #1: basic
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 4))
    pseudo = F.randn((g.number_of_edges(), 5))
    out = gmm_conv(g, (src_feat, dst_feat), pseudo)
    assert out.shape == (g.number_of_dst_nodes(), 2)

474
475
476
477
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_nn_conv(g, idtype):
    """NNConv with an Embedding edge function: forward shape check."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max')
    nn_conv.initialize(ctx=ctx)
    # test #1: basic
    feat = F.randn((g.number_of_src_nodes(), 5))
    # NOTE(review): randint upper bound is 4 while the Embedding has only 3
    # rows — indices of 3 look out of range; confirm mxnet's behavior here.
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    out = nn_conv(g, feat, etypes)
    assert out.shape == (g.number_of_dst_nodes(), 2)
486

487
488
489
490
491
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_nn_conv_bi(g, idtype):
    """NNConv on bipartite graphs with distinct src/dst input sizes."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    nn_conv = nn.NNConv((5, 4), 2, gluon.nn.Embedding(3, 5 * 2), 'max')
    nn_conv.initialize(ctx=ctx)
    # test #1: basic
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 4))
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    out = nn_conv(g, (src_feat, dst_feat), etypes)
    assert out.shape == (g.number_of_dst_nodes(), 2)

501
502
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sg_conv(out_dim):
    """SGConv (k=2) forward shape check on a self-looped random graph."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    g = dgl.add_self_loop(g)
    ctx = F.ctx()

    sgc = nn.SGConv(5, out_dim, 2)
    sgc.initialize(ctx=ctx)
    print(sgc)

    # test #1: basic
    feat = F.randn((g.number_of_nodes(), 5))
    out = sgc(g, feat)
    assert out.shape == (g.number_of_nodes(), out_dim)
515

516
def test_set2set():
    """Set2Set readout: one (1, 2*feat_dim) row per graph, single and batched."""
    g = dgl.from_networkx(nx.path_graph(10)).to(F.ctx())
    ctx = F.ctx()

    s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
    s2s.initialize(ctx=ctx)
    print(s2s)

    # test#1: single graph
    readout = s2s(g, F.randn((g.number_of_nodes(), 5)))
    assert readout.shape[0] == 1 and readout.shape[1] == 10 and readout.ndim == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g])
    readout = s2s(bg, F.randn((bg.number_of_nodes(), 5)))
    assert readout.shape[0] == 3 and readout.shape[1] == 10 and readout.ndim == 2

def test_glob_att_pool():
    """GlobalAttentionPooling readout on a single and a batched graph."""
    g = dgl.from_networkx(nx.path_graph(10)).to(F.ctx())
    ctx = F.ctx()

    gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
    gap.initialize(ctx=ctx)
    print(gap)
    # test#1: basic
    readout = gap(g, F.randn((g.number_of_nodes(), 5)))
    assert readout.shape[0] == 1 and readout.shape[1] == 10 and readout.ndim == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    readout = gap(bg, F.randn((bg.number_of_nodes(), 5)))
    assert readout.shape[0] == 4 and readout.shape[1] == 10 and readout.ndim == 2

def test_simple_pool():
    """Sum/Avg/Max/Sort pooling on a single graph, then on a batched graph
    checked against manual per-segment reductions."""
    g = dgl.from_networkx(nx.path_graph(15)).to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    check_close(F.squeeze(sum_pool(g, h0), 0), F.sum(h0, 0))
    check_close(F.squeeze(avg_pool(g, h0), 0), F.mean(h0, 0))
    check_close(F.squeeze(max_pool(g, h0), 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2

    # test#2: batched graph — node segments are 15/5/15/5/15 rows long
    g_ = dgl.from_networkx(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    bounds = [(0, 15), (15, 20), (20, 35), (35, 40), (40, 55)]

    def per_graph_truth(reduce_fn):
        # reduce each batched graph's node slice independently
        return mx.nd.stack(*[reduce_fn(h0[a:b], 0) for a, b in bounds], axis=0)

    check_close(sum_pool(bg, h0), per_graph_truth(F.sum))
    check_close(avg_pool(bg, h0), per_graph_truth(F.mean))
    check_close(max_pool(bg, h0), per_graph_truth(F.max))

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2

604
605
@pytest.mark.parametrize('O', [1, 2, 8])
def test_rgcn(O):
    """RelGraphConv shape checks: 'basis' and 'bdd' regularizers, with and
    without an edge norm, and with integer node-id input features."""
    ctx = F.ctx()
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1)).to(F.ctx())
    # 5 etypes, assigned round-robin over the edges
    R = 5
    etype = [i % R for i in range(g.number_of_edges())]
    B = 2
    I = 10

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

    # 'bdd' requires the output dim to be divisible by the number of blocks
    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
        rgc_bdd.initialize(ctx=ctx)
        h = nd.random.randn(100, I, ctx=ctx)
        r = nd.array(etype, ctx=ctx)
        h_new = rgc_bdd(g, h, r)
        assert list(h_new.shape) == [100, O]

    # with norm
    norm = nd.zeros((g.number_of_edges(), 1), ctx=ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
        rgc_bdd.initialize(ctx=ctx)
        h = nd.random.randn(100, I, ctx=ctx)
        r = nd.array(etype, ctx=ctx)
        h_new = rgc_bdd(g, h, r, norm)
        assert list(h_new.shape) == [100, O]

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randint(0, I, (100,), ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
def test_sequential():
    """nn.Sequential: layers carrying extra feature arguments through a single
    graph, then a per-layer list of graphs."""
    ctx = F.ctx()

    # test single graph
    class ExampleLayer(gluon.nn.Block):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    g = dgl.graph(([], [])).to(F.ctx())
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    net = nn.Sequential()
    for _ in range(3):
        net.add(ExampleLayer())
    net.initialize(ctx=ctx)
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # test multiple graphs: each layer halves the node count by pairing rows
    class ExampleLayer(gluon.nn.Block):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            return n_feat.reshape(graph.number_of_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.from_networkx(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.from_networkx(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.from_networkx(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())

    net = nn.Sequential()
    for _ in range(3):
        net.add(ExampleLayer())
    net.initialize(ctx=ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)

712
713
714
715
716
717
def myagg(alist, dsttype):
    """Custom HeteroGraphConv aggregator: position-weighted sum, where the
    i-th entry (0-based) is weighted by i + 1. `dsttype` is unused."""
    acc = alist[0]
    for weight, item in enumerate(alist[1:], start=2):
        acc = acc + weight * item
    return acc

718
@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """HeteroGraphConv over a 3-relation graph: checks output keys and shapes
    for every aggregator on the full graph, on a block, and that per-relation
    module args are routed only to the named modules."""
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    conv.initialize(ctx=F.ctx())
    print(conv)
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    def check_shapes(h):
        # only dst types with at least one inbound relation appear
        assert set(h.keys()) == {'user', 'game'}
        if agg != 'stack':
            assert h['user'].shape == (4, 3)
            assert h['game'].shape == (4, 4)
        else:
            # 'stack' adds an axis with one slice per inbound relation
            assert h['user'].shape == (4, 1, 3)
            assert h['game'].shape == (4, 2, 4)

    check_shapes(conv(g, {'user': uf, 'store': sf, 'game': gf}))

    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
    check_shapes(conv(block, ({'user': uf, 'game': gf, 'store': sf},
                              {'user': uf, 'game': gf, 'store': sf[0:0]})))

    check_shapes(conv(block, {'user': uf, 'game': gf, 'store': sf}))

    # test with mod args
    class MyMod(mx.gluon.nn.Block):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.s1 = s1
            self.s2 = s2
        def forward(self, g, h, arg1=None):  # mxnet does not support kwargs
            if arg1 is not None:
                self.carg1 += 1
            return F.zeros((g.number_of_dst_nodes(), self.s2))
    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    conv.initialize(ctx=F.ctx())
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    h = conv(g, {'user' : uf, 'store' : sf, 'game': gf}, mod_args)
    # only the relations named in mod_args saw the extra argument
    assert mod1.carg1 == 1
    assert mod2.carg1 == 1
    assert mod3.carg1 == 0

791
792
if __name__ == '__main__':
    # BUG FIX: the tests in this module are parametrized via pytest decorators
    # (idtype, graph cases, out_dim, ...), so calling them directly with no
    # arguments — as this block previously did (test_graph_conv(), test_rgcn(),
    # ...) — raises TypeError before anything runs. Delegate to pytest, which
    # expands the parametrizations properly.
    import sys
    sys.exit(pytest.main([__file__]))