# test_removal.py
import backend as F
import numpy as np
import dgl
4
from test_utils import parametrize_dtype
5

6
7
@parametrize_dtype
def test_node_removal(idtype):
    """Exercise remove_nodes / add_nodes and track the 'id' node feature.

    The graph starts with 10 nodes and one self-loop on node 0.  Nodes are
    removed, re-added and removed again; after every step the node count and
    the surviving rows of ``ndata['id']`` are checked.  The last removal is
    done with ``store_ids=True`` and the test expects ``dgl.NID`` / ``dgl.EID``
    to show up in the frames only in that case.
    """
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(10)
    graph.add_edge(0, 0)
    assert graph.number_of_nodes() == 10
    graph.ndata['id'] = F.arange(0, 10)

    # First removal: nodes 4, 5 and 6, without storing the original IDs.
    graph.remove_nodes(range(4, 7))
    assert graph.number_of_nodes() == 7
    after_first_removal = F.tensor([0, 1, 2, 3, 7, 8, 9])
    assert F.array_equal(graph.ndata['id'], after_first_removal)
    assert dgl.NID not in graph.ndata
    assert dgl.EID not in graph.edata

    # Growing the graph again: the three new rows are expected to hold 0.
    graph.add_nodes(3)
    assert graph.number_of_nodes() == 10
    after_growth = F.tensor([0, 1, 2, 3, 7, 8, 9, 0, 0, 0])
    assert F.array_equal(graph.ndata['id'], after_growth)

    # Second removal: nodes 1, 2 and 3, this time asking for stored IDs.
    graph.remove_nodes(range(1, 4), store_ids=True)
    assert graph.number_of_nodes() == 7
    after_second_removal = F.tensor([0, 7, 8, 9, 0, 0, 0])
    assert F.array_equal(graph.ndata['id'], after_second_removal)
    assert dgl.NID in graph.ndata
    assert dgl.EID in graph.edata
33

34
35
@parametrize_dtype
def test_multigraph_node_removal(idtype):
    """Check node/edge counts when removing nodes that carry duplicate edges.

    Every node gets two identical self-loops, so each removed node is
    expected to take two edges with it.
    """
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())

    def expect_sizes(num_nodes, num_edges):
        # Shared size check used after every mutation below.
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    graph.add_nodes(5)
    for node in range(5):
        # Two parallel self-loops per node.
        for _ in range(2):
            graph.add_edge(node, node)
    expect_sizes(5, 10)

    # Removing nodes 2 and 3.
    graph.remove_nodes([2, 3])
    expect_sizes(3, 6)

    # One more node, plus two further self-loops on node 1.
    graph.add_nodes(1)
    for _ in range(2):
        graph.add_edge(1, 1)
    expect_sizes(4, 8)

    # Removing node 0.
    graph.remove_nodes([0])
    expect_sizes(3, 6)

62
63
@parametrize_dtype
def test_multigraph_edge_removal(idtype):
    """Check node/edge counts when removing edges from a graph with duplicates.

    Every node gets two identical self-loops.  Removing edges must leave the
    node count at 5 throughout and change the edge count by exactly the
    number of removed edge IDs.
    """
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())

    def expect_sizes(num_nodes, num_edges):
        # Shared size check used after every mutation below.
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    graph.add_nodes(5)
    for node in range(5):
        # Two parallel self-loops per node.
        for _ in range(2):
            graph.add_edge(node, node)
    expect_sizes(5, 10)

    # Removing edge IDs 2 and 3.
    graph.remove_edges([2, 3])
    expect_sizes(5, 8)

    # Two further self-loops on node 1.
    for _ in range(2):
        graph.add_edge(1, 1)
    expect_sizes(5, 10)

    # Removing edge IDs 0 and 1.
    graph.remove_edges([0, 1])
    expect_sizes(5, 8)

89
90
@parametrize_dtype
def test_edge_removal(idtype):
    """Exercise remove_edges / add_edge and track the 'id' edge feature.

    A 5-node graph gets all 25 (src, dst) edges, labelled 0..24 in
    ``edata['id']``.  Edges are removed, one is added, and more are removed;
    after every step the counts and the surviving feature rows are checked.
    The last removal uses ``store_ids=True`` and the test expects ``dgl.NID``
    / ``dgl.EID`` to be present in the frames only in that case.
    """
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(5)
    for src in range(5):
        for dst in range(5):
            graph.add_edge(src, dst)
    graph.edata['id'] = F.arange(0, 25)

    # First removal: edge IDs 13..19, without storing the original IDs.
    graph.remove_edges(range(13, 20))
    remaining_ids = list(range(13)) + list(range(20, 25))
    assert graph.number_of_nodes() == 5
    assert graph.number_of_edges() == 18
    assert F.array_equal(graph.edata['id'], F.tensor(remaining_ids))
    assert dgl.NID not in graph.ndata
    assert dgl.EID not in graph.edata

    # A new self-loop on node 3; its feature row is expected to hold 0.
    graph.add_edge(3, 3)
    assert graph.number_of_nodes() == 5
    assert graph.number_of_edges() == 19
    assert F.array_equal(graph.edata['id'], F.tensor(remaining_ids + [0]))

    # Second removal: edge IDs 2..9, this time asking for stored IDs.
    graph.remove_edges(range(2, 10), store_ids=True)
    assert graph.number_of_nodes() == 5
    assert graph.number_of_edges() == 11
    final_ids = F.tensor([0, 1, 10, 11, 12, 20, 21, 22, 23, 24, 0])
    assert F.array_equal(graph.edata['id'], final_ids)
    assert dgl.NID in graph.ndata
    assert dgl.EID in graph.edata
120

121
122
@parametrize_dtype
def test_node_and_edge_removal(idtype):
    """Interleave node and edge removals/additions and verify the sizes.

    Starts from a 10-node graph holding all 100 (src, dst) edges and checks
    the node and edge counts after each mutation.
    """
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())

    def expect_sizes(num_nodes, num_edges):
        # Shared size check used after every mutation below.
        assert graph.number_of_nodes() == num_nodes
        assert graph.number_of_edges() == num_edges

    graph.add_nodes(10)
    for src in range(10):
        for dst in range(10):
            graph.add_edge(src, dst)
    graph.edata['id'] = F.arange(0, 100)
    expect_sizes(10, 100)

    # Dropping nodes 2 and 4 leaves an 8-node graph with 8 * 8 edges.
    graph.remove_nodes([2, 4])
    expect_sizes(8, 64)

    # Dropping edge IDs 10..19.
    graph.remove_edges(range(10, 20))
    expect_sizes(8, 54)

    # Two fresh nodes; the edge count must not move.
    graph.add_nodes(2)
    expect_sizes(10, 54)

    # All four edges among the two new nodes (IDs 8 and 9).
    for src in range(8, 10):
        for dst in range(8, 10):
            graph.add_edge(src, dst)
    expect_sizes(10, 58)

    # Dropping edge IDs 10..19 once more.
    graph.remove_edges(range(10, 20))
    expect_sizes(10, 48)

160
161
@parametrize_dtype
def test_node_frame(idtype):
    """Verify that node feature rows follow their nodes through remove_nodes.

    Random (10, 3) features are attached as ``ndata['h']``; after removing
    nodes 3..6 the feature must equal the rows of the surviving nodes.
    """
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(10)

    features = np.random.rand(10, 3)
    surviving_rows = [0, 1, 2, 7, 8, 9]
    expected = np.take(features, surviving_rows, axis=0)
    graph.ndata['h'] = F.tensor(features)

    # Removing nodes 3, 4, 5 and 6.
    graph.remove_nodes(range(3, 7))
    assert F.allclose(graph.ndata['h'], F.tensor(expected))
172

173
174
@parametrize_dtype
def test_edge_frame(idtype):
    """Verify that edge feature rows follow their edges through remove_edges.

    Ten edges forming the ring i -> (i + 1) % 10 get random (10, 3) features
    as ``edata['h']``; after removing edge IDs 3..6 the feature must equal the
    rows of the surviving edges.
    """
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(10)

    sources = list(range(10))
    # Rotate by one so that node 9 links back to node 0.
    destinations = sources[1:] + sources[:1]
    graph.add_edges(sources, destinations)

    features = np.random.rand(10, 3)
    surviving_rows = [0, 1, 2, 7, 8, 9]
    expected = np.take(features, surviving_rows, axis=0)
    graph.edata['h'] = F.tensor(features)

    # Removing edge IDs 3, 4, 5 and 6.
    graph.remove_edges(range(3, 7))
    assert F.allclose(graph.edata['h'], F.tensor(expected))
186

187
188
@parametrize_dtype
def test_issue1287(idtype):
    """Regression test for https://github.com/dmlc/dgl/issues/1287.

    Assigning node and edge features must work on a graph that has just had
    nodes removed, and likewise on one that has just had edges removed.  The
    test passes if none of the assignments raises.
    """
    edge_src = [0, 2, 3, 1, 1]
    edge_dst = [1, 0, 3, 1, 0]

    # Case 1: set features after removing nodes.
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(5)
    graph.add_edges(edge_src, edge_dst)
    graph.remove_nodes([0, 1])
    graph.ndata['h'] = F.randn((graph.number_of_nodes(), 3))
    graph.edata['h'] = F.randn((graph.number_of_edges(), 2))

    # Case 2: set features after removing edges.
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(5)
    graph.add_edges(edge_src, edge_dst)
    graph.remove_edges([0, 1])
    # The graph is moved to the test context once more before features are set.
    graph = graph.to(F.ctx())
    graph.ndata['h'] = F.randn((graph.number_of_nodes(), 3))
    graph.edata['h'] = F.randn((graph.number_of_edges(), 2))
208

209
210
211
212
213
214
215
216
if __name__ == '__main__':
    # Every test in this file takes a required ``idtype`` argument (normally
    # injected through @parametrize_dtype), so calling them with no argument
    # raises TypeError.  Run each one explicitly for every ID dtype instead.
    # NOTE(review): F.int32 / F.int64 are assumed to be the dtypes that
    # parametrize_dtype supplies under pytest -- confirm against test_utils.
    # The earlier call to test_frame_size() is gone: no function of that name
    # is defined or imported in this file, so it could only raise NameError.
    for idtype in (F.int32, F.int64):
        test_node_removal(idtype)
        test_edge_removal(idtype)
        test_multigraph_node_removal(idtype)
        test_multigraph_edge_removal(idtype)
        test_node_and_edge_removal(idtype)
        test_node_frame(idtype)
        test_edge_frame(idtype)
        test_issue1287(idtype)