test_nn.py 27.6 KB
Newer Older
1
import backend as F
2
3
4
import mxnet as mx
import networkx as nx
import numpy as np
5
import pytest
6
7
8
9
10
11
12
13
14
15
import scipy as sp
from mxnet import autograd, gluon, nd
from test_utils import parametrize_idtype
from test_utils.graph_cases import (
    get_cases,
    random_bipartite,
    random_dglgraph,
    random_graph,
)

16
import dgl
17
import dgl.function as fn
18
19
import dgl.nn.mxnet as nn

20

21
22
def check_close(a, b):
    """Assert two NDArray-like values are elementwise close (rtol/atol 1e-4)."""
    lhs = a.asnumpy()
    rhs = b.asnumpy()
    assert np.allclose(lhs, rhs, rtol=1e-4, atol=1e-4)
23

24

25
26
27
28
29
def _AXWb(A, X, W, b):
    """Dense reference for GraphConv: A @ (X W) + b.

    ``W`` and ``b`` are Gluon parameters; their data is fetched on the
    context of ``X``.
    """
    XW = mx.nd.dot(X, W.data(X.context))
    flat = XW.reshape(XW.shape[0], -1)
    Y = mx.nd.dot(A, flat).reshape(XW.shape)
    return Y + b.data(X.context)

30

nv-dlasalle's avatar
nv-dlasalle committed
31
@parametrize_idtype
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(idtype, out_dim):
    """GraphConv on a 3-node path graph.

    Checks the un-normalized conv against the dense reference ``_AXWb``,
    verifies the conv never leaves scratch fields on the graph's node/edge
    frames, repeats the checks under autograd train mode, and finally
    verifies that a pre-existing node feature is not overwritten.
    """
    g = dgl.from_networkx(nx.path_graph(3))
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # Transposed (dst x src) adjacency used by the dense reference.
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)

    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    conv.initialize(ctx=ctx)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    # The conv must not leave temporary fields behind on the graph.
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # Default normalization: only frame cleanliness is checked.
    conv = nn.GraphConv(5, out_dim)
    conv.initialize(ctx=ctx)

    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    conv = nn.GraphConv(5, out_dim)
    conv.initialize(ctx=ctx)

    # Same two checks again, but while autograd recording is enabled.
    with autograd.train_mode():
        # test#3: basic
        h0 = F.ones((3, 5))
        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        # test#4: basic
        h0 = F.ones((3, 5, 5))
        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0

    # test not override features
    g.ndata["h"] = 2 * F.ones((3, 1))
    h1 = conv(g, h0)
    assert len(g.ndata) == 1
    assert len(g.edata) == 0
    assert "h" in g.ndata
    check_close(g.ndata["h"], 2 * F.ones((3, 1)))

91

nv-dlasalle's avatar
nv-dlasalle committed
92
@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["homo", "block-bipartite"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right", "left"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
    """GraphConv output shape over norm/weight/bias combinations.

    When the conv owns no weight, an external weight matrix is passed
    at call time instead.
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    conv.initialize(ctx=ctx)
    ext_w = F.randn((5, out_dim)).as_in_context(ctx)
    n_src = g.number_of_src_nodes()
    n_dst = g.number_of_dst_nodes()
    feat = F.randn((n_src, 5)).as_in_context(ctx)
    out = conv(g, feat) if weight else conv(g, feat, ext_w)
    assert out.shape == (n_dst, out_dim)
114

115

nv-dlasalle's avatar
nv-dlasalle committed
116
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
    """GraphConv on bipartite graphs with a (src, dst) feature pair."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    conv.initialize(ctx=ctx)
    ext_w = F.randn((5, out_dim)).as_in_context(ctx)
    n_src = g.number_of_src_nodes()
    n_dst = g.number_of_dst_nodes()
    feat_src = F.randn((n_src, 5)).as_in_context(ctx)
    feat_dst = F.randn((n_dst, out_dim)).as_in_context(ctx)
    pair = (feat_src, feat_dst)
    out = conv(g, pair) if weight else conv(g, pair, ext_w)
    assert out.shape == (n_dst, out_dim)
138

139

140
141
142
143
144
145
146
147
148
149
150
151
def _S2AXWb(A, N, X, W, b):
    """Dense 2-hop reference for TAGConv: concat(X, ÂX, Â²X) @ W + b,
    where Â = N A N is the symmetrically normalized adjacency."""

    def _propagate(feat):
        # One normalized propagation step: N * (A @ (feat * N)).
        scaled = feat * N
        scaled = mx.nd.dot(A, scaled.reshape(scaled.shape[0], -1))
        return scaled * N

    hop1 = _propagate(X)
    hop2 = _propagate(hop1)
    stacked = mx.nd.concat(X, hop1, hop2, dim=-1)
    return mx.nd.dot(stacked, W) + b

152
153

@pytest.mark.parametrize("out_dim", [1, 2])
def test_tagconv(out_dim):
    """TAGConv on a 3-node path graph, compared against the dense
    2-hop reference ``_S2AXWb`` using the symmetric degree norm."""
    g = dgl.from_networkx(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx)
    # Symmetric normalization factor: D^{-1/2}.
    norm = mx.nd.power(g.in_degrees().astype("float32"), -0.5)

    conv = nn.TAGConv(5, out_dim, bias=True)
    conv.initialize(ctx=ctx)
    print(conv)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    # The conv must not leave scratch fields on the graph.
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # Broadcast the per-node norm over the feature dimensions.
    shp = norm.shape + (1,) * (h0.ndim - 1)
    norm = norm.reshape(shp).as_in_context(h0.context)

    assert F.allclose(
        h1, _S2AXWb(adj, norm, h0, conv.lin.data(ctx), conv.h_bias.data(ctx))
    )

    conv = nn.TAGConv(5, out_dim)
    conv.initialize(ctx=ctx)

    # test#2: basic — only the output width is checked here.
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == out_dim
183

184

nv-dlasalle's avatar
nv-dlasalle committed
185
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 20])
@pytest.mark.parametrize("num_heads", [1, 5])
def test_gat_conv(g, idtype, out_dim, num_heads):
    """GATConv output/attention shapes, plus the residual variant."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(10, out_dim, num_heads)  # n_heads = 5
    gat.initialize(ctx=ctx)
    print(gat)
    feat = F.randn((g.number_of_src_nodes(), 10))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    # Requesting attention returns per-edge coefficients as well.
    _, a = gat(g, feat, True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(10, out_dim, num_heads, residual=True)
    gat.initialize(ctx=ctx)
    h = gat(g, feat)

208

nv-dlasalle's avatar
nv-dlasalle committed
209
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    """GATConv on bipartite graphs with a (src, dst) feature pair."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.GATConv(5, out_dim, num_heads)
    conv.initialize(ctx=ctx)
    feat_src = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 5))
    feat = (feat_src, feat_dst)
    out = conv(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    # Also fetch the per-edge attention coefficients.
    _, attn = conv(g, feat, True)
    assert attn.shape == (g.number_of_edges(), num_heads, 1)
226

227

nv-dlasalle's avatar
nv-dlasalle committed
228
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
    """SAGEConv output width across aggregator types."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.SAGEConv(5, out_dim, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv.initialize(ctx=ctx)
    out = conv(g, feat)
    assert out.shape[-1] == out_dim
240

241

nv-dlasalle's avatar
nv-dlasalle committed
242
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    """SAGEConv on bipartite graphs; 'gcn' requires matching src/dst dims."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    dst_dim = 5 if aggre_type != "gcn" else 10
    conv = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    n_dst = g.number_of_dst_nodes()
    feat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((n_dst, dst_dim)),
    )
    conv.initialize(ctx=ctx)
    out = conv(g, feat)
    assert out.shape[-1] == out_dim
    assert out.shape[0] == n_dst
259

260

nv-dlasalle's avatar
nv-dlasalle committed
261
@parametrize_idtype
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi2(idtype, aggre_type, out_dim):
    """SAGEConv on an edgeless bipartite graph (5 src, 3 dst nodes)."""
    # Test the case for graphs without edges
    g = dgl.heterograph({("_U", "_E", "_V"): ([], [])}, {"_U": 5, "_V": 3})
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), out_dim, "gcn")
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage.initialize(ctx=ctx)
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    # NOTE(review): this loop shadows the parametrized ``aggre_type``
    # argument, so the "gcn" case above runs for every parameter value
    # and the loop always re-tests mean/pool — presumably intentional
    # but worth confirming.
    for aggre_type in ["mean", "pool"]:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage.initialize(ctx=ctx)
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3

283

284
def test_gg_conv():
    """GatedGraphConv output shape on a random 20-node graph."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()
    conv = nn.GatedGraphConv(10, 20, 3, 4)  # n_step = 3, n_etypes = 4
    conv.initialize(ctx=ctx)
    print(conv)

    # Basic shape check with random edge types in [0, 4).
    feat = F.randn((20, 10))
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    out = conv(g, feat, etypes)
    assert out.shape == (20, 20)

298
299

@pytest.mark.parametrize("out_dim", [1, 20])
def test_cheb_conv(out_dim):
    """ChebConv (k=3) output shape on a random 20-node graph."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()
    conv = nn.ChebConv(10, out_dim, 3)  # k = 3
    conv.initialize(ctx=ctx)
    print(conv)

    # Basic shape check.
    feat = F.randn((20, 10))
    out = conv(g, feat)
    assert out.shape == (20, out_dim)
312

313

nv-dlasalle's avatar
nv-dlasalle committed
314
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_agnn_conv(g, idtype):
    """AGNNConv preserves the feature width (10) on homo graphs."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.AGNNConv(0.1, True)
    conv.initialize(ctx=ctx)
    print(conv)
    feat = F.randn((g.number_of_src_nodes(), 10))
    out = conv(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), 10)
327

328

nv-dlasalle's avatar
nv-dlasalle committed
329
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_agnn_conv_bi(g, idtype):
    """AGNNConv on bipartite graphs with a (src, dst) feature pair."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.AGNNConv(0.1, True)
    conv.initialize(ctx=ctx)
    print(conv)
    pair = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    out = conv(g, pair)
    assert out.shape == (g.number_of_dst_nodes(), 5)
343

344

345
def test_appnp_conv():
    """APPNPConv preserves the feature width on a random 20-node graph."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    conv = nn.APPNPConv(3, 0.1, 0)
    conv.initialize(ctx=ctx)
    print(conv)

    # Basic shape check.
    feat = F.randn((20, 10))
    out = conv(g, feat)
    assert out.shape == (20, 10)

358
359

@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_cheb_conv(out_dim):
    """DenseChebConv must match sparse ChebConv for k = 1..3 when their
    parameters are tied."""
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.3)).to(F.ctx())
        adj = g.adjacency_matrix(transpose=True, ctx=ctx).tostype("default")
        cheb = nn.ChebConv(5, out_dim, k)
        dense_cheb = nn.DenseChebConv(5, out_dim, k)
        cheb.initialize(ctx=ctx)
        dense_cheb.initialize(ctx=ctx)

        # Tie dense parameters to the sparse conv's.  (The bias copy is
        # redundant inside the inner loop but harmless.)
        for i in range(len(cheb.fc)):
            dense_cheb.fc[i].weight.set_data(cheb.fc[i].weight.data())
            if cheb.bias is not None:
                dense_cheb.bias.set_data(cheb.bias.data())

        feat = F.randn((100, 5))
        # lambda_max = 2.0 is supplied explicitly to both variants.
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        assert F.allclose(out_cheb, out_dense_cheb)

380

nv-dlasalle's avatar
nv-dlasalle committed
381
@parametrize_idtype
@pytest.mark.parametrize("norm_type", ["both", "right", "none"])
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_graph_conv(idtype, g, norm_type, out_dim):
    """DenseGraphConv must match GraphConv when parameters are tied."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).tostype("default")
    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)
    conv.initialize(ctx=ctx)
    dense_conv.initialize(ctx=ctx)
    # Tie the dense conv's parameters to the sparse conv's.
    dense_conv.weight.set_data(conv.weight.data())
    dense_conv.bias.set_data(conv.bias.data())
    feat = F.randn((g.number_of_src_nodes(), 5))
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)

402

nv-dlasalle's avatar
nv-dlasalle committed
403
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "bipartite", "block-bipartite"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_sage_conv(idtype, g, out_dim):
    """DenseSAGEConv must match SAGEConv('gcn') when parameters are tied."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(transpose=True, ctx=ctx).tostype("default")
    sage = nn.SAGEConv(5, out_dim, "gcn")
    dense_sage = nn.DenseSAGEConv(5, out_dim)
    sage.initialize(ctx=ctx)
    dense_sage.initialize(ctx=ctx)
    # Tie the dense conv's fc to the sparse conv's neighbor fc.
    dense_sage.fc.weight.set_data(sage.fc_neigh.weight.data())
    dense_sage.fc.bias.set_data(sage.fc_neigh.bias.data())
    # Bipartite graphs take a (src, dst) feature pair; homo graphs a single tensor.
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5)),
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))

    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage)

430

nv-dlasalle's avatar
nv-dlasalle committed
431
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv(g, idtype, out_dim):
    """EdgeConv output shape on homo/block-bipartite graphs."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.EdgeConv(5, out_dim)
    conv.initialize(ctx=ctx)
    print(conv)
    # Basic shape check.
    feat = F.randn((g.number_of_src_nodes(), 5))
    out = conv(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), out_dim)
446

447

nv-dlasalle's avatar
nv-dlasalle committed
448
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    """EdgeConv on bipartite graphs with a (src, dst) feature pair."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.EdgeConv(5, out_dim)
    conv.initialize(ctx=ctx)
    print(conv)
    # Basic shape check.
    feat_src = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 5))
    out = conv(g, (feat_src, feat_dst))
    assert out.shape == (g.number_of_dst_nodes(), out_dim)
462

463

nv-dlasalle's avatar
nv-dlasalle committed
464
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv(g, idtype, aggregator_type):
    """GINConv with an identity apply-function preserves feature width."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)

    conv = nn.GINConv(lambda x: x, aggregator_type, 0.1)
    conv.initialize(ctx=ctx)
    print(conv)

    # Basic shape check.
    feat = F.randn((g.number_of_src_nodes(), 5))
    out = conv(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), 5)
479

480

nv-dlasalle's avatar
nv-dlasalle committed
481
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """GINConv on bipartite graphs with a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()

    gin_conv = nn.GINConv(lambda x: x, aggregator_type, 0.1)
    gin_conv.initialize(ctx=ctx)
    print(gin_conv)

    # test #2: bipartite
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    h = gin_conv(g, feat)
    # BUG FIX: this was ``return h.shape == ...``, which made the shape
    # check a silently discarded return value instead of an assertion.
    assert h.shape == (g.number_of_dst_nodes(), 5)
499

500

nv-dlasalle's avatar
nv-dlasalle committed
501
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
def test_gmm_conv(g, idtype):
    """GMMConv output shape with random pseudo-coordinates per edge."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.GMMConv(5, 2, 5, 3, "max")
    conv.initialize(ctx=ctx)
    feat = F.randn((g.number_of_src_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 5))
    out = conv(g, feat, pseudo)
    assert out.shape == (g.number_of_dst_nodes(), 2)
514

515

nv-dlasalle's avatar
nv-dlasalle committed
516
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_gmm_conv_bi(g, idtype):
    """GMMConv on bipartite graphs with distinct src/dst feature widths."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.GMMConv((5, 4), 2, 5, 3, "max")
    conv.initialize(ctx=ctx)
    # Basic shape check.
    feat_src = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 4))
    pseudo = F.randn((g.number_of_edges(), 5))
    out = conv(g, (feat_src, feat_dst), pseudo)
    assert out.shape == (g.number_of_dst_nodes(), 2)

530

nv-dlasalle's avatar
nv-dlasalle committed
531
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
def test_nn_conv(g, idtype):
    """NNConv with an Embedding edge-function; checks output shape."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), "max")
    conv.initialize(ctx=ctx)
    # Basic shape check with random edge types.
    feat = F.randn((g.number_of_src_nodes(), 5))
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    out = conv(g, feat, etypes)
    assert out.shape == (g.number_of_dst_nodes(), 2)
543

544

nv-dlasalle's avatar
nv-dlasalle committed
545
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
def test_nn_conv_bi(g, idtype):
    """NNConv on bipartite graphs with distinct src/dst feature widths."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    conv = nn.NNConv((5, 4), 2, gluon.nn.Embedding(3, 5 * 2), "max")
    conv.initialize(ctx=ctx)
    # Basic shape check.
    feat_src = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 4))
    etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
    out = conv(g, (feat_src, feat_dst), etypes)
    assert out.shape == (g.number_of_dst_nodes(), 2)

559
560

@pytest.mark.parametrize("out_dim", [1, 2])
def test_sg_conv(out_dim):
    """SGConv (k=2) output shape on a random graph with self-loops."""
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    g = dgl.add_self_loop(g)
    ctx = F.ctx()

    conv = nn.SGConv(5, out_dim, 2)
    conv.initialize(ctx=ctx)
    print(conv)

    # Basic shape check.
    feat = F.randn((g.number_of_nodes(), 5))
    out = conv(g, feat)
    assert out.shape == (g.number_of_nodes(), out_dim)
574

575

576
def test_set2set():
    """Set2Set readout: output width is 2x the input width, one row per graph."""
    g = dgl.from_networkx(nx.path_graph(10)).to(F.ctx())
    ctx = F.ctx()

    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers
    s2s.initialize(ctx=ctx)
    print(s2s)

    # Single graph: one readout row of width 10.
    feat = F.randn((g.number_of_nodes(), 5))
    readout = s2s(g, feat)
    assert readout.ndim == 2
    assert readout.shape == (1, 10)

    # Batched graph: one readout row per component graph.
    batched = dgl.batch([g, g, g])
    feat = F.randn((batched.number_of_nodes(), 5))
    readout = s2s(batched, feat)
    assert readout.ndim == 2
    assert readout.shape == (3, 10)

595

596
def test_glob_att_pool():
    """GlobalAttentionPooling readout on a single and a batched graph."""
    g = dgl.from_networkx(nx.path_graph(10)).to(F.ctx())
    ctx = F.ctx()

    pool = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
    pool.initialize(ctx=ctx)
    print(pool)

    # Single graph: one readout row of width 10.
    feat = F.randn((g.number_of_nodes(), 5))
    readout = pool(g, feat)
    assert readout.ndim == 2
    assert readout.shape == (1, 10)

    # Batched graph: one readout row per component graph.
    batched = dgl.batch([g, g, g, g])
    feat = F.randn((batched.number_of_nodes(), 5))
    readout = pool(batched, feat)
    assert readout.ndim == 2
    assert readout.shape == (4, 10)

614

615
def test_simple_pool():
    """Sum/Avg/Max/Sort pooling: single-graph results against backend
    reductions, then batched-graph results against per-segment truths."""
    g = dgl.from_networkx(nx.path_graph(15)).to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic — single graph, compare against direct reductions.
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sum_pool(g, h0)
    check_close(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    check_close(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    check_close(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    # SortPooling flattens the top-k rows: k * feature_dim columns.
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2

    # test#2: batched graph — sizes 15, 5, 15, 5, 15 (55 nodes total);
    # the slices below are the per-graph node segments in batch order.
    g_ = dgl.from_networkx(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = mx.nd.stack(
        F.sum(h0[:15], 0),
        F.sum(h0[15:20], 0),
        F.sum(h0[20:35], 0),
        F.sum(h0[35:40], 0),
        F.sum(h0[40:55], 0),
        axis=0,
    )
    check_close(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = mx.nd.stack(
        F.mean(h0[:15], 0),
        F.mean(h0[15:20], 0),
        F.mean(h0[20:35], 0),
        F.mean(h0[35:40], 0),
        F.mean(h0[40:55], 0),
        axis=0,
    )
    check_close(h1, truth)

    h1 = max_pool(bg, h0)
    truth = mx.nd.stack(
        F.max(h0[:15], 0),
        F.max(h0[15:20], 0),
        F.max(h0[20:35], 0),
        F.max(h0[35:40], 0),
        F.max(h0[40:55], 0),
        axis=0,
    )
    check_close(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2

675
676

@pytest.mark.parametrize("O", [1, 2, 8])
def test_rgcn(O):
    """RelGraphConv ('basis' and 'bdd' regularizers) output shapes.

    Covers plain input, input with an explicit edge norm, and integer
    node-id input.  The 'bdd' variant requires O divisible by the number
    of bases B, so it is guarded by ``O % B == 0``.
    """
    ctx = F.ctx()
    etype = []
    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1)).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    B = 2
    I = 10

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
        rgc_bdd.initialize(ctx=ctx)
        h = nd.random.randn(100, I, ctx=ctx)
        r = nd.array(etype, ctx=ctx)
        h_new = rgc_bdd(g, h, r)
        assert list(h_new.shape) == [100, O]

    # with norm
    norm = nd.zeros((g.number_of_edges(), 1), ctx=ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
        rgc_bdd.initialize(ctx=ctx)
        h = nd.random.randn(100, I, ctx=ctx)
        r = nd.array(etype, ctx=ctx)
        h_new = rgc_bdd(g, h, r, norm)
        assert list(h_new.shape) == [100, O]

    # id input: integer node ids instead of dense features.
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randint(0, I, (100,), ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

729

730
731
732
733
734
735
736
737
738
def test_sequential():
    """nn.Sequential: a single shared graph across layers, then a
    per-layer list of graphs."""
    ctx = F.ctx()

    # test single graph
    class ExampleLayer(gluon.nn.Block):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, graph, n_feat, e_feat):
            # Work on a local copy so the caller's graph frames stay clean.
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
            graph.apply_edges(fn.u_add_v("h", "h", "e"))
            e_feat += graph.edata["e"]
            return n_feat, e_feat

    g = dgl.graph(([], [])).to(F.ctx())
    g.add_nodes(3)
    # Fully connected 3-node graph (9 edges).
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    net = nn.Sequential()
    net.add(ExampleLayer())
    net.add(ExampleLayer())
    net.add(ExampleLayer())
    net.initialize(ctx=ctx)
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # test multiple graphs
    class ExampleLayer(gluon.nn.Block):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += graph.ndata["h"]
            # Each layer halves the node count: 32 -> 16 -> 8 -> 4.
            return n_feat.reshape(graph.number_of_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.from_networkx(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.from_networkx(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.from_networkx(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())

    net = nn.Sequential()
    net.add(ExampleLayer())
    net.add(ExampleLayer())
    net.add(ExampleLayer())
    net.initialize(ctx=ctx)
    n_feat = F.randn((32, 4))
    # One graph per layer, sized to match the halved features.
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)

785

786
787
788
789
790
791
def myagg(alist, dsttype):
    """Custom HeteroGraphConv aggregator: position-weighted sum where the
    i-th input (0-based) is scaled by i + 1.  ``dsttype`` is unused but
    required by the aggregator signature."""
    total = alist[0]
    for weight, item in enumerate(alist[1:], start=2):
        total = total + weight * item
    return total

792

nv-dlasalle's avatar
nv-dlasalle committed
793
@parametrize_idtype
@pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg])
def test_hetero_conv(agg, idtype):
    """HeteroGraphConv over a 3-relation graph for every aggregator.

    Checks output keys/shapes on the full graph, on a block, with
    positional per-module args, and on the graph after all edges are
    removed.
    """
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 0, 2, 1], [1, 2, 1, 3]),
            ("user", "plays", "game"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
            ("store", "sells", "game"): ([0, 0, 1, 1], [0, 3, 1, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    conv = nn.HeteroGraphConv(
        {
            "follows": nn.GraphConv(2, 3, allow_zero_in_degree=True),
            "plays": nn.GraphConv(2, 4, allow_zero_in_degree=True),
            "sells": nn.GraphConv(3, 4, allow_zero_in_degree=True),
        },
        agg,
    )
    conv.initialize(ctx=F.ctx())
    print(conv)
    uf = F.randn((4, 2))  # user features
    gf = F.randn((4, 4))  # game features
    sf = F.randn((2, 3))  # store features

    # Only 'user' and 'game' receive messages, so only they appear in h.
    h = conv(g, {"user": uf, "store": sf, "game": gf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        # 'stack' keeps one slice per incoming relation.
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    # Same checks on a block with explicit (src, dst) feature dicts.
    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    # Blocks also accept a single feature dict for the src side.
    h = conv(block, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    # test with mod args
    class MyMod(mx.gluon.nn.Block):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0  # counts how many times arg1 was supplied
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None):  # mxnet does not support kwargs
            if arg1 is not None:
                self.carg1 += 1
            return F.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv(
        {"follows": mod1, "plays": mod2, "sells": mod3}, agg
    )
    conv.initialize(ctx=F.ctx())
    # Positional extra args are routed per relation; 'sells' gets none.
    mod_args = {"follows": (1,), "plays": (1,)}
    h = conv(g, {"user": uf, "store": sf, "game": gf}, mod_args)
    assert mod1.carg1 == 1
    assert mod2.carg1 == 1
    assert mod3.carg1 == 0

    # conv on graph without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form="eid", etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}


if __name__ == "__main__":
    # NOTE(review): most of these tests are parametrized (idtype, g,
    # out_dim, ...) and will raise TypeError when called with no
    # arguments as done here — run the suite via pytest instead.
    test_graph_conv()
    test_gat_conv()
    test_sage_conv()
    test_gg_conv()
    test_cheb_conv()
    test_agnn_conv()
    test_appnp_conv()
    test_dense_cheb_conv()
    test_dense_graph_conv()
    test_dense_sage_conv()
    test_edge_conv()
    test_gin_conv()
    test_gmm_conv()
    test_nn_conv()
    test_sg_conv()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_rgcn()
    test_sequential()
    test_hetero_conv()