"src/diffusers/schedulers/scheduling_ddim.py" did not exist on "e01bcbb765d4651716eb815304d959d3e4a0b4ab"
test_batched_graph.py 4.19 KB
Newer Older
1
2
import networkx as nx
import dgl
Minjie Wang's avatar
Minjie Wang committed
3
import torch as th
4
5
6
7
8
9
10
11
12
13
14
15
import numpy as np

def tree1():
    """Generate a tree
         0
        / \
       1   2
      / \
     3   4
    Edges are from leaves to root.
    """
    g = dgl.DGLGraph()
Minjie Wang's avatar
Minjie Wang committed
16
    g.add_nodes(5)
17
18
19
20
    g.add_edge(3, 1)
    g.add_edge(4, 1)
    g.add_edge(1, 0)
    g.add_edge(2, 0)
Minjie Wang's avatar
Minjie Wang committed
21
22
    g.set_n_repr(th.Tensor([0, 1, 2, 3, 4]))
    g.set_e_repr(th.randn(4, 10))
23
24
25
26
27
28
29
30
31
32
33
34
    return g

def tree2():
    """Generate a tree
         1
        / \
       4   3
      / \
     2   0
    Edges are from leaves to root.
    """
    g = dgl.DGLGraph()
Minjie Wang's avatar
Minjie Wang committed
35
    g.add_nodes(5)
36
37
38
39
    g.add_edge(2, 4)
    g.add_edge(0, 4)
    g.add_edge(4, 1)
    g.add_edge(3, 1)
Minjie Wang's avatar
Minjie Wang committed
40
41
    g.set_n_repr(th.Tensor([0, 1, 2, 3, 4]))
    g.set_e_repr(th.randn(4, 10))
42
43
44
45
46
    return g

def test_batch_unbatch():
    t1 = tree1()
    t2 = tree2()
47
48
49
50
    n1 = t1.get_n_repr()
    n2 = t2.get_n_repr()
    e1 = t1.get_e_repr()
    e2 = t2.get_e_repr()
51
52

    bg = dgl.batch([t1, t2])
Minjie Wang's avatar
Minjie Wang committed
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
    assert bg.number_of_nodes() == 10
    assert bg.number_of_edges() == 8
    assert bg.batch_size == 2
    assert bg.batch_num_nodes == [5, 5]
    assert bg.batch_num_edges == [4, 4]

    tt1, tt2 = dgl.unbatch(bg)
    assert th.allclose(t1.get_n_repr(), tt1.get_n_repr())
    assert th.allclose(t1.get_e_repr(), tt1.get_e_repr())
    assert th.allclose(t2.get_n_repr(), tt2.get_n_repr())
    assert th.allclose(t2.get_e_repr(), tt2.get_e_repr())

def test_batch_unbatch1():
    t1 = tree1()
    t2 = tree2()
    b1 = dgl.batch([t1, t2])
    b2 = dgl.batch([t2, b1])
    assert b2.number_of_nodes() == 15
    assert b2.number_of_edges() == 12
    assert b2.batch_size == 3
    assert b2.batch_num_nodes == [5, 5, 5]
    assert b2.batch_num_edges == [4, 4, 4]

    s1, s2, s3 = dgl.unbatch(b2)
    assert th.allclose(t2.get_n_repr(), s1.get_n_repr())
    assert th.allclose(t2.get_e_repr(), s1.get_e_repr())
    assert th.allclose(t1.get_n_repr(), s2.get_n_repr())
    assert th.allclose(t1.get_e_repr(), s2.get_e_repr())
    assert th.allclose(t2.get_n_repr(), s3.get_n_repr())
    assert th.allclose(t2.get_e_repr(), s3.get_e_repr())
83
84
85
86
87
88

def test_batch_sendrecv():
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
Minjie Wang's avatar
Minjie Wang committed
89
    bg.register_message_func(lambda src, edge: src)
Minjie Wang's avatar
Minjie Wang committed
90
    bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1))
Lingfan Yu's avatar
Lingfan Yu committed
91
92
    u = [3, 4, 2 + 5, 0 + 5]
    v = [1, 1, 4 + 5, 4 + 5]
93

94
    bg.send(u, v)
95
96
    bg.recv(v)

Lingfan Yu's avatar
Lingfan Yu committed
97
    t1, t2 = dgl.unbatch(bg)
98
99
100
101
102
103
104
105
106
    assert t1.get_n_repr()[1] == 7
    assert t2.get_n_repr()[4] == 2


def test_batch_propagate():
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
Minjie Wang's avatar
Minjie Wang committed
107
    bg.register_message_func(lambda src, edge: src)
Minjie Wang's avatar
Minjie Wang committed
108
    bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1))
109
110
111
112
113
    # get leaves.

    order = []

    # step 1
Lingfan Yu's avatar
Lingfan Yu committed
114
115
    u = [3, 4, 2 + 5, 0 + 5]
    v = [1, 1, 4 + 5, 4 + 5]
116
117
118
    order.append((u, v))

    # step 2
Lingfan Yu's avatar
Lingfan Yu committed
119
120
    u = [1, 2, 4 + 5, 3 + 5]
    v = [0, 0, 1 + 5, 1 + 5]
121
122
    order.append((u, v))

Lingfan Yu's avatar
Lingfan Yu committed
123
124
    bg.propagate(traverser=order)
    t1, t2 = dgl.unbatch(bg)
125
126
127
128

    assert t1.get_n_repr()[0] == 9
    assert t2.get_n_repr()[1] == 5

129
130
def test_batched_edge_ordering():
    g1 = dgl.DGLGraph()
Lingfan Yu's avatar
Lingfan Yu committed
131
132
    g1.add_nodes(6)
    g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
Minjie Wang's avatar
Minjie Wang committed
133
    e1 = th.randn(5, 10)
134
135
    g1.set_e_repr(e1)
    g2 = dgl.DGLGraph()
Lingfan Yu's avatar
Lingfan Yu committed
136
137
    g2.add_nodes(6)
    g2.add_edges([0, 1 ,2 ,5, 4 ,5], [1, 2, 3, 4, 3, 0])
Minjie Wang's avatar
Minjie Wang committed
138
    e2 = th.randn(6, 10)
139
140
    g2.set_e_repr(e2)
    g = dgl.batch([g1, g2])
Lingfan Yu's avatar
Lingfan Yu committed
141
142
    r1 = g.get_e_repr()[g.edge_id(4, 5)]
    r2 = g1.get_e_repr()[g1.edge_id(4, 5)]
Minjie Wang's avatar
Minjie Wang committed
143
    assert th.equal(r1, r2)
144

Lingfan Yu's avatar
Lingfan Yu committed
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def test_batch_no_edge():
    # FIXME: current impl cannot handle this case!!!
    #        comment out for now to test CI
    return
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes(6)
    g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
    e1 = th.randn(5, 10)
    g1.set_e_repr(e1)
    g2 = dgl.DGLGraph()
    g2.add_nodes(6)
    g2.add_edges([0, 1, 2, 5, 4, 5], [1 ,2 ,3, 4, 3, 0])
    e2 = th.randn(6, 10)
    g2.set_e_repr(e2)
    g3 = dgl.DGLGraph()
    g3.add_nodes(1)  # no edges

    g = dgl.batch([g1, g3, g2]) # should not throw an error
    """

166
167
if __name__ == '__main__':
    test_batch_unbatch()
Minjie Wang's avatar
Minjie Wang committed
168
    test_batch_unbatch1()
Lingfan Yu's avatar
Lingfan Yu committed
169
170
171
172
    test_batched_edge_ordering()
    test_batch_sendrecv()
    test_batch_propagate()
    test_batch_no_edge()