# test_nn.py
1
2
3
import tensorflow as tf
from tensorflow.keras import layers
import networkx as nx
4
import pytest
5
6
7
8
import dgl
import dgl.nn.tensorflow as nn
import dgl.function as fn
import backend as F
9
10
from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from test_utils import parametrize_dtype
11
12
13
14
15
16
17
18
19
20
21
from copy import deepcopy

import numpy as np
import scipy as sp

def _AXWb(A, X, W, b):
    """Dense reference for a graph convolution: ``A @ (X @ W) + b``.

    ``X @ W`` may have more than two dimensions; it is flattened to 2-D for
    the product with ``A`` and reshaped back afterwards.
    """
    projected = tf.matmul(X, W)
    # Collapse every trailing dimension so the product with A is a 2-D matmul.
    flat = tf.reshape(projected, (projected.shape[0], -1))
    aggregated = tf.matmul(A, flat)
    return tf.reshape(aggregated, projected.shape) + b

def test_graph_conv():
    """Exercise ``nn.GraphConv`` on a 3-node path graph.

    With ``norm='none'`` and a bias, the layer output is compared against the
    dense reference ``_AXWb``. With default constructor arguments only the
    absence of leftover node/edge features on the graph is checked.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(ctx=ctx)))

    conv = nn.GraphConv(5, 2, norm='none', bias=True)
    print(conv)
    # test#1: basic (2-D features); test#2: more-dim (extra feature axis)
    for shape in ((3, 5), (3, 5, 5)):
        h0 = F.ones(shape)
        h1 = conv(g, h0)
        # The layer must not leave temporary features behind on the graph.
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    # test#3: basic; test#4: more-dim -- only graph cleanliness is asserted.
    for shape in ((3, 5), (3, 5, 5)):
        conv(g, F.ones(shape))
        assert len(g.ndata) == 0
        assert len(g.edata) == 0

    # TODO: add a reset_parameters check; the previous one was disabled.

74
75
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv2(idtype, g, norm, weight, bias):
    """Check the output shape of ``nn.GraphConv`` with a single feature tensor.

    When the layer is built with ``weight=False`` an external weight matrix is
    passed at call time instead.
    """
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, 2))
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5))
    if weight:
        h_out = conv(g, h)
    else:
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, 2)

93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv2_bi(idtype, g, norm, weight, bias):
    """Check the output shape of ``nn.GraphConv`` given a (src, dst) feature pair.

    When the layer is built with ``weight=False`` an external weight matrix is
    passed at call time instead.
    """
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias)
    ext_w = F.randn((5, 2))
    n_src, n_dst = g.number_of_src_nodes(), g.number_of_dst_nodes()
    feat_pair = (F.randn((n_src, 5)), F.randn((n_dst, 2)))
    # Only supply the external weight when the layer owns none.
    call_kwargs = {} if weight else {'weight': ext_w}
    h_out = conv(g, feat_pair, **call_kwargs)
    assert h_out.shape == (n_dst, 2)
111
112
113

def test_simple_pool():
    """Check sum/avg/max/sort pooling on a single graph and on a batched graph.

    Sum, average and max pooling are compared against the matching backend
    reduction over each graph's node rows; sort pooling is only checked for
    output shape.
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(ctx)

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # Each pooling layer paired with the backend reduction it should match.
    reductions = ((sum_pool, F.sum), (avg_pool, F.mean), (max_pool, F.max))

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    for pool, reduce_fn in reductions:
        h1 = pool(g, h0)
        assert F.allclose(F.squeeze(h1, 0), reduce_fn(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5)).to(ctx)
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    # Node-row ranges of the five component graphs (sizes 15, 5, 15, 5, 15).
    segments = ((0, 15), (15, 20), (20, 35), (35, 40), (40, 55))
    for pool, reduce_fn in reductions:
        h1 = pool(bg, h0)
        truth = tf.stack([reduce_fn(h0[lo:hi], 0) for lo, hi in segments], 0)
        assert F.allclose(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2

def uniform_attention(g, shape):
    """Return a tensor of ``shape`` whose entries are 1 / in-degree of each edge's destination.

    The in-degrees are reshaped to ``(num_edges, 1, ..., 1)`` so they
    broadcast over the trailing dimensions of ``shape``.
    """
    trailing_ones = (1,) * (len(shape) - 1)
    target_shape = (g.number_of_edges(),) + trailing_ones
    dst_in_degrees = g.in_degrees(g.edges()[1])
    denom = tf.cast(tf.reshape(dst_in_degrees, target_shape), tf.float32)
    return F.ones(shape) / denom

def test_edge_softmax():
    """Check ``nn.edge_softmax`` forward values and gradients.

    Covers three scenarios: constant scores on a path graph (expected result
    from ``uniform_attention``), a 30-node complete graph compared with the
    backend softmax, and a random sparse graph compared with a softmax
    computed through ``group_apply_edges``.
    """
    # Basic
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    edata = F.ones((g.number_of_edges(), 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test higher dimension case
    edata = F.ones((g.number_of_edges(), 3, 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test both forward and backward with Tensorflow built-in softmax.
    g = dgl.DGLGraph().to(F.ctx())
    g.add_nodes(30)
    # build a complete graph
    for i in range(30):
        for j in range(30):
            g.add_edge(i, j)

    score = F.randn((900, 1))
    with tf.GradientTape() as tape:
        tape.watch(score)
        y = tf.reshape(F.softmax(tf.reshape(score, (30, 30)), dim=0), (-1, 1))
    # Gradients are requested after recording stops so the gradient ops
    # themselves are not traced by the tape.
    grad_score = tape.gradient(y, [score])[0]

    with tf.GradientTape() as tape:
        tape.watch(score)
        y_dgl = nn.edge_softmax(g, score)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # check forward
    assert F.allclose(y_dgl, y)
    grads = tape.gradient(y_dgl, [score])
    # check gradient
    assert F.allclose(grads[0], grad_score)
    print(grads[0][:10], grad_score[:10])

    # Test 2
    def generate_rand_graph(n):
        """Build a read-only graph from a random sparse n x n 0/1 matrix."""
        arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
        return dgl.DGLGraph(arr, readonly=True)

    g = generate_rand_graph(50).to(F.ctx())
    a1 = F.randn((g.number_of_edges(), 1))
    a2 = tf.identity(a1)
    with tf.GradientTape() as tape:
        tape.watch(a1)
        g.edata['s'] = a1
        g.group_apply_edges('dst', lambda edges: {'ss': F.softmax(edges.data['s'], 1)})
        loss = tf.reduce_sum(g.edata['ss'])
    a1_grad = tape.gradient(loss, [a1])[0]

    with tf.GradientTape() as tape:
        tape.watch(a2)
        builtin_sm = nn.edge_softmax(g, a2)
        loss = tf.reduce_sum(builtin_sm)
    a2_grad = tape.gradient(loss, [a2])[0]
    print(a1_grad - a2_grad)
    assert len(g.ndata) == 0
    # Only the 's' and 'ss' features written by this test should be present.
    assert len(g.edata) == 2
    assert F.allclose(a1_grad, a2_grad, rtol=1e-4, atol=1e-4)  # Follow tolerance in unittest backend

def test_partial_edge_softmax():
    """Check ``nn.edge_softmax`` restricted to a subset of edge IDs.

    The result and gradient of calling ``edge_softmax`` with ``eids`` on the
    full graph must match calling it on the edge subgraph induced by the same
    ``eids``.
    """
    g = dgl.DGLGraph().to(F.ctx())
    g.add_nodes(30)
    # build a complete graph
    for i in range(30):
        for j in range(30):
            g.add_edge(i, j)

    score = F.randn((300, 1))
    # Pick 300 distinct edges out of the 900 in the complete graph.
    eids = np.random.choice(900, 300, replace=False).astype('int64')
    eids = F.zerocopy_from_numpy(eids)
    # compute partial edge softmax
    with tf.GradientTape() as tape:
        tape.watch(score)
        y_1 = nn.edge_softmax(g, score, eids)
    grad_1 = tape.gradient(y_1, [score])[0]
    # compute edge softmax on edge subgraph
    subg = g.edge_subgraph(eids)
    with tf.GradientTape() as tape:
        tape.watch(score)
        y_2 = nn.edge_softmax(subg, score)
    grad_2 = tape.gradient(y_2, [score])[0]

    assert F.allclose(y_1, y_2)
    assert F.allclose(grad_1, grad_2)

def test_glob_att_pool():
    """Check the output shape of ``nn.GlobalAttentionPooling`` on single and batched graphs."""
    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())

    gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))
    print(gap)

    # test#1: basic -- one graph gives one row of 10 features
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2

    # test#2: batched graph -- four graphs give four rows
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2


def test_rgcn():
    """Check that ``nn.RelGraphConv`` gives the same output with and without ``low_mem``.

    For the "basis" and "bdd" regularizers, a regular layer and a
    ``low_mem=True`` layer sharing the same parameters are run on the same
    inputs. Output shapes and values must agree, both without and with an
    edge ``norm`` tensor, and for integer node-ID input ("basis" only).
    """
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    # 5 etypes, assigned round-robin over the edges
    R = 5
    etype = [i % R for i in range(g.number_of_edges())]
    B = 2
    I = 10
    O = 8

    def check_low_mem_matches(regularizer, h, norm=None):
        """Build a regular/low_mem layer pair with shared weights and compare outputs."""
        rgc = nn.RelGraphConv(I, O, R, regularizer, B)
        rgc_low = nn.RelGraphConv(I, O, R, regularizer, B, low_mem=True)
        rgc_low.weight = rgc.weight
        if regularizer == "basis":
            # w_comp is only copied for the basis regularizer.
            rgc_low.w_comp = rgc.w_comp
        r = tf.constant(etype)
        extra = () if norm is None else (norm,)
        h_new = rgc(g, h, r, *extra)
        h_new_low = rgc_low(g, h, r, *extra)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # without norm
    for regularizer in ("basis", "bdd"):
        check_low_mem_matches(regularizer, tf.random.normal((100, I)))

    # with norm
    norm = tf.zeros((g.number_of_edges(), 1))
    for regularizer in ("basis", "bdd"):
        check_low_mem_matches(regularizer, tf.random.normal((100, I)), norm)

    # id input
    check_low_mem_matches("basis", tf.constant(np.random.randint(0, I, (100,))))
358

359
360
361
362
363
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_gat_conv(g, idtype):
    """Check the output shape of ``nn.GATConv(5, 2, 4)`` given a single feature tensor."""
    g = g.astype(idtype).to(F.ctx())
    gat = nn.GATConv(5, 2, 4)
    feat = F.randn((g.number_of_nodes(), 5))
    h = gat(g, feat)
    # Expected layout: (nodes, 4, 2), matching the constructor's last two arguments.
    assert h.shape == (g.number_of_nodes(), 4, 2)
368

369
370
371
372
373
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
def test_gat_conv_bi(g, idtype):
    """Check the output shape of ``nn.GATConv`` given a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    gat = nn.GATConv((5, 10), 2, 4)
    # Source features have width 5 and destination features width 10.
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 10)))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 4, 2)
378

379
380
381
382
383
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv(idtype, g, aggre_type):
    """Check the output feature width of ``nn.SAGEConv`` for each aggregator type."""
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_nodes(), 5))
    h = sage(g, feat)
    assert h.shape[-1] == 10

389
390
391
392
393
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv_bi(idtype, g, aggre_type):
    """Check the output shape of ``nn.SAGEConv`` given a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    # This test uses matching src/dst widths for 'gcn' and different widths otherwise.
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), dst_dim)))
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == g.number_of_dst_nodes()
401

402
403
404
@parametrize_dtype
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
def test_sage_conv_bi_empty(idtype, aggre_type):
    """Check ``nn.SAGEConv`` on a bipartite graph (5 src, 3 dst nodes) with no edges.

    NOTE(review): the parametrized ``aggre_type`` is not used by the body,
    which always runs 'gcn' followed by 'mean', 'pool' and 'lstm'. Confirm
    whether the parametrization is still wanted.
    """
    # Test the case for graphs without edges
    g = dgl.bipartite([], num_nodes=(5, 3)).astype(idtype).to(F.ctx())
    sage = nn.SAGEConv((3, 3), 2, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    h = sage(g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == 3
    # A distinct loop name avoids shadowing the parametrized argument.
    for agg_name in ['mean', 'pool', 'lstm']:
        sage = nn.SAGEConv((3, 1), 2, agg_name)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        h = sage(g, feat)
        assert h.shape[-1] == 2
        assert h.shape[0] == 3

420
421
def test_sgc_conv():
    """Check ``nn.SGConv`` output width, and that the cached variant reuses its first result."""
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    # not cached
    sgc = nn.SGConv(5, 10, 3)
    feat = F.randn((100, 5))

    h = sgc(g, feat)
    assert h.shape[-1] == 10

    # cached: a second call with different features must give the same output
    sgc = nn.SGConv(5, 10, 3, True)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == 10

def test_appnp_conv():
    """Check that ``nn.APPNPConv`` preserves the input feature width."""
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(F.ctx())
    appnp = nn.APPNPConv(10, 0.1)
    feat = F.randn((100, 5))

    h = appnp(g, feat)
    assert h.shape[-1] == 5

445
446
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv(g, idtype, aggregator_type):
    """Check the output shape of ``nn.GINConv`` wrapping a ``Dense(12)`` layer."""
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(
        tf.keras.layers.Dense(12),
        aggregator_type
    )
    feat = F.randn((g.number_of_nodes(), 5))
    h = gin(g, feat)
    assert h.shape == (g.number_of_nodes(), 12)
458

459
460
461
462
463
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
def test_gin_conv_bi(g, idtype, aggregator_type):
    """Check the output shape of ``nn.GINConv`` given a (src, dst) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(
        tf.keras.layers.Dense(12),
        aggregator_type
    )
    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)
471

472
473
474
475
476
477
def myagg(alist, dsttype):
    """Combine ``alist`` as a weighted sum where the k-th item (1-based) has weight k.

    The first item is taken as is; ``dsttype`` is accepted but not used.
    """
    total = alist[0]
    # Items after the first get weights 2, 3, ... in order.
    for weight, item in enumerate(alist[1:], start=2):
        total = total + weight * item
    return total

478
@parametrize_dtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
    """Check ``nn.HeteroGraphConv`` output keys and shapes for each aggregator.

    Covers dict input, (src, dst) pair input, and forwarding of per-relation
    positional/keyword arguments to the wrapped modules.
    """
    g = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
        ('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0), (2, 2)],
        ('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]},
        idtype=idtype, device=F.ctx())

    def check_output(h, expected):
        """Assert keys and shapes; ``expected`` maps ntype -> (n_relations, out_dim).

        Every node type in this graph has 4 destination nodes. With the
        'stack' aggregator the per-relation results form an extra axis 1.
        """
        assert set(h.keys()) == set(expected)
        for ntype, (n_rel, out_dim) in expected.items():
            if agg != 'stack':
                assert h[ntype].shape == (4, out_dim)
            else:
                assert h[ntype].shape == (4, n_rel, out_dim)

    conv = nn.HeteroGraphConv({
        'follows': nn.GraphConv(2, 3),
        'plays': nn.GraphConv(2, 4),
        'sells': nn.GraphConv(3, 4)},
        agg)
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))

    h = conv(g, {'user': uf})
    check_output(h, {'user': (1, 3), 'game': (1, 4)})

    # With store features present, 'game' receives from two relations.
    h = conv(g, {'user': uf, 'store': sf})
    check_output(h, {'user': (1, 3), 'game': (2, 4)})

    h = conv(g, {'store': sf})
    check_output(h, {'game': (1, 4)})

    # test with pair input
    conv = nn.HeteroGraphConv({
        'follows': nn.SAGEConv(2, 3, 'mean'),
        'plays': nn.SAGEConv((2, 4), 4, 'mean'),
        'sells': nn.SAGEConv(3, 4, 'mean')},
        agg)

    h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
    check_output(h, {'user': (1, 3), 'game': (1, 4)})

    # pair input requires both src and dst type features to be provided
    h = conv(g, ({'user': uf}, {'game' : gf}))
    check_output(h, {'game': (1, 4)})

    # test with mod args
    class MyMod(tf.keras.layers.Layer):
        """Dummy layer that counts how often each optional argument was supplied."""

        def __init__(self, s1, s2):
            super(MyMod, self).__init__()
            self.carg1 = 0
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def call(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return tf.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv({
        'follows': mod1,
        'plays': mod2,
        'sells': mod3},
        agg)
    mod_args = {'follows' : (1,), 'plays' : (1,)}
    mod_kwargs = {'sells' : {'arg2' : 'abc'}}
    h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
    # Each module must have received exactly the extra argument addressed to it.
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
    assert mod2.carg2 == 0
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
if __name__ == '__main__':
    # Several tests in this module are parametrized (e.g. test_gat_conv,
    # test_sage_conv, test_gin_conv) and cannot be called without arguments,
    # so pytest collects and runs the module instead of direct calls.
    #
    # Tests that were listed but disabled in the previous direct-call runner:
    #   test_set2set, test_set_trans, test_tagconv, test_agnn_conv,
    #   test_gated_graph_conv, test_nn_conv, test_gmm_conv,
    #   test_dense_graph_conv, test_dense_sage_conv, test_dense_cheb_conv,
    #   test_sequential
    raise SystemExit(pytest.main([__file__]))