import tensorflow as tf
from tensorflow.keras import layers
import networkx as nx
import pytest
import dgl
import dgl.nn.tensorflow as nn
import dgl.function as fn
import backend as F
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_idtype
from copy import deepcopy

import numpy as np
import scipy as sp
import scipy.sparse  # ensure sp.sparse resolves regardless of SciPy version

def _AXWb(A, X, W, b):
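    # Dense reference: Y = A @ (X W) + b, flattening any trailing feature dims for the matmul.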
    X = tf.matmul(X, W)
    Y = tf.reshape(tf.matmul(A, tf.reshape(X, (X.shape[0], -1))), X.shape)
    return Y + b

@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv(out_dim):
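    # Check GraphConv with norm='none' against the dense reference _AXWb on a 3-node path graph.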
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(transpose=True, ctx=ctx)))

    conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, out_dim)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    # old_weight = deepcopy(conv.weight.data)
    # conv.reset_parameters()
    # new_weight = conv.weight.data
    # assert not F.allclose(old_weight, new_weight)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right', 'left'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
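    # GraphConv forward on homo/block-bipartite graphs across norm/weight/bias
    # settings; a module built without an internal weight gets an external one.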
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, out_dim))
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5))
    h_dst = F.randn((ndst, out_dim))
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
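    # Same as test_graph_conv2, but feeding a (src_feat, dst_feat) pair on bipartite graphs.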
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, out_dim))
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5))
    h_dst = F.randn((ndst, out_dim))
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)

def test_simple_pool():
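    # Readout modules: Sum/Avg/Max/SortPooling on a single graph and on a batched graph.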
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = tf.stack([F.sum(h0[:15], 0),
                      F.sum(h0[15:20], 0),
                      F.sum(h0[20:35], 0),
                      F.sum(h0[35:40], 0),
                      F.sum(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = tf.stack([F.mean(h0[:15], 0),
                      F.mean(h0[15:20], 0),
                      F.mean(h0[20:35], 0),
                      F.mean(h0[35:40], 0),
                      F.mean(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = tf.stack([F.max(h0[:15], 0),
                      F.max(h0[15:20], 0),
                      F.max(h0[20:35], 0),
                      F.max(h0[35:40], 0),
                      F.max(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2

def test_glob_att_pool():
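    # GlobalAttentionPooling with Dense gate/feature networks; one output row per graph.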
    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())

    gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))
    print(gap)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2


@pytest.mark.parametrize('O', [1, 2, 8])
def test_rgcn(O):
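    # RelGraphConv: with parameters tied, the low_mem implementation must match
    # the default one for both 'basis' and 'bdd' regularizers, with and without
    # an edge norm, and with node-ID (integer) input.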
    etype = []
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    B = 2
    I = 10

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = tf.random.normal((100, I))
    r = tf.constant(etype)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
        rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
        rgc_bdd_low.weight = rgc_bdd.weight
        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
        h = tf.random.normal((100, I))
        r = tf.constant(etype)
        h_new = rgc_bdd(g, h, r)
        h_new_low = rgc_bdd_low(g, h, r)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # with norm
    norm = tf.zeros((g.number_of_edges(), 1))

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = tf.random.normal((100, I))
    r = tf.constant(etype)
    h_new = rgc_basis(g, h, r, norm)
    h_new_low = rgc_basis_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
        rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
        rgc_bdd_low.weight = rgc_bdd.weight
        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
        h = tf.random.normal((100, I))
        r = tf.constant(etype)
        h_new = rgc_bdd(g, h, r, norm)
        h_new_low = rgc_bdd_low(g, h, r, norm)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = tf.constant(np.random.randint(0, I, (100,)))
    r = tf.constant(etype)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
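    # GATConv output is (num_dst_nodes, num_heads, out_dim); attention is per edge and head.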
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    h = gat(g, feat)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
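    # GATConv on bipartite graphs, fed a (src_feat, dst_feat) pair.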
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
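    # SAGEConv keeps the requested output dimension for every aggregator type.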
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, out_dim, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
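    # Bipartite SAGEConv; the 'gcn' aggregator needs dst features sized like src, hence dst_dim.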
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()

@parametrize_idtype
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sage_conv_bi_empty(idtype, aggre_type, out_dim):
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv((3, 3), out_dim, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    for aggre_type in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)
    feat = F.randn((g.number_of_nodes(), 5))

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # cached
    sgc = nn.SGConv(5, out_dim, 3, cached=True)
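    # With cached=True the propagated features from the first call are reused,
    # so a perturbed input must yield the same output.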
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv(g, idtype):
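    # APPNP only propagates features, so the feature dimension is unchanged.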
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.number_of_nodes(), 5))

    h = appnp(g, feat)
    assert h.shape[-1] == 5

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
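    # GINConv's output dimension follows its apply function (Dense(12)).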
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(
        tf.keras.layers.Dense(12),
        aggregator_type
    )
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
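    # GINConv on bipartite graphs, fed a (src_feat, dst_feat) pair.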
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(
        tf.keras.layers.Dense(12),
        aggregator_type
    )
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv(g, idtype, out_dim):
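    # EdgeConv maps 5-dim node features to out_dim per destination node.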
    g = g.astype(idtype).to(F.ctx())
    edge_conv = nn.EdgeConv(out_dim)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
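    # EdgeConv on bipartite graphs with separate source and destination features.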
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(out_dim)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

def myagg(alist, dsttype):
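    # Custom cross-type reducer for HeteroGraphConv: weight the i-th per-relation result by (i + 1).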
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst

@parametrize_idtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
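    # Three-relation heterograph; check HeteroGraphConv output shapes under every aggregator.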
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
        ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
        ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
        idtype=idtype, device=F.ctx())
    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
        'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
        'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {'user': uf, 'store': sf, 'game': gf})
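    # 'user' receives from one relation (follows) and 'game' from two (plays,
    # sells), which determines the stacked dimension asserted below.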
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
    h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    h = conv(block, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}
    if agg != 'stack':
        assert h['user'].shape == (4, 3)
        assert h['game'].shape == (4, 4)
    else:
        assert h['user'].shape == (4, 1, 3)
        assert h['game'].shape == (4, 2, 4)

    # test with mod args
    class MyMod(tf.keras.layers.Layer):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2
        def call(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return tf.zeros((g.number_of_dst_nodes(), self.s2))
    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    h = conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on a graph without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form='eid', etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {'user': uf, 'game': gf, 'store': sf})
    assert set(h.keys()) == {'user', 'game'}

    block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [
                         0, 1, 2, 3], 'store': []}).to(F.ctx())
    h = conv(block, ({'user': uf, 'game': gf, 'store': sf},
             {'user': uf, 'game': gf, 'store': sf[0:0]}))
    assert set(h.keys()) == {'user', 'game'}


@pytest.mark.parametrize('out_dim', [1, 2])
def test_dense_cheb_conv(out_dim):
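    # DenseChebConv with weights tied to ChebConv must reproduce its output on a dense adjacency.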
    for k in range(3, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1, random_state=42))
        g = g.to(ctx)

        adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(transpose=True, ctx=ctx)))
        cheb = nn.ChebConv(5, out_dim, k, None, bias=True)
        dense_cheb = nn.DenseChebConv(5, out_dim, k, bias=True)

        # init cheb modules
        feat = F.ones((100, 5))
        out_cheb = cheb(g, feat, [2.0])

        dense_cheb.W = tf.reshape(cheb.linear.weights[0], (k, 5, out_dim))
        if cheb.linear.bias is not None:
            dense_cheb.bias = cheb.linear.bias

        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(out_cheb - out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)


if __name__ == '__main__':
    # The tests in this module are parametrized (idtype, graph cases, output
    # dimensions), so they cannot be called directly without arguments; run
    # the whole module through pytest to exercise all cases.
    pytest.main([__file__])