import mxnet as mx
import networkx as nx
import numpy as np
import scipy as sp
from mxnet import autograd, gluon, nd

import dgl
import dgl.nn.mxnet as nn

import backend as F

def check_close(a, b):
    """Assert that two NDArray-like values are elementwise close.

    Both arguments only need an ``asnumpy()`` method; the comparison is
    done in numpy with rtol=atol=1e-4 (loose enough for float32 math).
    """
    assert np.allclose(a.asnumpy(), b.asnumpy(), rtol=1e-4, atol=1e-4)

def _AXWb(A, X, W, b):
    """Dense reference for a graph convolution: ``A @ (X W) + b``.

    Trailing feature dims are flattened for the adjacency product and
    restored afterwards, so multi-dim features are supported.
    """
    ctx = X.context
    projected = mx.nd.dot(X, W.data(ctx))
    flat = projected.reshape(projected.shape[0], -1)
    aggregated = mx.nd.dot(A, flat).reshape(projected.shape)
    return aggregated + b.data(ctx)

def test_graph_conv():
    """Check GraphConv against a dense reference and that it leaves
    graph frames (ndata/edata) untouched."""
    g = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = g.adjacency_matrix(ctx=ctx)

    conv = nn.GraphConv(5, 2, norm=False, bias=True)
    conv.initialize(ctx=ctx)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)

    # test#3: basic
    h0 = F.ones((3, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: basic
    h0 = F.ones((3, 5, 5))
    h1 = conv(g, h0)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)

    # Repeat inside autograd train mode to exercise the training path.
    with autograd.train_mode():
        # test#3: basic
        h0 = F.ones((3, 5))
        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        # test#4: basic
        h0 = F.ones((3, 5, 5))
        h1 = conv(g, h0)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0

    # test not override features
    g.ndata["h"] = 2 * F.ones((3, 1))
    h1 = conv(g, h0)
    assert len(g.ndata) == 1
    assert len(g.edata) == 0
    assert "h" in g.ndata
    check_close(g.ndata['h'], 2 * F.ones((3, 1)))

def test_set2set():
    """Smoke-test Set2Set readout on a single graph and a batched graph."""
    g = dgl.DGLGraph(nx.path_graph(10))
    ctx = F.ctx()

    s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
    s2s.initialize(ctx=ctx)
    print(s2s)

    # test#1: basic — single graph yields a 1-D readout of size 2*hidden.
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = s2s(g, h0)
    assert h1.shape[0] == 10 and h1.ndim == 1

    # test#2: batched graph — one readout row per component graph.
    bg = dgl.batch([g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = s2s(bg, h0)
    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2

def test_glob_att_pool():
    """Smoke-test GlobalAttentionPooling on a single and a batched graph."""
    g = dgl.DGLGraph(nx.path_graph(10))
    ctx = F.ctx()

    # gate network maps features -> scalar score, feat network -> size 10.
    gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
    gap.initialize(ctx=ctx)
    print(gap)

    # test#1: basic — single graph yields a 1-D pooled vector.
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 10 and h1.ndim == 1

    # test#2: batched graph — one pooled row per component graph.
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2

def test_simple_pool():
    """Check Sum/Avg/Max pooling against direct reductions, and SortPooling
    shapes, on single and batched graphs of unequal sizes."""
    g = dgl.DGLGraph(nx.path_graph(15))

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic — pooled result must equal the plain node-axis reduction.
    h0 = F.randn((g.number_of_nodes(), 5))
    h1 = sum_pool(g, h0)
    check_close(h1, F.sum(h0, 0))
    h1 = avg_pool(g, h0)
    check_close(h1, F.mean(h0, 0))
    h1 = max_pool(g, h0)
    check_close(h1, F.max(h0, 0))
    h1 = sort_pool(g, h0)
    assert h1.shape[0] == 10 * 5 and h1.ndim == 1

    # test#2: batched graph — mixed sizes (15, 5, 15, 5, 15 nodes), so the
    # slice boundaries below are the cumulative node offsets.
    g_ = dgl.DGLGraph(nx.path_graph(5))
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = mx.nd.stack(F.sum(h0[:15], 0),
                        F.sum(h0[15:20], 0),
                        F.sum(h0[20:35], 0),
                        F.sum(h0[35:40], 0),
                        F.sum(h0[40:55], 0), axis=0)
    check_close(h1, truth)

    h1 = avg_pool(bg, h0)
    truth = mx.nd.stack(F.mean(h0[:15], 0),
                        F.mean(h0[15:20], 0),
                        F.mean(h0[20:35], 0),
                        F.mean(h0[35:40], 0),
                        F.mean(h0[40:55], 0), axis=0)
    check_close(h1, truth)

    h1 = max_pool(bg, h0)
    truth = mx.nd.stack(F.max(h0[:15], 0),
                        F.max(h0[15:20], 0),
                        F.max(h0[20:35], 0),
                        F.max(h0[35:40], 0),
                        F.max(h0[40:55], 0), axis=0)
    check_close(h1, truth)

    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2

def uniform_attention(g, shape):
    """Expected edge_softmax output for uniform logits: every in-edge of a
    node gets weight 1/in_degree(dst), broadcast over the trailing dims of
    ``shape`` (whose first dim is the number of edges)."""
    a = mx.nd.ones(shape)
    target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
    return a / g.in_degrees(g.edges()[1]).reshape(target_shape).astype('float32')

def test_edge_softmax():
    """edge_softmax on all-ones logits must equal uniform attention and
    must not write into the graph's frames."""
    # Basic
    g = dgl.DGLGraph(nx.path_graph(3))
    edata = F.ones((g.number_of_edges(), 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
            1e-4, 1e-4)

    # Test higher dimension case
    edata = F.ones((g.number_of_edges(), 3, 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
            1e-4, 1e-4)

def test_rgcn():
    """Smoke-test RelGraphConv with "basis" and "bdd" regularizers on a
    random readonly graph: plain forward, forward with edge norm, and
    forward from integer node-id input."""
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    R = 5   # number of relation (edge) types
    # Assign edge types round-robin; use R instead of a repeated magic 5.
    etype = [i % R for i in range(g.number_of_edges())]
    B = 2   # number of bases / block-diagonal blocks
    I = 10  # input feature size
    O = 8   # output feature size

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
    rgc_bdd.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_bdd(g, h, r)
    assert list(h_new.shape) == [100, O]

    # with norm
    norm = nd.zeros((g.number_of_edges(), 1), ctx=ctx)

    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
    rgc_bdd.initialize(ctx=ctx)
    h = nd.random.randn(100, I, ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_bdd(g, h, r, norm)
    assert list(h_new.shape) == [100, O]

    # id input — integer node ids are looked up as embeddings.
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    rgc_basis.initialize(ctx=ctx)
    h = nd.random.randint(0, I, (100,), ctx=ctx)
    r = nd.array(etype, ctx=ctx)
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]

if __name__ == '__main__':
    # Run every test when executed as a script (outside pytest).
    test_graph_conv()
    test_edge_softmax()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
    test_rgcn()