"""Tests for batching and unbatching DGLGraph objects (dgl.batch / dgl.unbatch)."""
import networkx as nx
import numpy as np
import torch as th

import dgl


def tree1():
    r"""Generate a tree

         0
        / \
       1   2
      / \
     3   4

    Edges are from leaves to root. Node features are the node ids as a float
    tensor ``[0, 1, 2, 3, 4]``; edge features are a random ``(4, 10)`` tensor.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edge(3, 1)
    g.add_edge(4, 1)
    g.add_edge(1, 0)
    g.add_edge(2, 0)
    g.set_n_repr(th.Tensor([0, 1, 2, 3, 4]))
    g.set_e_repr(th.randn(4, 10))
    return g

def tree2():
    r"""Generate a tree

         1
        / \
       4   3
      / \
     2   0

    Edges are from leaves to root. Node features are the node ids as a float
    tensor ``[0, 1, 2, 3, 4]``; edge features are a random ``(4, 10)`` tensor.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edge(2, 4)
    g.add_edge(0, 4)
    g.add_edge(4, 1)
    g.add_edge(3, 1)
    g.set_n_repr(th.Tensor([0, 1, 2, 3, 4]))
    g.set_e_repr(th.randn(4, 10))
    return g

def test_batch_unbatch():
    """Batch two trees, check the batch bookkeeping, then unbatch and check
    that every node/edge feature survives the round trip."""
    t1 = tree1()
    t2 = tree2()

    # Snapshot the features before batching. The unbatched graphs are compared
    # against these originals, so the check does not depend on t1/t2 being
    # left untouched by batch/unbatch.
    n1 = t1.get_n_repr()
    n2 = t2.get_n_repr()
    e1 = t1.get_e_repr()
    e2 = t2.get_e_repr()

    bg = dgl.batch([t1, t2])
    assert bg.number_of_nodes() == 10
    assert bg.number_of_edges() == 8
    assert bg.batch_size == 2
    assert bg.batch_num_nodes == [5, 5]
    assert bg.batch_num_edges == [4, 4]

    tt1, tt2 = dgl.unbatch(bg)
    assert th.allclose(n1, tt1.get_n_repr())
    assert th.allclose(e1, tt1.get_e_repr())
    assert th.allclose(n2, tt2.get_n_repr())
    assert th.allclose(e2, tt2.get_e_repr())

def test_batch_unbatch1():
    """Batching an already-batched graph yields a flat batch of its parts.

    ``batch([t2, batch([t1, t2])])`` must report three graphs of five nodes
    and four edges each, and unbatch into t2, t1, t2 in that order.
    """
    t1 = tree1()
    t2 = tree2()
    b1 = dgl.batch([t1, t2])
    b2 = dgl.batch([t2, b1])
    assert b2.number_of_nodes() == 15
    assert b2.number_of_edges() == 12
    assert b2.batch_size == 3
    assert b2.batch_num_nodes == [5, 5, 5]
    assert b2.batch_num_edges == [4, 4, 4]

    s1, s2, s3 = dgl.unbatch(b2)
    assert th.allclose(t2.get_n_repr(), s1.get_n_repr())
    assert th.allclose(t2.get_e_repr(), s1.get_e_repr())
    assert th.allclose(t1.get_n_repr(), s2.get_n_repr())
    assert th.allclose(t1.get_e_repr(), s2.get_e_repr())
    assert th.allclose(t2.get_n_repr(), s3.get_n_repr())
    assert th.allclose(t2.get_e_repr(), s3.get_e_repr())

def test_batch_sendrecv():
    """One send/recv step over the leaf edges of both trees in a batch.

    After the step, tree1's node 1 holds 3 + 4 = 7 and tree2's node 4 holds
    2 + 0 = 2.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    # The message is the source node's feature; the reducer sums messages
    # along dim 1 (presumably the per-node message axis -- confirm).
    bg.register_message_func(lambda src, edge: src)
    bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1))

    # Leaf -> parent edges, written in each tree's own node numbering.
    e1 = [(3, 1), (4, 1)]
    e2 = [(2, 4), (0, 4)]

    # query_new_edge presumably maps per-tree endpoints to batched-graph ids.
    u1, v1 = bg.query_new_edge(t1, *zip(*e1))
    u2, v2 = bg.query_new_edge(t2, *zip(*e2))
    u = np.concatenate((u1, u2)).tolist()
    v = np.concatenate((v1, v2)).tolist()

    bg.send(u, v)
    bg.recv(v)

    # The return value is discarded and the assertions read t1/t2 directly.
    # This test therefore expects unbatch to write the updated features back
    # into the original graphs.
    dgl.unbatch(bg)
    assert t1.get_n_repr()[1] == 7
    assert t2.get_n_repr()[4] == 2


def test_batch_propagate():
    """Two-step leaf-to-root propagation over a batch of two trees.

    Step 1 fills the internal nodes (tree1 node 1 = 3 + 4, tree2 node 4 = 2 + 0).
    Step 2 fills the roots: tree1 node 0 = 7 + 2 = 9 and tree2 node 1 = 2 + 3 = 5.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    # The message is the source node's feature; the reducer sums messages
    # along dim 1.
    bg.register_message_func(lambda src, edge: src)
    bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1))

    # Each step is (tree1 edges, tree2 edges) in per-tree node ids, ordered
    # from the leaves toward the roots.
    steps = [
        ([(3, 1), (4, 1)], [(2, 4), (0, 4)]),
        ([(1, 0), (2, 0)], [(4, 1), (3, 1)]),
    ]

    # Build the propagation order as (src ids, dst ids) in the batched graph.
    order = []
    for e1, e2 in steps:
        u1, v1 = bg.query_new_edge(t1, *zip(*e1))
        u2, v2 = bg.query_new_edge(t2, *zip(*e2))
        u = np.concatenate((u1, u2)).tolist()
        v = np.concatenate((v1, v2)).tolist()
        order.append((u, v))

    bg.propagate(iterator=order)
    # As in send/recv, results are read back from t1/t2 after unbatching.
    dgl.unbatch(bg)

    assert t1.get_n_repr()[0] == 9
    assert t2.get_n_repr()[1] == 5


def test_batched_edge_ordering():
    """The feature of edge (4, 5) of the first graph is unchanged in the batch.

    The edges are inserted in a non-sorted order, so this checks that the
    batched graph looks up edge features by endpoints rather than by
    insertion position. It assumes that g1, being first in the batch, keeps
    its node ids in the batched graph.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes_from([0, 1, 2, 3, 4, 5])
    g1.add_edges_from([(4, 5), (4, 3), (2, 3), (2, 1), (0, 1)])
    e1 = th.randn(5, 10)
    g1.set_e_repr(e1)

    g2 = dgl.DGLGraph()
    g2.add_nodes_from([0, 1, 2, 3, 4, 5])
    g2.add_edges_from([(0, 1), (1, 2), (2, 3), (5, 4), (4, 3), (5, 0)])
    e2 = th.randn(6, 10)
    g2.set_e_repr(e2)

    g = dgl.batch([g1, g2])
    r1 = g.get_e_repr()[g.get_edge_id(4, 5)]
    r2 = g1.get_e_repr()[g1.get_edge_id(4, 5)]
    assert th.equal(r1, r2)

if __name__ == '__main__':
    test_batch_unbatch()
    test_batch_unbatch1()
    # NOTE(review): the calls below are disabled only for direct script runs;
    # pytest still collects these tests by their test_ prefix.
    #test_batched_edge_ordering()
    #test_batch_sendrecv()
    #test_batch_propagate()