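"""Unit tests for dgl.nn.pytorch modules: GraphConv, Set2Set, global attention
pooling, sum/avg/max/sort pooling, the Set Transformer encoder/decoder, and
edge_softmax, run against the PyTorch test backend."""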
import torch as th
import networkx as nx
import dgl
import dgl.nn.pytorch as nn
import backend as F
from copy import deepcopy

import numpy as np
import scipy as sp
import scipy.sparse  # sp.sparse is used below; scipy does not auto-import submodules

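# Dense reference for an un-normalized graph convolution: Y = A (X W) + b.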
def _AXWb(A, X, W, b):
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b

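# GraphConv should match the dense reference above and must not leave any
# features behind in g.ndata / g.edata.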
def test_graph_conv():
    g = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)

    conv = nn.GraphConv(5, 2, norm=False, bias=True)
    if F.gpu_ctx():
        conv.cuda()
    print(conv)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    if F.gpu_ctx():
        conv.cuda()
    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    conv = nn.GraphConv(5, 2)
    if F.gpu_ctx():
        conv.cuda()
    # test#5: basic (fresh GraphConv instance)
    h0 = F.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#6: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    # test reset_parameters
    old_weight = deepcopy(conv.weight.data)
    conv.reset_parameters()
    new_weight = conv.weight.data
    assert not F.allclose(old_weight, new_weight)

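# Set2Set readout: with input size 5 the per-graph output has size 2 * 5 = 10
# (concatenation of the query and the attention-weighted readout).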
def test_set2set():
    g = dgl.DGLGraph(nx.path_graph(10))

    s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
    if F.gpu_ctx():
        s2s.cuda()
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(h0, g)
    assert h1.shape[0] == 10 and h1.dim() == 1

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(11))
    g2 = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(h0, bg)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2

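# GlobalAttentionPooling(gate_nn, feat_nn): gate_nn scores each node and
# feat_nn maps node features to the 10-dim representation summed per graph.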
def test_glob_att_pool():
    g = dgl.DGLGraph(nx.path_graph(10))

    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
    if F.gpu_ctx():
        gap.cuda()
    print(gap)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(h0, g)
    assert h1.shape[0] == 10 and h1.dim() == 1

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(h0, bg)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2

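# Sum/Avg/Max pooling reduce over all nodes of a graph; SortPooling(k=10)
# keeps the k top-sorted nodes, giving a flat k * feat_size vector per graph.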
def test_simple_pool():
    g = dgl.DGLGraph(nx.path_graph(15))

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sum_pool(h0, g)
    assert F.allclose(h1, F.sum(h0, 0))
    h1 = avg_pool(h0, g)
    assert F.allclose(h1, F.mean(h0, 0))
    h1 = max_pool(h0, g)
    assert F.allclose(h1, F.max(h0, 0))
    h1 = sort_pool(h0, g)
    assert h1.shape[0] == 10 * 5 and h1.dim() == 1

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))

    h1 = sum_pool(h0, bg)
    truth = th.stack([F.sum(h0[:15], 0),
                      F.sum(h0[15:20], 0),
                      F.sum(h0[20:35], 0),
                      F.sum(h0[35:40], 0),
                      F.sum(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = avg_pool(h0, bg)
    truth = th.stack([F.mean(h0[:15], 0),
                      F.mean(h0[15:20], 0),
                      F.mean(h0[20:35], 0),
                      F.mean(h0[35:40], 0),
                      F.mean(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = max_pool(h0, bg)
    truth = th.stack([F.max(h0[:15], 0),
                      F.max(h0[15:20], 0),
                      F.max(h0[20:35], 0),
                      F.max(h0[35:40], 0),
                      F.max(h0[40:55], 0)], 0)
    assert F.allclose(h1, truth)

    h1 = sort_pool(h0, bg)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2

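# Set Transformer: the SAB/ISAB encoders preserve the input shape, while the
# decoder (4 seed vectors, d_model=50) returns a 4 * 50 = 200 dim readout per graph.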
def test_set_trans():
    g = dgl.DGLGraph(nx.path_graph(15))

    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
    if F.gpu_ctx():
        st_enc_0.cuda()
        st_enc_1.cuda()
        st_dec.cuda()
    print(st_enc_0, st_enc_1, st_dec)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 50))
    h1 = st_enc_0(h0, g)
    assert h1.shape == h0.shape
    h1 = st_enc_1(h0, g)
    assert h1.shape == h0.shape
    h2 = st_dec(h1, g)
    assert h2.shape[0] == 200 and h2.dim() == 1

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(5))
    g2 = dgl.DGLGraph(nx.path_graph(10))
    bg = dgl.batch([g, g1, g2])
    h0 = F.randn((bg.number_of_nodes(), 50))
    h1 = st_enc_0(h0, bg)
    assert h1.shape == h0.shape
    h1 = st_enc_1(h0, bg)
    assert h1.shape == h0.shape

    h2 = st_dec(h1, bg)
    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2

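# Expected edge_softmax output when all logits are equal: every incoming edge
# of a node gets weight 1 / in-degree of its destination.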
def uniform_attention(g, shape):
    a = th.ones(shape)
    target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
    return a / g.in_degrees(g.edges()[1]).view(target_shape).float()

def test_edge_softmax():
    # Basic
    g = dgl.DGLGraph(nx.path_graph(3))
    edata = F.ones((g.number_of_edges(), 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

    # Test higher dimension case
    edata = F.ones((g.number_of_edges(), 3, 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert F.allclose(a, uniform_attention(g, a.shape))

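    # On a complete 30-node graph edge (i, j) has id i * 30 + j, so edge_softmax
    # over the incoming edges of node j equals a column-wise (dim=0) softmax of
    # the 30 x 30 score matrix.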
    # Test both forward and backward with PyTorch built-in softmax.
    g = dgl.DGLGraph()
    g.add_nodes(30)
    # build a complete graph
    for i in range(30):
        for j in range(30):
            g.add_edge(i, j)

    score = F.randn((900, 1))
    score.requires_grad_()
    grad = F.randn((900, 1))
    y = F.softmax(score.view(30, 30), dim=0).view(-1, 1)
    y.backward(grad)
    grad_score = score.grad.clone()  # snapshot the reference gradient before zeroing it below
    score.grad.zero_()
    y_dgl = nn.edge_softmax(g, score)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # check forward
    assert F.allclose(y_dgl, y)
    y_dgl.backward(grad)
    # check backward: gradient should match the PyTorch reference
    assert F.allclose(score.grad, grad_score)
    print(score.grad[:10], grad_score[:10])

    # Test 2: on a random readonly graph, edge_softmax should match a
    # per-destination softmax computed manually with group_apply_edges.
    def generate_rand_graph(n):
        arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
        return dgl.DGLGraph(arr, readonly=True)

    g = generate_rand_graph(50)
    a1 = F.randn((g.number_of_edges(), 1)).requires_grad_()
    a2 = a1.clone().detach().requires_grad_()
    g.edata['s'] = a1
    g.group_apply_edges('dst', lambda edges: {'ss':F.softmax(edges.data['s'], 1)})
    g.edata['ss'].sum().backward()
    
    builtin_sm = nn.edge_softmax(g, a2)
    builtin_sm.sum().backward()
    print(a1.grad - a2.grad)
    assert len(g.ndata) == 0
    assert len(g.edata) == 2
    assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend
    

if __name__ == '__main__':
    test_graph_conv()
    test_edge_softmax()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_set_trans()