test_nn.py 18.7 KB
Newer Older
1
2
3
import tensorflow as tf
from tensorflow.keras import layers
import networkx as nx
4
import pytest
5
6
7
8
import dgl
import dgl.nn.tensorflow as nn
import dgl.function as fn
import backend as F
9
10
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
11
12
13
14
15
16
17
18
19
20
from copy import deepcopy

import numpy as np
import scipy as sp

def _AXWb(A, X, W, b):
    """Dense reference computation ``A @ (X @ W) + b``.

    ``X @ W`` may have more than two dimensions; it is flattened to two
    dimensions (keeping the leading one) for the multiplication by ``A`` and
    reshaped back to the projected shape before ``b`` is added.
    """
    projected = tf.matmul(X, W)
    flattened = tf.reshape(projected, (projected.shape[0], -1))
    aggregated = tf.matmul(A, flattened)
    return tf.reshape(aggregated, projected.shape) + b

21
22
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv(out_dim):
    """Run GraphConv on a 3-node path graph and compare with ``_AXWb``.

    A layer with ``norm='none'`` and a bias is checked against the dense
    reference for 2-D and 3-D node features.  Default-configured layers are
    then run on the same inputs.  After every forward pass the graph must
    hold no node or edge features.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = tf.sparse.to_dense(
        tf.sparse.reorder(g.adjacency_matrix(transpose=True, ctx=ctx)))
    feat_shapes = ((3, 5), (3, 5, 5))  # basic, then one extra feature dim

    def assert_graph_untouched():
        assert len(g.ndata) == 0
        assert len(g.edata) == 0

    # test#1 / test#2: un-normalized layer with bias vs. the dense reference
    conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
    print(conv)
    for shape in feat_shapes:
        h0 = F.ones(shape)
        h1 = conv(g, h0)
        assert_graph_untouched()
        assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # test#3 / test#4: two independent default-configured layers; only the
    # absence of leftover graph features is checked.
    for _ in range(2):
        conv = nn.GraphConv(5, out_dim)
        for shape in feat_shapes:
            conv(g, F.ones(shape))
            assert_graph_untouched()

    # test reset_parameters (disabled)
    # old_weight = deepcopy(conv.weight.data)
    # conv.reset_parameters()
    # new_weight = conv.weight.data
    # assert not F.allclose(old_weight, new_weight)

75
76
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right', 'left'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
    """Check the GraphConv output shape for a single source-feature tensor.

    When the layer is built with ``weight=False`` an external weight matrix
    is passed at call time instead.
    """
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    external_weight = F.randn((5, out_dim))
    num_src, num_dst = g.number_of_src_nodes(), g.number_of_dst_nodes()
    src_feat = F.randn((num_src, 5))
    # Drawn but not fed to the layer in this variant.
    h_dst = F.randn((num_dst, out_dim))

    call_kwargs = {} if weight else {'weight': external_weight}
    out = conv(g, src_feat, **call_kwargs)
    assert out.shape == (num_dst, out_dim)
94

95
96
97
98
99
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
    """Check the GraphConv output shape for a (source, destination) feature pair.

    When the layer is built with ``weight=False`` an external weight matrix
    is passed at call time instead.
    """
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    external_weight = F.randn((5, out_dim))
    num_src, num_dst = g.number_of_src_nodes(), g.number_of_dst_nodes()
    src_feat = F.randn((num_src, 5))
    dst_feat = F.randn((num_dst, out_dim))

    call_kwargs = {} if weight else {'weight': external_weight}
    out = conv(g, (src_feat, dst_feat), **call_kwargs)
    assert out.shape == (num_dst, out_dim)
114
115
116

def test_simple_pool():
    """Check sum/avg/max/sort pooling on one graph and on a batch of five.

    Sum, average and max readouts are compared with the matching backend
    reduction over each graph's node rows; sort pooling (k = 10) is only
    checked for its output shape.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(ctx)

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)
    # Each readout layer paired with the reduction it should reproduce.
    readouts = ((sum_pool, F.sum), (avg_pool, F.mean), (max_pool, F.max))

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    for pool, reduce_fn in readouts:
        h1 = pool(g, h0)
        assert F.allclose(F.squeeze(h1, 0), reduce_fn(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1
    assert h1.shape[1] == 10 * 5
    assert h1.ndim == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    # Row ranges of the five member graphs (15, 5, 15, 5, 15 nodes).
    bounds = ((0, 15), (15, 20), (20, 35), (35, 40), (40, 55))
    for pool, reduce_fn in readouts:
        h1 = pool(bg, h0)
        truth = tf.stack([reduce_fn(h0[lo:hi], 0) for lo, hi in bounds], 0)
        assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5
    assert h1.shape[1] == 10 * 5
    assert h1.ndim == 2

def test_glob_att_pool():
    """Check GlobalAttentionPooling output shapes for one graph and a batch of four."""
    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())

    gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))
    print(gap)

    def check_readout(graph, num_graphs):
        # One 10-wide row is expected per graph in the (possibly batched) input.
        feat = F.randn((graph.number_of_nodes(), 5))
        pooled = gap(graph, feat)
        assert pooled.shape[0] == num_graphs
        assert pooled.shape[1] == 10
        assert pooled.ndim == 2

    # test#1: basic
    check_readout(g, 1)

    # test#2: batched graph
    check_readout(dgl.batch([g, g, g, g]), 4)


185
186
@pytest.mark.parametrize('O', [1, 2, 8])
def test_rgcn(O):
    """Compare RelGraphConv with its ``low_mem=True`` variant.

    For each configuration a regular layer and a low-memory layer are made to
    share parameters, run on the same inputs, and required to produce equal
    ``[100, O]`` outputs.  Covered: the "basis" regularizer (always) and "bdd"
    (only when ``O`` is divisible by the number of bases), both without and
    with an edge norm, plus "basis" with integer node ids as input.
    """
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    # 5 etypes
    num_rels = 5
    num_bases = 2
    in_feat = 10
    etype = [eid % 5 for eid in range(g.number_of_edges())]

    def make_pair(regularizer):
        # Build a regular and a low-memory layer that share their parameters.
        full = nn.RelGraphConv(in_feat, O, num_rels, regularizer, num_bases)
        low = nn.RelGraphConv(in_feat, O, num_rels, regularizer, num_bases, low_mem=True)
        low.weight = full.weight
        if regularizer == "basis":
            low.w_comp = full.w_comp
        low.loop_weight = full.loop_weight
        return full, low

    def compare(full, low, *inputs):
        # Both variants must give the same [100, O] output on identical inputs.
        out = full(g, *inputs)
        out_low = low(g, *inputs)
        assert list(out.shape) == [100, O]
        assert list(out_low.shape) == [100, O]
        assert F.allclose(out, out_low)

    def dense_inputs():
        return tf.random.normal((100, in_feat)), tf.constant(etype)

    regularizers = ["basis", "bdd"] if O % num_bases == 0 else ["basis"]

    # without norm
    for regularizer in regularizers:
        full, low = make_pair(regularizer)
        h, r = dense_inputs()
        compare(full, low, h, r)

    # with norm
    norm = tf.zeros((g.number_of_edges(), 1))
    for regularizer in regularizers:
        full, low = make_pair(regularizer)
        h, r = dense_inputs()
        compare(full, low, h, r, norm)

    # id input
    full, low = make_pair("basis")
    h = tf.constant(np.random.randint(0, in_feat, (100,))) * 1
    r = tf.constant(etype) * 1
    compare(full, low, h, r)
264

265
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    """Check GATConv output and attention shapes, then run a residual variant."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = None  # assigned after the layer is built, see below

    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))

    out = gat(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), num_heads, out_dim)

    _, attention = gat(g, feat, get_attention=True)
    assert attention.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection
    residual_gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    residual_gat(g, feat)

283
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    """Check GATConv output and attention shapes for a (source, destination) feature pair."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)

    gat = nn.GATConv(5, out_dim, num_heads)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    feat = (src_feat, dst_feat)

    out = gat(g, feat)
    assert out.shape == (g.number_of_dst_nodes(), num_heads, out_dim)

    _, attention = gat(g, feat, get_attention=True)
    assert attention.shape == (g.number_of_edges(), num_heads, 1)
296

297
298
299
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
    """Check that SAGEConv produces ``out_dim`` features per node."""
    in_dim = 5
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(in_dim, out_dim, aggre_type)
    src_feat = F.randn((g.number_of_src_nodes(), in_dim))
    out = sage(g, src_feat)
    assert out.shape[-1] == out_dim
307

308
309
310
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    """Check the SAGEConv output shape for a (source, destination) feature pair.

    The destination feature width equals the source width (10) for the
    'gcn' aggregator and is 5 otherwise.
    """
    g = g.astype(idtype).to(F.ctx())
    src_dim = 10
    dst_dim = 10 if aggre_type == 'gcn' else 5
    sage = nn.SAGEConv((src_dim, dst_dim), out_dim, aggre_type)
    src_feat = F.randn((g.number_of_src_nodes(), src_dim))
    dst_feat = F.randn((g.number_of_dst_nodes(), dst_dim))
    out = sage(g, (src_feat, dst_feat))
    assert out.shape[-1] == out_dim
    assert out.shape[0] == g.number_of_dst_nodes()
320

321
322
@parametrize_dtype
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi_empty(idtype, aggre_type, out_dim):
    """Run SAGEConv on a bipartite graph (5 source, 3 destination nodes) with no edges.

    NOTE(review): the parametrized ``aggre_type`` argument is never consulted;
    the body always exercises 'gcn' first and then 'mean', 'pool' and 'lstm'.
    """
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.to(F.ctx())
    g = g.astype(idtype).to(F.ctx())

    # (aggregator, destination feature width): 'gcn' uses matching source and
    # destination widths, the other aggregators use a 1-wide destination feature.
    cases = [('gcn', 3), ('mean', 1), ('pool', 1), ('lstm', 1)]
    for aggregator, dst_dim in cases:
        sage = nn.SAGEConv((3, dst_dim), out_dim, aggregator)
        feat = (F.randn((5, 3)), F.randn((3, dst_dim)))
        out = sage(g, feat)
        assert out.shape[-1] == out_dim
        assert out.shape[0] == 3

340
341
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    """Check the SGConv output width, with and without caching.

    With caching enabled, a second call on shifted features must return the
    same result as the first call.
    """
    g = g.astype(idtype).to(F.ctx())
    feat = None

    # not cached
    plain = nn.SGConv(5, out_dim, 3)
    feat = F.randn((g.number_of_nodes(), 5))
    out = plain(g, feat)
    assert out.shape[-1] == out_dim

    # cached
    cached = nn.SGConv(5, out_dim, 3, True)
    first = cached(g, feat)
    second = cached(g, feat + 1)
    assert F.allclose(first, second)
    assert first.shape[-1] == out_dim
359

360
361
362
363
364
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv(g, idtype):
    """Check that APPNPConv keeps the input feature width."""
    feat_dim = 5
    g = g.astype(idtype).to(F.ctx())
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.number_of_nodes(), feat_dim))

    out = appnp(g, feat)
    assert out.shape[-1] == 5

371
372
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """Check the GINConv output shape when wrapping a 12-unit Dense layer."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    apply_func = tf.keras.layers.Dense(12)
    gin = nn.GINConv(apply_func, aggregator_type)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    out = gin(g, src_feat)
    assert out.shape == (g.number_of_dst_nodes(), 12)
384

385
386
387
388
389
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """Check the GINConv output shape for a (source, destination) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    apply_func = tf.keras.layers.Dense(12)
    gin = nn.GINConv(apply_func, aggregator_type)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    out = gin(g, (src_feat, dst_feat))
    assert out.shape == (g.number_of_dst_nodes(), 12)
397

kyawlinoo's avatar
kyawlinoo committed
398
399
400
401
402
403
404
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv(g, idtype, out_dim):
    """Check the EdgeConv output shape for a single source-feature tensor."""
    g = g.astype(idtype).to(F.ctx())
    layer = nn.EdgeConv(out_dim)

    src_feat = F.randn((g.number_of_src_nodes(), 5))
    out = layer(g, src_feat)
    assert out.shape == (g.number_of_dst_nodes(), out_dim)
kyawlinoo's avatar
kyawlinoo committed
408
409
410
411
412
413
414
415
416
417
418
419
420
421

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    """Check the EdgeConv output shape for a (source, destination) feature pair."""
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    layer = nn.EdgeConv(out_dim)

    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    out = layer(g, (src_feat, dst_feat))
    assert out.shape == (g.number_of_dst_nodes(), out_dim)

422
423
424
425
426
427
def myagg(alist, dsttype):
    """Combine ``alist`` as ``alist[0] + 2 * alist[1] + 3 * alist[2] + ...``.

    ``dsttype`` is accepted but not used.
    """
    total = alist[0]
    # The item at position i (i >= 1) is scaled by i + 1.
    for factor, item in enumerate(alist[1:], start=2):
        total = total + factor * item
    return total

428
@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """Exercise HeteroGraphConv on a small user/game/store graph.

    Output keys and shapes are checked for the full graph, for a block given
    a (source, destination) feature pair, and for the same block given a
    single feature dict.  A final section checks that per-relation positional
    and keyword arguments are forwarded to the right sub-modules.
    """
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    # 'stack' keeps one slice per incoming relation: 'user' has one incoming
    # relation ('follows'), 'game' has two ('plays' and 'sells').
    if agg == 'stack':
        user_shape, game_shape = (4, 1, 3), (4, 2, 4)
    else:
        user_shape, game_shape = (4, 3), (4, 4)

    def check_outputs(out):
        assert set(out.keys()) == {'user', 'game'}
        assert out['user'].shape == user_shape
        assert out['game'].shape == game_shape

    # full graph
    check_outputs(conv(g, {'user': uf, 'store': sf, 'game': gf}))

    # block with separate source and destination features
    block = dgl.to_block(
        g.to(F.cpu()),
        {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
    check_outputs(conv(block, ({'user': uf, 'game': gf, 'store': sf},
                               {'user': uf, 'game': gf, 'store': sf[0:0]})))

    # block with a single feature dict
    check_outputs(conv(block, {'user': uf, 'game': gf, 'store': sf}))

    # test with mod args
    class MyMod(tf.keras.layers.Layer):
        """Stub layer that counts how often ``arg1`` / ``arg2`` are supplied."""

        def __init__(self, s1, s2):
            super().__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def call(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return tf.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)

    # (module, expected carg1, expected carg2)
    for mod, want_arg1, want_arg2 in ((mod1, 1, 0), (mod2, 1, 0), (mod3, 0, 1)):
        assert mod.carg1 == want_arg1
        assert mod.carg2 == want_arg2

505

506
507
@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim):
    """Check that DenseChebConv matches ChebConv once their parameters are shared.

    The sparse layer is run first so that its linear weights exist; they are
    then reshaped to ``(k, 5, out_dim)`` and installed on the dense layer.
    """
    ctx = F.ctx()
    for k in range(3, 4):
        sparse_adj = sp.sparse.random(100, 100, density=0.1, random_state=42)
        g = dgl.DGLGraph(sparse_adj).to(ctx)
        adj = tf.sparse.to_dense(
            tf.sparse.reorder(g.adjacency_matrix(transpose=True, ctx=ctx)))

        cheb = nn.ChebConv(5, out_dim, k, None, bias=True)
        dense_cheb = nn.DenseChebConv(5, out_dim, k, bias=True)

        # init cheb modules
        feat = F.ones((100, 5))
        out_cheb = cheb(g, feat, [2.0])

        # copy the sparse layer's parameters onto the dense layer
        dense_cheb.W = tf.reshape(cheb.linear.weights[0], (k, 5, out_dim))
        shared_bias = cheb.linear.bias
        if shared_bias is not None:
            dense_cheb.bias = shared_bias

        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(out_cheb - out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)


530
531
532
533
534
535
536
537
538
539
540
541
542
if __name__ == '__main__':
    # Most tests in this module take arguments that pytest supplies through
    # parametrization (e.g. ``out_dim``, ``g``, ``idtype``), so calling them
    # directly with no arguments raises TypeError.  Hand the file to pytest
    # instead and propagate its exit status to the shell.
    #
    # NOTE(review): names kept from commented-out calls; they are not defined
    # in the visible part of this module:
    #   test_set2set, test_set_trans, test_tagconv, test_agnn_conv,
    #   test_gated_graph_conv, test_nn_conv, test_gmm_conv,
    #   test_dense_graph_conv, test_dense_sage_conv, test_sequential
    raise SystemExit(pytest.main([__file__]))