"docs/vscode:/vscode.git/clone" did not exist on "40528e9ae7d56740c00d838299198d34111717bb"
test_nn.py 5.43 KB
Newer Older
1
2
3
4
5
import mxnet as mx
import networkx as nx
import numpy as np
import dgl
import dgl.nn.mxnet as nn
6
import backend as F
7
from mxnet import autograd, gluon
8

9
10
def check_close(a, b):
    """Assert that two ndarray-like objects (anything exposing ``asnumpy``)
    are element-wise close within rtol/atol of 1e-4."""
    lhs = a.asnumpy()
    rhs = b.asnumpy()
    assert np.allclose(lhs, rhs, rtol=1e-4, atol=1e-4)
11
12
13
14
15
16
17
18

def _AXWb(A, X, W, b):
    """Dense reference computation ``A @ (X W) + b`` used to check GraphConv.

    ``W`` and ``b`` are gluon parameters; their values are materialized on
    the same context as ``X``.  The flatten/reshape dance lets ``X`` carry
    extra trailing feature dimensions.
    """
    ctx = X.context
    xw = mx.nd.dot(X, W.data(ctx))
    axw = mx.nd.dot(A, xw.reshape(xw.shape[0], -1)).reshape(xw.shape)
    return axw + b.data(ctx)

def test_graph_conv():
    """Exercise nn.GraphConv on a 3-node path graph.

    Covers: un-normalized conv vs. the dense reference ``_AXWb``, the
    default (normalized) conv, train-mode execution, and the guarantee
    that the layer does not clobber pre-existing node features.
    """
    graph = dgl.DGLGraph(nx.path_graph(3))
    ctx = F.ctx()
    adj = graph.adjacency_matrix(ctx=ctx)

    layer = nn.GraphConv(5, 2, norm=False, bias=True)
    layer.initialize(ctx=ctx)

    # test#1: basic
    feat = F.ones((3, 5))
    out = layer(feat, graph)
    assert len(graph.ndata) == 0
    assert len(graph.edata) == 0
    check_close(out, _AXWb(adj, feat, layer.weight, layer.bias))

    # test#2: more-dim
    feat = F.ones((3, 5, 5))
    out = layer(feat, graph)
    assert len(graph.ndata) == 0
    assert len(graph.edata) == 0
    check_close(out, _AXWb(adj, feat, layer.weight, layer.bias))

    layer = nn.GraphConv(5, 2)
    layer.initialize(ctx=ctx)

    # test#3: basic
    feat = F.ones((3, 5))
    out = layer(feat, graph)
    assert len(graph.ndata) == 0
    assert len(graph.edata) == 0

    # test#4: basic
    feat = F.ones((3, 5, 5))
    out = layer(feat, graph)
    assert len(graph.ndata) == 0
    assert len(graph.edata) == 0

    layer = nn.GraphConv(5, 2)
    layer.initialize(ctx=ctx)

    with autograd.train_mode():
        # test#3: basic
        feat = F.ones((3, 5))
        out = layer(feat, graph)
        assert len(graph.ndata) == 0
        assert len(graph.edata) == 0

        # test#4: basic
        feat = F.ones((3, 5, 5))
        out = layer(feat, graph)
        assert len(graph.ndata) == 0
        assert len(graph.edata) == 0

    # test not override features
    graph.ndata["h"] = 2 * F.ones((3, 1))
    out = layer(feat, graph)
    assert len(graph.ndata) == 1
    assert len(graph.edata) == 0
    assert "h" in graph.ndata
    check_close(graph.ndata['h'], 2 * F.ones((3, 1)))
73
74
75

def test_set2set():
    """Set2Set readout: single 10-node path graph, then a batch of three."""
    graph = dgl.DGLGraph(nx.path_graph(10))
    ctx = F.ctx()

    readout = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
    readout.initialize(ctx=ctx)
    print(readout)

    # test#1: basic
    feat = F.randn((graph.number_of_nodes(), 5))
    out = readout(feat, graph)
    assert out.shape[0] == 10 and out.ndim == 1

    # test#2: batched graph
    batched = dgl.batch([graph, graph, graph])
    feat = F.randn((batched.number_of_nodes(), 5))
    out = readout(feat, batched)
    assert out.shape[0] == 3 and out.shape[1] == 10 and out.ndim == 2

def test_glob_att_pool():
    """GlobalAttentionPooling readout: single graph, then a batch of four."""
    graph = dgl.DGLGraph(nx.path_graph(10))
    ctx = F.ctx()

    pool = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
    pool.initialize(ctx=ctx)
    print(pool)

    # test#1: basic
    feat = F.randn((graph.number_of_nodes(), 5))
    out = pool(feat, graph)
    assert out.shape[0] == 10 and out.ndim == 1

    # test#2: batched graph
    batched = dgl.batch([graph, graph, graph, graph])
    feat = F.randn((batched.number_of_nodes(), 5))
    out = pool(feat, batched)
    assert out.shape[0] == 4 and out.shape[1] == 10 and out.ndim == 2

def test_simple_pool():
    """Sum/Avg/Max/Sort pooling on a single graph and on a 5-graph batch."""
    graph = dgl.DGLGraph(nx.path_graph(15))

    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10) # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)

    # test#1: basic — pooled result must match the plain reduction over nodes
    feat = F.randn((graph.number_of_nodes(), 5))
    check_close(sum_pool(feat, graph), F.sum(feat, 0))
    check_close(avg_pool(feat, graph), F.mean(feat, 0))
    check_close(max_pool(feat, graph), F.max(feat, 0))
    out = sort_pool(feat, graph)
    assert out.shape[0] == 10 * 5 and out.ndim == 1

    # test#2: batched graph of alternating 15- and 5-node paths
    small = dgl.DGLGraph(nx.path_graph(5))
    batched = dgl.batch([graph, small, graph, small, graph])
    feat = F.randn((batched.number_of_nodes(), 5))

    # Node-count boundaries of the five graphs inside the batch.
    bounds = [0, 15, 20, 35, 40, 55]
    segments = [feat[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]

    check_close(sum_pool(feat, batched),
                mx.nd.stack(*[F.sum(seg, 0) for seg in segments], axis=0))
    check_close(avg_pool(feat, batched),
                mx.nd.stack(*[F.mean(seg, 0) for seg in segments], axis=0))
    check_close(max_pool(feat, batched),
                mx.nd.stack(*[F.max(seg, 0) for seg in segments], axis=0))

    out = sort_pool(feat, batched)
    assert out.shape[0] == 5 and out.shape[1] == 10 * 5 and out.ndim == 2

162
163
164
165
166
167
168
169
def uniform_attention(g, shape):
    """Expected edge_softmax output for all-equal logits.

    With identical scores, softmax over incoming edges is uniform, so each
    edge receives 1 / in-degree of its destination node, broadcast to
    ``shape``.
    """
    ones = mx.nd.ones(shape)
    dst_nodes = g.edges()[1]
    bcast_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
    return ones / g.in_degrees(dst_nodes).reshape(bcast_shape).astype('float32')

def test_edge_softmax():
    """edge_softmax on uniform logits equals the analytic uniform attention,
    and must leave the graph's ndata/edata untouched."""
    # Basic
    graph = dgl.DGLGraph(nx.path_graph(3))
    logits = F.ones((graph.number_of_edges(), 1))
    scores = nn.edge_softmax(graph, logits)
    assert len(graph.ndata) == 0
    assert len(graph.edata) == 0
    assert np.allclose(scores.asnumpy(),
                       uniform_attention(graph, scores.shape).asnumpy(),
                       1e-4, 1e-4)

    # Test higher dimension case
    logits = F.ones((graph.number_of_edges(), 3, 1))
    scores = nn.edge_softmax(graph, logits)
    assert len(graph.ndata) == 0
    assert len(graph.edata) == 0
    assert np.allclose(scores.asnumpy(),
                       uniform_attention(graph, scores.shape).asnumpy(),
                       1e-4, 1e-4)

185
186
if __name__ == '__main__':
    # Run all test cases when executed as a script (outside pytest).
    test_graph_conv()
    test_edge_softmax()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()