import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
import dgl.function as fn
import backend as F
import pytest

from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype

from copy import deepcopy

import numpy as np
import scipy as sp

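# Dense reference for GraphConv with norm='none': Y = A (X W) + b.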
def _AXWb(A, X, W, b):
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b

def test_graph_conv0():
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)

    conv = nn.GraphConv(5, 2, norm='none', bias=True)
    conv = conv.to(ctx)
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    conv = conv.to(ctx)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)

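# GraphConv's norm options: 'both' applies symmetric normalization
# D^{-1/2} A D^{-1/2}, 'right' divides the aggregation by the in-degree,
# and 'none' leaves the plain sum.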
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv(idtype, g, norm, weight, bias):
    # Test one tensor input
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, 2)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, 2)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_bi(idtype, g, norm, weight, bias):
    # Test a pair of tensor inputs
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, 2)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    h_dst = F.randn((ndst, 2)).to(F.ctx())
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, 2)

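# Dense reference for a 2-hop TAGConv: stack [X, A_hat X, A_hat^2 X] with
# A_hat = D^{-1/2} A D^{-1/2} (N holds the D^{-1/2} factors), then apply the
# linear layer; W.rot90() rearranges TAGConv's (out, 3*in) weight to match
# the hop stacking order used here.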
def _S2AXWb(A, N, X, W, b):
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b

def test_tagconv():
    g = dgl.DGLGraph(nx.path_graph(3))
    g = g.to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TAGConv(5, 2, bias=True)
    conv = conv.to(ctx)
    print(conv)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))

    conv = nn.TAGConv(5, 2)
    conv = conv.to(ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == 2

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)

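# Set2Set readout returns one row per graph with 2 * input_dim features
# (here 2 * 5 = 10), hence the shape checks below.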
def test_set2set():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
    s2s = s2s.to(ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2

def test_glob_att_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(10))
    g = g.to(F.ctx())

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    gap = gap.to(ctx)
    print(gap)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2

def test_simple_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15))
    g = g.to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    sum_pool = sum_pool.to(ctx)
    avg_pool = avg_pool.to(ctx)
    max_pool = max_pool.to(ctx)
    sort_pool = sort_pool.to(ctx)
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = th.stack([F.sum(h0[:15], 0),
                      F.sum(h0[15:20], 0),
                      F.sum(h0[20:35], 0),
                      F.sum(h0[35:40], 0),
                      F.sum(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = th.stack([F.mean(h0[:15], 0),
                      F.mean(h0[15:20], 0),
                      F.mean(h0[20:35], 0),
                      F.mean(h0[35:40], 0),
                      F.mean(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = th.stack([F.max(h0[:15], 0),
                      F.max(h0[15:20], 0),
                      F.max(h0[20:35], 0),
                      F.max(h0[35:40], 0),
                      F.max(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2

def test_set_trans():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    st_enc_0 = st_enc_0.to(ctx)
    st_enc_1 = st_enc_1.to(ctx)
    st_dec = st_dec.to(ctx)
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(g, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(g, h0)
    assert h1.shape == h0.shape
    h2 = st_dec(g, h1)
    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(bg, h0)
    assert h1.shape == h0.shape
    h1 = st_enc_1(bg, h0)
    assert h1.shape == h0.shape

    h2 = st_dec(bg, h1)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2

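# With identical scores on every edge, edge_softmax should reduce to
# 1 / in_degree(dst) on each incoming edge; this builds that expected tensor.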
def uniform_attention(g, shape):
    a = F.ones(shape)
    target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
    return a / g.in_degrees(g.edges(order='eid')[1]).view(target_shape).float()

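# Cross-check edge_softmax, forward and backward, against th.softmax on a
# random 30-node/900-edge graph whose scores are reshaped to a 30x30 matrix.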
@parametrize_dtype
def test_edge_softmax(idtype):
    # Basic
    g = dgl.graph(nx.path_graph(3))
    g = g.astype(idtype).to(F.ctx())
    edata = F.ones((g.number_of_edges(), 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test higher dimension case
    edata = F.ones((g.number_of_edges(), 3, 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test both forward and backward with PyTorch built-in softmax.
    g = dgl.rand_graph(30, 900)
    g = g.astype(idtype).to(F.ctx())

    score = F.randn((900, 1))
    score.requires_grad_()
    grad = F.randn((900, 1))
    y = F.softmax(score.view(30, 30), dim=0).view(-1, 1)
    y.backward(grad)
    # clone: zero_() below would otherwise wipe this reference as well
    grad_score = score.grad.clone()
    score.grad.zero_()
    y_dgl = nn.edge_softmax(g, score)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # check forward
    assert F.allclose(y_dgl, y)
    y_dgl.backward(grad)
    # check gradient
    assert F.allclose(score.grad, grad_score)
    print(score.grad[:10], grad_score[:10])
    
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite', 'homo'], exclude=['zero-degree', 'dglgraph']))
def test_edge_softmax2(idtype, g):
    g = g.astype(idtype).to(F.ctx())
    g = g.local_var()
    g.srcdata.clear()
    g.dstdata.clear()
    g.edata.clear()
    a1 = F.randn((g.number_of_edges(), 1)).requires_grad_()
    a2 = a1.clone().detach().requires_grad_()
    g.edata['s'] = a1
    g.group_apply_edges('dst', lambda edges: {'ss':F.softmax(edges.data['s'], 1)})
    g.edata['ss'].sum().backward()
    
    builtin_sm = nn.edge_softmax(g, a2)
    builtin_sm.sum().backward()
    #print(a1.grad - a2.grad)
    assert len(g.srcdata) == 0
    assert len(g.dstdata) == 0
    assert len(g.edata) == 2
    assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend
    """
    # Test 2
    def generate_rand_graph(n, m=None, ctor=dgl.DGLGraph):
        if m is None:
            m = n
        arr = (sp.sparse.random(m, n, density=0.1, format='coo') != 0).astype(np.int64)
        return ctor(arr, readonly=True)

    for g in [generate_rand_graph(50),
              generate_rand_graph(50, ctor=dgl.graph),
              generate_rand_graph(100, 50, ctor=dgl.bipartite)]:
        a1 = F.randn((g.number_of_edges(), 1)).requires_grad_()
        a2 = a1.clone().detach().requires_grad_()
        g.edata['s'] = a1
        g.group_apply_edges('dst', lambda edges: {'ss':F.softmax(edges.data['s'], 1)})
        g.edata['ss'].sum().backward()

        builtin_sm = nn.edge_softmax(g, a2)
        builtin_sm.sum().backward()
        print(a1.grad - a2.grad)
        assert len(g.srcdata) == 0
        assert len(g.dstdata) == 0
        assert len(g.edata) == 2
        assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend
    """

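# Softmax over a subset of edges (eids) should match a full edge_softmax on the
# edge-induced subgraph with preserve_nodes=True, in both outputs and gradients.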
@parametrize_dtype
def test_partial_edge_softmax(idtype):
    g = dgl.rand_graph(30, 900)
    g = g.astype(idtype).to(F.ctx())

    score = F.randn((300, 1))
    score.requires_grad_()
    grad = F.randn((300, 1))
    eids = np.random.choice(900, 300, replace=False)
    eids = F.tensor(eids, dtype=g.idtype)
    # compute partial edge softmax
    y_1 = nn.edge_softmax(g, score, eids)
    y_1.backward(grad)
    grad_1 = score.grad.clone()  # clone before zeroing; score.grad is reused below
    score.grad.zero_()
    # compute edge softmax on edge subgraph
    subg = g.edge_subgraph(eids, preserve_nodes=True)
    y_2 = nn.edge_softmax(subg, score)
    y_2.backward(grad)
    grad_2 = score.grad.clone()
    score.grad.zero_()

    assert F.allclose(y_1, y_2)
    assert F.allclose(grad_1, grad_2)

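# low_mem=True only changes RelGraphConv's computation strategy, so with shared
# parameters the low-memory and regular variants must produce identical outputs
# for both 'basis' and 'bdd' regularizers, with and without edge norm.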
def test_rgcn():
    ctx = F.ctx()
    etype = []
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    B = 2
    I = 10
    O = 8

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
    rgc_bdd_low.weight = rgc_bdd.weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r)
    h_new_low = rgc_bdd_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    # with norm
    norm = th.zeros((g.number_of_edges(), 1)).to(ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r, norm)
    h_new_low = rgc_basis_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
    rgc_bdd_low.weight = rgc_bdd.weight
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r, norm)
    h_new_low = rgc_bdd_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    h = th.randint(0, I, (100,)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_gat_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((g.number_of_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_nodes(), 4, 2)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_gat_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv((5, 10), 2, 4)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 10)))
    gat = gat.to(ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 4, 2)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(idtype, g, aggre_type):
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv_bi(idtype, g, aggre_type):
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    sage = sage.to(F.ctx())
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == g.number_of_dst_nodes()

@parametrize_dtype
def test_sage_conv2(idtype):
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.bipartite([], num_nodes=(5, 3))
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    sage = nn.SAGEConv((3, 3), 2, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    sage = sage.to(ctx)
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 3
    for aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), 2, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        sage = sage.to(ctx)
        h = sage(g, feat)
        assert h.shape[-1] == 2
        assert h.shape[0] == 3

def test_sgc_conv():
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    # not cached
    sgc = nn.SGConv(5, 10, 3)
    feat = F.randn((100, 5))
    sgc = sgc.to(ctx)

    h = sgc(g, feat)
    assert h.shape[-1] == 10

    # cached: the first forward pass fixes the propagated result, so a
    # second call with different features must return the same output
    sgc = nn.SGConv(5, 10, 3, True)
    sgc = sgc.to(ctx)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == 10

def test_appnp_conv():
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((100, 5))
    appnp = appnp.to(ctx)

    h = appnp(g, feat)
    assert h.shape[-1] == 5

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    feat = F.randn((g.number_of_nodes(), 5))
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_nodes(), 12)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        th.nn.Linear(5, 12),
        aggregator_type
    )
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    gin = gin.to(ctx)
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_agnn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = F.randn((g.number_of_nodes(), 5))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_nodes(), 5)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_agnn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    agnn = nn.AGNNConv(1)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    agnn = agnn.to(ctx)
    h = agnn(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 5)

def test_gated_graph_conv():
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    g = g.to(F.ctx())
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((100, 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_nn_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv(5, 10, edge_func, 'mean')
    feat = F.randn((g.number_of_nodes(), 5))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, feat, efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_nn_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_func = th.nn.Linear(4, 5 * 10)
    nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    efeat = F.randn((g.number_of_edges(), 4))
    nnconv = nnconv.to(ctx)
    h = nnconv(g, (feat, feat_dst), efeat)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo']))
def test_gmm_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
    feat = F.randn((g.number_of_nodes(), 5))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, feat, pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite', 'block-bipartite']))
def test_gmm_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean')
    feat = F.randn((g.number_of_src_nodes(), 5))
    feat_dst = F.randn((g.number_of_dst_nodes(), 2))
    pseudo = F.randn((g.number_of_edges(), 3))
    gmmconv = gmmconv.to(ctx)
    h = gmmconv(g, (feat, feat_dst), pseudo)
    # currently we only do shape check
    assert h.shape[-1] == 10

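# The Dense* modules take a dense adjacency matrix instead of a DGLGraph; with
# parameters copied from the sparse counterpart, the outputs must match.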
@parametrize_dtype
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite']))
def test_dense_graph_conv(norm_type, g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    # TODO(minjie): enable the following option after #1385
    adj = g.adjacency_matrix(ctx=ctx).to_dense()
    conv = nn.GraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
    dense_conv.weight.data = conv.weight.data
    dense_conv.bias.data = conv.bias.data
    feat = F.randn((g.number_of_src_nodes(), 5))
    conv = conv.to(ctx)
    dense_conv = dense_conv.to(ctx)
    out_conv = conv(g, feat)
    out_dense_conv = dense_conv(adj, feat)
    assert F.allclose(out_conv, out_dense_conv)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite']))
def test_dense_sage_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx).to_dense()
    sage = nn.SAGEConv(5, 2, 'gcn')
    dense_sage = nn.DenseSAGEConv(5, 2)
    dense_sage.fc.weight.data = sage.fc_neigh.weight.data
    dense_sage.fc.bias.data = sage.fc_neigh.bias.data
    if len(g.ntypes) == 2:
        feat = (
            F.randn((g.number_of_src_nodes(), 5)),
            F.randn((g.number_of_dst_nodes(), 5))
        )
    else:
        feat = F.randn((g.number_of_nodes(), 5))
    sage = sage.to(ctx)
    dense_sage = dense_sage.to(ctx)
    out_sage = sage(g, feat)
    out_dense_sage = dense_sage(adj, feat)
    assert F.allclose(out_sage, out_dense_sage), g

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_edge_conv(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, 2).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_nodes(), 2)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_edge_conv_bi(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(5, 2).to(ctx)
    print(edge_conv)
    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), 2)

def test_dense_cheb_conv():
    for k in range(1, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
        g = g.to(F.ctx())
        adj = g.adjacency_matrix(ctx=ctx).to_dense()
        cheb = nn.ChebConv(5, 2, k, None)
        dense_cheb = nn.DenseChebConv(5, 2, k)
        #for i in range(len(cheb.fc)):
        #    dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()
        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(k, 5, 2)
        if cheb.linear.bias is not None:
            dense_cheb.bias.data = cheb.linear.bias.data
        feat = F.randn((100, 5))
        cheb = cheb.to(ctx)
        dense_cheb = dense_cheb.to(ctx)
        out_cheb = cheb(g, feat, [2.0])
        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(k, out_cheb, out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)

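# nn.Sequential chains graph modules: each layer receives the graph (or, when a
# list of graphs is supplied, its own graph) plus the features returned by the
# previous layer.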
def test_sequential():
    ctx = F.ctx()
    # Test single graph
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat, e_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
            e_feat += graph.edata['e']
            return n_feat, e_feat

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    g = g.to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    net = net.to(ctx)
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # Test multiple graphs
    class ExampleLayer(th.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata['h'] = n_feat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            n_feat += graph.ndata['h']
            return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())
    g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())
    g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())
    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    net = net.to(ctx)
    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)

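# AtomicConv pairs each radial filter with each entry of features_to_use, so
# 2 kernels x 2 atom types should give the 4 output channels asserted below.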
def test_atomic_conv():
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    aconv = nn.AtomicConv(interaction_cutoffs=F.tensor([12.0, 12.0]),
                          rbf_kernel_means=F.tensor([0.0, 2.0]),
                          rbf_kernel_scaling=F.tensor([4.0, 4.0]),
                          features_to_use=F.tensor([6.0, 8.0]))

    ctx = F.ctx()
    if F.gpu_ctx():
        aconv = aconv.to(ctx)

    feat = F.randn((100, 1))
    dist = F.randn((g.number_of_edges(), 1))

    h = aconv(g, feat, dist)
    # currently we only do shape check
    assert h.shape[-1] == 4

def test_cf_conv():
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    cfconv = nn.CFConv(node_in_feats=2,
                       edge_in_feats=3,
                       hidden_feats=2,
                       out_feats=3)

    ctx = F.ctx()
    if F.gpu_ctx():
        cfconv = cfconv.to(ctx)

    node_feats = F.randn((100, 2))
    edge_feats = F.randn((g.number_of_edges(), 3))
    h = cfconv(g, node_feats, edge_feats)
    # currently we only do shape check
    assert h.shape[-1] == 3

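# Custom cross-type aggregator for HeteroGraphConv: it receives the list of
# per-relation results for a destination type (and the type name) and must
# return a single combined tensor.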
def myagg(alist, dsttype):
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst

@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    g = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
        ('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0), (2, 2)],
        ('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3),
        'plays': nn.GraphConv(2, 4),
        'sells': nn.GraphConv(3, 4)},
        agg)
    conv = conv.to(F.ctx())
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))
    uf_dst = F.randn((4, 3))
    gf_dst = F.randn((4, 4))

    h = conv(g, {'user': uf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 1, 4)

    h = conv(g, {'user': uf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    h = conv(g, {'store': sf})
    assert set(h.keys()) == {'game'}
    if agg != 'stack':
        assert h['game'].shape == (4, 4)
    else:
        assert h['game'].shape == (4, 1, 4)

    # test with pair input
    conv = nn.HeteroGraphConv({
        'follows': nn.SAGEConv(2, 3, 'mean'),
        'plays': nn.SAGEConv((2, 4), 4, 'mean'),
        'sells': nn.SAGEConv(3, 4, 'mean')},
        agg)
    conv = conv.to(F.ctx())

    h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 1, 4)

    # pair input requires both src and dst type features to be provided
    h = conv(g, ({'user': uf}, {'game' : gf}))
    assert set(h.keys()) == {'game'}
    if agg != 'stack':
        assert h['game'].shape == (4, 4)
    else:
        assert h['game'].shape == (4, 1, 4)

    # test with mod args
    class MyMod(th.nn.Module):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2
        def forward(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return th.zeros((g.number_of_dst_nodes(), self.s2))
    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    conv = conv.to(F.ctx())
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

if __name__ == '__main__':
    # Parametrized tests (test_graph_conv, test_edge_softmax,
    # test_partial_edge_softmax, test_gat_conv, test_sage_conv, test_gin_conv,
    # test_agnn_conv, test_nn_conv, test_gmm_conv, test_edge_conv,
    # test_dense_graph_conv, test_dense_sage_conv, test_hetero_conv, ...)
    # take pytest fixtures as arguments; run them via pytest instead.
    test_graph_conv0()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_set_trans()
    test_rgcn()
    test_tagconv()
    test_sgc_conv()
    test_appnp_conv()
    test_gated_graph_conv()
    test_dense_cheb_conv()
    test_sequential()
    test_atomic_conv()
    test_cf_conv()