import mxnet as mx
import networkx as nx
import numpy as np
import dgl
import dgl.nn.mxnet as nn
from mxnet import autograd

def check_eq(a, b):
    """Assert that a and b have the same shape and are elementwise equal."""
    assert a.shape == b.shape
    assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))

def _AXWb(A, X, W, b):
    """Reference result of an unnormalized graph convolution: A @ (X @ W) + b."""
    X = mx.nd.dot(X, W.data(X.context))
    Y = mx.nd.dot(A, X.reshape(X.shape[0], -1)).reshape(X.shape)
    return Y + b.data(X.context)

def test_graph_conv():
    g = dgl.DGLGraph(nx.path_graph(3))
    adj = g.adjacency_matrix()
    ctx = mx.cpu(0)

    # without normalization, GraphConv should compute exactly A @ X @ W + b
    conv = nn.GraphConv(5, 2, norm=False, bias=True)
    conv.initialize(ctx=ctx)
    # test#1: basic
    h0 = mx.nd.ones((3, 5))
    h1 = conv(h0, g)
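    # the conv call must not leave any temporary node/edge features on the graph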
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_eq(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = mx.nd.ones((3, 5, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    check_eq(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # with the default (normalized) GraphConv
    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)

    # test#3: basic
    h0 = mx.nd.ones((3, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    # test#4: more-dim
    h0 = mx.nd.ones((3, 5, 5))
    h1 = conv(h0, g)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0

    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)

    # repeat the basic and more-dim tests under autograd's training mode
    with autograd.train_mode():
        # test#5: basic, in training mode
        h0 = mx.nd.ones((3, 5))
        h1 = conv(h0, g)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        # test#6: more-dim, in training mode
        h0 = mx.nd.ones((3, 5, 5))
        h1 = conv(h0, g)
        assert len(g.ndata) == 0
        assert len(g.edata) == 0

    # test that conv does not overwrite features already stored on the graph
    g.ndata["h"] = 2 * mx.nd.ones((3, 1))
    h1 = conv(h0, g)
    assert len(g.ndata) == 1
    assert len(g.edata) == 0
    assert "h" in g.ndata
    check_eq(g.ndata["h"], 2 * mx.nd.ones((3, 1)))

def uniform_attention(g, shape):
    """Expected softmax output when all logits are equal:
    every in-edge of a node receives weight 1 / in_degree."""
    a = mx.nd.ones(shape)
    target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
    return a / g.in_degrees(g.edges()[1]).reshape(target_shape).astype('float32')

def test_edge_softmax():
    # Basic
    g = dgl.DGLGraph(nx.path_graph(3))
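    # nx.path_graph(3) becomes a bidirected DGLGraph: 3 nodes, 4 directed edges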
    edata = mx.nd.ones((g.number_of_edges(), 1))
    a = nn.edge_softmax(g, edata)
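    # identical logits on every edge: softmax should give each edge 1/in_degree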
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
                       rtol=1e-4, atol=1e-4)

    # Test the higher-dimensional case
    edata = mx.nd.ones((g.number_of_edges(), 3, 1))
    a = nn.edge_softmax(g, edata)
    assert len(g.ndata) == 0
    assert len(g.edata) == 0
    assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
                       rtol=1e-4, atol=1e-4)

if __name__ == '__main__':
    test_graph_conv()
    test_edge_softmax()