# test_removal.py
import os
import backend as F
import networkx as nx
import numpy as np
import dgl
6
from test_utils import parametrize_dtype
7

8
9
@parametrize_dtype
def test_node_removal(idtype):
    """Removing nodes compacts node IDs and keeps surviving node features in order.

    Nodes added after a removal receive zero-filled features.
    """
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    g.add_edge(0, 0)
    assert g.number_of_nodes() == 10
    g.ndata['id'] = F.arange(0, 10)

    # remove nodes 4, 5, 6; the survivors are renumbered contiguously
    g.remove_nodes(range(4, 7))
    assert g.number_of_nodes() == 7
    assert F.array_equal(g.ndata['id'], F.tensor([0, 1, 2, 3, 7, 8, 9]))

    # add nodes; their 'id' feature is zero-filled
    g.add_nodes(3)
    assert g.number_of_nodes() == 10
    assert F.array_equal(g.ndata['id'], F.tensor([0, 1, 2, 3, 7, 8, 9, 0, 0, 0]))

    # remove nodes 1, 2, 3 (IDs refer to the current, renumbered graph)
    g.remove_nodes(range(1, 4))
    assert g.number_of_nodes() == 7
    assert F.array_equal(g.ndata['id'], F.tensor([0, 7, 8, 9, 0, 0, 0]))

    # node 0 was never removed, so its self-loop must still be present
    assert g.number_of_edges() == 1

32
33
@parametrize_dtype
def test_multigraph_node_removal(idtype):
    """Removing nodes from a multigraph drops every parallel edge incident to them."""
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(5)
    # two parallel self-loops per node -> 10 edges
    for i in range(5):
        g.add_edge(i, i)
        g.add_edge(i, i)
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 10

    # remove nodes; each removed node takes its two self-loops with it
    g.remove_nodes([2, 3])
    assert g.number_of_nodes() == 3
    assert g.number_of_edges() == 6

    # add a node and two more parallel edges on node 1
    g.add_nodes(1)
    g.add_edge(1, 1)
    g.add_edge(1, 1)
    assert g.number_of_nodes() == 4
    assert g.number_of_edges() == 8

    # remove node 0 and its two self-loops
    g.remove_nodes([0])
    assert g.number_of_nodes() == 3
    assert g.number_of_edges() == 6

60
61
@parametrize_dtype
def test_multigraph_edge_removal(idtype):
    """Removing edges by ID in a multigraph removes only the selected parallel edges."""
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(5)
    # two parallel self-loops per node -> 10 edges
    for i in range(5):
        g.add_edge(i, i)
        g.add_edge(i, i)
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 10

    # remove edges; node count is unaffected
    g.remove_edges([2, 3])
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 8

    # add edges
    g.add_edge(1, 1)
    g.add_edge(1, 1)
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 10

    # remove edges
    g.remove_edges([0, 1])
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 8

87
88
@parametrize_dtype
def test_edge_removal(idtype):
    """Removing edges compacts edge IDs and keeps surviving edge features in order.

    Edges added after a removal receive zero-filled features.
    """
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(5)
    # all 25 ordered pairs (including self-loops), added in row-major order
    for i in range(5):
        for j in range(5):
            g.add_edge(i, j)
    g.edata['id'] = F.arange(0, 25)

    # remove edges 13..19
    g.remove_edges(range(13, 20))
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 18
    assert F.array_equal(g.edata['id'], F.tensor(list(range(13)) + list(range(20, 25))))

    # add an edge; its 'id' feature is zero-filled
    g.add_edge(3, 3)
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 19
    assert F.array_equal(g.edata['id'], F.tensor(list(range(13)) + list(range(20, 25)) + [0]))

    # remove edges 2..9 (IDs refer to the current, renumbered graph)
    g.remove_edges(range(2, 10))
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 11
    assert F.array_equal(g.edata['id'], F.tensor([0, 1, 10, 11, 12, 20, 21, 22, 23, 24, 0]))

115
116
@parametrize_dtype
def test_node_and_edge_removal(idtype):
    """Interleaved node/edge additions and removals keep node and edge counts consistent."""
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    # every ordered pair (i, j), self-loops included -> 100 edges
    for i in range(10):
        for j in range(10):
            g.add_edge(i, j)
    # NOTE(review): 'id' is assigned but never asserted on in this test
    g.edata['id'] = F.arange(0, 100)
    assert g.number_of_nodes() == 10
    assert g.number_of_edges() == 100

    # remove nodes; only edges between the 8 survivors remain (8 * 8 = 64)
    g.remove_nodes([2, 4])
    assert g.number_of_nodes() == 8
    assert g.number_of_edges() == 64

    # remove 10 edges
    g.remove_edges(range(10, 20))
    assert g.number_of_nodes() == 8
    assert g.number_of_edges() == 54

    # add isolated nodes; edge count is unaffected
    g.add_nodes(2)
    assert g.number_of_nodes() == 10
    assert g.number_of_edges() == 54

    # add the 2 * 2 edges among the new nodes 8 and 9
    for i in range(8, 10):
        for j in range(8, 10):
            g.add_edge(i, j)
    assert g.number_of_nodes() == 10
    assert g.number_of_edges() == 58

    # remove another 10 edges
    g.remove_edges(range(10, 20))
    assert g.number_of_nodes() == 10
    assert g.number_of_edges() == 48

154
155
@parametrize_dtype
def test_node_frame(idtype):
    """Node features follow their nodes when other nodes are removed."""
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    data = np.random.rand(10, 3)
    # rows of the nodes that survive removing nodes 3..6
    new_data = data.take([0, 1, 2, 7, 8, 9], axis=0)
    g.ndata['h'] = F.tensor(data)

    # remove nodes
    g.remove_nodes(range(3, 7))
    assert F.allclose(g.ndata['h'], F.tensor(new_data))
166

167
168
@parametrize_dtype
def test_edge_frame(idtype):
    """Edge features follow their edges when other edges are removed."""
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    # ring: edge k goes from node k to node (k + 1) % 10
    g.add_edges(list(range(10)), list(range(1, 10)) + [0])
    data = np.random.rand(10, 3)
    # rows of the edges that survive removing edges 3..6
    new_data = data.take([0, 1, 2, 7, 8, 9], axis=0)
    g.edata['h'] = F.tensor(data)

    # remove edges
    g.remove_edges(range(3, 7))
    assert F.allclose(g.edata['h'], F.tensor(new_data))
180

181
182
@parametrize_dtype
def test_issue1287(idtype):
    """Regression test for https://github.com/dmlc/dgl/issues/1287.

    Assigning node/edge features right after removing nodes or edges must
    succeed, with feature tensors sized to the post-removal graph.
    """
    # setting features after removing nodes
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(5)
    g.add_edges([0, 2, 3, 1, 1], [1, 0, 3, 1, 0])
    g.remove_nodes([0, 1])
    # nodes 2, 3, 4 remain; only the 3 -> 3 edge avoids the removed nodes
    assert g.number_of_nodes() == 3
    assert g.number_of_edges() == 1
    g.ndata['h'] = F.randn((g.number_of_nodes(), 3))
    g.edata['h'] = F.randn((g.number_of_edges(), 2))

    # setting features after removing edges
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(5)
    g.add_edges([0, 2, 3, 1, 1], [1, 0, 3, 1, 0])
    g.remove_edges([0, 1])
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 3
    g.ndata['h'] = F.randn((g.number_of_nodes(), 3))
    g.edata['h'] = F.randn((g.number_of_edges(), 2))
202

203
204
205
206
207
208
209
210
if __name__ == '__main__':
    # Under pytest, ``idtype`` is supplied by @parametrize_dtype. When run as a
    # script, every test needs it passed explicitly.
    # NOTE(review): assumed to match the dtypes parametrize_dtype uses in
    # test_utils — confirm.
    for idtype in (F.int32, F.int64):
        test_node_removal(idtype)
        test_edge_removal(idtype)
        test_multigraph_node_removal(idtype)
        test_multigraph_edge_removal(idtype)
        test_node_and_edge_removal(idtype)
        test_node_frame(idtype)
        test_edge_frame(idtype)
        test_issue1287(idtype)