# tests/tensorflow/test_nn.py -- tests for the TensorFlow backend of dgl.nn
1
2
3
import tensorflow as tf
from tensorflow.keras import layers
import networkx as nx
4
import pytest
5
6
7
8
import dgl
import dgl.nn.tensorflow as nn
import dgl.function as fn
import backend as F
9
10
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
11
12
13
14
15
16
17
18
19
20
from copy import deepcopy

import numpy as np
import scipy as sp

def _AXWb(A, X, W, b):
    """Dense reference for an unnormalized graph convolution: ``A @ (X @ W) + b``.

    ``X`` may carry extra axes between the node axis and the feature axis
    (e.g. shape ``(N, 5, in)``). After the projection by ``W`` those axes are
    flattened, so the aggregation with ``A`` acts on the first (node) axis
    only. The projected shape is then restored before ``b`` is added.
    """
    projected = tf.matmul(X, W)
    node_major = tf.reshape(projected, (projected.shape[0], -1))
    aggregated = tf.matmul(A, node_major)
    return tf.reshape(aggregated, projected.shape) + b


@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv(out_dim):
    """GraphConv on a 3-node path graph.

    With ``norm='none'`` the output is compared against the dense reference
    ``_AXWb``. With default settings only output shapes are checked. In every
    case the forward pass must leave no node/edge data behind on the graph.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(transpose=False, ctx=ctx)))

    conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # Default normalization: no dense reference, so check shapes and
    # that the graph is left unmodified.
    conv = nn.GraphConv(5, out_dim)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape == (3, out_dim)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert h1.shape == (3, 5, out_dim)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
    """GraphConv output shape on homogeneous graphs and blocks.

    When the layer is built with ``weight=False``, an external weight
    matrix is passed at call time instead.
    """
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, out_dim))
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5))
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
    """GraphConv output shape on bipartite graphs with a (src, dst) feature pair.

    When the layer is built with ``weight=False``, an external weight
    matrix is passed at call time instead.
    """
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, out_dim))
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5))
    h_dst = F.randn((ndst, out_dim))
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


def test_simple_pool():
    """Sum/avg/max/sort pooling on a single graph and on a batch of graphs."""
    g = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    graphs = [g, g_, g, g_, g]
    bg = dgl.batch(graphs)
    h0 = F.randn((bg.number_of_nodes(), 5))

    # Row range of h0 owned by each member graph, in batch order.
    bounds = []
    start = 0
    for member in graphs:
        stop = start + member.number_of_nodes()
        bounds.append((start, stop))
        start = stop

    for pool, reduce_fn in ((sum_pool, F.sum), (avg_pool, F.mean), (max_pool, F.max)):
        h1 = pool(bg, h0)
        truth = tf.stack([reduce_fn(h0[lo:hi], 0) for lo, hi in bounds], 0)
        assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == len(graphs) and h1.shape[1] == 10 * 5 and h1.ndim == 2

def test_glob_att_pool():
    """GlobalAttentionPooling output shape on a single graph and on a batch."""
    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())

    gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))
    print(gap)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2


def _check_rgcn_low_mem(g, feat, etypes, in_dim, out_dim, num_rels,
                        regularizer, num_bases, norm=None):
    """Assert RelGraphConv with and without ``low_mem`` agree when sharing weights."""
    conv = nn.RelGraphConv(in_dim, out_dim, num_rels, regularizer, num_bases)
    conv_low = nn.RelGraphConv(in_dim, out_dim, num_rels, regularizer, num_bases,
                               low_mem=True)
    # Share parameters so both layers compute the same function.
    conv_low.weight = conv.weight
    if regularizer == "basis":
        # w_comp is only shared for the "basis" regularizer.
        conv_low.w_comp = conv.w_comp
    conv_low.loop_weight = conv.loop_weight
    # Pass norm only when given, so the no-norm case uses the 3-argument call.
    args = (g, feat, etypes) if norm is None else (g, feat, etypes, norm)
    h_new = conv(*args)
    h_new_low = conv_low(*args)
    expected_shape = [g.number_of_nodes(), out_dim]
    assert list(h_new.shape) == expected_shape
    assert list(h_new_low.shape) == expected_shape
    assert F.allclose(h_new, h_new_low)


@pytest.mark.parametrize('O', [1, 2, 8])
def test_rgcn(O):
    """RelGraphConv with ``low_mem=True`` must match the default implementation.

    Covers the "basis" and "bdd" regularizers with and without an edge norm,
    plus integer node-id input for "basis".
    """
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    num_nodes = g.number_of_nodes()
    R = 5   # number of edge types
    B = 2   # num_bases argument, used by both "basis" and "bdd"
    I = 10  # input feature size
    etype = [i % R for i in range(g.number_of_edges())]

    # First pass without norm, second with an all-zero per-edge norm.
    for norm in (None, tf.zeros((g.number_of_edges(), 1))):
        _check_rgcn_low_mem(g, tf.random.normal((num_nodes, I)), tf.constant(etype),
                            I, O, R, "basis", B, norm)
        # "bdd" is only exercised when O is divisible by B (presumably the
        # block-diagonal decomposition needs evenly divisible sizes).
        if O % B == 0:
            _check_rgcn_low_mem(g, tf.random.normal((num_nodes, I)), tf.constant(etype),
                                I, O, R, "bdd", B, norm)

    # id input: integer node ids in [0, I) instead of dense features.
    h = tf.constant(np.random.randint(0, I, (num_nodes,)))
    _check_rgcn_low_mem(g, h, tf.constant(etype), I, O, R, "basis", B)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    """GATConv output and attention shapes on homogeneous graphs and blocks."""
    g = g.astype(idtype).to(F.ctx())
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_nodes(), 5))
    h = gat(g, feat)
    assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    """GATConv output and attention shapes on bipartite graphs with (src, dst) features."""
    g = g.astype(idtype).to(F.ctx())
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
    """SAGEConv output width on homogeneous graphs and blocks."""
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, out_dim, aggre_type)
    feat = F.randn((g.number_of_nodes(), 5))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    """SAGEConv on bipartite graphs with differing src/dst feature widths."""
    g = g.astype(idtype).to(F.ctx())
    # 'gcn' is given equal src/dst widths; the other aggregators get a narrower dst input.
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()


@parametrize_dtype
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi_empty(idtype, aggre_type, out_dim):
    """SAGEConv on a bipartite graph with no edges (3 dst nodes, 5 src nodes).

    Each parametrization now exercises exactly its own ``aggre_type``.
    Previously the parameter was ignored and shadowed by an inner loop.
    """
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    # 'gcn' is given equal src/dst widths; the other aggregators use a 1-wide dst input.
    dst_dim = 3 if aggre_type == 'gcn' else 1
    sage = nn.SAGEConv((3, dst_dim), out_dim, aggre_type)
    feat = (F.randn((5, 3)), F.randn((3, dst_dim)))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    """SGConv output width, and that the cached variant ignores changed input features."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)
    feat = F.randn((g.number_of_nodes(), 5))

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # cached: the second call must reuse the first result even though feat changed
    sgc = nn.SGConv(5, out_dim, 3, True)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv(g, idtype):
    """APPNPConv keeps the input feature width."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.number_of_nodes(), 5))

    h = appnp(g, feat)
    assert h.shape[-1] == 5


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """GINConv with a Dense(12) apply function on homogeneous graphs and blocks."""
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(
        tf.keras.layers.Dense(12),
        aggregator_type
    )
    feat = F.randn((g.number_of_nodes(), 5))
    h = gin(g, feat)
    assert h.shape == (g.number_of_nodes(), 12)


@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """GINConv with a Dense(12) apply function on bipartite graphs with (src, dst) features."""
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(
        tf.keras.layers.Dense(12),
        aggregator_type
    )
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)


def myagg(alist, dsttype):
    """Custom HeteroGraphConv aggregator: ``alist[0] + 2*alist[1] + 3*alist[2] + ...``.

    ``dsttype`` is accepted to match the aggregator call signature but is unused.
    Terms are accumulated left to right, one at a time.
    """
    total = alist[0]
    for factor, item in enumerate(alist[1:], start=2):
        total = total + factor * item
    return total


def _check_hetero_conv_output(h, agg):
    """Check the per-node-type output shapes of the HeteroGraphConv built in test_hetero_conv."""
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        # 'stack' keeps one slice per incoming relation: 'user' is the
        # destination of 'follows' only, 'game' of both 'plays' and 'sells'.
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)


@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """HeteroGraphConv on a heterograph, on a block, and with per-relation module args."""
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    # Full heterograph.
    h = conv(g, {'user': uf, 'store': sf, 'game': gf})
    _check_hetero_conv_output(h, agg)

    # Block with an explicit (src, dst) feature pair; 'store' has no dst nodes.
    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
    h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
    _check_hetero_conv_output(h, agg)

    # Block with a single feature dict.
    h = conv(block, {'user': uf, 'game': gf, 'store': sf})
    _check_hetero_conv_output(h, agg)

    # test with mod args
    class MyMod(tf.keras.layers.Layer):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2
        def call(self, g, h, arg1=None, *, arg2=None):
            # Count how often each extra argument was forwarded.
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return tf.zeros((g.number_of_dst_nodes(), self.s2))
    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    # Positional args reach only 'follows'/'plays'; kwargs reach only 'sells'.
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1


@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim):
    """ChebConv and DenseChebConv must agree once the dense layer reuses ChebConv's parameters."""
    ctx = F.ctx()
    # The graph, its dense adjacency and the input features do not depend on k,
    # so build them once outside the loop.
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1, random_state=42))
    g = g.to(ctx)
    adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(transpose=False, ctx=ctx)))
    feat = F.ones((100, 5))

    for k in range(3, 4):
        cheb = nn.ChebConv(5, out_dim, k, None, bias=True)
        dense_cheb = nn.DenseChebConv(5, out_dim, k, bias=True)

        # Call ChebConv first so its parameters exist before they are copied.
        out_cheb = cheb(g, feat, [2.0])

        dense_cheb.W = tf.reshape(cheb.linear.weights[0], (k, 5, out_dim))
        if cheb.linear.bias is not None:
            dense_cheb.bias = cheb.linear.bias

        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(out_cheb - out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)


if __name__ == '__main__':
    # Every test here takes pytest-parametrized arguments (out_dim, O, g,
    # idtype, ...), so calling them directly with no arguments raises
    # TypeError. Let pytest collect and parametrize them instead.
    raise SystemExit(pytest.main([__file__]))