# test_heterograph-remove.py: tests for node and edge removal on DGLGraph.
import backend as F

3
import dgl
4
5
import numpy as np
from test_utils import parametrize_idtype
6
7


@parametrize_idtype
def test_node_removal(idtype):
    """Removing and re-adding nodes keeps the "id" node feature aligned.

    Also checks that ``dgl.NID`` / ``dgl.EID`` are only written into the
    feature frames when ``remove_nodes`` is called with ``store_ids=True``.
    """
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(10)
    graph.add_edges(0, 0)
    assert graph.number_of_nodes() == 10
    graph.ndata["id"] = F.arange(0, 10)

    # Drop nodes 4, 5 and 6; the surviving rows keep their original ids and
    # no id mapping is recorded without store_ids.
    graph.remove_nodes(range(4, 7))
    expected_ids = F.tensor([0, 1, 2, 3, 7, 8, 9])
    assert graph.number_of_nodes() == 7
    assert F.array_equal(graph.ndata["id"], expected_ids)
    assert dgl.NID not in graph.ndata
    assert dgl.EID not in graph.edata

    # Newly added nodes get a zero "id" feature.
    graph.add_nodes(3)
    padded_ids = F.tensor([0, 1, 2, 3, 7, 8, 9, 0, 0, 0])
    assert graph.number_of_nodes() == 10
    assert F.array_equal(graph.ndata["id"], padded_ids)

    # With store_ids=True the removal records id mappings under
    # dgl.NID / dgl.EID.
    graph.remove_nodes(range(1, 4), store_ids=True)
    remaining_ids = F.tensor([0, 7, 8, 9, 0, 0, 0])
    assert graph.number_of_nodes() == 7
    assert F.array_equal(graph.ndata["id"], remaining_ids)
    assert dgl.NID in graph.ndata
    assert dgl.EID in graph.edata

@parametrize_idtype
def test_multigraph_node_removal(idtype):
    """Removing nodes from a multigraph also removes all parallel edges on them."""
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())

    def expect(num_nodes, num_edges):
        # Shared check of the current node and edge counts.
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    graph.add_nodes(5)
    # Two parallel self-loops on every node -> 10 edges.
    for node in range(5):
        for _ in range(2):
            graph.add_edges(node, node)
    expect(5, 10)

    # Nodes 2 and 3 take their four self-loops with them.
    graph.remove_nodes([2, 3])
    expect(3, 6)

    # One extra node, plus two parallel self-loops on node 1.
    graph.add_nodes(1)
    for _ in range(2):
        graph.add_edges(1, 1)
    expect(4, 8)

    graph.remove_nodes([0])
    expect(3, 6)

@parametrize_idtype
def test_multigraph_edge_removal(idtype):
    """Removing edges from a multigraph drops exactly the requested edge ids."""
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())

    def expect(num_nodes, num_edges):
        # Shared check of the current node and edge counts.
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    graph.add_nodes(5)
    # Two parallel self-loops on every node -> 10 edges.
    for node in range(5):
        for _ in range(2):
            graph.add_edges(node, node)
    expect(5, 10)

    # The node count stays at 5 throughout: only edges are removed.
    graph.remove_edges([2, 3])
    expect(5, 8)

    for _ in range(2):
        graph.add_edges(1, 1)
    expect(5, 10)

    graph.remove_edges([0, 1])
    expect(5, 8)

@parametrize_idtype
def test_edge_removal(idtype):
    """Edge removal keeps the "id" edge feature aligned and honours store_ids."""
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(5)
    # All 25 (src, dst) pairs, inserted src-major, so edge k is (k // 5, k % 5).
    for src in range(5):
        for dst in range(5):
            graph.add_edges(src, dst)
    graph.edata["id"] = F.arange(0, 25)

    # Drop edge ids 13..19; no id mapping is recorded without store_ids.
    graph.remove_edges(range(13, 20))
    surviving = list(range(13)) + list(range(20, 25))
    assert graph.number_of_nodes() == 5
    assert graph.number_of_edges() == 18
    assert F.array_equal(graph.edata["id"], F.tensor(surviving))
    assert dgl.NID not in graph.ndata
    assert dgl.EID not in graph.edata

    # A newly added edge is appended with a zero "id" feature.
    graph.add_edges(3, 3)
    surviving.append(0)
    assert graph.number_of_nodes() == 5
    assert graph.number_of_edges() == 19
    assert F.array_equal(graph.edata["id"], F.tensor(surviving))

    # Remove positions 2..9 of the current edge list, recording dgl.EID.
    graph.remove_edges(range(2, 10), store_ids=True)
    expected = surviving[:2] + surviving[10:]
    assert graph.number_of_nodes() == 5
    assert graph.number_of_edges() == 11
    assert F.array_equal(graph.edata["id"], F.tensor(expected))
    assert dgl.EID in graph.edata

@parametrize_idtype
def test_node_and_edge_removal(idtype):
    """Interleave node/edge additions and removals, checking counts after each."""
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())

    def expect(num_nodes, num_edges):
        # Shared check of the current node and edge counts.
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    graph.add_nodes(10)
    # Complete directed graph with self-loops: 10 x 10 edges.
    for src in range(10):
        for dst in range(10):
            graph.add_edges(src, dst)
    graph.edata["id"] = F.arange(0, 100)
    expect(10, 100)

    # Removing 2 of 10 nodes leaves the 8 x 8 = 64 edges among the rest.
    graph.remove_nodes([2, 4])
    expect(8, 64)

    graph.remove_edges(range(10, 20))
    expect(8, 54)

    # Adding nodes adds no edges.
    graph.add_nodes(2)
    expect(10, 54)

    # Fully connect the two new nodes (8 and 9), including self-loops.
    for src in (8, 9):
        for dst in (8, 9):
            graph.add_edges(src, dst)
    expect(10, 58)

    graph.remove_edges(range(10, 20))
    expect(10, 48)

@parametrize_idtype
def test_node_frame(idtype):
    """Node feature rows stay with their nodes when other nodes are removed."""
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(10)
    features = np.random.rand(10, 3)
    graph.ndata["h"] = F.tensor(features)

    # Removing nodes 3..6 should leave exactly rows 0-2 and 7-9.
    graph.remove_nodes(range(3, 7))
    kept = features[[0, 1, 2, 7, 8, 9]]
    assert F.allclose(graph.ndata["h"], F.tensor(kept))

@parametrize_idtype
def test_edge_frame(idtype):
    """Edge feature rows stay with their edges when other edges are removed."""
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(10)
    # A directed 10-cycle: edge i goes from node i to node (i + 1) % 10.
    sources = list(range(10))
    graph.add_edges(sources, [(s + 1) % 10 for s in sources])
    features = np.random.rand(10, 3)
    graph.edata["h"] = F.tensor(features)

    # Removing edges 3..6 should leave exactly rows 0-2 and 7-9.
    graph.remove_edges(range(3, 7))
    kept = features[[0, 1, 2, 7, 8, 9]]
    assert F.allclose(graph.edata["h"], F.tensor(kept))

def _issue1287_graph(idtype):
    """Return the 5-node, 5-edge graph used to reproduce DGL issue #1287.

    The graph is cast to ``idtype`` and placed on ``F.ctx()``.
    """
    g = dgl.DGLGraph().astype(idtype).to(F.ctx())
    g.add_nodes(5)
    g.add_edges([0, 2, 3, 1, 1], [1, 0, 3, 1, 0])
    return g


@parametrize_idtype
def test_issue1287(idtype):
    """Reproduce https://github.com/dmlc/dgl/issues/1287.

    Node and edge features must be assignable after nodes or edges have been
    removed from a graph.
    """
    # Remove nodes, then set features. Every edge except 3->3 touches node 0
    # or node 1, so exactly one edge survives.
    g = _issue1287_graph(idtype)
    g.remove_nodes([0, 1])
    assert g.number_of_nodes() == 3
    assert g.number_of_edges() == 1
    g.ndata["h"] = F.randn((g.number_of_nodes(), 3))
    g.edata["h"] = F.randn((g.number_of_edges(), 2))

    # Remove edges, then set features. No further ``.to(F.ctx())`` is needed:
    # the helper already placed the graph on the test device.
    g = _issue1287_graph(idtype)
    g.remove_edges([0, 1])
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 3
    g.ndata["h"] = F.randn((g.number_of_nodes(), 3))
    g.edata["h"] = F.randn((g.number_of_edges(), 2))

if __name__ == "__main__":
    # Under pytest every test receives ``idtype`` from ``@parametrize_idtype``.
    # When run as a script it must be passed explicitly, so run each test for
    # int32 and int64 (presumably the same dtypes parametrize_idtype uses --
    # confirm in test_utils). The former ``test_frame_size()`` call was dropped
    # because no such test is defined in this module.
    for idtype in (F.int32, F.int64):
        test_node_removal(idtype)
        test_edge_removal(idtype)
        test_multigraph_node_removal(idtype)
        test_multigraph_edge_removal(idtype)
        test_node_and_edge_removal(idtype)
        test_node_frame(idtype)
        test_edge_frame(idtype)
        test_issue1287(idtype)