import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
import backend as F
from copy import deepcopy

import numpy as np
import scipy as sp
import scipy.sparse  # make sure the sparse submodule is loaded so sp.sparse resolves below

def _AXWb(A, X, W, b):
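    # Dense reference for an un-normalized graph convolution: A @ (X W) + b.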
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b

def test_graph_conv():
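    # GraphConv on a 3-node path graph: the un-normalized layer is checked against
    # the dense reference _AXWb; the remaining blocks only check that 2-D and 3-D
    # features run through the layer without leaving node/edge data behind.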
    g = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)

    conv = nn.GraphConv(5, 2, norm=False, bias=True)
    if F.gpu_ctx():
        conv.cuda()
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    if F.gpu_ctx():
        conv.cuda()
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    conv = nn.GraphConv(5, 2)
    if F.gpu_ctx():
        conv.cuda()
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)

def _S2AXWb(A, N, X, W, b):
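    # Dense reference for a two-hop propagation: stack the 0-, 1- and 2-hop
    # features (each hop scaled by the degree norm N on both sides) and apply
    # the weights W and bias b of the linear layer.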
    X1 = X * N
    X1 = th.matmul(A, X1.view(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = th.matmul(A, X2.view(X2.shape[0], -1))
    X2 = X2 * N
    X = th.cat([X, X1, X2], dim=-1)
    Y = th.matmul(X, W.rot90())

    return Y + b

def test_tgconv():
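    # TGConv (topology adaptive convolution) on a 3-node path graph, checked
    # against the dense two-hop reference _S2AXWb.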
    g = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)
    norm = th.pow(g.in_degrees().float(), -0.5)

    conv = nn.TGConv(5, 2, bias=True)
    if F.gpu_ctx():
        conv.cuda()
    print(conv)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.dim() - 1)
    norm = th.reshape(norm, shp).to(ctx)

    assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))

    conv = nn.TGConv(5, 2)
    if F.gpu_ctx():
        conv.cuda()
    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.lin.weight.data)
    conv.reset_parameters()
    new_weight = conv.lin.weight.data
    assert not F.allclose(old_weight, new_weight)

def test_set2set():
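    # Set2Set readout: a single graph yields a 1-D vector of size 2 * input_dim,
    # a batched graph yields one such row per graph.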
    g = dgl.DGLGraph(nx.path_graph(10))

    s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
    if F.gpu_ctx():
        s2s.cuda()
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(h0, g)
    assert h1.shape[0] == 10 and h1.dim() == 1

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11))
    g2 = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(h0, bg)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2

def test_glob_att_pool():
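    # Global attention pooling with a 5->1 gate network and a 5->10 feature
    # network: expect a length-10 readout per graph.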
    g = dgl.DGLGraph(nx.path_graph(10))

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    if F.gpu_ctx():
        gap.cuda()
    print(gap)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(h0, g)
    assert h1.shape[0] == 10 and h1.dim() == 1

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(h0, bg)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2

def test_simple_pool():
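    # Sum/avg/max pooling are compared against the corresponding backend
    # reductions; SortPooling(k=10) should yield k * feat_dim values per graph.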
    g = dgl.DGLGraph(nx.path_graph(15))

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sum_pool(h0, g)
    assert F.allclose(h1, F.sum(h0, 0))
    h1 = avg_pool(h0, g)
    assert F.allclose(h1, F.mean(h0, 0))
    h1 = max_pool(h0, g)
    assert F.allclose(h1, F.max(h0, 0))
    h1 = sort_pool(h0, g)
    assert h1.shape[0] == 10 * 5 and h1.dim() == 1

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))

    h1 = sum_pool(h0, bg)
    truth = th.stack([F.sum(h0[:15], 0),
                      F.sum(h0[15:20], 0),
                      F.sum(h0[20:35], 0),
                      F.sum(h0[35:40], 0),
                      F.sum(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = avg_pool(h0, bg)
    truth = th.stack([F.mean(h0[:15], 0),
                      F.mean(h0[15:20], 0),
                      F.mean(h0[20:35], 0),
                      F.mean(h0[35:40], 0),
                      F.mean(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = max_pool(h0, bg)
    truth = th.stack([F.max(h0[:15], 0),
                      F.max(h0[15:20], 0),
                      F.max(h0[20:35], 0),
                      F.max(h0[35:40], 0),
                      F.max(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = sort_pool(h0, bg)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2

def test_set_trans():
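    # Set Transformer: both encoders (SAB and ISAB blocks) must preserve the
    # input shape, while the decoder pools each graph to a fixed-size vector.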
    g = dgl.DGLGraph(nx.path_graph(15))

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    if F.gpu_ctx():
        st_enc_0.cuda()
        st_enc_1.cuda()
        st_dec.cuda()
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(h0, g)
    assert h1.shape == h0.shape
    h1 = st_enc_1(h0, g)
    assert h1.shape == h0.shape
    h2 = st_dec(h1, g)
    assert h2.shape[0] == 200 and h2.dim() == 1

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(h0, bg)
    assert h1.shape == h0.shape
    h1 = st_enc_1(h0, bg)
    assert h1.shape == h0.shape

    h2 = st_dec(h1, bg)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2

def uniform_attention(g, shape):
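    # Expected edge_softmax output for constant scores: every incoming edge of a
    # node gets weight 1 / in_degree(dst).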
    a = th.ones(shape)
    target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
    return a / g.in_degrees(g.edges()[1]).view(target_shape).float()

def test_edge_softmax():
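    # edge_softmax is checked three ways: uniform attention on constant scores,
    # forward/backward against a dense softmax on a complete graph, and gradients
    # against a group_apply_edges softmax on a random graph.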
    # Basic
    g = dgl.DGLGraph(nx.path_graph(3))
    edata = F.ones((g.number_of_edges(), 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test higher dimension case
    edata = F.ones((g.number_of_edges(), 3, 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test both forward and backward with PyTorch built-in softmax.
    g = dgl.DGLGraph()
    g.add_nodes(30)
    # build a complete graph
    for i in range(30):
        for j in range(30):
            g.add_edge(i, j)

    score = F.randn((900, 1))
    score.requires_grad_()
    grad = F.randn((900, 1))
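    # Edges were added in row-major (src, dst) order, so score.view(30, 30) puts
    # all edges sharing a destination in one column; softmax over dim=0 therefore
    # normalizes per destination node, which is what edge_softmax computes.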
    y = F.softmax(score.view(30, 30), dim=0).view(-1, 1)
    y.backward(grad)
    grad_score = score.grad.clone()  # snapshot before zeroing; score.grad is reused for the DGL result
    score.grad.zero_()
    y_dgl = nn.edge_softmax(g, score)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # check forward
    assert F.allclose(y_dgl, y)
    y_dgl.backward(grad)
    # check gradient
    assert F.allclose(score.grad, grad_score)
    print(score.grad[:10], grad_score[:10])
    
    # Test 2
    def generate_rand_graph(n):
        arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
        return dgl.DGLGraph(arr, readonly=True)
    
    g = generate_rand_graph(50)
    a1 = F.randn((g.number_of_edges(), 1)).requires_grad_()
    a2 = a1.clone().detach().requires_grad_()
    g.edata['s'] = a1
    g.group_apply_edges('dst', lambda edges: {'ss':F.softmax(edges.data['s'], 1)})
    g.edata['ss'].sum().backward()
    
    builtin_sm = nn.edge_softmax(g, a2)
    builtin_sm.sum().backward()
    print(a1.grad - a2.grad)
    assert len(g.ndata) == 0
    assert len(g.edata) == 2
    assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend
    
def test_rgcn():
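    # RelGraphConv with 5 relation types on a random 100-node graph, covering the
    # basis and block-diagonal decompositions, an optional edge norm, and node-id
    # (embedding lookup) input.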
    ctx = F.ctx()
    etype = []
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    B = 2
    I = 10
    O = 8

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r)
    assert list(h_new.shape) == [100, O]

    # with norm
    norm = th.zeros((g.number_of_edges(), 1)).to(ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    h = th.randint(0, I, (100,)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

if __name__ == '__main__':
    test_graph_conv()
    test_tgconv()
    test_edge_softmax()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_set_trans()
    test_rgcn()