# test_nn.py: unit tests for the TensorFlow implementations in dgl.nn.
1
2
3
from copy import deepcopy

import backend as F
4
5
6
7

import dgl
import dgl.function as fn
import dgl.nn.tensorflow as nn
8
import networkx as nx
9
import numpy as np
10
import pytest
11
12
13
14
15
16
17
18
19
20
21
import scipy as sp
import tensorflow as tf
from tensorflow.keras import layers
from test_utils import parametrize_idtype
from test_utils.graph_cases import (
    get_cases,
    random_bipartite,
    random_dglgraph,
    random_graph,
)

def _AXWb(A, X, W, b):
    """Dense reference for an unnormalized GraphConv: ``A @ (X @ W) + b``.

    ``X`` may carry extra feature axes after the node axis. After projecting
    with ``W``, each node's features are flattened into one row, propagated
    through ``A``, and reshaped back to the projected shape before the bias
    is added.
    """
    projected = tf.matmul(X, W)
    flat = tf.reshape(projected, (projected.shape[0], -1))
    propagated = tf.matmul(A, flat)
    return tf.reshape(propagated, projected.shape) + b

@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(out_dim):
    """GraphConv on a 3-node path graph.

    With ``norm="none"`` the output is compared against the dense reference
    ``_AXWb``. With the default normalization only output shapes are
    checked. In every case the layer must not leave data on the graph.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = tf.sparse.to_dense(
        tf.sparse.reorder(g.adjacency_matrix(transpose=True, ctx=ctx))
    )

    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # Default normalization: values are not checked, only shapes.
    conv = nn.GraphConv(5, out_dim)
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert h1.shape == (3, out_dim)
    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert h1.shape == (3, 5, out_dim)

@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["homo", "block-bipartite"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right", "left"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
    """GraphConv on homogeneous graphs and blocks gives one row per dst node.

    If the layer is built with ``weight=False``, an external weight matrix
    is supplied at call time instead.
    """
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, out_dim))
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5))
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)

@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
    """GraphConv on bipartite graphs fed a ``(src_feat, dst_feat)`` pair."""
    g = g.astype(idtype).to(F.ctx())
    layer = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    external_weight = F.randn((5, out_dim))
    num_src, num_dst = g.number_of_src_nodes(), g.number_of_dst_nodes()
    feat_pair = (F.randn((num_src, 5)), F.randn((num_dst, out_dim)))
    # A layer built without its own weight must be given one at call time.
    call_kwargs = {} if weight else {"weight": external_weight}
    out = layer(g, feat_pair, **call_kwargs)
    assert out.shape == (num_dst, out_dim)

def test_simple_pool():
    """Sum/avg/max/sort pooling on a single graph and on a batched graph.

    For the batched graph, the expected per-graph readouts are computed from
    node-row offsets derived from the member graph sizes. Slice boundaries
    are not hard-coded.
    """
    g = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    graphs = [g, g_, g, g_, g]
    bg = dgl.batch(graphs)
    h0 = F.randn((bg.number_of_nodes(), 5))

    # Node rows of the batch follow member order, so member i owns rows
    # offsets[i]:offsets[i + 1].
    offsets = [0]
    for member in graphs:
        offsets.append(offsets[-1] + member.number_of_nodes())

    def per_graph(reduce_fn):
        """Stack ``reduce_fn`` applied to each member graph's node rows."""
        return tf.stack(
            [
                reduce_fn(h0[lo:hi], 0)
                for lo, hi in zip(offsets[:-1], offsets[1:])
            ],
            0,
        )

    assert F.allclose(sum_pool(bg, h0), per_graph(F.sum))
    assert F.allclose(avg_pool(bg, h0), per_graph(F.mean))
    assert F.allclose(max_pool(bg, h0), per_graph(F.max))

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2

def test_glob_att_pool():
    """GlobalAttentionPooling yields one 10-d readout per (batched) graph."""
    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())
    gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))
    print(gap)

    def check_readout(graph, num_graphs):
        node_feat = F.randn((graph.number_of_nodes(), 5))
        readout = gap(graph, node_feat)
        assert (
            readout.shape[0] == num_graphs
            and readout.shape[1] == 10
            and readout.ndim == 2
        )

    # test#1: basic
    check_readout(g, 1)
    # test#2: batched graph
    check_readout(dgl.batch([g, g, g, g]), 4)


def _check_rgcn_low_mem(
    g, in_dim, out_dim, num_rels, num_bases, regularizer, h, r, norm=None
):
    """Assert that a ``low_mem=True`` RelGraphConv reproduces the default one.

    Builds both modules with the same configuration and gives the
    low-memory module the default module's parameters. Both run on
    ``(g, h, r)``, plus ``norm`` when it is given. The outputs must have
    shape ``[g.number_of_nodes(), out_dim]`` and be numerically close.
    """
    rgc = nn.RelGraphConv(in_dim, out_dim, num_rels, regularizer, num_bases)
    rgc_low = nn.RelGraphConv(
        in_dim, out_dim, num_rels, regularizer, num_bases, low_mem=True
    )
    rgc_low.weight = rgc.weight
    if regularizer == "basis":
        # w_comp is shared only for the basis regularizer.
        rgc_low.w_comp = rgc.w_comp
    rgc_low.loop_weight = rgc.loop_weight

    args = (g, h, r) if norm is None else (g, h, r, norm)
    h_new = rgc(*args)
    h_new_low = rgc_low(*args)
    expected_shape = [g.number_of_nodes(), out_dim]
    assert list(h_new.shape) == expected_shape
    assert list(h_new_low.shape) == expected_shape
    assert F.allclose(h_new, h_new_low)


@pytest.mark.parametrize("O", [1, 2, 8])
def test_rgcn(O):
    """RelGraphConv with ``low_mem=True`` matches the default implementation.

    Covers the "basis" regularizer, and "bdd" when ``O`` is divisible by
    ``B``. Both are run with and without an edge norm. "basis" is also run
    with integer node-ID input.
    """
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(
        F.ctx()
    )
    # 5 etypes, assigned round-robin over edge IDs
    R = 5
    etype = [i % R for i in range(g.number_of_edges())]
    B = 2
    I = 10

    # without norm
    _check_rgcn_low_mem(
        g, I, O, R, B, "basis", tf.random.normal((100, I)), tf.constant(etype)
    )
    if O % B == 0:
        _check_rgcn_low_mem(
            g, I, O, R, B, "bdd", tf.random.normal((100, I)), tf.constant(etype)
        )

    # with norm
    norm = tf.zeros((g.number_of_edges(), 1))
    _check_rgcn_low_mem(
        g,
        I,
        O,
        R,
        B,
        "basis",
        tf.random.normal((100, I)),
        tf.constant(etype),
        norm,
    )
    if O % B == 0:
        _check_rgcn_low_mem(
            g,
            I,
            O,
            R,
            B,
            "bdd",
            tf.random.normal((100, I)),
            tf.constant(etype),
            norm,
        )

    # id input: integer node IDs in [0, I) instead of dense features
    h = tf.constant(np.random.randint(0, I, (100,))) * 1
    r = tf.constant(etype) * 1
    _check_rgcn_low_mem(g, I, O, R, B, "basis", h, r)

@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    """GATConv output and attention shapes, with and without a residual."""
    g = g.astype(idtype).to(F.ctx())
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection: the output shape must be unchanged
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)

@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    """GATConv on bipartite graphs with a ``(src_feat, dst_feat)`` pair."""
    g = g.astype(idtype).to(F.ctx())
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
    """SAGEConv on graphs and blocks: one ``out_dim``-wide row per dst node."""
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, out_dim, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()

@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    """SAGEConv on bipartite graphs with separate src/dst feature sizes."""
    g = g.astype(idtype).to(F.ctx())
    src_dim = 10
    # "gcn" is run with dst features as wide as the src features.
    dst_dim = src_dim if aggre_type == "gcn" else 5
    sage = nn.SAGEConv((src_dim, dst_dim), out_dim, aggre_type)
    src_feat = F.randn((g.number_of_src_nodes(), src_dim))
    dst_feat = F.randn((g.number_of_dst_nodes(), dst_dim))
    out = sage(g, (src_feat, dst_feat))
    assert out.shape[-1] == out_dim
    assert out.shape[0] == g.number_of_dst_nodes()

@parametrize_idtype
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn", "lstm"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi_empty(idtype, aggre_type, out_dim):
    """SAGEConv on a bipartite graph without edges (5 src, 3 dst nodes).

    Each parametrized aggregator is actually exercised. The output must
    still have one ``out_dim``-wide row per dst node.
    """
    g = dgl.heterograph({("_U", "_E", "_V"): ([], [])}, {"_U": 5, "_V": 3})
    g = g.astype(idtype).to(F.ctx())
    # "gcn" is run with dst features as wide as the src features (3); the
    # other aggregators use 1-d dst features.
    dst_dim = 3 if aggre_type == "gcn" else 1
    sage = nn.SAGEConv((3, dst_dim), out_dim, aggre_type)
    feat = (F.randn((5, 3)), F.randn((3, dst_dim)))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3

@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    """SGConv output width, plus the behaviour of ``cached=True``."""
    g = g.astype(idtype).to(F.ctx())

    # not cached
    uncached = nn.SGConv(5, out_dim, 3)
    feat = F.randn((g.number_of_nodes(), 5))
    assert uncached(g, feat).shape[-1] == out_dim

    # cached: a second call with a different input is expected to give the
    # same output as the first call.
    cached = nn.SGConv(5, out_dim, 3, True)
    first = cached(g, feat)
    second = cached(g, feat + 1)
    assert F.allclose(first, second)
    assert first.shape[-1] == out_dim

@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv(g, idtype):
    """APPNPConv keeps the feature width of its input (5)."""
    g = g.astype(idtype).to(F.ctx())
    appnp = nn.APPNPConv(10, 0.1)
    in_feat = F.randn((g.number_of_nodes(), 5))
    out_feat = appnp(g, in_feat)
    assert out_feat.shape[-1] == in_feat.shape[-1]

@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv(g, idtype, aggregator_type):
    """GINConv with a Dense(12) apply function yields 12-d dst features."""
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(tf.keras.layers.Dense(12), aggregator_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)

@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """GINConv on bipartite graphs with a ``(src_feat, dst_feat)`` pair."""
    g = g.astype(idtype).to(F.ctx())
    apply_fn = tf.keras.layers.Dense(12)
    gin = nn.GINConv(apply_fn, aggregator_type)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    out = gin(g, (src_feat, dst_feat))
    assert out.shape == (g.number_of_dst_nodes(), 12)

@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv(g, idtype, out_dim):
    """EdgeConv maps 5-d src features to one ``out_dim`` row per dst node."""
    g = g.astype(idtype).to(F.ctx())
    layer = nn.EdgeConv(out_dim)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    expected = (g.number_of_dst_nodes(), out_dim)
    assert layer(g, src_feat).shape == expected

@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    """EdgeConv on bipartite graphs with a ``(src_feat, dst_feat)`` pair."""
    g = g.astype(idtype).to(F.ctx())
    edge_conv = nn.EdgeConv(out_dim)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

def myagg(alist, dsttype):
    """Custom HeteroGraphConv aggregator: position-weighted sum of results.

    Returns ``alist[0] + 2 * alist[1] + 3 * alist[2] + ...``. ``dsttype`` is
    accepted for the aggregator call signature but is not used.
    """
    total = alist[0]
    for factor, item in enumerate(alist[1:], start=2):
        total = total + factor * item
    return total

def _assert_hetero_conv_shapes(h, agg):
    """Check the outputs of the HeteroGraphConv built in ``test_hetero_conv``.

    Only the destination types "user" (reached by "follows") and "game"
    (reached by "plays" and "sells") may appear. With ``agg="stack"`` the
    per-relation results are stacked on axis 1. Every other aggregator
    leaves one tensor per node type with the module's output width.
    """
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)


@parametrize_idtype
@pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg])
def test_hetero_conv(agg, idtype):
    """HeteroGraphConv: aggregation, block inputs, per-module args, no edges."""
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 0, 2, 1], [1, 2, 1, 3]),
            ("user", "plays", "game"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
            ("store", "sells", "game"): ([0, 0, 1, 1], [0, 3, 1, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    conv = nn.HeteroGraphConv(
        {
            "follows": nn.GraphConv(2, 3, allow_zero_in_degree=True),
            "plays": nn.GraphConv(2, 4, allow_zero_in_degree=True),
            "sells": nn.GraphConv(3, 4, allow_zero_in_degree=True),
        },
        agg,
    )
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {"user": uf, "store": sf, "game": gf})
    _assert_hetero_conv_shapes(h, agg)

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    # Explicit (src, dst) inputs; "store" has no dst nodes in this block.
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    _assert_hetero_conv_shapes(h, agg)

    # A single input dict for the block (no explicit dst inputs).
    h = conv(block, {"user": uf, "game": gf, "store": sf})
    _assert_hetero_conv_shapes(h, agg)

    # test with mod args
    class MyMod(tf.keras.layers.Layer):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def call(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return tf.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv(
        {"follows": mod1, "plays": mod2, "sells": mod3}, agg
    )
    mod_args = {"follows": (1,), "plays": (1,)}
    mod_kwargs = {"sells": {"arg2": "abc"}}
    h = conv(
        g,
        {"user": uf, "game": gf, "store": sf},
        mod_args=mod_args,
        mod_kwargs=mod_kwargs,
    )
    # Positional extras reach only "follows"/"plays"; the keyword extra
    # reaches only "sells".
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on graph without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form="eid", etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}

@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_cheb_conv(out_dim):
    """DenseChebConv on a dense adjacency matches ChebConv on the graph.

    The dense module receives the sparse module's linear weights, reshaped
    to ``(k, 5, out_dim)``, and its bias. Both modules are called with the
    same extra argument 2.0 (presumably lambda_max; TODO confirm).
    """
    ctx = F.ctx()
    # The graph, its dense adjacency and the input features do not depend
    # on k, so they are built once.
    g = dgl.DGLGraph(
        sp.sparse.random(100, 100, density=0.1, random_state=42)
    ).to(ctx)
    adj = tf.sparse.to_dense(
        tf.sparse.reorder(g.adjacency_matrix(transpose=True, ctx=ctx))
    )
    feat = F.ones((100, 5))

    for k in range(3, 4):
        cheb = nn.ChebConv(5, out_dim, k, None, bias=True)
        dense_cheb = nn.DenseChebConv(5, out_dim, k, bias=True)

        # Run ChebConv first to initialize its linear layer, then copy its
        # parameters into the dense module.
        out_cheb = cheb(g, feat, [2.0])
        dense_cheb.W = tf.reshape(cheb.linear.weights[0], (k, 5, out_dim))
        if cheb.linear.bias is not None:
            dense_cheb.bias = cheb.linear.bias

        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        assert F.allclose(out_cheb, out_dense_cheb)


if __name__ == "__main__":
    # Most tests here take arguments that pytest supplies through
    # parametrization (out_dim, idtype, g, ...). Calling them directly with
    # no arguments raises TypeError, so run this file through pytest.
    raise SystemExit(pytest.main([__file__]))