"vscode:/vscode.git/clone" did not exist on "14f7b545bdc384fd669b55f8a5fd586e1fe25623"
# test_transform.py -- tests for DGL graph transformation utilities.
import networkx as nx
import numpy as np
import dgl
import dgl.function as fn
5
import backend as F
6
7
8
9
10
11
12

# Feature width of the random/zero tensors created in test_line_graph.
D = 5

# line graph related
def test_line_graph():
    """Check that a shared line graph aliases the original edge features.

    The line graph is built with ``shared=True``; the assertions verify that
    node features written on the line graph are visible as edge features on
    the original graph, both for an existing field and for a new one.
    """
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    G.edata['h'] = F.randn((2 * N, D))
    n_edges = G.number_of_edges()
    L = G.line_graph(shared=True)
    # one line-graph node per edge of the original graph
    assert L.number_of_nodes() == 2 * N
    L.ndata['h'] = F.randn((2 * N, D))

    # update node features on line graph should reflect to edge features on
    # original graph.
    u = [0, 0, 2, 3]
    v = [1, 2, 0, 0]
    eid = G.edge_ids(u, v)
    L.nodes[eid].data['h'] = F.zeros((len(u), D))
    assert F.allclose(G.edges[u, v].data['h'], F.zeros((len(u), D)))

    # adding a new node feature on line graph should also reflect to a new
    # edge feature on original graph
    data = F.randn((n_edges, D))
    L.ndata['w'] = data
    assert F.allclose(G.edata['w'], data)
# line graph without backtracking

def test_no_backtracking():
    """Check that a non-backtracking line graph has no u->v->u transitions.

    For every leaf ``i`` of the star graph, the line-graph nodes for edges
    ``(0, i)`` and ``(i, 0)`` must not be connected in either direction.
    """
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    L = G.line_graph(backtracking=False)
    assert L.number_of_nodes() == 2 * N
    # nx.star_graph(N) has centre 0 and leaves 1..N inclusive, so the range
    # must extend to N + 1 to cover the last leaf as well.
    for i in range(1, N + 1):
        e1 = G.edge_id(0, i)
        e2 = G.edge_id(i, 0)
        assert not L.has_edge_between(e1, e2)
        assert not L.has_edge_between(e2, e1)

# reverse graph related
def test_reverse():
    """Check that ``reverse()`` flips every edge and keeps the graph's size.

    The reversed graph must have the same multigraph flag, node count and
    edge count, contain each original edge with swapped endpoints, and give
    each swapped edge the same edge id as the original.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5)
    # The graph need not be completely connected (nodes 3 and 4 get no edges).
    g.add_edges([0, 1, 2], [1, 2, 1])
    g.ndata['h'] = F.tensor([[0.], [1.], [2.], [3.], [4.]])
    g.edata['h'] = F.tensor([[5.], [6.], [7.]])
    rg = g.reverse()

    assert g.is_multigraph == rg.is_multigraph

    assert g.number_of_nodes() == rg.number_of_nodes()
    assert g.number_of_edges() == rg.number_of_edges()
    # every original edge (u, v) must appear as (v, u) in the reversed graph
    reversed_present = rg.has_edges_between([1, 2, 1], [0, 1, 2])
    assert F.allclose(F.astype(reversed_present, F.float32), F.ones((3,)))
    # edge ids are carried over to the swapped endpoints
    assert g.edge_id(0, 1) == rg.edge_id(1, 0)
    assert g.edge_id(1, 2) == rg.edge_id(2, 1)
    assert g.edge_id(2, 1) == rg.edge_id(1, 2)

def test_reverse_shared_frames():
    """Check that a reversed graph built with shared frames aliases features.

    With ``share_ndata=True`` and ``share_edata=True`` the assertions verify
    that node and edge features compare equal between the two graphs right
    after reversal, after a write through either graph, and after message
    passing on the reversed graph.
    """
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2], [1, 2, 1])
    g.ndata['h'] = F.tensor([[0.], [1.], [2.]])
    g.edata['h'] = F.tensor([[3.], [4.], [5.]])

    rg = g.reverse(share_ndata=True, share_edata=True)
    assert F.allclose(g.ndata['h'], rg.ndata['h'])
    assert F.allclose(g.edata['h'], rg.edata['h'])
    # the same edges, addressed by swapped endpoints, carry the same data
    assert F.allclose(g.edges[[0, 2], [1, 1]].data['h'],
                      rg.edges[[1, 1], [0, 2]].data['h'])

    # a node-feature write through the reversed graph is seen by the original
    rg.ndata['h'] = rg.ndata['h'] + 1
    assert F.allclose(rg.ndata['h'], g.ndata['h'])

    # an edge-feature write through the original is seen by the reversed graph
    g.edata['h'] = g.edata['h'] - 1
    assert F.allclose(rg.edata['h'], g.edata['h'])

    src_msg = fn.copy_src(src='h', out='m')
    sum_reduce = fn.sum(msg='m', out='h')

    # results of message passing on the reversed graph land in the shared frame
    rg.update_all(src_msg, sum_reduce)
    assert F.allclose(g.ndata['h'], rg.ndata['h'])
# simple graph related
def test_simple_graph():
    """Check that ``to_simple_graph`` collapses duplicate edges.

    The input edge list repeats ``(0, 1)``; the simplified graph must not be
    a multigraph, must have three edges, and its edge set must equal the set
    of distinct input pairs.
    """
    elist = [(0, 1), (0, 2), (1, 2), (0, 1)]
    multi = dgl.DGLGraph(elist, readonly=True)
    assert multi.is_multigraph

    simple = dgl.to_simple_graph(multi)
    assert not simple.is_multigraph
    assert sg_edge_count(simple) == 3 if False else simple.number_of_edges() == 3

    heads, tails = simple.edges()
    head_ids = list(F.asnumpy(heads))
    tail_ids = list(F.asnumpy(tails))
    observed = {(s, d) for s, d in zip(head_ids, tail_ids)}
    assert observed == set(elist)
# bidirected graph related
def test_bidirected_graph():
    """Check ``to_bidirected`` for every readonly combination of input/output.

    The fixture has self-loops and duplicate edges and lacks ``(1, 2)``, the
    reverse of ``(2, 1)``. The result must have ten edges and its distinct
    edge set must be the input's distinct edges plus ``(1, 2)``.
    """
    def _test(in_readonly, out_readonly):
        pairs = [(0, 0), (0, 1), (0, 1), (1, 0), (1, 1), (2, 1), (2, 2), (2, 2)]
        g = dgl.DGLGraph(pairs, readonly=in_readonly)
        # (1, 2) is the only reverse edge missing from the input
        expected = set(pairs) | {(1, 2)}

        big = dgl.to_bidirected(g, out_readonly)
        assert big.number_of_edges() == 10

        heads, tails = big.edges()
        head_ids = list(F.asnumpy(heads))
        tail_ids = list(F.asnumpy(tails))
        observed = {(s, d) for s, d in zip(head_ids, tail_ids)}
        assert observed == expected

    # same order as the explicit calls: (T, T), (T, F), (F, T), (F, F)
    for in_flag in (True, False):
        for out_flag in (True, False):
            _test(in_flag, out_flag)

# k-hop and Laplacian related
def test_khop_graph():
    """Compare one pass on the k-hop graph with k passes on the original.

    For k in 0..3, summing copied source features k times on the original
    graph must match (within tolerance) a single such pass on
    ``dgl.khop_graph(g, k)``.
    """
    N = 20
    feat = F.randn((N, 5))
    g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
    for hops in range(4):
        hop_graph = dgl.khop_graph(g, hops)

        # reference: `hops` rounds of message passing on the original graph
        g.ndata['h'] = feat
        rounds_done = 0
        while rounds_done < hops:
            g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            rounds_done += 1
        expected = g.ndata.pop('h')

        # candidate: a single round on the k-hop graph
        hop_graph.ndata['h'] = feat
        hop_graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
        actual = hop_graph.ndata.pop('h')

        assert F.allclose(expected, actual, rtol=1e-3, atol=1e-3)

def test_khop_adj():
    """Compare multiplying by the k-hop adjacency with k message passes.

    For k in 0..2, summing copied source features k times on the graph must
    match (within tolerance) ``khop_adj(g, k) @ feat``.
    """
    N = 20
    feat = F.randn((N, 5))
    g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
    for hops in range(3):
        hop_adj = F.tensor(dgl.khop_adj(g, hops))

        # reference: `hops` rounds of message passing on the original graph
        g.ndata['h'] = feat
        rounds_done = 0
        while rounds_done < hops:
            g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            rounds_done += 1
        expected = g.ndata.pop('h')

        # candidate: one dense product with the k-hop adjacency
        actual = F.matmul(hop_adj, feat)

        assert F.allclose(expected, actual, rtol=1e-3, atol=1e-3)

def test_laplacian_lambda_max():
    """Check that ``laplacian_lambda_max`` stays below 2 (plus a tolerance).

    Covers a single random graph and a batch of four random graphs; for the
    batch, one value per component graph is expected.
    """
    N = 20
    eps = 1e-6
    upper_bound = 2 + eps

    # test DGLGraph
    single = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
    single_result = dgl.laplacian_lambda_max(single)
    assert single_result[0] < upper_bound

    # test BatchedDGLGraph
    N_arr = [20, 30, 10, 12]
    members = []
    for size in N_arr:
        members.append(dgl.DGLGraph(nx.erdos_renyi_graph(size, 0.3)))
    bg = dgl.batch(members)
    batch_result = dgl.laplacian_lambda_max(bg)
    assert len(batch_result) == len(N_arr)
    for value in batch_result:
        assert value < upper_bound

# script entry point
if __name__ == '__main__':
    # Direct execution runs each test below in turn; no test runner needed.
    test_line_graph()
    test_no_backtracking()
    test_reverse()
    test_reverse_shared_frames()
    test_simple_graph()
    test_bidirected_graph()
    test_khop_adj()
    test_khop_graph()
    test_laplacian_lambda_max()