"""Unit tests for DGL's MXNet NN modules (``dgl.nn.mxnet``)."""
import mxnet as mx
import networkx as nx
import numpy as np
import scipy as sp
import pytest

import dgl
import dgl.nn.mxnet as nn
import dgl.function as fn
import backend as F

from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
from mxnet import autograd, gluon, nd

def check_close(a, b):
    """Assert that two MXNet NDArrays are elementwise close (rtol/atol 1e-4)."""
    lhs = a.asnumpy()
    rhs = b.asnumpy()
    assert np.allclose(lhs, rhs, rtol=1e-4, atol=1e-4)
def _AXWb(A, X, W, b):
    """Dense reference for an un-normalized graph convolution: A @ (X W) + b.

    ``W`` and ``b`` are gluon Parameters; their data is fetched on X's context.
    ``X`` may have extra trailing dims, which are flattened for the A-product
    and restored afterwards.
    """
    XW = mx.nd.dot(X, W.data(X.context))
    aggregated = mx.nd.dot(A, XW.reshape(XW.shape[0], -1)).reshape(XW.shape)
    return aggregated + b.data(X.context)
@parametrize_dtype
def test_graph_conv(idtype):
    """GraphConv forward on a 3-node path graph, checked against a dense reference.

    Also verifies the module leaves no scratch features behind on the graph
    and does not clobber pre-existing node features.
    """
    g = dgl.from_networkx(nx.path_graph(3))
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=False, ctx=ctx)

    # norm='none' so the output can be compared exactly with _AXWb.
    conv = nn.GraphConv(5, 2, norm='none', bias=True)
    conv.initialize(ctx=ctx)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    # conv must not leave temporary fields on the graph
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # default norm: only shape/side-effect checks below, no dense reference
    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)

    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)

    # same checks under autograd training mode
    with autograd.train_mode():
        # test#3: basic
        h0 = F.ones((3, 5))
        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        # test#4: basic
        h0 = F.ones((3, 5, 5))
        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0

    # test not override features
    g.ndata["h"] = 2 * F.ones((3, 1))
    h1 = conv(g, h0)
    assert len(g.ndata) == 1
    assert len(g.edata) == 0
    assert "h" in g.ndata
    check_close(g.ndata['h'], 2 * F.ones((3, 1)))
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [False])
def test_graph_conv2(idtype, g, norm, weight, bias):
    """GraphConv shape check on homogeneous graphs, with and without an
    externally supplied weight matrix."""
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias)
    conv.initialize(ctx=F.ctx())
    # external weight used when the module owns no weight (weight=False)
    ext_w = F.randn((5, 2)).as_in_context(F.ctx())
    nsrc = ndst = g.number_of_nodes()
    h = F.randn((nsrc, 5)).as_in_context(F.ctx())
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, ext_w)
    assert h_out.shape == (ndst, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [False])
def test_graph_conv2_bi(idtype, g, norm, weight, bias):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias)
    conv.initialize(ctx=F.ctx())
    ext_w = F.randn((5, 2)).as_in_context(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).as_in_context(F.ctx())
    h_dst = F.randn((ndst, 2)).as_in_context(F.ctx())
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), ext_w)
    assert h_out.shape == (ndst, 2)
118

119
120
121
122
123
124
125
126
127
128
129
130
131
def _S2AXWb(A, N, X, W, b):
    """Dense 2-hop reference for TAGConv.

    Builds the hop-0/1/2 features (each propagation is N*A*N-scaled),
    concatenates them along the last axis, and applies the linear layer W, b.
    """
    hop1 = X * N
    hop1 = mx.nd.dot(A, hop1.reshape(hop1.shape[0], -1))
    hop1 = hop1 * N

    hop2 = hop1 * N
    hop2 = mx.nd.dot(A, hop2.reshape(hop2.shape[0], -1))
    hop2 = hop2 * N

    stacked = mx.nd.concat(X, hop1, hop2, dim=-1)
    return mx.nd.dot(stacked, W) + b
def test_tagconv():
    """TAGConv on a 3-node path graph, checked against the dense 2-hop
    reference _S2AXWb with symmetric degree normalization."""
    g = dgl.from_networkx(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=False, ctx=ctx)
    # D^{-1/2} normalization vector
    norm = mx.nd.power(g.in_degrees().astype('float32'), -0.5)

    conv = nn.TAGConv(5, 2, bias=True)
    conv.initialize(ctx=ctx)
    print(conv)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # broadcast norm over the feature dims
    shp = norm.shape + (1,) * (h0.ndim - 1)
    norm = norm.reshape(shp).as_in_context(h0.context)

    assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.data(ctx), conv.h_bias.data(ctx)))

    conv = nn.TAGConv(5, 2)
    conv.initialize(ctx=ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == 2
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gat_conv(g, idtype):
    """GATConv output/attention shape checks on homogeneous graphs."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(10, 20, 5) # n_heads = 5
    gat.initialize(ctx=ctx)
    print(gat)
    feat = F.randn((g.number_of_nodes(), 10))
    h = gat(g, feat)
    assert h.shape == (g.number_of_nodes(), 5, 20)
    # get_attention=True also returns per-edge attention scores
    _, a = gat(g, feat, True)
    assert a.shape == (g.number_of_edges(), 5, 1)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gat_conv_bi(g, idtype):
    """GATConv output/attention shape checks on bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, 2, 4)
    gat.initialize(ctx=ctx)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 4, 2)
    # get_attention=True also returns per-edge attention scores
    _, a = gat(g, feat, True)
    assert a.shape == (g.number_of_edges(), 4, 1)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv(idtype, g, aggre_type):
    """SAGEConv output-dim check on homogeneous graphs for each aggregator."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_nodes(), 5))
    sage.initialize(ctx=ctx)
    h = sage(g, feat)
    assert h.shape[-1] == 10
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv_bi(idtype, g, aggre_type):
    """SAGEConv shape check on bipartite graphs for each aggregator."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # 'gcn' aggregation requires src and dst feature dims to match
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    sage.initialize(ctx=ctx)
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == g.number_of_dst_nodes()
@parametrize_dtype
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv_bi2(idtype, aggre_type):
    # Test the case for graphs without edges
    # NOTE(review): the parametrized `aggre_type` is never used — the body
    # hard-codes 'gcn' and then shadows the parameter in the for-loop below,
    # so each parametrization runs the identical test. Verify intent.
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), 2, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage.initialize(ctx=ctx)
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 3
    for aggre_type in ['mean', 'pool']:
        sage = nn.SAGEConv((3, 1), 2, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage.initialize(ctx=ctx)
        h = sage(g, feat)
        assert h.shape[-1] == 2
        assert h.shape[0] == 3
def test_gg_conv():
    """GatedGraphConv output shape check on a random Erdos-Renyi graph."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    gg_conv = nn.GatedGraphConv(10, 20, 3, 4) # n_step = 3, n_etypes = 4
    gg_conv.initialize(ctx=ctx)
    print(gg_conv)

    # test#1: basic
    h0 = F.randn((20, 10))
    # random edge type in [0, 4) per edge
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    h1 = gg_conv(g, h0, etypes)
    assert h1.shape == (20, 20)
def test_cheb_conv():
    """ChebConv output shape check on a random Erdos-Renyi graph."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    cheb = nn.ChebConv(10, 20, 3) # k = 3
    cheb.initialize(ctx=ctx)
    print(cheb)

    # test#1: basic
    h0 = F.randn((20, 10))
    h1 = cheb(g, h0)
    assert h1.shape == (20, 20)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_agnn_conv(g, idtype):
    """AGNNConv shape check on homogeneous graphs (feature dim is preserved)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn_conv = nn.AGNNConv(0.1, True)
    agnn_conv.initialize(ctx=ctx)
    print(agnn_conv)
    feat = F.randn((g.number_of_nodes(), 10))
    h = agnn_conv(g, feat)
    assert h.shape == (g.number_of_nodes(), 10)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_agnn_conv_bi(g, idtype):
    """AGNNConv shape check on bipartite graphs (feature dim is preserved)."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn_conv = nn.AGNNConv(0.1, True)
    agnn_conv.initialize(ctx=ctx)
    print(agnn_conv)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = agnn_conv(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)
def test_appnp_conv():
    """APPNPConv shape check on a random Erdos-Renyi graph."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    # k=3 propagation steps, alpha=0.1, no edge dropout
    appnp_conv = nn.APPNPConv(3, 0.1, 0)
    appnp_conv.initialize(ctx=ctx)
    print(appnp_conv)

    # test#1: basic
    h0 = F.randn((20, 10))
    h1 = appnp_conv(g, h0)
    assert h1.shape == (20, 10)
def test_dense_cheb_conv():
    """ChebConv vs. DenseChebConv equivalence for k = 1..3.

    Copies the sparse module's parameters into the dense module and checks
    both produce the same output (lambda_max fixed at 2.0).
    """
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.3)).to(F.ctx())
        adj = g.adjacency_matrix(transpose=False, ctx=ctx).tostype('default')
        cheb = nn.ChebConv(5, 2, k)
        dense_cheb = nn.DenseChebConv(5, 2, k)
        cheb.initialize(ctx=ctx)
        dense_cheb.initialize(ctx=ctx)

        # copy the per-hop linear weights
        for i in range(len(cheb.fc)):
            dense_cheb.fc[i].weight.set_data(
                cheb.fc[i].weight.data())
        # FIX: bias copy hoisted out of the loop — it was redundantly
        # re-executed once per hop (same data each time).
        if cheb.bias is not None:
            dense_cheb.bias.set_data(
                cheb.bias.data())

        feat = F.randn((100, 5))
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        assert F.allclose(out_cheb, out_dense_cheb)
@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_dense_graph_conv(idtype, g, norm_type):
    """GraphConv vs. DenseGraphConv equivalence with shared parameters."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=False, ctx=ctx).tostype('default')
    conv = nn.GraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
    conv.initialize(ctx=ctx)
    dense_conv.initialize(ctx=ctx)
    # share parameters so the two modules are numerically comparable
    dense_conv.weight.set_data(
        conv.weight.data())
    dense_conv.bias.set_data(
        conv.bias.data())
    feat = F.randn((g.number_of_src_nodes(), 5))
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite', 'block-bipartite']))
def test_dense_sage_conv(idtype, g):
    """SAGEConv ('gcn') vs. DenseSAGEConv equivalence with shared parameters."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=False, ctx=ctx).tostype('default')
    sage = nn.SAGEConv(5, 2, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, 2)
    sage.initialize(ctx=ctx)
    dense_sage.initialize(ctx=ctx)
    # share parameters so the two modules are numerically comparable
    dense_sage.fc.weight.set_data(
        sage.fc_neigh.weight.data())
    dense_sage.fc.bias.set_data(
        sage.fc_neigh.bias.data())
    # bipartite graphs take a (src, dst) feature pair
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5))
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))

    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_edge_conv(g, idtype):
    """EdgeConv output shape check on homogeneous graphs."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, 2)
    edge_conv.initialize(ctx=ctx)
    print(edge_conv)
    # test #1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_edge_conv_bi(g, idtype):
    """EdgeConv output shape check on bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, 2)
    edge_conv.initialize(ctx=ctx)
    print(edge_conv)
    # test #1: basic
    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """GINConv (identity apply-function) shape check on homogeneous graphs."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    # identity apply function keeps the feature dim unchanged
    gin_conv = nn.GINConv(lambda x: x, aggregator_type, 0.1)
    gin_conv.initialize(ctx=ctx)
    print(gin_conv)

    # test #1: basic
    feat = F.randn((g.number_of_nodes(), 5))
    h = gin_conv(g, feat)
    assert h.shape == (g.number_of_nodes(), 5)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """GINConv (identity apply-function) shape check on bipartite graphs."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    # identity apply function keeps the feature dim unchanged
    gin_conv = nn.GINConv(lambda x: x, aggregator_type, 0.1)
    gin_conv.initialize(ctx=ctx)
    print(gin_conv)

    # test #2: bipartite
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = gin_conv(g, feat)
    # BUG FIX: this was `return h.shape == ...`, which silently discarded the
    # comparison result — the shape check could never fail. Assert instead.
    assert h.shape == (g.number_of_dst_nodes(), 5)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gmm_conv(g, idtype):
    """GMMConv shape check on homogeneous graphs with per-edge pseudo-coords."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # in=5, out=2, pseudo-coordinate dim=5, 3 kernels, 'max' aggregation
    gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max')
    gmm_conv.initialize(ctx=ctx)
    h0 = F.randn((g.number_of_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 5))
    h1 = gmm_conv(g, h0, pseudo)
    assert h1.shape == (g.number_of_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gmm_conv_bi(g, idtype):
    """GMMConv shape check on bipartite graphs with (src, dst) features."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # (src_in=5, dst_in=4), out=2, pseudo dim=5, 3 kernels, 'max' aggregation
    gmm_conv = nn.GMMConv((5, 4), 2, 5, 3, 'max')
    gmm_conv.initialize(ctx=ctx)
    # test #1: basic
    h0 = F.randn((g.number_of_src_nodes(), 5))
    hd = F.randn((g.number_of_dst_nodes(), 4))
    pseudo = F.randn((g.number_of_edges(), 5))
    h1 = gmm_conv(g, (h0, hd), pseudo)
    assert h1.shape == (g.number_of_dst_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_nn_conv(g, idtype):
    """NNConv shape check on homogeneous graphs using an Embedding edge-func."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max')
    nn_conv.initialize(ctx=ctx)
    # test #1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    # NOTE(review): etypes are sampled in [0, 4) but the Embedding has only
    # 3 rows — index 3 looks out of range; confirm whether this is intended.
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    h1 = nn_conv(g, h0, etypes)
    assert h1.shape == (g.number_of_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_nn_conv_bi(g, idtype):
    """NNConv shape check on bipartite graphs using an Embedding edge-func."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    nn_conv = nn.NNConv((5, 4), 2, gluon.nn.Embedding(3, 5 * 2), 'max')
    nn_conv.initialize(ctx=ctx)
    # test #1: basic
    h0 = F.randn((g.number_of_src_nodes(), 5))
    hd = F.randn((g.number_of_dst_nodes(), 4))
    # NOTE(review): etypes sampled in [0, 4) vs Embedding(3, ·) — see
    # test_nn_conv; confirm the upper bound.
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    h1 = nn_conv(g, (h0, hd), etypes)
    assert h1.shape == (g.number_of_dst_nodes(), 2)
def test_sg_conv():
    """SGConv shape check on a random graph (self-loops added so no node has
    zero in-degree)."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    g = dgl.add_self_loop(g)
    ctx = F.ctx()

    sgc = nn.SGConv(5, 2, 2)
    sgc.initialize(ctx=ctx)
    print(sgc)

    # test #1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sgc(g, h0)
    assert h1.shape == (g.number_of_nodes(), 2)
def test_set2set():
    """Set2Set readout: one row per graph, output dim 2x the input dim."""
    g = dgl.from_networkx(nx.path_graph(10)).to(F.ctx())
    ctx = F.ctx()

    s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
    s2s.initialize(ctx=ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    # one readout row of size 2*5 for a single graph
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    # one readout row per graph in the batch
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2
def test_glob_att_pool():
    """GlobalAttentionPooling readout: one row per graph."""
    g = dgl.from_networkx(nx.path_graph(10)).to(F.ctx())
    ctx = F.ctx()

    # gate network maps to 1 score per node; feature network maps to dim 10
    gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
    gap.initialize(ctx=ctx)
    print(gap)
    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    # one readout row per graph in the batch
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
def test_simple_pool():
    """Sum/Avg/Max/Sort pooling readouts, on a single graph and on a batch.

    The batched case interleaves 15-node and 5-node path graphs and checks
    each pooled row against a manually sliced reference.
    """
    g = dgl.from_networkx(nx.path_graph(15)).to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sum_pool(g, h0)
    check_close(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    check_close(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    check_close(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    # sort pooling flattens the top-k rows: k * feature_dim columns
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2

    # test#2: batched graph
    g_ = dgl.from_networkx(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    # node slices follow the batch layout: 15, 5, 15, 5, 15 nodes
    truth = mx.nd.stack(F.sum(h0[:15], 0),
                        F.sum(h0[15:20], 0),
                        F.sum(h0[20:35], 0),
                        F.sum(h0[35:40], 0),
                        F.sum(h0[40:55], 0), axis=0)
    check_close(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = mx.nd.stack(F.mean(h0[:15], 0),
                        F.mean(h0[15:20], 0),
                        F.mean(h0[20:35], 0),
                        F.mean(h0[35:40], 0),
                        F.mean(h0[40:55], 0), axis=0)
    check_close(h1, truth)

    h1 = max_pool(bg, h0)
    truth = mx.nd.stack(F.max(h0[:15], 0),
                        F.max(h0[15:20], 0),
                        F.max(h0[20:35], 0),
                        F.max(h0[35:40], 0),
                        F.max(h0[40:55], 0), axis=0)
    check_close(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2
def test_rgcn():
    """RelGraphConv shape checks: basis and bdd regularizers, with and
    without an edge norm, and with integer node-id input."""
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1)).to(F.ctx())
    # 5 etypes
    R = 5
    # assign edge types round-robin over the 5 relations
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    B = 2   # number of bases / blocks
    I = 10  # input dim
    O = 8   # output dim

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
    rgc_bdd.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_bdd(g, h, r)
    assert list(h_new.shape) == [100, O]

    # with norm
    norm = nd.zeros((g.number_of_edges(), 1), ctx=ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
    rgc_bdd.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_bdd(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randint(0, I, (100,), ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]
def test_sequential():
    """nn.Sequential: chaining blocks over a single graph (node+edge feats)
    and over a list of graphs (one graph consumed per layer)."""
    ctx = F.ctx()
    # test single graph
    class ExampleLayer(gluon.nn.Block):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, graph, n_feat, e_feat):
            # residual-style update of both node and edge features
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    # 3-node graph with all 9 directed edges (including self-loops)
    g = dgl.graph(([], [])).to(F.ctx())
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    net = nn.Sequential()
    net.add(ExampleLayer())
    net.add(ExampleLayer())
    net.add(ExampleLayer())
    net.initialize(ctx=ctx)
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # test multiple graphs
    class ExampleLayer(gluon.nn.Block):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            # halve the node count each layer: 32 -> 16 -> 8 -> 4
            return n_feat.reshape(graph.number_of_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.from_networkx(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.from_networkx(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.from_networkx(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())

    net = nn.Sequential()
    net.add(ExampleLayer())
    net.add(ExampleLayer())
    net.add(ExampleLayer())
    net.initialize(ctx=ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)
def myagg(alist, dsttype):
    """Custom cross-type aggregator: position-weighted sum of the inputs.

    Computes alist[0] + 2*alist[1] + 3*alist[2] + ...; ``dsttype`` is part of
    the HeteroGraphConv aggregator signature and is unused here.
    """
    total = alist[0]
    for weight, item in enumerate(alist[1:], start=2):
        total = total + weight * item
    return total
@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """HeteroGraphConv over a 3-relation graph with each built-in aggregator
    plus a custom one; checks full-graph, block, and mod_args code paths."""
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    conv.initialize(ctx=F.ctx())
    print(conv)
    uf = F.randn((4, 2))  # user features
    gf = F.randn((4, 4))  # game features
    sf = F.randn((2, 3))  # store features

    h = conv(g, {'user': uf, 'store': sf, 'game': gf})
    # only dst types of some relation get outputs; 'store' has none
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        # 'stack' keeps one slice per incoming relation (game has 2)
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    # block (MFG) path with explicit (src, dst) feature dicts
    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
    h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    # block path with a single feature dict (dst feats derived internally)
    h = conv(block, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    # test with mod args
    class MyMod(mx.gluon.nn.Block):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0  # counts calls that received arg1
            self.s1 = s1
            self.s2 = s2
        def forward(self, g, h, arg1=None):  # mxnet does not support kwargs
            if arg1 is not None:
                self.carg1 += 1
            return F.zeros((g.number_of_dst_nodes(), self.s2))
    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    conv.initialize(ctx=F.ctx())
    # extra positional args routed per-relation; 'sells' gets none
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    h = conv(g, {'user' : uf, 'store' : sf, 'game': gf}, mod_args)
    assert mod1.carg1 == 1
    assert mod2.carg1 == 1
    assert mod3.carg1 == 0
if __name__ == '__main__':
    # Ad-hoc runner for direct invocation.
    # NOTE(review): several of these tests are pytest-parametrized above
    # (e.g. test_graph_conv takes idtype, test_gat_conv takes g), so the
    # bare calls below will fail with TypeError when run directly — this
    # block appears stale; prefer running the file through pytest.
    test_graph_conv()
    test_gat_conv()
    test_sage_conv()
    test_gg_conv()
    test_cheb_conv()
    test_agnn_conv()
    test_appnp_conv()
    test_dense_cheb_conv()
    test_dense_graph_conv()
    test_dense_sage_conv()
    test_edge_conv()
    test_gin_conv()
    test_gmm_conv()
    test_nn_conv()
    test_sg_conv()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_rgcn()
    test_sequential()