test_nn.py 32.3 KB
Newer Older
1
2
3
4
import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
5
import dgl.function as fn
6
import backend as F
7
import pytest
8
9
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
10
11
from copy import deepcopy

12
13
14
import numpy as np
import scipy as sp

15
16
17
18
19
def _AXWb(A, X, W, b):
    """Dense reference for a graph convolution: ``A @ (X @ W) + b``.

    Any trailing feature dimensions of the projected input are flattened
    before the multiplication with ``A`` and restored afterwards.
    """
    projected = th.matmul(X, W)
    num_rows = projected.shape[0]
    flat = projected.view(num_rows, -1)
    aggregated = th.matmul(A, flat)
    return aggregated.view_as(projected) + b

20
21
def test_graph_conv0():
    """Exercise ``nn.GraphConv`` on a three-node path graph.

    The layer built with ``norm='none'`` is compared against the dense
    reference ``_AXWb`` for a 2-D and a 3-D input.  The layer built with
    default arguments is only run to confirm that the graph's node and edge
    frames stay empty.  Finally ``reset_parameters`` must yield a weight
    that differs from the previous one.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = g.adjacency_matrix(transpose=False, ctx=ctx)
    # basic input and an input with one extra feature dimension
    input_shapes = [(3, 5), (3, 5, 5)]

    conv = nn.GraphConv(5, 2, norm='none', bias=True)
    conv = conv.to(ctx)
    print(conv)
    # test#1 (basic) / test#2 (more-dim): compare with the dense reference
    for shape in input_shapes:
        h0 = F.ones(shape)
        h1 = conv(g, h0)
        # the layer must not leave temporary features on the graph
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    conv = conv.to(ctx)
    # test#3 (basic) / test#4 (more-dim): default arguments, smoke test only.
    # This section used to appear twice verbatim; one copy is sufficient.
    for shape in input_shapes:
        h0 = F.ones(shape)
        conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)
72

73
74
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv(idtype, g, norm, weight, bias):
    """Shape check for ``nn.GraphConv`` fed a single feature tensor.

    When the layer is built with ``weight=False`` an external weight is
    passed to the forward call instead.
    """
    ctx = F.ctx()
    graph = g.astype(idtype).to(ctx)
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(ctx)
    ext_w = F.randn((5, 2)).to(ctx)
    num_src = graph.number_of_src_nodes()
    num_dst = graph.number_of_dst_nodes()
    feat = F.randn((num_src, 5)).to(ctx)
    # only supply the external weight when the layer owns none
    forward_kwargs = {} if weight else {'weight': ext_w}
    out = conv(graph, feat, **forward_kwargs)
    assert out.shape == (num_dst, 2)

92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_bi(idtype, g, norm, weight, bias):
    """Shape check for ``nn.GraphConv`` fed a (source, destination) feature pair.

    When the layer is built with ``weight=False`` an external weight is
    passed to the forward call instead.
    """
    ctx = F.ctx()
    graph = g.astype(idtype).to(ctx)
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(ctx)
    ext_w = F.randn((5, 2)).to(ctx)
    num_src = graph.number_of_src_nodes()
    num_dst = graph.number_of_dst_nodes()
    feat_src = F.randn((num_src, 5)).to(ctx)
    feat_dst = F.randn((num_dst, 2)).to(ctx)
    # only supply the external weight when the layer owns none
    forward_kwargs = {} if weight else {'weight': ext_w}
    out = conv(graph, (feat_src, feat_dst), **forward_kwargs)
    assert out.shape == (num_dst, 2)
111

112
113
114
115
116
117
118
119
120
121
122
123
def _S2AXWb(A, N, X, W, b):
    """Dense reference for a two-hop topology-adaptive convolution.

    Starting from ``X``, each hop scales the previous features by ``N``,
    multiplies by ``A`` and scales by ``N`` again.  The input and both hop
    results are concatenated along the last dimension and projected with a
    linear map.

    Parameters
    ----------
    A : matrix multiplied from the left at every hop.
    N : per-row scaling factor, broadcast against the features.
    X : input features (first dimension indexes rows of ``A``).
    W : projection weight in ``torch.nn.Linear`` layout
        ``(out_features, 3 * in_features)``.
    b : bias added to the projected result.

    Returns
    -------
    ``cat([X, X1, X2], -1) @ W.T + b``
    """
    hops = [X]
    for _ in range(2):
        scaled = hops[-1] * N
        propagated = th.matmul(A, scaled.view(scaled.shape[0], -1))
        hops.append(propagated * N)
    stacked = th.cat(hops, dim=-1)
    # A Linear weight is applied as ``x @ W.T``.  The previous ``W.rot90()``
    # transposed *and reversed* the feature axis, which only agrees with the
    # layer when the stacked features happen to be symmetric under reversal.
    projected = th.matmul(stacked, W.t())

    return projected + b

124
def test_tagconv():
    """Check ``nn.TAGConv`` on a three-node path graph.

    The biased layer is compared with the dense reference ``_S2AXWb`` using
    an inverse-square-root in-degree scaling; a layer built with default
    arguments is shape-checked and its ``reset_parameters`` must change the
    linear weight.
    """
    ctx = F.ctx()
    graph = dgl.DGLGraph(nx.path_graph(3))
    graph = graph.to(F.ctx())
    adj = graph.adjacency_matrix(transpose=False, ctx=ctx)
    inv_sqrt_deg = graph.in_degrees().float().pow(-0.5)

    conv = nn.TAGConv(5, 2, bias=True).to(ctx)
    print(conv)

    # test#1: basic
    feat = F.ones((3, 5))
    out = conv(graph, feat)
    # the layer must not leave temporary features on the graph
    assert len(graph.ndata) == 0
    assert len(graph.edata) == 0
    # append singleton axes so the scaling broadcasts over feature dims
    broadcast_shape = inv_sqrt_deg.shape + (1,) * (feat.dim() - 1)
    inv_sqrt_deg = inv_sqrt_deg.reshape(broadcast_shape).to(ctx)

    expected = _S2AXWb(adj, inv_sqrt_deg, feat, conv.lin.weight, conv.lin.bias)
    assert F.allclose(out, expected)

    conv = nn.TAGConv(5, 2).to(ctx)

    # test#2: basic
    feat = F.ones((3, 5))
    out = conv(graph, feat)
    assert out.shape[-1] == 2

    # test reset_parameters
    weight_before = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    weight_after = conv.lin.weight.data
    assert not F.allclose(weight_before, weight_after)

159
def test_set2set():
    """Shape check for ``nn.Set2Set`` on a single graph and on a batch of three."""
    ctx = F.ctx()
    graph = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())

    # hidden size 5, 3 iters, 3 layers
    s2s = nn.Set2Set(5, 3, 3).to(ctx)
    print(s2s)

    # test#1: basic
    feat = F.randn((graph.number_of_nodes(), 5))
    readout = s2s(graph, feat)
    assert tuple(readout.shape) == (1, 10)

    # test#2: batched graph
    graph_11 = dgl.DGLGraph(nx.path_graph(11)).to(F.ctx())
    graph_5 = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    batched = dgl.batch([graph, graph_11, graph_5])
    feat = F.randn((batched.number_of_nodes(), 5))
    readout = s2s(batched, feat)
    # one readout row per graph in the batch
    assert tuple(readout.shape) == (3, 10)

def test_glob_att_pool():
    """Shape check for ``nn.GlobalAttentionPooling`` on one graph and a batch of four."""
    ctx = F.ctx()
    graph = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())

    gate_nn = th.nn.Linear(5, 1)
    feat_nn = th.nn.Linear(5, 10)
    gap = nn.GlobalAttentionPooling(gate_nn, feat_nn).to(ctx)
    print(gap)

    # test#1: basic
    feat = F.randn((graph.number_of_nodes(), 5))
    readout = gap(graph, feat)
    assert tuple(readout.shape) == (1, 10)

    # test#2: batched graph
    batched = dgl.batch([graph] * 4)
    feat = F.randn((batched.number_of_nodes(), 5))
    readout = gap(batched, feat)
    # one readout row per graph in the batch
    assert tuple(readout.shape) == (4, 10)

def test_simple_pool():
    """Check sum/avg/max pooling values and the sort-pooling output shape.

    Sum, average and max pooling are compared with the matching backend
    reduction, first on a single 15-node graph and then per graph on a batch
    of five graphs (15, 5, 15, 5, 15 nodes).  Sort pooling is shape-checked.
    """
    ctx = F.ctx()
    graph = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    feat = F.randn((graph.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    # each pooling layer paired with the backend reduction it should equal
    pool_and_reduce = (
        (sum_pool, F.sum),
        (avg_pool, F.mean),
        (max_pool, F.max),
    )
    for pool, reduce_fn in pool_and_reduce:
        pooled = pool(graph, feat)
        assert F.allclose(F.squeeze(pooled, 0), reduce_fn(feat, 0))
    pooled = sort_pool(graph, feat)
    assert tuple(pooled.shape) == (1, 10 * 5)

    # test#2: batched graph
    small = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    batched = dgl.batch([graph, small, graph, small, graph])
    feat = F.randn((batched.number_of_nodes(), 5))
    # node-index boundaries of the five graphs inside the batch
    bounds = [0, 15, 20, 35, 40, 55]
    segments = list(zip(bounds, bounds[1:]))
    for pool, reduce_fn in pool_and_reduce:
        pooled = pool(batched, feat)
        truth = th.stack([reduce_fn(feat[lo:hi], 0) for lo, hi in segments], 0)
        assert F.allclose(pooled, truth)

    pooled = sort_pool(batched, feat)
    assert tuple(pooled.shape) == (5, 10 * 5)

def test_set_trans():
    """Shape check for the set-transformer encoder ('sab' and 'isab') and decoder.

    Both encoders must preserve the input shape; the decoder must emit one
    row of width 200 per graph.  Runs on a single graph and on a batch of
    three graphs.
    """
    ctx = F.ctx()
    # Graphs are placed on the test context like the modules below; the
    # other pooling tests in this file do the same.
    g = dgl.DGLGraph(nx.path_graph(15)).to(ctx)

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    g2 = dgl.DGLGraph(nx.path_graph(10)).to(ctx)
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    # one decoder row per graph in the batch
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2

Minjie Wang's avatar
Minjie Wang committed
292
293
294
295
def test_rgcn():
    """Check that ``nn.RelGraphConv`` agrees with its ``low_mem`` variant.

    Edge types are given as one type id per edge, cycling through the ``R``
    relation types.  For each configuration a regular layer and a
    ``low_mem=True`` layer share their parameters and must return equal
    outputs of shape ``[100, O]``.  Covered: "basis" and "bdd" regularizers
    without and with an edge norm, plus "basis" with integer node ids as
    input.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(ctx)
    num_nodes = 100
    # 5 etypes
    R = 5
    B = 2
    I = 10
    O = 8
    # cycle through all R relation types (previously a hard-coded ``% 5``)
    etype = [i % R for i in range(g.number_of_edges())]
    r = th.tensor(etype).to(ctx)

    def check_low_mem_matches(regularizer, norm=None, id_input=False):
        """Build a regular/low_mem layer pair sharing weights and compare outputs."""
        conv = nn.RelGraphConv(I, O, R, regularizer, B).to(ctx)
        conv_low = nn.RelGraphConv(I, O, R, regularizer, B, low_mem=True).to(ctx)
        # share parameters so both layers compute the same function
        conv_low.weight = conv.weight
        if regularizer == "basis":
            # the original test copies w_comp only for the basis regularizer
            conv_low.w_comp = conv.w_comp
        conv_low.loop_weight = conv.loop_weight
        if id_input:
            h = th.randint(0, I, (num_nodes,)).to(ctx)
        else:
            h = th.randn((num_nodes, I)).to(ctx)
        # pass the norm positionally only when one is supplied
        args = (g, h, r) if norm is None else (g, h, r, norm)
        h_new = conv(*args)
        h_new_low = conv_low(*args)
        assert list(h_new.shape) == [num_nodes, O]
        assert list(h_new_low.shape) == [num_nodes, O]
        assert F.allclose(h_new, h_new_low)

    check_low_mem_matches("basis")
    check_low_mem_matches("bdd")

    # with norm
    norm = th.zeros((g.number_of_edges(), 1)).to(ctx)
    check_low_mem_matches("basis", norm=norm)
    check_low_mem_matches("bdd", norm=norm)

    # id input
    check_low_mem_matches("basis", id_input=True)
371

372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452

def test_rgcn_sorted():
    """Check ``nn.RelGraphConv`` against its ``low_mem`` variant with list etypes.

    Here the edge types are passed as a plain Python list of five counts
    (200 each) rather than a per-edge tensor.  For each configuration a
    regular layer and a ``low_mem=True`` layer share their parameters and
    must return equal outputs of shape ``[100, O]``.  Covered: "basis" and
    "bdd" regularizers without and with an edge norm, plus "basis" with
    integer node ids as input.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(ctx)
    num_nodes = 100
    # 5 etypes
    R = 5
    etype = [200, 200, 200, 200, 200]
    B = 2
    I = 10
    O = 8
    r = etype

    def check_low_mem_matches(regularizer, norm=None, id_input=False):
        """Build a regular/low_mem layer pair sharing weights and compare outputs."""
        conv = nn.RelGraphConv(I, O, R, regularizer, B).to(ctx)
        conv_low = nn.RelGraphConv(I, O, R, regularizer, B, low_mem=True).to(ctx)
        # share parameters so both layers compute the same function
        conv_low.weight = conv.weight
        if regularizer == "basis":
            # the original test copies w_comp only for the basis regularizer
            conv_low.w_comp = conv.w_comp
        conv_low.loop_weight = conv.loop_weight
        if id_input:
            h = th.randint(0, I, (num_nodes,)).to(ctx)
        else:
            h = th.randn((num_nodes, I)).to(ctx)
        # pass the norm positionally only when one is supplied
        args = (g, h, r) if norm is None else (g, h, r, norm)
        h_new = conv(*args)
        h_new_low = conv_low(*args)
        assert list(h_new.shape) == [num_nodes, O]
        assert list(h_new_low.shape) == [num_nodes, O]
        assert F.allclose(h_new, h_new_low)

    check_low_mem_matches("basis")
    check_low_mem_matches("bdd")

    # with norm
    norm = th.zeros((g.number_of_edges(), 1)).to(ctx)
    check_low_mem_matches("basis", norm=norm)
    check_low_mem_matches("bdd", norm=norm)

    # id input
    check_low_mem_matches("basis", id_input=True)


453
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gat_conv(g, idtype):
    """Shape check for ``nn.GATConv`` output and its attention tensor."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((graph.number_of_nodes(), 5))
    gat = gat.to(ctx)
    out = gat(graph, feat)
    assert out.shape == (graph.number_of_nodes(), 4, 2)
    # second return value is the per-edge attention
    _, attn = gat(graph, feat, get_attention=True)
    assert attn.shape == (graph.number_of_edges(), 4, 1)
465

466
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gat_conv_bi(g, idtype):
    """Shape check for ``nn.GATConv`` fed a (source, destination) feature pair."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    gat = nn.GATConv(5, 2, 4)
    feat_src = F.randn((graph.number_of_src_nodes(), 5))
    feat_dst = F.randn((graph.number_of_dst_nodes(), 5))
    feat = (feat_src, feat_dst)
    gat = gat.to(ctx)
    out = gat(graph, feat)
    assert out.shape == (graph.number_of_dst_nodes(), 4, 2)
    # second return value is the per-edge attention
    _, attn = gat(graph, feat, get_attention=True)
    assert attn.shape == (graph.number_of_edges(), 4, 1)
478

479
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(idtype, g, aggre_type):
    """Check the output width of ``nn.SAGEConv`` for every aggregator type."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(ctx)
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((graph.number_of_nodes(), 5))
    sage = sage.to(F.ctx())
    out = sage(graph, feat)
    assert out.shape[-1] == 10

490
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv_bi(idtype, g, aggre_type):
    """Shape check for ``nn.SAGEConv`` fed a (source, destination) feature pair."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(ctx)
    # the 'gcn' case uses the same width (10) on both sides, the others use 5
    dst_dim = 10 if aggre_type == 'gcn' else 5
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat_src = F.randn((graph.number_of_src_nodes(), 10))
    feat_dst = F.randn((graph.number_of_dst_nodes(), dst_dim))
    sage = sage.to(F.ctx())
    out = sage(graph, (feat_src, feat_dst))
    assert out.shape[-1] == 2
    assert out.shape[0] == graph.number_of_dst_nodes()
502

503
504
505
@parametrize_dtype
def test_sage_conv2(idtype):
    """Shape check for ``nn.SAGEConv`` on a bipartite graph that has no edges.

    The graph has 5 '_U' nodes, 3 '_V' nodes and an empty edge list.  Every
    aggregator type must still return a ``(3, 2)`` result, one row per
    destination node.
    """
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # destination feature width used with each aggregator
    cases = [('gcn', 3), ('mean', 1), ('pool', 1), ('lstm', 1)]
    for aggre_type, dst_dim in cases:
        sage = nn.SAGEConv((3, dst_dim), 2, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, dst_dim)))
        sage = sage.to(ctx)
        # Copy the features to the test context for every aggregator; the
        # loop cases previously skipped this step, unlike the 'gcn' case.
        feat = (F.copy_to(feat[0], ctx), F.copy_to(feat[1], ctx))
        h = sage(g, feat)
        assert h.shape[-1] == 2
        assert h.shape[0] == 3

524
525
526
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_sgc_conv(g, idtype):
    """Check ``nn.SGConv`` output width, with and without the fourth (caching) flag.

    With the flag set, a second call on shifted features must return the
    same result as the first call.
    """
    ctx = F.ctx()
    graph = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, 10, 3)
    feat = F.randn((graph.number_of_nodes(), 5))
    sgc = sgc.to(ctx)

    out = sgc(graph, feat)
    assert out.shape[-1] == 10

    # cached
    sgc = nn.SGConv(5, 10, 3, True).to(ctx)
    first = sgc(graph, feat)
    second = sgc(graph, feat + 1)
    assert F.allclose(first, second)
    assert first.shape[-1] == 10

545
546
547
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv(g, idtype):
    """Check that ``nn.APPNPConv`` keeps the feature width unchanged."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((graph.number_of_nodes(), 5))
    appnp = appnp.to(ctx)

    out = appnp(graph, feat)
    assert out.shape[-1] == 5

557
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """Shape check for ``nn.GINConv`` wrapping a 5->12 linear layer."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    apply_func = th.nn.Linear(5, 12)
    gin = nn.GINConv(apply_func, aggregator_type)
    feat = F.randn((graph.number_of_nodes(), 5))
    gin = gin.to(ctx)
    out = gin(graph, feat)
    assert out.shape == (graph.number_of_nodes(), 12)
571

572
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """Shape check for ``nn.GINConv`` fed a (source, destination) feature pair."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    apply_func = th.nn.Linear(5, 12)
    gin = nn.GINConv(apply_func, aggregator_type)
    feat_src = F.randn((graph.number_of_src_nodes(), 5))
    feat_dst = F.randn((graph.number_of_dst_nodes(), 5))
    gin = gin.to(ctx)
    out = gin(graph, (feat_src, feat_dst))
    assert out.shape == (graph.number_of_dst_nodes(), 12)
586

587
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_agnn_conv(g, idtype):
    """Check that ``nn.AGNNConv`` returns features of the input shape."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    agnn = nn.AGNNConv(1)
    feat = F.randn((graph.number_of_nodes(), 5))
    agnn = agnn.to(ctx)
    out = agnn(graph, feat)
    assert out.shape == (graph.number_of_nodes(), 5)
597

598
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_agnn_conv_bi(g, idtype):
    """Shape check for ``nn.AGNNConv`` fed a (source, destination) feature pair."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    agnn = nn.AGNNConv(1)
    feat_src = F.randn((graph.number_of_src_nodes(), 5))
    feat_dst = F.randn((graph.number_of_dst_nodes(), 5))
    agnn = agnn.to(ctx)
    out = agnn(graph, (feat_src, feat_dst))
    assert out.shape == (graph.number_of_dst_nodes(), 5)
608

609
610
611
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gated_graph_conv(g, idtype):
    """Check the output width of ``nn.GatedGraphConv`` with three edge types."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    # assign edge types 0, 1, 2 cyclically by edge index
    edge_types = th.arange(graph.number_of_edges()) % 3
    feat = F.randn((graph.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    edge_types = edge_types.to(ctx)

    out = ggconv(graph, feat, edge_types)
    # current we only do shape check
    assert out.shape[-1] == 10

624
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_nn_conv(g, idtype):
    """Check the output width of ``nn.NNConv`` with a linear edge function."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    # maps 4 edge features to 5 * 10 values per edge
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean')
    node_feat = F.randn((graph.number_of_nodes(), 5))
    edge_feat = F.randn((graph.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    out = nnconv(graph, node_feat, edge_feat)
    # currently we only do shape check
    assert out.shape[-1] == 10

638
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_nn_conv_bi(g, idtype):
    """Check the output width of ``nn.NNConv`` fed a (source, destination) pair."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    # maps 4 edge features to 5 * 10 values per edge
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean')
    feat_src = F.randn((graph.number_of_src_nodes(), 5))
    feat_dst = F.randn((graph.number_of_dst_nodes(), 2))
    edge_feat = F.randn((graph.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    out = nnconv(graph, (feat_src, feat_dst), edge_feat)
    # currently we only do shape check
    assert out.shape[-1] == 10

653
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gmm_conv(g, idtype):
    """Check the output width of ``nn.GMMConv`` with 3-dimensional pseudo coordinates."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
    node_feat = F.randn((graph.number_of_nodes(), 5))
    pseudo = F.randn((graph.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    out = gmmconv(graph, node_feat, pseudo)
    # currently we only do shape check
    assert out.shape[-1] == 10

666
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite', 'block-bipartite'], exclude=['zero-degree']))
def test_gmm_conv_bi(g, idtype):
    """Check the output width of ``nn.GMMConv`` fed a (source, destination) pair."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean')
    feat_src = F.randn((graph.number_of_src_nodes(), 5))
    feat_dst = F.randn((graph.number_of_dst_nodes(), 2))
    pseudo = F.randn((graph.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    out = gmmconv(graph, (feat_src, feat_dst), pseudo)
    # currently we only do shape check
    assert out.shape[-1] == 10

680
@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
def test_dense_graph_conv(norm_type, g, idtype):
    """Compare ``nn.GraphConv`` with ``nn.DenseGraphConv`` on the dense adjacency.

    Both layers are given the same weight and bias and must produce equal
    outputs for every normalization type.
    """
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    # TODO(minjie): enable the following option after #1385
    dense_adj = graph.adjacency_matrix(transpose=False, ctx=ctx).to_dense()
    sparse_layer = nn.GraphConv(5, 2, norm=norm_type, bias=True)
    dense_layer = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
    # give the dense layer the sparse layer's parameters
    dense_layer.weight.data = sparse_layer.weight.data
    dense_layer.bias.data = sparse_layer.bias.data
    feat = F.randn((graph.number_of_src_nodes(), 5))
    sparse_layer = sparse_layer.to(ctx)
    dense_layer = dense_layer.to(ctx)
    expected = sparse_layer(graph, feat)
    actual = dense_layer(dense_adj, feat)
    assert F.allclose(expected, actual)

699
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite']))
def test_dense_sage_conv(g, idtype):
    """Compare ``nn.SAGEConv`` ('gcn') with ``nn.DenseSAGEConv`` on the dense adjacency.

    The dense layer reuses the neighbour-projection weight and bias of the
    sparse layer; both must produce equal outputs.
    """
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    dense_adj = graph.adjacency_matrix(transpose=False, ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, 2, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, 2)
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.fc_neigh.bias.data
    # graphs with two node types get a (source, destination) feature pair
    has_two_ntypes = len(graph.ntypes) == 2
    if not has_two_ntypes:
        feat = F.randn((graph.number_of_nodes(), 5))
    else:
        feat = (
            F.randn((graph.number_of_src_nodes(), 5)),
            F.randn((graph.number_of_dst_nodes(), 5)),
        )
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    expected = sage(graph, feat)
    actual = dense_sage(dense_adj, feat)
    # the graph is attached as the assertion message to identify the case
    assert F.allclose(expected, actual), graph

722
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_edge_conv(g, idtype):
    """Shape check for ``nn.EdgeConv`` mapping 5 input features to 2."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    edge_conv = nn.EdgeConv(5, 2).to(ctx)
    print(edge_conv)
    feat = F.randn((graph.number_of_nodes(), 5))
    out = edge_conv(graph, feat)
    assert out.shape == (graph.number_of_nodes(), 2)
732

733
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_edge_conv_bi(g, idtype):
    """Shape check for ``nn.EdgeConv`` fed a (source, destination) feature pair."""
    ctx = F.ctx()
    graph = g.astype(idtype).to(F.ctx())
    edge_conv = nn.EdgeConv(5, 2).to(ctx)
    print(edge_conv)
    feat_src = F.randn((graph.number_of_src_nodes(), 5))
    feat_dst = F.randn((graph.number_of_dst_nodes(), 5))
    out = edge_conv(graph, (feat_src, feat_dst))
    assert out.shape == (graph.number_of_dst_nodes(), 2)
744
745
746
747
748

def test_dense_cheb_conv():
    """Compare ``nn.ChebConv`` with ``nn.DenseChebConv`` for k = 1, 2, 3.

    For each ``k`` a fresh random 100-node graph is built, the dense layer
    receives the sparse layer's parameters, and both outputs must match.
    """
    ctx = F.ctx()
    for order in range(1, 4):
        graph = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        graph = graph.to(F.ctx())
        dense_adj = graph.adjacency_matrix(transpose=False, ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, 2, order, None)
        dense_cheb = nn.DenseChebConv(5, 2, order)
        # reshape the single linear weight into the dense layer's (k, 5, 2) layout
        linear = cheb.linear
        dense_cheb.W.data = linear.weight.data.transpose(-1, -2).view(order, 5, 2)
        if linear.bias is not None:
            dense_cheb.bias.data = linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        # the third argument (2.0) is passed as a list to the sparse layer
        # and as a scalar to the dense layer
        out_cheb = cheb(graph, feat, [2.0])
        out_dense_cheb = dense_cheb(dense_adj, feat, 2.0)
        print(order, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)

766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
def test_sequential():
    """Shape check for ``nn.Sequential`` in single-graph and multi-graph mode."""
    ctx = F.ctx()

    # Test single graph
    class NodeEdgeLayer(th.nn.Module):
        """Adds summed neighbour features to nodes and endpoint sums to edges, in place."""

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    # fully connected directed graph on 3 nodes, self-loops included
    src = [0, 1, 2, 0, 1, 2, 0, 1, 2]
    dst = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges(src, dst)
    g = g.to(F.ctx())
    net = nn.Sequential(NodeEdgeLayer(), NodeEdgeLayer(), NodeEdgeLayer())
    node_feat = F.randn((3, 4))
    edge_feat = F.randn((9, 4))
    net = net.to(ctx)
    node_feat, edge_feat = net(g, node_feat, edge_feat)
    assert node_feat.shape == (3, 4)
    assert edge_feat.shape == (9, 4)

    # Test multiple graph
    class HalvingLayer(th.nn.Module):
        """Adds summed neighbour features, then sums consecutive node pairs."""

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)

    # one graph per layer: 32 -> 16 -> 8 nodes, so the output has 4 rows
    graph_specs = [(32, 0.05), (16, 0.2), (8, 0.8)]
    graphs = [
        dgl.DGLGraph(nx.erdos_renyi_graph(num_nodes, prob)).to(F.ctx())
        for num_nodes, prob in graph_specs
    ]
    net = nn.Sequential(HalvingLayer(), HalvingLayer(), HalvingLayer())
    net = net.to(ctx)
    node_feat = F.randn((32, 4))
    node_feat = net(graphs, node_feat)
    assert node_feat.shape == (4, 4)

815
816
817
818
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_atomic_conv(g, idtype):
    """Check the output width of ``nn.AtomicConv`` (two kernels, two feature values)."""
    graph = g.astype(idtype).to(F.ctx())
    conv_kwargs = dict(
        interaction_cutoffs=F.tensor([12.0, 12.0]),
        rbf_kernel_means=F.tensor([0.0, 2.0]),
        rbf_kernel_scaling=F.tensor([4.0, 4.0]),
        features_to_use=F.tensor([6.0, 8.0]),
    )
    aconv = nn.AtomicConv(**conv_kwargs)

    ctx = F.ctx()
    # the module is only moved when the backend reports a GPU context
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    node_feat = F.randn((graph.number_of_nodes(), 1))
    distances = F.randn((graph.number_of_edges(), 1))

    out = aconv(graph, node_feat, distances)
    # current we only do shape check
    assert out.shape[-1] == 4

835
836
837
838
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_cf_conv(g, idtype):
    """Smoke-test ``nn.CFConv``: run one forward pass and check the last dim."""
    g = g.astype(idtype).to(F.ctx())

    node_dim, edge_dim, hidden_dim, out_dim = 2, 3, 2, 3
    cfconv = nn.CFConv(node_in_feats=node_dim,
                       edge_in_feats=edge_dim,
                       hidden_feats=hidden_dim,
                       out_feats=out_dim)

    # Move the module only when the backend reports a GPU context.
    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    node_feats = F.randn((g.number_of_nodes(), node_dim))
    edge_feats = F.randn((g.number_of_edges(), edge_dim))
    out = cfconv(g, node_feats, edge_feats)
    # Currently only the output shape is verified.
    assert out.shape[-1] == out_dim
853

854
855
856
857
858
859
def myagg(alist, dsttype):
    """Combine per-relation results as a position-weighted sum.

    The first entry is taken as is; the entry at index ``i`` (``i >= 1``)
    is scaled by ``i + 1`` before being added. ``dsttype`` is accepted but
    not used by this aggregator.
    """
    total = alist[0]
    # Weights start at 2 because index 1 is scaled by (1 + 1).
    for weight, item in enumerate(alist[1:], start=2):
        total = total + weight * item
    return total

860
@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """Check ``nn.HeteroGraphConv`` output keys/shapes and per-module arguments.

    Covers dict inputs, (src, dst) pair inputs, and forwarding of
    ``mod_args`` / ``mod_kwargs`` to the wrapped modules.
    """
    relations = {
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2]),
    }
    g = dgl.heterograph(relations, idtype=idtype, device=F.ctx())

    # Every aggregator except 'stack' yields 2-D outputs per node type.
    flat_output = agg != 'stack'

    def check_output(out, expected):
        """Assert key set and shapes; ``expected`` maps ntype -> (n_rels, dim)."""
        assert set(out.keys()) == set(expected)
        for ntype, (n_rels, dim) in expected.items():
            if flat_output:
                assert out[ntype].shape == (4, dim)
            else:
                assert out[ntype].shape == (4, n_rels, dim)

    graph_convs = {
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True),
    }
    conv = nn.HeteroGraphConv(graph_convs, agg)
    conv = conv.to(F.ctx())

    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    # Only 'user' features: 'follows' and 'plays' each contribute once.
    check_output(conv(g, {'user': uf}),
                 {'user': (1, 3), 'game': (1, 4)})

    # 'user' and 'store' features: 'game' now receives two relations.
    check_output(conv(g, {'user': uf, 'store': sf}),
                 {'user': (1, 3), 'game': (2, 4)})

    # Only 'store' features: just 'sells' produces output.
    check_output(conv(g, {'store': sf}),
                 {'game': (1, 4)})

    # test with pair input
    sage_convs = {
        'follows': nn.SAGEConv(2, 3, 'mean'),
        'plays': nn.SAGEConv((2, 4), 4, 'mean'),
        'sells': nn.SAGEConv(3, 4, 'mean'),
    }
    conv = nn.HeteroGraphConv(sage_convs, agg)
    conv = conv.to(F.ctx())

    check_output(conv(g, ({'user': uf}, {'user': uf, 'game': gf})),
                 {'user': (1, 3), 'game': (1, 4)})

    # pair input requires both src and dst type features to be provided
    check_output(conv(g, ({'user': uf}, {'game': gf})),
                 {'game': (1, 4)})

    # test with mod args
    class MyMod(th.nn.Module):
        """Stub module that counts how often each optional argument is passed."""

        def __init__(self, s1, s2):
            super().__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            num_dst = g.number_of_dst_nodes()
            return th.zeros((num_dst, self.s2))

    mod1, mod2, mod3 = MyMod(2, 3), MyMod(2, 4), MyMod(3, 4)
    conv = nn.HeteroGraphConv(
        {'follows': mod1, 'plays': mod2, 'sells': mod3},
        agg)
    conv = conv.to(F.ctx())

    mod_args = {'follows': (1,), 'plays': (1,)}
    mod_kwargs = {'sells': {'arg2': 'abc'}}
    conv(g, {'user': uf, 'store': sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)

    # Positional arg reaches 'follows'/'plays'; keyword arg reaches 'sells' only.
    assert (mod1.carg1, mod1.carg2) == (1, 0)
    assert (mod2.carg1, mod2.carg2) == (1, 0)
    assert (mod3.carg1, mod3.carg2) == (0, 1)

961
962
if __name__ == '__main__':
    test_graph_conv()
963
964
965
966
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_set_trans()
Minjie Wang's avatar
Minjie Wang committed
967
    test_rgcn()
968
969
970
971
972
973
974
975
976
977
978
979
980
    test_tagconv()
    test_gat_conv()
    test_sage_conv()
    test_sgc_conv()
    test_appnp_conv()
    test_gin_conv()
    test_agnn_conv()
    test_gated_graph_conv()
    test_nn_conv()
    test_gmm_conv()
    test_dense_graph_conv()
    test_dense_sage_conv()
    test_dense_cheb_conv()
981
    test_sequential()
982
    test_atomic_conv()
983
    test_cf_conv()