import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
import backend as F
from copy import deepcopy

import numpy as np
import scipy as sp
import scipy.sparse  # sp.sparse is used below; import it explicitly instead of relying on dgl's imports

def _AXWb(A, X, W, b):
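    """Dense reference for graph convolution: project X with W, aggregate with
    the adjacency matrix A, then add the bias b."""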
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b

def test_graph_conv():
    g = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)

    conv = nn.GraphConv(5, 2, norm=False, bias=True)
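    # With norm=False there is no degree normalization, so the layer should
    # match the dense reference _AXWb(adj, h0, weight, bias) checked below.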
    if F.gpu_ctx():
        conv.cuda()
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    if F.gpu_ctx():
        conv.cuda()
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    conv = nn.GraphConv(5, 2)
    if F.gpu_ctx():
        conv.cuda()
    # test#5: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#6: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)

def test_set2set():
    g = dgl.DGLGraph(nx.path_graph(10))

    s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
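    # Set2Set concatenates the query with the attention readout, so the
    # output size is 2 * input size = 10, as asserted below.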
    if F.gpu_ctx():
        s2s.cuda()
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(h0, g)
    assert h1.shape[0] == 10 and h1.dim() == 1

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11))
    g2 = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(h0, bg)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2

def test_glob_att_pool():
    g = dgl.DGLGraph(nx.path_graph(10))

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
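    # gate_nn (5 -> 1) scores each node; feat_nn (5 -> 10) projects features,
    # so the pooled representation of each graph is 10-dimensional.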
    if F.gpu_ctx():
        gap.cuda()
    print(gap)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(h0, g)
    assert h1.shape[0] == 10 and h1.dim() == 1

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(h0, bg)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2

def test_simple_pool():
    g = dgl.DGLGraph(nx.path_graph(15))

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
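    # SortPooling keeps the features of the top k=10 sorted nodes and flattens
    # them, giving a 10 * 5 = 50-dim readout per graph.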
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sum_pool(h0, g)
    assert F.allclose(h1, F.sum(h0, 0))
    h1 = avg_pool(h0, g)
    assert F.allclose(h1, F.mean(h0, 0))
    h1 = max_pool(h0, g)
    assert F.allclose(h1, F.max(h0, 0))
    h1 = sort_pool(h0, g)
    assert h1.shape[0] == 10 * 5 and h1.dim() == 1

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))

    h1 = sum_pool(h0, bg)
    truth = th.stack([F.sum(h0[:15], 0),
                      F.sum(h0[15:20], 0),
                      F.sum(h0[20:35], 0),
                      F.sum(h0[35:40], 0),
                      F.sum(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = avg_pool(h0, bg)
    truth = th.stack([F.mean(h0[:15], 0),
                      F.mean(h0[15:20], 0),
                      F.mean(h0[20:35], 0),
                      F.mean(h0[35:40], 0),
                      F.mean(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = max_pool(h0, bg)
    truth = th.stack([F.max(h0[:15], 0),
                      F.max(h0[15:20], 0),
                      F.max(h0[20:35], 0),
                      F.max(h0[35:40], 0),
                      F.max(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = sort_pool(h0, bg)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2

def test_set_trans():
    g = dgl.DGLGraph(nx.path_graph(15))

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
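    # Encoders keep the per-node feature size (50); the decoder pools each
    # graph into k * d_model = 4 * 50 = 200 features, as asserted below.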
    if F.gpu_ctx():
        st_enc_0.cuda()
        st_enc_1.cuda()
        st_dec.cuda()
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(h0, g)
    assert h1.shape == h0.shape
    h1 = st_enc_1(h0, g)
    assert h1.shape == h0.shape
    h2 = st_dec(h1, g)
    assert h2.shape[0] == 200 and h2.dim() == 1

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(h0, bg)
    assert h1.shape == h0.shape
    h1 = st_enc_1(h0, bg)
    assert h1.shape == h0.shape

    h2 = st_dec(h1, bg)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2

def uniform_attention(g, shape):
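    """Expected edge_softmax output when all logits are equal: every edge
    gets weight 1 / in_degree of its destination node."""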
    a = th.ones(shape)
    target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
    return a / g.in_degrees(g.edges()[1]).view(target_shape).float()

def test_edge_softmax():
    # Basic
    g = dgl.DGLGraph(nx.path_graph(3))
    edata = F.ones((g.number_of_edges(), 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test a higher-dimensional case
    edata = F.ones((g.number_of_edges(), 3, 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test both forward and backward with PyTorch built-in softmax.
    g = dgl.DGLGraph()
    g.add_nodes(30)
    # build a complete graph
    for i in range(30):
        for j in range(30):
            g.add_edge(i, j)

    score = F.randn((900, 1))
    score.requires_grad_()
    grad = F.randn((900, 1))
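    # In the complete graph, edge (i, j) has id i * 30 + j, so a softmax over
    # dim 0 of the 30 x 30 score matrix normalizes over all in-edges of each
    # destination node, which is exactly what edge_softmax computes.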
    y = F.softmax(score.view(30, 30), dim=0).view(-1, 1)
    y.backward(grad)
    grad_score = score.grad.clone()  # clone before zero_(), otherwise grad_score aliases score.grad
    score.grad.zero_()
    y_dgl = nn.edge_softmax(g, score)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # check forward
    assert F.allclose(y_dgl, y)
    y_dgl.backward(grad)
    # check gradient
    assert F.allclose(score.grad, grad_score)
    print(score.grad[:10], grad_score[:10])
    
    # Test 2: compare edge_softmax against a per-destination softmax
    # computed with group_apply_edges.
    def generate_rand_graph(n):
        arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
        return dgl.DGLGraph(arr, readonly=True)
    
    g = generate_rand_graph(50)
    a1 = F.randn((g.number_of_edges(), 1)).requires_grad_()
    a2 = a1.clone().detach().requires_grad_()
    g.edata['s'] = a1
    g.group_apply_edges('dst', lambda edges: {'ss':F.softmax(edges.data['s'], 1)})
    g.edata['ss'].sum().backward()
    
    builtin_sm = nn.edge_softmax(g, a2)
    builtin_sm.sum().backward()
    print(a1.grad - a2.grad)
    assert len(g.ndata) == 0
    assert len(g.edata) == 2
    assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend
    
def test_rgcn():
    ctx = F.ctx()
    etype = []
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    B = 2
    I = 10
    O = 8
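    # "basis": relation weights are linear combinations of B=2 shared bases;
    # "bdd": block-diagonal weight decomposition, also with B=2 blocks.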

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r)
    assert list(h_new.shape) == [100, O]

    # with norm
    norm = th.zeros((g.number_of_edges(), 1)).to(ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_bdd(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    # id input
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
    h = th.randint(0, I, (100,)).to(ctx)
    r = th.tensor(etype).to(ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

if __name__ == '__main__':
    test_graph_conv()
    test_edge_softmax()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_set_trans()
    test_rgcn()