# Unit tests for dgl.nn.mxnet modules (conv layers, pooling, edge softmax, etc.).
import mxnet as mx
import networkx as nx
import numpy as np
import pytest
import scipy as sp
from mxnet import autograd, gluon, nd

import dgl
import dgl.function as fn
import dgl.nn.mxnet as nn

import backend as F
from test_utils import parametrize_dtype
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph

def check_close(a, b):
    """Assert that two mxnet NDArrays are element-wise close (rtol=atol=1e-4)."""
    lhs, rhs = a.asnumpy(), b.asnumpy()
    assert np.allclose(lhs, rhs, rtol=1e-4, atol=1e-4)
def _AXWb(A, X, W, b):
    """Dense reference for GraphConv: A @ (X @ W) + b.

    Trailing feature dims are flattened for the adjacency product and
    restored afterwards, matching the layer's handling of >2-D input.
    """
    ctx = X.context
    XW = mx.nd.dot(X, W.data(ctx))
    AXW = mx.nd.dot(A, XW.reshape(XW.shape[0], -1)).reshape(XW.shape)
    return AXW + b.data(ctx)
@parametrize_dtype
def test_graph_conv(idtype):
    """GraphConv on a 3-node path graph: output values against the dense
    reference, >2-D inputs, train mode, and no leftover scratch features."""
    g = dgl.graph(nx.path_graph(3))
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)

    def assert_clean(graph):
        # the layer must not leave any features behind on the graph
        assert len(graph.ndata) == 0
        assert len(graph.edata) == 0

    conv = nn.GraphConv(5, 2, norm='none', bias=True)
    conv.initialize(ctx=ctx)
    # test#1: basic, compare against dense reference
    feat = F.ones((3, 5))
    out = conv(g, feat)
    assert_clean(g)
    check_close(out, _AXWb(adj, feat, conv.weight, conv.bias))
    # test#2: more-dim input
    feat = F.ones((3, 5, 5))
    out = conv(g, feat)
    assert_clean(g)
    check_close(out, _AXWb(adj, feat, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)
    # test#3: default normalization, 2-D input
    feat = F.ones((3, 5))
    out = conv(g, feat)
    assert_clean(g)
    # test#4: default normalization, 3-D input
    feat = F.ones((3, 5, 5))
    out = conv(g, feat)
    assert_clean(g)

    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)
    with autograd.train_mode():
        # repeat both shape cases under training mode
        feat = F.ones((3, 5))
        out = conv(g, feat)
        assert_clean(g)
        feat = F.ones((3, 5, 5))
        out = conv(g, feat)
        assert_clean(g)

    # test not override features
    g.ndata["h"] = 2 * F.ones((3, 1))
    out = conv(g, feat)
    assert len(g.ndata) == 1
    assert len(g.edata) == 0
    assert "h" in g.ndata
    check_close(g.ndata['h'], 2 * F.ones((3, 1)))
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [False])
def test_graph_conv2(idtype, g, norm, weight, bias):
    """GraphConv output shape on homogeneous graphs, with either the layer's
    own weight or an externally supplied one."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias)
    conv.initialize(ctx=ctx)
    ext_w = F.randn((5, 2)).as_in_context(ctx)
    nsrc = ndst = g.number_of_nodes()
    feat = F.randn((nsrc, 5)).as_in_context(ctx)
    out = conv(g, feat) if weight else conv(g, feat, ext_w)
    assert out.shape == (ndst, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [False])
def test_graph_conv2_bi(idtype, g, norm, weight, bias):
    """GraphConv on bipartite graphs with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias)
    conv.initialize(ctx=ctx)
    ext_w = F.randn((5, 2)).as_in_context(ctx)
    nsrc, ndst = g.number_of_src_nodes(), g.number_of_dst_nodes()
    feat_src = F.randn((nsrc, 5)).as_in_context(ctx)
    feat_dst = F.randn((ndst, 2)).as_in_context(ctx)
    if weight:
        out = conv(g, (feat_src, feat_dst))
    else:
        out = conv(g, (feat_src, feat_dst), ext_w)
    assert out.shape == (ndst, 2)
def _S2AXWb(A, N, X, W, b):
    """Dense 2-hop TAGConv reference: concat [X, AX', A2X'] then linear.

    N is the per-node degree normalizer; each hop computes N * (A @ (prev * N)).
    """
    hop1 = mx.nd.dot(A, (X * N).reshape(X.shape[0], -1)) * N
    hop2 = mx.nd.dot(A, (hop1 * N).reshape(hop1.shape[0], -1)) * N
    stacked = mx.nd.concat(X, hop1, hop2, dim=-1)
    return mx.nd.dot(stacked, W) + b
def test_tagconv():
    """TAGConv on a 3-node path graph: values against the dense 2-hop
    reference, then a plain output-dimension check."""
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)
    norm = mx.nd.power(g.in_degrees().astype('float32'), -0.5)

    conv = nn.TAGConv(5, 2, bias=True)
    conv.initialize(ctx=ctx)
    print(conv)

    # test#1: compare against the dense reference implementation
    feat = F.ones((3, 5))
    out = conv(g, feat)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (feat.ndim - 1)
    norm = norm.reshape(shp).as_in_context(feat.context)
    assert F.allclose(out, _S2AXWb(adj, norm, feat, conv.lin.data(ctx), conv.h_bias.data(ctx)))

    # test#2: output dimension only
    conv = nn.TAGConv(5, 2)
    conv.initialize(ctx=ctx)
    feat = F.ones((3, 5))
    out = conv(g, feat)
    assert out.shape[-1] == 2
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_gat_conv(g, idtype):
    """GATConv output shape is (num_nodes, num_heads, out_dim)."""
    g = g.astype(idtype).to(F.ctx())
    layer = nn.GATConv(10, 20, 5)  # 5 attention heads
    layer.initialize(ctx=F.ctx())
    print(layer)
    feat = F.randn((g.number_of_nodes(), 10))
    out = layer(g, feat)
    assert out.shape == (g.number_of_nodes(), 5, 20)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_gat_conv_bi(g, idtype):
    """GATConv on bipartite graphs with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    layer = nn.GATConv(5, 2, 4)
    layer.initialize(ctx=F.ctx())
    feat = (F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)))
    out = layer(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), 4, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv(idtype, g, aggre_type):
    """SAGEConv output dimension for every aggregator type."""
    g = g.astype(idtype).to(F.ctx())
    layer = nn.SAGEConv(5, 10, aggre_type)
    layer.initialize(ctx=F.ctx())
    feat = F.randn((g.number_of_nodes(), 5))
    out = layer(g, feat)
    assert out.shape[-1] == 10
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv_bi(idtype, g, aggre_type):
    """SAGEConv on bipartite graphs with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    # NOTE(review): for 'gcn' the dst dim is set equal to the src dim (10);
    # presumably that aggregator requires matching dims — confirm in SAGEConv.
    dst_dim = 5 if aggre_type != 'gcn' else 10
    layer = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    layer.initialize(ctx=F.ctx())
    feat = (F.randn((g.number_of_src_nodes(), 10)),
            F.randn((g.number_of_dst_nodes(), dst_dim)))
    out = layer(g, feat)
    assert out.shape[-1] == 2
    assert out.shape[0] == g.number_of_dst_nodes()
@parametrize_dtype
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv_bi2(idtype, aggre_type):
    """SAGEConv on a bipartite graph that has no edges at all."""
    # Test the case for graphs without edges
    g = dgl.bipartite([], num_nodes=(5, 3))
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    def run(in_dims, agg):
        # one SAGEConv forward pass; output must still cover all 3 dst nodes
        layer = nn.SAGEConv(in_dims, 2, agg)
        layer.initialize(ctx=ctx)
        feat = (F.randn((5, in_dims[0])), F.randn((3, in_dims[1])))
        out = layer(g, feat)
        assert out.shape[-1] == 2
        assert out.shape[0] == 3

    run((3, 3), 'gcn')
    for agg in ['mean', 'pool']:
        run((3, 1), agg)
def test_gg_conv():
    """GatedGraphConv shape check with random edge types."""
    g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    layer = nn.GatedGraphConv(10, 20, 3, 4)  # 3 steps, 4 edge types
    layer.initialize(ctx=ctx)
    print(layer)

    # test#1: basic
    feat = F.randn((20, 10))
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    out = layer(g, feat, etypes)
    assert out.shape == (20, 20)
def test_cheb_conv():
    """ChebConv (k=3) shape check on a random graph."""
    g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    layer = nn.ChebConv(10, 20, 3)  # Chebyshev order k = 3
    layer.initialize(ctx=ctx)
    print(layer)

    # test#1: basic
    feat = F.randn((20, 10))
    out = layer(g, feat)
    assert out.shape == (20, 20)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_agnn_conv(g, idtype):
    """AGNNConv preserves the input feature dimension."""
    g = g.astype(idtype).to(F.ctx())
    layer = nn.AGNNConv(0.1, True)
    layer.initialize(ctx=F.ctx())
    print(layer)
    feat = F.randn((g.number_of_nodes(), 10))
    out = layer(g, feat)
    assert out.shape == (g.number_of_nodes(), 10)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_agnn_conv_bi(g, idtype):
    """AGNNConv on bipartite graphs with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    layer = nn.AGNNConv(0.1, True)
    layer.initialize(ctx=F.ctx())
    print(layer)
    feat = (F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)))
    out = layer(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), 5)
def test_appnp_conv():
    """APPNPConv keeps the input feature dimension."""
    g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    layer = nn.APPNPConv(3, 0.1, 0)
    layer.initialize(ctx=ctx)
    print(layer)

    # test#1: basic
    feat = F.randn((20, 10))
    out = layer(g, feat)
    assert out.shape == (20, 10)
def test_dense_cheb_conv():
    """ChebConv and DenseChebConv agree when their parameters are tied,
    for polynomial orders k = 1..3."""
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.3)).to(F.ctx())
        adj = g.adjacency_matrix(ctx=ctx).tostype('default')
        cheb = nn.ChebConv(5, 2, k)
        dense_cheb = nn.DenseChebConv(5, 2, k)
        cheb.initialize(ctx=ctx)
        dense_cheb.initialize(ctx=ctx)

        # copy the sparse layer's parameters into the dense layer
        for i, fc in enumerate(cheb.fc):
            dense_cheb.fc[i].weight.set_data(fc.weight.data())
            if cheb.bias is not None:
                dense_cheb.bias.set_data(cheb.bias.data())

        feat = F.randn((100, 5))
        out_sparse = cheb(g, feat, [2.0])
        out_dense = dense_cheb(adj, feat, 2.0)
        assert F.allclose(out_sparse, out_dense)
@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_dense_graph_conv(idtype, g, norm_type):
    """GraphConv and DenseGraphConv agree when weight and bias are shared."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx).tostype('default')
    conv = nn.GraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
    conv.initialize(ctx=ctx)
    dense_conv.initialize(ctx=ctx)
    # tie parameters so both layers compute the same function
    dense_conv.weight.set_data(conv.weight.data())
    dense_conv.bias.set_data(conv.bias.data())
    feat = F.randn((g.number_of_src_nodes(), 5))
    out_conv = conv(g, feat)
    out_dense = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite', 'block-bipartite']))
def test_dense_sage_conv(idtype, g):
    """SAGEConv('gcn') and DenseSAGEConv agree when parameters are tied."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx).tostype('default')
    sage = nn.SAGEConv(5, 2, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, 2)
    sage.initialize(ctx=ctx)
    dense_sage.initialize(ctx=ctx)
    # tie the dense layer to the sparse layer's neighbor projection
    dense_sage.fc.weight.set_data(sage.fc_neigh.weight.data())
    dense_sage.fc.bias.set_data(sage.fc_neigh.bias.data())
    if len(g.ntypes) == 2:
        # bipartite case: separate src/dst features
        feat = (F.randn((g.number_of_src_nodes(), 5)),
                F.randn((g.number_of_dst_nodes(), 5)))
    else:
        feat = F.randn((g.number_of_nodes(), 5))

    out_sage = sage(g, feat)
    out_dense = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_edge_conv(g, idtype):
    """EdgeConv output shape on homogeneous graphs."""
    g = g.astype(idtype).to(F.ctx())
    layer = nn.EdgeConv(5, 2)
    layer.initialize(ctx=F.ctx())
    print(layer)

    # test #1: basic
    feat = F.randn((g.number_of_nodes(), 5))
    out = layer(g, feat)
    assert out.shape == (g.number_of_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_edge_conv_bi(g, idtype):
    """EdgeConv on bipartite graphs with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    layer = nn.EdgeConv(5, 2)
    layer.initialize(ctx=F.ctx())
    print(layer)
    # test #1: basic
    feat_src = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 5))
    out = layer(g, (feat_src, feat_dst))
    assert out.shape == (g.number_of_dst_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """GINConv with an identity apply-function keeps the feature shape."""
    g = g.astype(idtype).to(F.ctx())
    layer = nn.GINConv(lambda x: x, aggregator_type, 0.1)
    layer.initialize(ctx=F.ctx())
    print(layer)

    # test #1: basic
    feat = F.randn((g.number_of_nodes(), 5))
    out = layer(g, feat)
    assert out.shape == (g.number_of_nodes(), 5)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """GINConv on bipartite graphs: output has one row per destination node."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    gin_conv = nn.GINConv(lambda x: x, aggregator_type, 0.1)
    gin_conv.initialize(ctx=ctx)
    print(gin_conv)

    # test #2: bipartite (src, dst) pair input
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = gin_conv(g, feat)
    # BUG FIX: was `return h.shape == ...`, which silently discarded the
    # shape check (pytest ignores return values) — must be an assert.
    assert h.shape == (g.number_of_dst_nodes(), 5)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_gmm_conv(g, idtype):
    """GMMConv shape check with random pseudo-coordinates."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    layer = nn.GMMConv(5, 2, 5, 3, 'max')
    layer.initialize(ctx=ctx)
    feat = F.randn((g.number_of_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 5))
    out = layer(g, feat, pseudo)
    assert out.shape == (g.number_of_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_gmm_conv_bi(g, idtype):
    """GMMConv on bipartite graphs with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    layer = nn.GMMConv((5, 4), 2, 5, 3, 'max')
    layer.initialize(ctx=ctx)
    # test #1: basic
    feat_src = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 4))
    pseudo = F.randn((g.number_of_edges(), 5))
    out = layer(g, (feat_src, feat_dst), pseudo)
    assert out.shape == (g.number_of_dst_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_nn_conv(g, idtype):
    """NNConv with an Embedding edge-network: shape check."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    layer = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max')
    layer.initialize(ctx=ctx)
    # test #1: basic
    feat = F.randn((g.number_of_nodes(), 5))
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    out = layer(g, feat, etypes)
    assert out.shape == (g.number_of_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_nn_conv_bi(g, idtype):
    """NNConv on bipartite graphs with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    layer = nn.NNConv((5, 4), 2, gluon.nn.Embedding(3, 5 * 2), 'max')
    layer.initialize(ctx=ctx)
    # test #1: basic
    feat_src = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 4))
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    out = layer(g, (feat_src, feat_dst), etypes)
    assert out.shape == (g.number_of_dst_nodes(), 2)
def test_sg_conv():
    """SGConv (k=2) shape check on a random graph."""
    g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    layer = nn.SGConv(5, 2, 2)
    layer.initialize(ctx=ctx)
    print(layer)

    # test #1: basic
    feat = F.randn((g.number_of_nodes(), 5))
    out = layer(g, feat)
    assert out.shape == (g.number_of_nodes(), 2)
def test_set2set():
    """Set2Set readout: one row of width 2*dim per graph, single and batched."""
    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())
    ctx = F.ctx()

    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    s2s.initialize(ctx=ctx)
    print(s2s)

    # test#1: single graph -> single readout row of width 2 * 5
    feat = F.randn((g.number_of_nodes(), 5))
    readout = s2s(g, feat)
    assert readout.shape[0] == 1 and readout.shape[1] == 10 and readout.ndim == 2

    # test#2: batched graph -> one row per component graph
    bg = dgl.batch([g, g, g])
    feat = F.randn((bg.number_of_nodes(), 5))
    readout = s2s(bg, feat)
    assert readout.shape[0] == 3 and readout.shape[1] == 10 and readout.ndim == 2
def test_glob_att_pool():
    """GlobalAttentionPooling: one pooled row per graph, single and batched."""
    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())
    ctx = F.ctx()

    pool = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
    pool.initialize(ctx=ctx)
    print(pool)

    # test#1: single graph
    feat = F.randn((g.number_of_nodes(), 5))
    out = pool(g, feat)
    assert out.shape[0] == 1 and out.shape[1] == 10 and out.ndim == 2

    # test#2: batched graph of four copies
    bg = dgl.batch([g, g, g, g])
    feat = F.randn((bg.number_of_nodes(), 5))
    out = pool(bg, feat)
    assert out.shape[0] == 4 and out.shape[1] == 10 and out.ndim == 2
def test_simple_pool():
    """Sum/Avg/Max/Sort pooling on a single graph and on a batched graph."""
    g = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: single graph — pooled result equals the plain node reduction
    feat = F.randn((g.number_of_nodes(), 5))
    check_close(F.squeeze(sum_pool(g, feat), 0), F.sum(feat, 0))
    check_close(F.squeeze(avg_pool(g, feat), 0), F.mean(feat, 0))
    check_close(F.squeeze(max_pool(g, feat), 0), F.max(feat, 0))
    out = sort_pool(g, feat)
    assert out.shape[0] == 1 and out.shape[1] == 10 * 5 and out.ndim == 2

    # test#2: batched graph of 15/5/15/5/15 nodes -> node id segments below
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    feat = F.randn((bg.number_of_nodes(), 5))
    segments = [(0, 15), (15, 20), (20, 35), (35, 40), (40, 55)]

    for pool, reduction in ((sum_pool, F.sum), (avg_pool, F.mean), (max_pool, F.max)):
        truth = mx.nd.stack(*[reduction(feat[a:b], 0) for a, b in segments], axis=0)
        check_close(pool(bg, feat), truth)

    out = sort_pool(bg, feat)
    assert out.shape[0] == 5 and out.shape[1] == 10 * 5 and out.ndim == 2
def uniform_attention(g, shape):
    """Expected edge_softmax output when all logits are equal: each edge
    gets 1 / in_degree(dst), broadcast over trailing feature dims."""
    ones = mx.nd.ones(shape).as_in_context(g.device)
    target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
    deg = g.in_degrees(g.edges()[1]).reshape(target_shape).astype('float32')
    return ones / deg
def test_edge_softmax():
    """edge_softmax of constant logits equals uniform attention and leaves
    the graph's feature frames untouched."""
    # Basic
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    edata = F.ones((g.number_of_edges(), 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
            1e-4, 1e-4)

    # Test higher dimension case
    edata = F.ones((g.number_of_edges(), 3, 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
            1e-4, 1e-4)
def test_partial_edge_softmax():
    """edge_softmax restricted to a subset of edges matches edge_softmax on
    the corresponding edge subgraph, in value and in gradient."""
    g = dgl.DGLGraph().to(F.ctx())
    g.add_nodes(30)
    # build a complete 30x30 graph (900 edges)
    for u in range(30):
        for v in range(30):
            g.add_edge(u, v)

    score = F.randn((300, 1))
    score.attach_grad()
    grad = F.randn((300, 1))
    eids = F.tensor(np.random.choice(900, 300, replace=False), g.idtype)

    # compute partial edge softmax on the full graph
    with mx.autograd.record():
        y_1 = nn.edge_softmax(g, score, eids)
        y_1.backward(grad)
        grad_1 = score.grad

    # compute edge softmax on the edge subgraph over the same edges
    subg = g.edge_subgraph(eids, preserve_nodes=True)
    with mx.autograd.record():
        y_2 = nn.edge_softmax(subg, score)
        y_2.backward(grad)
        grad_2 = score.grad

    assert F.allclose(y_1, y_2)
    assert F.allclose(grad_1, grad_2)
def test_rgcn():
    """RelGraphConv ('basis' and 'bdd' regularizers): shape checks with and
    without an edge norm, and with node-id (integer) input."""
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    # 5 etypes, assigned round-robin over the edges
    R = 5
    etype = [i % 5 for i in range(g.number_of_edges())]
    B = 2   # number of bases / blocks
    I = 10  # input dimension
    O = 8   # output dimension

    def check(regularizer, feat, norm=None):
        # build a fresh layer, run one forward pass, verify the output shape
        layer = nn.RelGraphConv(I, O, R, regularizer, B)
        layer.initialize(ctx=ctx)
        r = nd.array(etype, ctx=ctx)
        out = layer(g, feat, r) if norm is None else layer(g, feat, r, norm)
        assert list(out.shape) == [100, O]

    check("basis", nd.random.randn(100, I, ctx=ctx))
    check("bdd", nd.random.randn(100, I, ctx=ctx))

    # with norm
    norm = nd.zeros((g.number_of_edges(), 1), ctx=ctx)
    check("basis", nd.random.randn(100, I, ctx=ctx), norm)
    check("bdd", nd.random.randn(100, I, ctx=ctx), norm)

    # id input (integer node ids instead of dense features)
    check("basis", nd.random.randint(0, I, (100,), ctx=ctx))
def test_sequential():
    """nn.Sequential forwarding over a single graph and over a list of graphs."""
    ctx = F.ctx()

    # case 1: all layers share one graph and pass (n_feat, e_feat) along
    class NodeEdgeLayer(gluon.nn.Block):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    g = dgl.DGLGraph().to(F.ctx())
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    net = nn.Sequential()
    net.add(NodeEdgeLayer())
    net.add(NodeEdgeLayer())
    net.add(NodeEdgeLayer())
    net.initialize(ctx=ctx)
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # case 2: each layer consumes one graph from a list and halves the rows
    class HalvingLayer(gluon.nn.Block):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            return n_feat.reshape(graph.number_of_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())
    net = nn.Sequential()
    net.add(HalvingLayer())
    net.add(HalvingLayer())
    net.add(HalvingLayer())
    net.initialize(ctx=ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)
def myagg(alist, dsttype):
    """Custom HeteroGraphConv reducer: sum of (i + 1) * alist[i] over inputs.

    `dsttype` is accepted to satisfy the aggregator signature but unused.
    """
    total = alist[0]
    for idx, val in enumerate(alist[1:], start=1):
        total = total + (idx + 1) * val
    return total
@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """HeteroGraphConv: per-relation modules, partial inputs, pair inputs,
    and per-module positional args, for every aggregator."""
    g = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
        ('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0), (2, 2)],
        ('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]},
        idtype=idtype, device=F.ctx())

    def expect(tensor, plain_shape, stacked_shape):
        # 'stack' aggregation inserts an extra per-relation axis
        if agg != 'stack':
            assert tensor.shape == plain_shape
        else:
            assert tensor.shape == stacked_shape

    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3),
        'plays': nn.GraphConv(2, 4),
        'sells': nn.GraphConv(3, 4)},
        agg)
    conv.initialize(ctx=F.ctx())
    print(conv)
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))
    uf_dst = F.randn((4, 3))
    gf_dst = F.randn((4, 4))

    out = conv(g, {'user': uf})
    assert set(out.keys()) == {'user', 'game'}
    expect(out['user'], (4, 3), (4, 1, 3))
    expect(out['game'], (4, 4), (4, 1, 4))

    out = conv(g, {'user': uf, 'store': sf})
    assert set(out.keys()) == {'user', 'game'}
    expect(out['user'], (4, 3), (4, 1, 3))
    expect(out['game'], (4, 4), (4, 2, 4))

    out = conv(g, {'store': sf})
    assert set(out.keys()) == {'game'}
    expect(out['game'], (4, 4), (4, 1, 4))

    # test with pair input: (src-feature dict, dst-feature dict)
    conv = nn.HeteroGraphConv({
        'follows': nn.SAGEConv(2, 3, 'mean'),
        'plays': nn.SAGEConv((2, 4), 4, 'mean'),
        'sells': nn.SAGEConv(3, 4, 'mean')},
        agg)
    conv.initialize(ctx=F.ctx())

    out = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
    assert set(out.keys()) == {'user', 'game'}
    expect(out['user'], (4, 3), (4, 1, 3))
    expect(out['game'], (4, 4), (4, 1, 4))

    # pair input requires both src and dst type features to be provided
    out = conv(g, ({'user': uf}, {'game' : gf}))
    assert set(out.keys()) == {'game'}
    expect(out['game'], (4, 4), (4, 1, 4))

    # test with mod args: positional args routed per relation
    class MyMod(mx.gluon.nn.Block):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None):  # mxnet does not support kwargs
            if arg1 is not None:
                self.carg1 += 1
            return F.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    conv.initialize(ctx=F.ctx())
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    out = conv(g, {'user' : uf, 'store' : sf}, mod_args)
    # only the relations listed in mod_args received the extra argument
    assert mod1.carg1 == 1
    assert mod2.carg1 == 1
    assert mod3.carg1 == 0
if __name__ == '__main__':
    # BUG FIX: this block used to call parametrized tests (test_graph_conv,
    # test_gat_conv, test_sage_conv, ...) without their required fixture
    # arguments, which raised TypeError on the very first call.  Tests
    # decorated with @parametrize_dtype / @pytest.mark.parametrize must be
    # run through pytest; only the argument-free tests are invoked here.
    test_tagconv()
    test_gg_conv()
    test_cheb_conv()
    test_appnp_conv()
    test_dense_cheb_conv()
    test_sg_conv()
    test_edge_softmax()
    test_partial_edge_softmax()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_rgcn()
    test_sequential()