"docs/vscode:/vscode.git/clone" did not exist on "9c1352eb5736d9e71d37959db44b6a641e898772"
test_nn.py 2.17 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import mxnet as mx
import networkx as nx
import numpy as np
import dgl
import dgl.nn.mxnet as nn
from mxnet import autograd

def check_eq(a, b):
    """Assert that ``a`` and ``b`` have the same shape and are equal element-wise.

    The comparison is exact (``==``), with no tolerance.
    """
    assert a.shape == b.shape
    # Number of positions where the two arrays agree ...
    num_matching = mx.nd.sum(a == b).asnumpy()
    # ... must equal the total number of elements.
    num_elements = int(np.prod(list(a.shape)))
    assert num_matching == num_elements

def _AXWb(A, X, W, b):
    """Reference computation ``A . (X . W) + b`` used to check GraphConv output.

    ``W`` and ``b`` are objects exposing ``.data(ctx)`` (presumably Gluon
    parameters -- confirm against the caller); their values are fetched on the
    context of the feature tensor.
    """
    projected = mx.nd.dot(X, W.data(X.context))
    # Collapse every trailing axis so the product with A is a plain 2-D dot,
    # then restore the projected shape afterwards.
    num_rows = projected.shape[0]
    flattened = projected.reshape(num_rows, -1)
    aggregated = mx.nd.dot(A, flattened).reshape(projected.shape)
    bias = b.data(projected.context)
    return aggregated + bias

def test_graph_conv():
    """Exercise ``nn.GraphConv`` on the graph built from ``nx.path_graph(3)``.

    Only the ``norm=False, bias=True`` layer has its output verified (against
    ``_AXWb``); the default-configured layers are run as smoke tests, both
    outside and inside ``autograd.train_mode()``.
    """
    graph = dgl.DGLGraph(nx.path_graph(3))
    adj = graph.adjacency_matrix()
    ctx = mx.cpu(0)
    # A plain 2-D feature matrix, then features with an extra trailing axis.
    feat_shapes = [(3, 5), (3, 5, 5)]

    # Layer built with norm=False and bias=True: output must equal A X W + b.
    conv = nn.GraphConv(5, 2, norm=False, bias=True)
    conv.initialize(ctx=ctx)
    for feat_shape in feat_shapes:
        feats = mx.nd.ones(feat_shape)
        out = conv(feats, graph)
        check_eq(out, _AXWb(adj, feats, conv.weight, conv.bias))

    # Default-configured layer: only checks that the forward pass runs.
    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)
    for feat_shape in feat_shapes:
        feats = mx.nd.ones(feat_shape)
        conv(feats, graph)

    # Fresh default-configured layer, same forward passes in training mode.
    conv = nn.GraphConv(5, 2)
    conv.initialize(ctx=ctx)
    with autograd.train_mode():
        for feat_shape in feat_shapes:
            feats = mx.nd.ones(feat_shape)
            conv(feats, graph)

    # A node feature stored under "_gconv_feat" must still be present after
    # the layer runs (presumably the layer uses this name internally --
    # confirm). `feats` is the (3, 5, 5) tensor left over from the last loop.
    graph.ndata["_gconv_feat"] = 2 * mx.nd.ones((3, 1))
    conv(feats, graph)
    assert "_gconv_feat" in graph.ndata

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def uniform_attention(g, shape):
    """Return ones of ``shape`` divided by a per-edge in-degree.

    The in-degree is looked up for the second element of ``g.edges()``
    (presumably the destination node of each edge -- confirm) and reshaped to
    ``(num_edges, 1, ..., 1)`` so it broadcasts against ``shape``.
    """
    # One leading edge axis followed by singleton axes, matching len(shape).
    singleton_axes = (1,) * (len(shape) - 1)
    target_shape = (g.number_of_edges(),) + singleton_axes

    edge_nodes = g.edges()[1]
    degrees = g.in_degrees(edge_nodes).reshape(target_shape).astype('float32')
    return mx.nd.ones(shape) / degrees

def test_edge_softmax():
    """Check ``nn.edge_softmax`` on all-ones edge data against ``uniform_attention``.

    Runs once with edge data of shape ``(E, 1)`` and once with an extra
    middle axis, ``(E, 3, 1)``.
    """
    g = dgl.DGLGraph(nx.path_graph(3))
    num_edges = g.number_of_edges()

    # Basic case first, then the higher-dimensional case.
    for edata_shape in [(num_edges, 1), (num_edges, 3, 1)]:
        edata = mx.nd.ones(edata_shape)
        attention = nn.edge_softmax(g, edata)
        expected = uniform_attention(g, attention.shape)
        assert np.allclose(attention.asnumpy(), expected.asnumpy(),
                           rtol=1e-4, atol=1e-4)

78
79
if __name__ == '__main__':
    # Allow running the tests directly as a script, without a test runner.
    test_graph_conv()
    test_edge_softmax()