"""Tests for batching (``dgl.batch``) and unbatching (``dgl.unbatch``) of DGLGraph objects."""
import networkx as nx
import dgl
import torch as th
import numpy as np

def tree1():
    r"""Build the first 5-node test tree.

         0
        / \
       1   2
      / \
     3   4

    Edges are from leaves to root. Node feature ``'h'`` of node ``i`` is
    ``i``; edge feature ``'h'`` is a random ``(4, 10)`` tensor.

    Returns:
        A ``dgl.DGLGraph`` with 5 nodes and 4 edges.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5)
    for src, dst in [(3, 1), (4, 1), (1, 0), (2, 0)]:
        g.add_edge(src, dst)
    g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
    g.set_e_repr({'h' : th.randn(4, 10)})
    return g

def tree2():
    r"""Build the second 5-node test tree.

         1
        / \
       4   3
      / \
     2   0

    Edges are from leaves to root. Node feature ``'h'`` of node ``i`` is
    ``i``; edge feature ``'h'`` is a random ``(4, 10)`` tensor.

    Returns:
        A ``dgl.DGLGraph`` with 5 nodes and 4 edges.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5)
    for src, dst in [(2, 4), (0, 4), (4, 1), (3, 1)]:
        g.add_edge(src, dst)
    g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
    g.set_e_repr({'h' : th.randn(4, 10)})
    return g

def test_batch_unbatch():
    """Batch two trees, check the batch layout, then unbatch and compare features."""
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    assert bg.number_of_nodes() == 10
    assert bg.number_of_edges() == 8
    assert bg.batch_size == 2
    assert bg.batch_num_nodes == [5, 5]
    assert bg.batch_num_edges == [4, 4]

    # Unbatching must hand back graphs whose features match the inputs.
    tt1, tt2 = dgl.unbatch(bg)
    assert th.allclose(t1.get_n_repr()['h'], tt1.get_n_repr()['h'])
    assert th.allclose(t1.get_e_repr()['h'], tt1.get_e_repr()['h'])
    assert th.allclose(t2.get_n_repr()['h'], tt2.get_n_repr()['h'])
    assert th.allclose(t2.get_e_repr()['h'], tt2.get_e_repr()['h'])

def test_batch_unbatch1():
    """Batch a batched graph and check that it unbatches into three trees.

    ``batch([t2, batch([t1, t2])])`` is expected to contain three component
    graphs, in the order t2, t1, t2.
    """
    t1 = tree1()
    t2 = tree2()
    b1 = dgl.batch([t1, t2])
    b2 = dgl.batch([t2, b1])
    assert b2.number_of_nodes() == 15
    assert b2.number_of_edges() == 12
    assert b2.batch_size == 3
    assert b2.batch_num_nodes == [5, 5, 5]
    assert b2.batch_num_edges == [4, 4, 4]

    s1, s2, s3 = dgl.unbatch(b2)
    assert th.allclose(t2.get_n_repr()['h'], s1.get_n_repr()['h'])
    assert th.allclose(t2.get_e_repr()['h'], s1.get_e_repr()['h'])
    assert th.allclose(t1.get_n_repr()['h'], s2.get_n_repr()['h'])
    assert th.allclose(t1.get_e_repr()['h'], s2.get_e_repr()['h'])
    assert th.allclose(t2.get_n_repr()['h'], s3.get_n_repr()['h'])
    assert th.allclose(t2.get_e_repr()['h'], s3.get_e_repr()['h'])

def test_batch_sendrecv():
    """Run one send/recv step on a batched graph and check each tree's result."""
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    # Messages copy the source's 'h'; the reducer sums incoming messages.
    # NOTE(review): assumes msgs['m'] is (num_nodes, num_msgs), so dim 1 is
    # the message axis -- confirm against the DGL version in use.
    bg.register_message_func(lambda src, edge: {'m' : src['h']})
    bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})

    # tree2's node ids are shifted by tree1's node count inside bg.
    offset = t1.number_of_nodes()
    # tree1: 3, 4 -> 1;  tree2: 2, 0 -> 4.
    u = [3, 4, 2 + offset, 0 + offset]
    v = [1, 1, 4 + offset, 4 + offset]

    bg.send(u, v)
    bg.recv(v)

    t1, t2 = dgl.unbatch(bg)
    # tree1 node 1: h[3] + h[4] = 7;  tree2 node 4: h[2] + h[0] = 2.
    assert t1.get_n_repr()['h'][1] == 7
    assert t2.get_n_repr()['h'][4] == 2


def test_batch_propagate():
    """Propagate leaf-to-root over two levels on a batched graph and check both roots."""
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    # Messages copy the source's 'h'; the reducer sums incoming messages.
    # NOTE(review): assumes msgs['m'] is (num_nodes, num_msgs), so dim 1 is
    # the message axis -- confirm against the DGL version in use.
    bg.register_message_func(lambda src, edge: {'m' : src['h']})
    bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})

    # tree2's node ids are shifted by tree1's node count inside bg.
    offset = t1.number_of_nodes()
    order = []

    # Step 1 -- tree1: 3, 4 -> 1;  tree2: 2, 0 -> 4.
    u = [3, 4, 2 + offset, 0 + offset]
    v = [1, 1, 4 + offset, 4 + offset]
    order.append((u, v))

    # Step 2 -- tree1: 1, 2 -> 0;  tree2: 4, 3 -> 1.
    u = [1, 2, 4 + offset, 3 + offset]
    v = [0, 0, 1 + offset, 1 + offset]
    order.append((u, v))

    bg.propagate(traverser=order)
    t1, t2 = dgl.unbatch(bg)

    # tree1 root 0: (3 + 4) + 2 = 9;  tree2 root 1: (2 + 0) + 3 = 5.
    assert t1.get_n_repr()['h'][0] == 9
    assert t2.get_n_repr()['h'][1] == 5


def test_batched_edge_ordering():
    """Check that an edge of the first graph keeps its feature after batching.

    The edge is looked up by its endpoints in the batched graph, and its
    feature is compared with the original.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes(6)
    g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
    e1 = th.randn(5, 10)
    g1.set_e_repr({'h' : e1})

    g2 = dgl.DGLGraph()
    g2.add_nodes(6)
    g2.add_edges([0, 1, 2, 5, 4, 5], [1, 2, 3, 4, 3, 0])
    e2 = th.randn(6, 10)
    g2.set_e_repr({'h' : e2})

    g = dgl.batch([g1, g2])
    # g1 is first in the batch, so its node ids (and its edge 4 -> 5) keep
    # the same endpoints in g.
    r1 = g.get_e_repr()['h'][g.edge_id(4, 5)]
    r2 = g1.get_e_repr()['h'][g1.edge_id(4, 5)]
    assert th.equal(r1, r2)


def test_batch_no_edge():
    """Batch graphs where one component has no edges, and check the layout."""
    g1 = dgl.DGLGraph()
    g1.add_nodes(6)
    g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
    g2 = dgl.DGLGraph()
    g2.add_nodes(6)
    g2.add_edges([0, 1, 2, 5, 4, 5], [1, 2, 3, 4, 3, 0])
    g3 = dgl.DGLGraph()
    g3.add_nodes(1)  # no edges

    # g3 (no edges) sits between two graphs that do have edges.
    g = dgl.batch([g1, g3, g2])  # should not throw an error
    assert g.batch_size == 3
    assert g.number_of_nodes() == 6 + 1 + 6
    assert g.number_of_edges() == 5 + 0 + 6
    assert g.batch_num_nodes == [6, 1, 6]
    assert g.batch_num_edges == [5, 0, 6]


if __name__ == '__main__':
    # Allow running the suite directly without pytest.
    test_batch_unbatch()
    test_batch_unbatch1()
    test_batched_edge_ordering()
    test_batch_sendrecv()
    test_batch_propagate()
    test_batch_no_edge()