test_batched_graph.py 4.04 KB
Newer Older
1
2
import networkx as nx
import dgl
Minjie Wang's avatar
Minjie Wang committed
3
import torch as th
4
import numpy as np
5
import utils as U
6
7
8
9
10
11
12
13
14
15
16

def tree1():
    # Raw docstring: without the ``r`` prefix, the trailing "\" in the ASCII
    # art escapes the following newline and collapses the drawing.
    r"""Generate a 5-node tree whose edges point from leaves to root.

         0
        / \
       1   2
      / \
     3   4

    Node feature 'h' holds each node's own id (built with ``th.Tensor``,
    i.e. the default float dtype). Edge feature 'h' is a random (4, 10)
    tensor from ``th.randn``, one row per edge in insertion order.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edge(3, 1)
    g.add_edge(4, 1)
    g.add_edge(1, 0)
    g.add_edge(2, 0)
    g.ndata['h'] = th.Tensor([0, 1, 2, 3, 4])
    g.edata['h'] = th.randn(4, 10)
    return g

def tree2():
    # Raw docstring: without the ``r`` prefix, the trailing "\" in the ASCII
    # art escapes the following newline and collapses the drawing.
    r"""Generate a 5-node tree whose edges point from leaves to root.

         1
        / \
       4   3
      / \
     2   0

    Node feature 'h' holds each node's own id (built with ``th.Tensor``,
    i.e. the default float dtype). Edge feature 'h' is a random (4, 10)
    tensor from ``th.randn``, one row per edge in insertion order.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edge(2, 4)
    g.add_edge(0, 4)
    g.add_edge(4, 1)
    g.add_edge(3, 1)
    g.ndata['h'] = th.Tensor([0, 1, 2, 3, 4])
    g.edata['h'] = th.randn(4, 10)
    return g

def test_batch_unbatch():
    """Batching two trees and unbatching them round-trips size and features."""
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    # Sizes of the batched graph are the sums over both 5-node, 4-edge trees.
    assert bg.number_of_nodes() == 10
    assert bg.number_of_edges() == 8
    assert bg.batch_size == 2
    assert bg.batch_num_nodes == [5, 5]
    assert bg.batch_num_edges == [4, 4]

    # Unbatching must give back each input graph's node and edge features.
    tt1, tt2 = dgl.unbatch(bg)
    assert U.allclose(t1.ndata['h'], tt1.ndata['h'])
    assert U.allclose(t1.edata['h'], tt1.edata['h'])
    assert U.allclose(t2.ndata['h'], tt2.ndata['h'])
    assert U.allclose(t2.edata['h'], tt2.edata['h'])

def test_batch_unbatch1():
    """Batching an already-batched graph yields one slot per original tree."""
    tree_a = tree1()
    tree_b = tree2()
    inner = dgl.batch([tree_a, tree_b])
    outer = dgl.batch([tree_b, inner])

    # Three 5-node / 4-edge trees in total.
    assert outer.number_of_nodes() == 15
    assert outer.number_of_edges() == 12
    assert outer.batch_size == 3
    assert outer.batch_num_nodes == [5, 5, 5]
    assert outer.batch_num_edges == [4, 4, 4]

    # Exactly three pieces must come back, in order: tree_b, tree_a, tree_b.
    s1, s2, s3 = dgl.unbatch(outer)
    for reference, piece in zip((tree_b, tree_a, tree_b), (s1, s2, s3)):
        assert U.allclose(reference.ndata['h'], piece.ndata['h'])
        assert U.allclose(reference.edata['h'], piece.edata['h'])

def test_batch_sendrecv():
    """One send/recv round on a batched graph sums leaf ids into their parents."""
    bg = dgl.batch([tree1(), tree2()])

    def copy_src(edges):
        # Each message carries the source node's 'h'.
        return {'m' : edges.src['h']}

    def sum_mailbox(nodes):
        # Incoming messages are summed along dim 1 of the mailbox.
        return {'h' : th.sum(nodes.mailbox['m'], 1)}

    bg.register_message_func(copy_src)
    bg.register_reduce_func(sum_mailbox)

    # In the batch, tree2's node ids follow tree1's 5 nodes.
    offset = 5
    src = [3, 4, 2 + offset, 0 + offset]
    dst = [1, 1, 4 + offset, 4 + offset]
    bg.send((src, dst))
    bg.recv(dst)

    g1, g2 = dgl.unbatch(bg)
    # tree1 node 1 <- leaves 3 + 4; tree2 node 4 <- leaves 2 + 0.
    assert g1.ndata['h'][1] == 7
    assert g2.ndata['h'][4] == 2


def test_batch_propagate():
    """Two-level prop_edges over a batched graph accumulates leaf ids at roots."""
    bg = dgl.batch([tree1(), tree2()])
    bg.register_message_func(lambda edges: {'m' : edges.src['h']})
    bg.register_reduce_func(lambda nodes: {'h' : th.sum(nodes.mailbox['m'], 1)})

    shift = 5  # tree2's node ids start after tree1's 5 nodes in the batch
    order = [
        # Step 1: leaves into internal nodes (tree1: 3,4 -> 1; tree2: 2,0 -> 4).
        ([3, 4, 2 + shift, 0 + shift], [1, 1, 4 + shift, 4 + shift]),
        # Step 2: into the roots (tree1: 1,2 -> 0; tree2: 4,3 -> 1).
        ([1, 2, 4 + shift, 3 + shift], [0, 0, 1 + shift, 1 + shift]),
    ]
    bg.prop_edges(order)

    g1, g2 = dgl.unbatch(bg)
    # tree1 root: (3 + 4) + 2 = 9; tree2 root: (2 + 0) + 3 = 5.
    assert g1.ndata['h'][0] == 9
    assert g2.ndata['h'][1] == 5
def test_batched_edge_ordering():
    """An edge of the first batched graph keeps its feature row after batching."""
    first = dgl.DGLGraph()
    first.add_nodes(6)
    first.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
    first.edata['h'] = th.randn(5, 10)

    second = dgl.DGLGraph()
    second.add_nodes(6)
    second.add_edges([0, 1, 2, 5, 4, 5], [1, 2, 3, 4, 3, 0])
    second.edata['h'] = th.randn(6, 10)

    batched = dgl.batch([first, second])
    # Edge (4, 5) comes from `first`, which occupies the leading node ids of
    # the batch, so the same (src, dst) pair addresses it in both graphs.
    from_batch = batched.edata['h'][batched.edge_id(4, 5)]
    from_source = first.edata['h'][first.edge_id(4, 5)]
    assert th.equal(from_batch, from_source)
def test_batch_no_edge():
    """Batching must accept a component graph that has no edges at all."""
    g1 = dgl.DGLGraph()
    g1.add_nodes(6)
    g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
    g2 = dgl.DGLGraph()
    g2.add_nodes(6)
    g2.add_edges([0, 1, 2, 5, 4, 5], [1, 2, 3, 4, 3, 0])
    g3 = dgl.DGLGraph()
    g3.add_nodes(1)  # no edges
    g = dgl.batch([g1, g3, g2])  # should not throw an error

    # The edgeless graph still occupies its own slot and contributes its node.
    assert g.batch_size == 3
    assert g.batch_num_nodes == [6, 1, 6]
    assert g.batch_num_edges == [5, 0, 6]
    assert g.number_of_nodes() == 13
    assert g.number_of_edges() == 11
if __name__ == '__main__':
    # When run as a script, execute every test in a fixed order.
    for test_fn in (
        test_batch_unbatch,
        test_batch_unbatch1,
        test_batched_edge_ordering,
        test_batch_sendrecv,
        test_batch_propagate,
        test_batch_no_edge,
    ):
        test_fn()