test_removal.py 5.93 KB
Newer Older
1
2
import backend as F
import numpy as np
nv-dlasalle's avatar
nv-dlasalle committed
3
from test_utils import parametrize_idtype
4

5
6
7
import dgl


@parametrize_idtype
def test_node_removal(idtype):
    """Check node removal: counts, feature slicing, zero-fill and ``store_ids``.

    Node 0 carries a self-loop and is never removed, so exactly one edge
    survives every removal below. When ``store_ids=True`` is passed, the
    surviving node and edge IDs must be recorded under ``dgl.NID`` /
    ``dgl.EID``, not merely the keys created.
    """
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    g.add_edge(0, 0)
    assert g.number_of_nodes() == 10
    assert g.number_of_edges() == 1
    g.ndata["id"] = F.arange(0, 10)

    # remove nodes; surviving rows keep their features, in order
    g.remove_nodes(range(4, 7))
    assert g.number_of_nodes() == 7
    assert g.number_of_edges() == 1
    assert F.array_equal(g.ndata["id"], F.tensor([0, 1, 2, 3, 7, 8, 9]))
    # without store_ids no ID mapping is recorded
    assert dgl.NID not in g.ndata
    assert dgl.EID not in g.edata

    # add nodes; the new rows are zero-filled
    g.add_nodes(3)
    assert g.number_of_nodes() == 10
    assert F.array_equal(
        g.ndata["id"], F.tensor([0, 1, 2, 3, 7, 8, 9, 0, 0, 0])
    )

    # remove nodes, this time recording the original IDs
    g.remove_nodes(range(1, 4), store_ids=True)
    assert g.number_of_nodes() == 7
    assert g.number_of_edges() == 1
    assert F.array_equal(g.ndata["id"], F.tensor([0, 7, 8, 9, 0, 0, 0]))
    assert dgl.NID in g.ndata
    assert dgl.EID in g.edata
    # The stored IDs must index the 10-node graph as it was just before this
    # removal. Compare through numpy so the check does not depend on idtype.
    assert np.array_equal(
        F.asnumpy(g.ndata[dgl.NID]), [0, 4, 5, 6, 7, 8, 9]
    )
    assert np.array_equal(F.asnumpy(g.edata[dgl.EID]), [0])

@parametrize_idtype
def test_multigraph_node_removal(idtype):
    """Removing nodes from a multigraph drops every parallel edge they touch."""
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())

    def expect(num_nodes, num_edges):
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    # Five nodes, each carrying two parallel self-loops.
    graph.add_nodes(5)
    for node in range(5):
        for _ in range(2):
            graph.add_edge(node, node)
    expect(5, 10)

    # Dropping nodes 2 and 3 takes their four self-loops with them.
    graph.remove_nodes([2, 3])
    expect(3, 6)

    # Append one node, then give node 1 two more parallel self-loops.
    graph.add_nodes(1)
    for _ in range(2):
        graph.add_edge(1, 1)
    expect(4, 8)

    # Dropping node 0 removes its two self-loops.
    graph.remove_nodes([0])
    expect(3, 6)

@parametrize_idtype
def test_multigraph_edge_removal(idtype):
    """Parallel edges can be removed individually by edge ID."""
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())

    def expect(num_nodes, num_edges):
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    # Five nodes, each with two parallel self-loops: node k owns edge IDs
    # 2k and 2k + 1.
    graph.add_nodes(5)
    for node in range(5):
        for _ in range(2):
            graph.add_edge(node, node)
    expect(5, 10)

    # Edge IDs 2 and 3 are node 1's self-loops; the node set is unchanged.
    graph.remove_edges([2, 3])
    expect(5, 8)

    # Put two parallel self-loops back on node 1.
    for _ in range(2):
        graph.add_edge(1, 1)
    expect(5, 10)

    # Edge IDs 0 and 1 are still node 0's self-loops.
    graph.remove_edges([0, 1])
    expect(5, 8)

@parametrize_idtype
def test_edge_removal(idtype):
    """Check edge removal: counts, feature slicing, zero-fill and ``store_ids``.

    With ``store_ids=True`` the surviving edges' pre-removal IDs must be
    recorded under ``dgl.EID``, not merely the key created.
    """
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(5)
    for i in range(5):
        for j in range(5):
            g.add_edge(i, j)
    g.edata["id"] = F.arange(0, 25)

    # remove edges; nodes are untouched and surviving rows keep their order
    g.remove_edges(range(13, 20))
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 18
    assert F.array_equal(
        g.edata["id"], F.tensor(list(range(13)) + list(range(20, 25)))
    )
    # without store_ids no ID mapping is recorded
    assert dgl.NID not in g.ndata
    assert dgl.EID not in g.edata

    # add edges; the new row is zero-filled
    g.add_edge(3, 3)
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 19
    assert F.array_equal(
        g.edata["id"], F.tensor(list(range(13)) + list(range(20, 25)) + [0])
    )

    # remove edges, this time recording the original IDs
    g.remove_edges(range(2, 10), store_ids=True)
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 11
    assert F.array_equal(
        g.edata["id"], F.tensor([0, 1, 10, 11, 12, 20, 21, 22, 23, 24, 0])
    )
    assert dgl.EID in g.edata
    # EID must hold each surviving edge's position in the 19-edge graph as it
    # was just before this removal. Compare through numpy to ignore idtype.
    assert np.array_equal(
        F.asnumpy(g.edata[dgl.EID]), [0, 1] + list(range(10, 19))
    )

@parametrize_idtype
def test_node_and_edge_removal(idtype):
    """Interleave node/edge removals and additions on a dense 10-node graph."""
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())

    def expect(num_nodes, num_edges):
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    # Complete directed graph on 10 nodes, self-loops included: 100 edges
    # added in row-major (src, dst) order.
    graph.add_nodes(10)
    for k in range(100):
        graph.add_edge(k // 10, k % 10)
    graph.edata["id"] = F.arange(0, 100)
    expect(10, 100)

    # Dropping two nodes leaves the 8 x 8 = 64 edges among the survivors.
    graph.remove_nodes([2, 4])
    expect(8, 64)

    graph.remove_edges(range(10, 20))
    expect(8, 54)

    # Freshly added nodes have no edges.
    graph.add_nodes(2)
    expect(10, 54)

    # Connect the new nodes 8 and 9 to each other and to themselves.
    for k in range(4):
        graph.add_edge(8 + k // 2, 8 + k % 2)
    expect(10, 58)

    graph.remove_edges(range(10, 20))
    expect(10, 48)

@parametrize_idtype
def test_node_frame(idtype):
    """Node features stay aligned with surviving nodes after removal."""
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(10)

    features = np.random.rand(10, 3)
    survivors = [0, 1, 2, 7, 8, 9]
    expected = features[survivors]
    graph.ndata["h"] = F.tensor(features)

    # Nodes 3..6 go away; the survivors' rows must remain, in order.
    graph.remove_nodes(range(3, 7))
    assert F.allclose(graph.ndata["h"], F.tensor(expected))

@parametrize_idtype
def test_edge_frame(idtype):
    """Edge features stay aligned with surviving edges after removal."""
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(10)

    # Directed ring 0 -> 1 -> ... -> 9 -> 0; edge k starts at node k.
    sources = list(range(10))
    targets = sources[1:] + sources[:1]
    graph.add_edges(sources, targets)

    features = np.random.rand(10, 3)
    survivors = [0, 1, 2, 7, 8, 9]
    expected = features[survivors]
    graph.edata["h"] = F.tensor(features)

    # Edges 3..6 go away; the survivors' rows must remain, in order.
    graph.remove_edges(range(3, 7))
    assert F.allclose(graph.edata["h"], F.tensor(expected))

@parametrize_idtype
def test_issue1287(idtype):
    """Regression test for https://github.com/dmlc/dgl/issues/1287.

    Features must be assignable to a graph after nodes or edges have been
    removed from it.
    """

    def build_graph():
        graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
        graph.add_nodes(5)
        graph.add_edges([0, 2, 3, 1, 1], [1, 0, 3, 1, 0])
        return graph

    def set_features(graph):
        graph.ndata["h"] = F.randn((graph.number_of_nodes(), 3))
        graph.edata["h"] = F.randn((graph.number_of_edges(), 2))

    # setting features after removing nodes
    after_node_removal = build_graph()
    after_node_removal.remove_nodes([0, 1])
    set_features(after_node_removal)

    # setting features after removing edges (and moving the graph again)
    after_edge_removal = build_graph()
    after_edge_removal.remove_edges([0, 1])
    after_edge_removal = after_edge_removal.to(F.ctx())
    set_features(after_edge_removal)

if __name__ == "__main__":
    # Every test declares a required ``idtype`` argument (pytest supplies it
    # via ``parametrize_idtype``), so a direct script run must pass one.
    # NOTE(review): assumes parametrize_idtype covers F.int32 and F.int64 --
    # confirm against test_utils.
    # The previous runner also called ``test_frame_size``, which is not
    # defined in this module; it has been dropped, and the missing
    # ``test_issue1287`` call added.
    for _idtype in (F.int32, F.int64):
        test_node_removal(_idtype)
        test_edge_removal(_idtype)
        test_multigraph_node_removal(_idtype)
        test_multigraph_edge_removal(_idtype)
        test_node_and_edge_removal(_idtype)
        test_node_frame(_idtype)
        test_edge_frame(_idtype)
        test_issue1287(_idtype)