import mxnet as mx
import networkx as nx
import numpy as np
import scipy as sp
import scipy.sparse  # ensure sp.sparse is available for test_rgcn
import dgl
import dgl.nn.mxnet as nn
import backend as F
from mxnet import autograd, gluon, nd

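# Helper: assert that two NDArrays agree elementwise within loose
# float32 tolerances (rtol=atol=1e-4).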
def check_close(a, b):
    assert np.allclose(a.asnumpy(), b.asnumpy(), rtol=1e-4, atol=1e-4)

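# Dense reference for an unnormalized graph convolution: Y = A X W + b,
# computed directly against the dense adjacency matrix.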
def _AXWb(A, X, W, b):
    X = mx.nd.dot(X, W.data(X.context))
    Y = mx.nd.dot(A, X.reshape(X.shape[0], -1)).reshape(X.shape)
    return Y + b.data(X.context)

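# GraphConv forward on a 3-node path graph, checked against the dense
# reference above; also verifies the module leaves g.ndata/g.edata empty.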
def test_graph_conv():
    g = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)

    conv = nn.GraphConv(5, 2, norm=False, bias=True)
    conv.initialize(ctx=ctx)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)

    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)

    with autograd.train_mode():
        # test#3 again, under train_mode
        h0 = F.ones((3, 5))
        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        # test#4 again, under train_mode
        h0 = F.ones((3, 5, 5))
        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0

    # test that forward does not override existing node features
    g.ndata["h"] = 2 * F.ones((3, 1))
    h1 = conv(g, h0)
    assert len(g.ndata) == 1
    assert len(g.edata) == 0
    assert "h" in g.ndata
    check_close(g.ndata["h"], 2 * F.ones((3, 1)))

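# Dense reference for a 2-hop TAGConv: Y = [X, A_hat X, A_hat^2 X] W + b,
# where A_hat = D^{-1/2} A D^{-1/2} is applied via the per-node norm N.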
def _S2AXWb(A, N, X, W, b):
    X1 = X * N
    X1 = mx.nd.dot(A, X1.reshape(X1.shape[0], -1))
    X1 = X1 * N
    X2 = X1 * N
    X2 = mx.nd.dot(A, X2.reshape(X2.shape[0], -1))
    X2 = X2 * N
    X = mx.nd.concat(X, X1, X2, dim=-1)
    Y = mx.nd.dot(X, W)

    return Y + b

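# TAGConv on a path graph: compares the module output against the
# symmetric-normalized 2-hop dense reference _S2AXWb.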
def test_tagconv():
    g = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)
    norm = mx.nd.power(g.in_degrees().astype('float32'), -0.5)

    conv = nn.TAGConv(5, 2, bias=True)
    conv.initialize(ctx=ctx)
    print(conv)

    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    shp = norm.shape + (1,) * (h0.ndim - 1)
    norm = norm.reshape(shp).as_in_context(h0.context)

    assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.data(ctx), conv.h_bias.data(ctx)))

    conv = nn.TAGConv(5, 2)
    conv.initialize(ctx=ctx)

    # test#2: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert h1.shape[-1] == 2

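# Set2Set readout: output dim is twice the input dim (here 2 * 5 = 10),
# with one row per graph in the (possibly batched) input.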
def test_set2set():
    g = dgl.DGLGraph(nx.path_graph(10))
    ctx = F.ctx()

    s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
    s2s.initialize(ctx=ctx)
    print(s2s)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 10 and h1.ndim == 1

    # test#2: batched graph
    bg = dgl.batch([g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2

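# Global attention pooling: the gate network (Dense(1)) scores each node,
# the feature network maps node feats to dim 10, readout is one row per graph.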
def test_glob_att_pool():
    g = dgl.DGLGraph(nx.path_graph(10))
    ctx = F.ctx()

    gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
    gap.initialize(ctx=ctx)
    print(gap)
    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 10 and h1.ndim == 1

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2

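# Sum/avg/max readouts are checked against the corresponding whole-graph
# reductions; SortPooling keeps k=10 nodes, so its output has 10 * 5 entries.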
def test_simple_pool():
    g = dgl.DGLGraph(nx.path_graph(15))

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sum_pool(g, h0)
    check_close(h1, F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    check_close(h1, F.mean(h0, 0))
    h1 = max_pool(g, h0)
    check_close(h1, F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 10 * 5 and h1.ndim == 1

    # test#2: batched graph
    g_ = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = mx.nd.stack(F.sum(h0[:15], 0),
                        F.sum(h0[15:20], 0),
                        F.sum(h0[20:35], 0),
                        F.sum(h0[35:40], 0),
                        F.sum(h0[40:55], 0), axis=0)
    check_close(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = mx.nd.stack(F.mean(h0[:15], 0),
                        F.mean(h0[15:20], 0),
                        F.mean(h0[20:35], 0),
                        F.mean(h0[35:40], 0),
                        F.mean(h0[40:55], 0), axis=0)
    check_close(h1, truth)

    h1 = max_pool(bg, h0)
    truth = mx.nd.stack(F.max(h0[:15], 0),
                        F.max(h0[15:20], 0),
                        F.max(h0[20:35], 0),
                        F.max(h0[35:40], 0),
                        F.max(h0[40:55], 0), axis=0)
    check_close(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2

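# Expected edge softmax of all-ones logits: every in-edge of a node
# receives weight 1 / in_degree(dst).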
def uniform_attention(g, shape):
    a = mx.nd.ones(shape)
    target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
    return a / g.in_degrees(g.edges()[1]).reshape(target_shape).astype('float32')

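# edge_softmax with constant logits must reproduce the uniform attention
# above and must not leave any scratch fields in g.ndata/g.edata.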
def test_edge_softmax():
    # Basic
    g = dgl.DGLGraph(nx.path_graph(3))
    edata = F.ones((g.number_of_edges(), 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
            1e-4, 1e-4)

    # Test higher dimension case
    edata = F.ones((g.number_of_edges(), 3, 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
            1e-4, 1e-4)

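# RelGraphConv on a random 100-node graph with 5 relation types, covering
# both "basis" and "bdd" (block-diagonal decomposition) regularizers,
# an optional per-edge norm, and integer node-ID input.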
def test_rgcn():
    ctx = F.ctx()
    etype = []
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
        etype.append(i % 5)
    B = 2
    I = 10
    O = 8

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
    rgc_bdd.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_bdd(g, h, r)
    assert list(h_new.shape) == [100, O]

    # with norm
    norm = nd.zeros((g.number_of_edges(), 1), ctx=ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
    rgc_bdd.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_bdd(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    # integer node-ID input (RelGraphConv treats a 1-D integer array as node IDs)
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randint(0, I, (100,), ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

if __name__ == '__main__':
    test_graph_conv()
    test_tagconv()
    test_edge_softmax()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_rgcn()