import tensorflow as tf
from tensorflow.keras import layers
import networkx as nx
import pytest
import dgl
import dgl.nn.tensorflow as nn
import dgl.function as fn
import backend as F
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
from copy import deepcopy

import numpy as np
import scipy as sp

def _AXWb(A, X, W, b):
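    """Dense reference for GraphConv with norm='none': Y = A (X W) + b."""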
    X = tf.matmul(X, W)
    Y = tf.reshape(tf.matmul(A, tf.reshape(X, (X.shape[0], -1))), X.shape)
    return Y + b

@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv(out_dim):
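    """GraphConv must match the dense reference _AXWb and leave g's feature frames untouched."""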
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(transpose=False, ctx=ctx)))

    conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, out_dim)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0


    # test reset_parameters
    # old_weight = deepcopy(conv.weight.data)
    # conv.reset_parameters()
    # new_weight = conv.weight.data
    # assert not F.allclose(old_weight, new_weight)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
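    """Output-shape check for GraphConv across norm/weight/bias settings."""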
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, out_dim))
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5))
    h_dst = F.randn((ndst, out_dim))
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
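    """Like test_graph_conv2, but feeding (src, dst) feature pairs on bipartite graphs."""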
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, out_dim))
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5))
    h_dst = F.randn((ndst, out_dim))
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)

def test_simple_pool():
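    """Sum/Avg/Max/Sort pooling on a single graph and on a batched graph."""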
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = tf.stack([F.sum(h0[:15], 0),
                      F.sum(h0[15:20], 0),
                      F.sum(h0[20:35], 0),
                      F.sum(h0[35:40], 0),
                      F.sum(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = tf.stack([F.mean(h0[:15], 0),
                      F.mean(h0[15:20], 0),
                      F.mean(h0[20:35], 0),
                      F.mean(h0[35:40], 0),
                      F.mean(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = tf.stack([F.max(h0[:15], 0),
                      F.max(h0[15:20], 0),
                      F.max(h0[20:35], 0),
                      F.max(h0[35:40], 0),
                      F.max(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2

def test_glob_att_pool():
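    """GlobalAttentionPooling output shapes on single and batched graphs."""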
    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())

    gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))
    print(gap)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2


@pytest.mark.parametrize('O', [1, 2, 8])
def test_rgcn(O):
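    """RelGraphConv: the regular and low_mem code paths must agree exactly."""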
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    # 5 edge types, assigned round-robin over the edges
    R = 5
    etype = [i % R for i in range(g.number_of_edges())]
    B = 2
    I = 10

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
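    # Copy the parameters so the low-memory variant computes the same function.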
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = tf.random.normal((100, I))
    r = tf.constant(etype)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
        rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
        rgc_bdd_low.weight = rgc_bdd.weight
        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
        h = tf.random.normal((100, I))
        r = tf.constant(etype)
        h_new = rgc_bdd(g, h, r)
        h_new_low = rgc_bdd_low(g, h, r)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # with norm
    norm = tf.zeros((g.number_of_edges(), 1))

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = tf.random.normal((100, I))
    r = tf.constant(etype)
    h_new = rgc_basis(g, h, r, norm)
    h_new_low = rgc_basis_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
        rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
        rgc_bdd_low.weight = rgc_bdd.weight
        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
        h = tf.random.normal((100, I))
        r = tf.constant(etype)
        h_new = rgc_bdd(g, h, r, norm)
        h_new_low = rgc_bdd_low(g, h, r, norm)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = tf.constant(np.random.randint(0, I, (100,)))
    r = tf.constant(etype)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
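    """GATConv output and attention shapes on homogeneous graphs."""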
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_nodes(), 5))
    h = gat(g, feat)
    assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
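    """GATConv on bipartite graphs with separate (src, dst) features."""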
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
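    """SAGEConv output dimension for each aggregator type."""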
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, out_dim, aggre_type)
    feat = F.randn((g.number_of_nodes(), 5))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
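    """SAGEConv on bipartite graphs; the 'gcn' aggregator needs matching src/dst dims."""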
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()

@parametrize_dtype
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi_empty(idtype, aggre_type, out_dim):
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv((3, 3), out_dim, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
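    # Also exercise the other aggregators on the same edgeless graph; this
    # loop intentionally shadows the parametrized aggre_type.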
    for aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sgc_conv(g, idtype, out_dim):
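    """SGConv with and without caching; the cache must be reused on the second call."""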
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)
    feat = F.randn((g.number_of_nodes(), 5))

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # cached
    sgc = nn.SGConv(5, out_dim, 3, True)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv(g, idtype):
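    """APPNPConv preserves the input feature dimension."""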
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.number_of_nodes(), 5))

    h = appnp(g, feat)
    assert h.shape[-1] == 5

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
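    """GINConv with a Dense apply function, for each aggregator type."""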
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        tf.keras.layers.Dense(12),
        aggregator_type
    )
    feat = F.randn((g.number_of_nodes(), 5))
    h = gin(g, feat)
    assert h.shape == (g.number_of_nodes(), 12)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
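    """GINConv on bipartite graphs with (src, dst) feature pairs."""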
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(
        tf.keras.layers.Dense(12),
        aggregator_type
    )
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv(g, idtype, out_dim):
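    """EdgeConv output shape on homogeneous graphs."""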
    g = g.astype(idtype).to(F.ctx())
    edge_conv = nn.EdgeConv(out_dim)

    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_nodes(), out_dim)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
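    """EdgeConv on bipartite graphs with (src, dst) feature pairs."""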
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(out_dim)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

def myagg(alist, dsttype):
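    """Custom cross-type aggregator: alist[0] + sum of (i + 1) * alist[i] for i >= 1."""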
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst

@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
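    """HeteroGraphConv on a heterograph and a block, with built-in and custom aggregators."""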
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {'user': uf, 'store': sf, 'game': gf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
    h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    h = conv(block, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    # test with mod args
    class MyMod(tf.keras.layers.Layer):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2
        def call(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return tf.zeros((g.number_of_dst_nodes(), self.s2))
    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    h = conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1


@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim):
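    """DenseChebConv must match ChebConv once the weights are copied over."""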
    for k in range(3, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1, random_state=42))
        g = g.to(ctx)

        adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(transpose=False, ctx=ctx)))
        cheb = nn.ChebConv(5, out_dim, k, None, bias=True)
        dense_cheb = nn.DenseChebConv(5, out_dim, k, bias=True)

        # init cheb modules
        feat = F.ones((100, 5))
        out_cheb = cheb(g, feat, [2.0])

        dense_cheb.W = tf.reshape(cheb.linear.weights[0], (k, 5, out_dim))
        if cheb.linear.bias is not None:
            dense_cheb.bias = cheb.linear.bias

        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(out_cheb - out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)


if __name__ == '__main__':
    # The tests above take pytest-injected arguments (idtype, graph cases,
    # out_dim, ...), so they cannot be called directly; run the whole module
    # through pytest instead.
    pytest.main([__file__])
    # Tests from the other backends that are not implemented here:
    # test_set2set()
    # test_set_trans()
    # test_tagconv()
    # test_agnn_conv()
    # test_gated_graph_conv()
    # test_nn_conv()
    # test_gmm_conv()
    # test_dense_graph_conv()
    # test_dense_sage_conv()
    # test_sequential()