"git@developer.sourcefind.cn:change/sglang.git" did not exist on "10d60cd41bb520d2cbd16f577f6d60f578e3ab4a"
test_nn.py 19.6 KB
Newer Older
1
2
3
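# Tests for DGL's TensorFlow nn modules (dgl.nn.tensorflow); intended to be run with pytest.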
from copy import deepcopy

import backend as F

import dgl
import dgl.function as fn
import dgl.nn.tensorflow as nn
import networkx as nx
import numpy as np
import pytest
import scipy as sp
import tensorflow as tf
from tensorflow.keras import layers
from utils import parametrize_idtype
from utils.graph_cases import (
    get_cases,
    random_bipartite,
    random_dglgraph,
    random_graph,
)


def _AXWb(A, X, W, b):
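    # Dense reference: computes A @ (X W) + b (flattening trailing dims), used to check GraphConv.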
    X = tf.matmul(X, W)
    Y = tf.reshape(tf.matmul(A, tf.reshape(X, (X.shape[0], -1))), X.shape)
    return Y + b


@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = tf.sparse.to_dense(
        tf.sparse.reorder(g.adj_external(transpose=True, ctx=ctx))
    )

    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
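    # With norm="none" the layer should exactly match the dense reference _AXWb.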
    # conv = conv
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, out_dim)
    # conv = conv
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    conv = nn.GraphConv(5, out_dim)
    # conv = conv
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    # old_weight = deepcopy(conv.weight.data)
    # conv.reset_parameters()
    # new_weight = conv.weight.data
    # assert not F.allclose(old_weight, new_weight)


@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["homo", "block-bipartite"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right", "left"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, out_dim))
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5))
    h_dst = F.randn((ndst, out_dim))
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, out_dim))
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5))
    h_dst = F.randn((ndst, out_dim))
    if weight:
        h_out = conv(g, (h, h_dst))
    else:
        h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


def test_simple_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    h1 = sum_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
    h1 = max_pool(g, h0)
    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
    bg = dgl.batch([g, g_, g, g_, g])
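    # bg has node segments of 15, 5, 15, 5, 15, hence the slice boundaries 15/20/35/40/55 below.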
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = tf.stack(
        [
            F.sum(h0[:15], 0),
            F.sum(h0[15:20], 0),
            F.sum(h0[20:35], 0),
            F.sum(h0[35:40], 0),
            F.sum(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = tf.stack(
        [
            F.mean(h0[:15], 0),
            F.mean(h0[15:20], 0),
            F.mean(h0[20:35], 0),
            F.mean(h0[35:40], 0),
            F.mean(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = max_pool(bg, h0)
    truth = tf.stack(
        [
            F.max(h0[:15], 0),
            F.max(h0[15:20], 0),
            F.max(h0[20:35], 0),
            F.max(h0[35:40], 0),
            F.max(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2


def test_glob_att_pool():
    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())

    gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))
    print(gap)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2


@pytest.mark.parametrize("O", [1, 2, 8])
def test_rgcn(O):
    etype = []
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(
        F.ctx()
    )
    # 5 etypes
    R = 5
    for i in range(g.num_edges()):
        etype.append(i % 5)
    B = 2
    I = 10

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
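    # The low_mem variant below shares its parameters with rgc_basis, so their outputs must match.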
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = tf.random.normal((100, I))
    r = tf.constant(etype)
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
        rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
        rgc_bdd_low.weight = rgc_bdd.weight
        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
        h = tf.random.normal((100, I))
        r = tf.constant(etype)
        h_new = rgc_bdd(g, h, r)
        h_new_low = rgc_bdd_low(g, h, r)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # with norm
    norm = tf.zeros((g.num_edges(), 1))

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = tf.random.normal((100, I))
    r = tf.constant(etype)
    h_new = rgc_basis(g, h, r, norm)
    h_new_low = rgc_basis_low(g, h, r, norm)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)

    if O % B == 0:
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
        rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
        rgc_bdd_low.weight = rgc_bdd.weight
        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
        h = tf.random.normal((100, I))
        r = tf.constant(etype)
        h_new = rgc_bdd(g, h, r, norm)
        h_new_low = rgc_bdd_low(g, h, r, norm)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
    rgc_basis_low.weight = rgc_basis.weight
    rgc_basis_low.w_comp = rgc_basis.w_comp
    rgc_basis_low.loop_weight = rgc_basis.loop_weight
    h = tf.constant(np.random.randint(0, I, (100,))) * 1
    r = tf.constant(etype) * 1
    h_new = rgc_basis(g, h, r)
    h_new_low = rgc_basis_low(g, h, r)
    assert list(h_new.shape) == [100, O]
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    h = gat(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.num_edges(), num_heads, 1)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, out_dim, aggre_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != "gcn" else 10
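    # For the "gcn" aggregator the destination feature size must match the source size (10).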
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), dst_dim)),
    )
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()


@parametrize_idtype
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi_empty(idtype, aggre_type, out_dim):
    # Test the case for graphs without edges
    g = dgl.heterograph({("_U", "_E", "_V"): ([], [])}, {"_U": 5, "_V": 3}).to(
        F.ctx()
    )
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv((3, 3), out_dim, "gcn")
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    for aggre_type in ["mean", "pool", "lstm"]:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    # not cached
    sgc = nn.SGConv(5, out_dim, 3)
    feat = F.randn((g.num_nodes(), 5))

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # cached
    sgc = nn.SGConv(5, out_dim, 3, True)
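    # With caching enabled, the first result is reused, so a changed input must not change the output.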
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((g.num_nodes(), 5))

    h = appnp(g, feat)
    assert h.shape[-1] == 5


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(tf.keras.layers.Dense(12), aggregator_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv_bi(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(tf.keras.layers.Dense(12), aggregator_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    edge_conv = nn.EdgeConv(out_dim)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    edge_conv = nn.EdgeConv(out_dim)

    h0 = F.randn((g.number_of_src_nodes(), 5))
    x0 = F.randn((g.number_of_dst_nodes(), 5))
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


def myagg(alist, dsttype):
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst


@parametrize_idtype
@pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg])
def test_hetero_conv(agg, idtype):
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 0, 2, 1], [1, 2, 1, 3]),
            ("user", "plays", "game"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
            ("store", "sells", "game"): ([0, 0, 1, 1], [0, 3, 1, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    conv = nn.HeteroGraphConv(
        {
            "follows": nn.GraphConv(2, 3, allow_zero_in_degree=True),
            "plays": nn.GraphConv(2, 4, allow_zero_in_degree=True),
            "sells": nn.GraphConv(3, 4, allow_zero_in_degree=True),
        },
        agg,
    )
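    # Node features: 4 users (dim 2), 4 games (dim 4), 2 stores (dim 3).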
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {"user": uf, "store": sf, "game": gf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    h = conv(block, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    # test with mod args
    class MyMod(tf.keras.layers.Layer):
        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def call(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return tf.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv(
        {"follows": mod1, "plays": mod2, "sells": mod3}, agg
    )
    mod_args = {"follows": (1,), "plays": (1,)}
    mod_kwargs = {"sells": {"arg2": "abc"}}
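    # Positional mod_args reach the "follows"/"plays" modules; the keyword in mod_kwargs reaches "sells" only.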
    h = conv(
        g,
        {"user": uf, "game": gf, "store": sf},
        mod_args=mod_args,
        mod_kwargs=mod_kwargs,
    )
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on graph without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form="eid", etype=etype), etype=etype)
    assert g.num_edges() == 0
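    # Even with all edges removed, HeteroGraphConv still returns outputs for the destination node types.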
    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}

    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}


@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_cheb_conv(out_dim):
    for k in range(3, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(
            sp.sparse.random(100, 100, density=0.1, random_state=42)
        )
        g = g.to(ctx)

        adj = tf.sparse.to_dense(
            tf.sparse.reorder(g.adj_external(transpose=True, ctx=ctx))
        )
        cheb = nn.ChebConv(5, out_dim, k, None, bias=True)
        dense_cheb = nn.DenseChebConv(5, out_dim, k, bias=True)

        # init cheb modules
        feat = F.ones((100, 5))
        out_cheb = cheb(g, feat, [2.0])
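        # Tie DenseChebConv's parameters to ChebConv's so both should produce the same output.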

        dense_cheb.W = tf.reshape(cheb.linear.weights[0], (k, 5, out_dim))
        if cheb.linear.bias is not None:
            dense_cheb.bias = cheb.linear.bias

        out_dense_cheb = dense_cheb(adj, feat, 2.0)
        print(out_cheb - out_dense_cheb)
        assert F.allclose(out_cheb, out_dense_cheb)


if __name__ == "__main__":
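    # Note: most of these tests are parametrized; invoking them directly here
    # requires the graph/idtype/dimension arguments that pytest normally supplies.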
    test_graph_conv()
    # test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    # test_set_trans()
    test_rgcn()
    # test_tagconv()
    test_gat_conv()
    test_sage_conv()
    test_sgc_conv()
    test_appnp_conv()
    test_gin_conv()
    test_edge_conv()
    # test_agnn_conv()
    # test_gated_graph_conv()
    # test_nn_conv()
    # test_gmm_conv()
    # test_dense_graph_conv()
    # test_dense_sage_conv()
    test_dense_cheb_conv()
    # test_sequential()
    test_hetero_conv()