"""Basic DGLGraph tests.

Covers batched get/set/pop of node and edge representations, autograd
through a partial node update, and message passing via send/recv,
send_and_recv, pull, push and update_all.
"""
import numpy as np
import torch as th
from torch.autograd import Variable

# NOTE(review): __REPR__ is not referenced in this file's visible code;
# kept in case the import itself is being exercised.
from dgl.graph import DGLGraph, __REPR__

# Feature width of every node and edge representation used in these tests.
D = 32
# Shapes of the message tensors seen by ``reduce_func``. Tests clear this set,
# run a message-passing routine, then compare it against the expected shapes.
reduce_msg_shapes = set()
def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` have the same shape and equal elements.

    Raises:
        AssertionError: if the shapes differ or if any element differs.
    """
    assert a.shape == b.shape
    # Equal everywhere <=> no position compares unequal.
    n_diff = int((a != b).sum())
    assert n_diff == 0
def message_func(hu, e_uv):
    """Message function: forward the source-node representation unchanged.

    Asserts that ``hu`` is 2-D with feature width ``D``. The edge
    representation ``e_uv`` is ignored.
    """
    shape = hu.shape
    assert len(shape) == 2
    assert shape[-1] == D
    return hu

def reduce_func(hv, msgs):
    """Reduce function: add the sum of incoming messages to each node's state.

    Every ``msgs`` shape is recorded in the module-level ``reduce_msg_shapes``
    set so the tests can check how messages were batched.

    Args:
        hv: current representations of the receiving nodes.
        msgs: 3-D message tensor whose last dimension is ``D``; it is summed
            over dimension 1. The tests expect shapes such as ``(1, 3, D)``
            for one node receiving three messages, so dim 0 presumably
            indexes receiving nodes and dim 1 their messages.

    Returns:
        ``hv + th.sum(msgs, 1)``.
    """
    # Record the shape first, so even a batch that fails the checks below is seen.
    reduce_msg_shapes.add(tuple(msgs.shape))
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    return hv + th.sum(msgs, 1)
def generate_graph(grad=False):
    """Build the shared 10-node test graph with random node/edge features.

    Node 0 is the source and node 9 the sink. For each i in 1..8 the edges
    0->i and i->9 are added, in that order. A single back edge 9->0 is added
    last, giving 17 edges in total.

    Args:
        grad: whether the random node/edge representations require gradients.

    Returns:
        A DGLGraph whose node representation is a (10, D) tensor and whose
        edge representation is a (17, D) tensor.
    """
    num_nodes = 10
    sink = num_nodes - 1
    g = DGLGraph()
    g.add_nodes(num_nodes)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, sink):
        g.add_edge(0, i)
        g.add_edge(i, sink)
    # add a back flow from 9 to 0
    g.add_edge(sink, 0)
    # two edges per intermediate node plus the back edge: 2 * 8 + 1 == 17
    num_edges = 2 * (num_nodes - 2) + 1
    ncol = Variable(th.randn(num_nodes, D), requires_grad=grad)
    ecol = Variable(th.randn(num_edges, D), requires_grad=grad)
    g.set_n_repr(ncol)
    g.set_e_repr(ecol)
    return g

def test_batch_setter_getter():
    """Batched set/get/pop of node and edge representations, full and partial."""
    def _pfc(x):
        # first feature column as a plain Python list, for easy comparison
        return list(x.numpy()[:, 0])
    g = generate_graph()
    # set all nodes
    g.set_n_repr(th.zeros((10, D)))
    assert _pfc(g.get_n_repr()) == [0.] * 10
    # pop nodes
    assert _pfc(g.pop_n_repr()) == [0.] * 10
    assert len(g.get_n_repr()) == 0
    g.set_n_repr(th.zeros((10, D)))
    # set partial nodes
    u = th.tensor([1, 3, 5])
    g.set_n_repr(th.ones((3, D)), u)
    assert _pfc(g.get_n_repr()) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = th.tensor([1, 2, 3])
    assert _pfc(g.get_n_repr(u)) == [1., 0., 1.]

    # Edge ids of the graph built by generate_graph:
    # s, d, eid
    # 0, 1, 0
    # 1, 9, 1
    # 0, 2, 2
    # 2, 9, 3
    # 0, 3, 4
    # 3, 9, 5
    # 0, 4, 6
    # 4, 9, 7
    # 0, 5, 8
    # 5, 9, 9
    # 0, 6, 10
    # 6, 9, 11
    # 0, 7, 12
    # 7, 9, 13
    # 0, 8, 14
    # 8, 9, 15
    # 9, 0, 16

    # set all edges
    g.set_e_repr(th.zeros((17, D)))
    assert _pfc(g.get_e_repr()) == [0.] * 17
    # pop edges
    assert _pfc(g.pop_e_repr()) == [0.] * 17
    assert len(g.get_e_repr()) == 0
    g.set_e_repr(th.zeros((17, D)))
    # set partial edges (many-many)
    u = th.tensor([0, 0, 2, 5, 9])
    v = th.tensor([1, 3, 9, 9, 0])
    g.set_e_repr(th.ones((5, D)), u, v)
    truth = [0.] * 17
    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    assert _pfc(g.get_e_repr()) == truth
    # set partial edges (many-one)
    u = th.tensor([3, 4, 6])
    v = th.tensor([9])
    g.set_e_repr(th.ones((3, D)), u, v)
    truth[5] = truth[7] = truth[11] = 1.
    assert _pfc(g.get_e_repr()) == truth
    # set partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([4, 5, 6])
    g.set_e_repr(th.ones((3, D)), u, v)
    truth[6] = truth[8] = truth[10] = 1.
    assert _pfc(g.get_e_repr()) == truth
    # get partial edges (many-many)
    u = th.tensor([0, 6, 0])
    v = th.tensor([6, 9, 7])
    assert _pfc(g.get_e_repr(u, v)) == [1., 1., 0.]
    # get partial edges (many-one)
    u = th.tensor([5, 6, 7])
    v = th.tensor([9])
    assert _pfc(g.get_e_repr(u, v)) == [1., 1., 0.]
    # get partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([3, 4, 5])
    assert _pfc(g.get_e_repr(u, v)) == [1., 1., 1.]
def test_batch_setter_autograd():
    """Gradients after a partial set_n_repr reach only the surviving old rows and the new rows."""
    g = generate_graph(grad=True)
    before = g.get_n_repr()
    # overwrite nodes 1, 2 and 8 with a fresh leaf tensor
    nodes = th.tensor([1, 2, 8])
    new_rows = Variable(th.zeros((len(nodes), D)), requires_grad=True)
    g.set_n_repr(new_rows, nodes)
    after = g.get_n_repr()
    upstream = 2 * th.ones((10, D))
    after.backward(upstream)
    # overwritten rows (1, 2, 8) get no gradient through the old tensor
    expected_old = th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.])
    expected_new = th.tensor([2., 2., 2.])
    check_eq(before.grad[:, 0], expected_old)
    check_eq(new_rows.grad[:, 0], expected_new)
def test_batch_send():
    """send() accepts many-many, one-many and many-one (u, v) edge lists."""
    g = generate_graph()
    def _fmsg(hu, edge):
        # every send below covers exactly five edges
        assert hu.shape == (5, D)
        return hu
    g.register_message_func(_fmsg)
    # many-many send
    u = th.tensor([0, 0, 0, 0, 0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send(u, v)
    # one-many send
    u = th.tensor([0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send(u, v)
    # many-one send
    u = th.tensor([1, 2, 3, 4, 5])
    v = th.tensor([9])
    g.send(u, v)
def test_batch_recv():
    """recv() after send() reduces messages in batches shaped by message count."""
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.send(u, v)
    g.recv(th.unique(v))
    # nodes 1, 2, 3 receive one message each; node 9 receives three
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()

def test_update_routines():
    """send_and_recv, pull, push and update_all produce the expected reduce batches."""
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)

    # send_and_recv: nodes 1, 2, 3 get one message each, node 9 gets three
    reduce_msg_shapes.clear()
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.send_and_recv(u, v)
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # pull: nodes 1, 2, 3 each have one in-edge, node 9 has eight
    v = th.tensor([1, 2, 3, 9])
    reduce_msg_shapes.clear()
    g.pull(v)
    assert reduce_msg_shapes == {(1, 8, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # push: node 0 feeds nodes 1..8 once each; nodes 1, 2, 3 feed node 9
    v = th.tensor([0, 1, 2, 3])
    reduce_msg_shapes.clear()
    g.push(v)
    assert reduce_msg_shapes == {(1, 3, D), (8, 1, D)}
    reduce_msg_shapes.clear()

    # update_all: nodes 0..8 have one in-edge each, node 9 has eight
    reduce_msg_shapes.clear()
    g.update_all()
    assert reduce_msg_shapes == {(1, 8, D), (9, 1, D)}
    reduce_msg_shapes.clear()

if __name__ == '__main__':
    # Allow running the tests directly without a test runner.
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_update_routines()