# test_heterograph-remove.py
1
2
import backend as F

3
import dgl
4
import numpy as np
5
from utils import parametrize_idtype
6
7


8
9
10
11
12
13
14
def create_graph(idtype, num_node):
    """Build an edgeless graph holding ``num_node`` nodes.

    The empty graph is cast with ``astype(idtype)`` and moved to the test
    context ``F.ctx()`` before the nodes are added.

    Args:
        idtype: ID type forwarded to ``astype``.
        num_node: Number of nodes to add.

    Returns:
        The newly created graph.
    """
    graph = dgl.graph([]).astype(idtype).to(F.ctx())
    graph.add_nodes(num_node)
    return graph


@parametrize_idtype
def test_node_removal(idtype):
    """Removing nodes drops their rows from ``ndata`` in order.

    Original-ID bookkeeping (``dgl.NID`` / ``dgl.EID``) is only attached
    when ``store_ids=True`` is requested.
    """
    graph = create_graph(idtype, 10)
    graph.add_edges(0, 0)
    assert graph.num_nodes() == 10
    graph.ndata["id"] = F.arange(0, 10)

    # Drop nodes 4, 5 and 6; surviving rows keep their relative order.
    graph.remove_nodes(range(4, 7))
    assert graph.num_nodes() == 7
    expected_ids = F.tensor([0, 1, 2, 3, 7, 8, 9])
    assert F.array_equal(graph.ndata["id"], expected_ids)
    # No store_ids -> no original-id features.
    assert dgl.NID not in graph.ndata
    assert dgl.EID not in graph.edata

    # Newly added nodes come back with zero-filled "id" values.
    graph.add_nodes(3)
    assert graph.num_nodes() == 10
    expected_ids = F.tensor([0, 1, 2, 3, 7, 8, 9, 0, 0, 0])
    assert F.array_equal(graph.ndata["id"], expected_ids)

    # store_ids=True records original node and edge ids as features.
    graph.remove_nodes(range(1, 4), store_ids=True)
    assert graph.num_nodes() == 7
    expected_ids = F.tensor([0, 7, 8, 9, 0, 0, 0])
    assert F.array_equal(graph.ndata["id"], expected_ids)
    assert dgl.NID in graph.ndata
    assert dgl.EID in graph.edata
42

43

@parametrize_idtype
def test_multigraph_node_removal(idtype):
    """Removing a node also removes every parallel edge incident to it."""
    graph = create_graph(idtype, 5)
    # Two parallel self-loops per node -> 10 edges in total.
    for node in range(5):
        for _ in range(2):
            graph.add_edges(node, node)

    def check_size(num_nodes, num_edges):
        assert graph.num_nodes() == num_nodes
        assert graph.num_edges() == num_edges

    check_size(5, 10)

    # Nodes 2 and 3 take their four self-loops with them.
    graph.remove_nodes([2, 3])
    check_size(3, 6)

    graph.add_nodes(1)
    for _ in range(2):
        graph.add_edges(1, 1)
    check_size(4, 8)

    # Node 0 carries two self-loops.
    graph.remove_nodes([0])
    check_size(3, 6)
69

70

@parametrize_idtype
def test_multigraph_edge_removal(idtype):
    """Removing parallel edges by id leaves the node set untouched."""
    graph = create_graph(idtype, 5)
    # Two parallel self-loops per node -> 10 edges in total.
    for node in range(5):
        for _ in range(2):
            graph.add_edges(node, node)

    def check_size(num_nodes, num_edges):
        assert graph.num_nodes() == num_nodes
        assert graph.num_edges() == num_edges

    check_size(5, 10)

    graph.remove_edges([2, 3])
    check_size(5, 8)

    for _ in range(2):
        graph.add_edges(1, 1)
    check_size(5, 10)

    graph.remove_edges([0, 1])
    check_size(5, 8)
95

96

@parametrize_idtype
def test_edge_removal(idtype):
    """Removing edges drops their rows from ``edata`` in order.

    ``dgl.EID`` is only recorded when ``store_ids=True`` is requested.
    """
    graph = create_graph(idtype, 5)
    # Complete digraph with self-loops: 25 edges added in row-major order.
    for src in range(5):
        for dst in range(5):
            graph.add_edges(src, dst)
    graph.edata["id"] = F.arange(0, 25)

    graph.remove_edges(range(13, 20))
    assert graph.num_nodes() == 5
    assert graph.num_edges() == 18
    surviving = list(range(13)) + list(range(20, 25))
    assert F.array_equal(graph.edata["id"], F.tensor(surviving))
    assert dgl.NID not in graph.ndata
    assert dgl.EID not in graph.edata

    # The appended edge's "id" feature is zero-filled.
    graph.add_edges(3, 3)
    assert graph.num_nodes() == 5
    assert graph.num_edges() == 19
    surviving = surviving + [0]
    assert F.array_equal(graph.edata["id"], F.tensor(surviving))

    # Drop edge positions 2..9 and ask for the original edge ids.
    graph.remove_edges(range(2, 10), store_ids=True)
    assert graph.num_nodes() == 5
    assert graph.num_edges() == 11
    surviving = surviving[:2] + surviving[10:]
    assert F.array_equal(graph.edata["id"], F.tensor(surviving))
    assert dgl.EID in graph.edata
131

132

@parametrize_idtype
def test_node_and_edge_removal(idtype):
    """Interleave node/edge removals and additions on a dense graph."""
    graph = create_graph(idtype, 10)
    # Complete digraph with self-loops: 10 * 10 = 100 edges.
    for src in range(10):
        for dst in range(10):
            graph.add_edges(src, dst)
    graph.edata["id"] = F.arange(0, 100)

    def check_size(num_nodes, num_edges):
        assert graph.num_nodes() == num_nodes
        assert graph.num_edges() == num_edges

    check_size(10, 100)

    # Dropping nodes 2 and 4 removes every edge touching them: 8 * 8 remain.
    graph.remove_nodes([2, 4])
    check_size(8, 64)

    graph.remove_edges(range(10, 20))
    check_size(8, 54)

    # Added nodes start out isolated.
    graph.add_nodes(2)
    check_size(10, 54)

    # Fully connect the two new nodes (ids 8 and 9), self-loops included.
    for src in range(8, 10):
        for dst in range(8, 10):
            graph.add_edges(src, dst)
    check_size(10, 58)

    graph.remove_edges(range(10, 20))
    check_size(10, 48)
169

170

@parametrize_idtype
def test_node_frame(idtype):
    """Node features of surviving nodes are preserved after node removal."""
    graph = create_graph(idtype, 10)
    features = np.random.rand(10, 3)
    # Rows of the nodes that survive removing 3, 4, 5 and 6.
    expected = features[[0, 1, 2, 7, 8, 9]]
    graph.ndata["h"] = F.tensor(features)

    graph.remove_nodes(range(3, 7))
    assert F.allclose(graph.ndata["h"], F.tensor(expected))

182

@parametrize_idtype
def test_edge_frame(idtype):
    """Edge features of surviving edges are preserved after edge removal."""
    graph = create_graph(idtype, 10)
    # Ring: edge i goes from node i to node (i + 1) % 10.
    sources = list(range(10))
    destinations = [(node + 1) % 10 for node in range(10)]
    graph.add_edges(sources, destinations)
    features = np.random.rand(10, 3)
    # Rows of the edges that survive removing edge ids 3, 4, 5 and 6.
    expected = features[[0, 1, 2, 7, 8, 9]]
    graph.edata["h"] = F.tensor(features)

    graph.remove_edges(range(3, 7))
    assert F.allclose(graph.edata["h"], F.tensor(expected))

195

@parametrize_idtype
def test_issue1287(idtype):
    """Regression test for https://github.com/dmlc/dgl/issues/1287.

    Assigning node/edge features after removing nodes or edges must not fail.
    """

    def build():
        # Fresh 5-node graph with the same five edges for each scenario.
        fresh = create_graph(idtype, 5)
        fresh.add_edges([0, 2, 3, 1, 1], [1, 0, 3, 1, 0])
        return fresh

    # Scenario 1: set features after removing nodes.
    graph = build()
    graph.remove_nodes([0, 1])
    graph.ndata["h"] = F.randn((graph.num_nodes(), 3))
    graph.edata["h"] = F.randn((graph.num_edges(), 2))

    # Scenario 2: set features after removing edges and calling .to(F.ctx()).
    graph = build()
    graph.remove_edges([0, 1])
    graph = graph.to(F.ctx())
    graph.ndata["h"] = F.randn((graph.num_nodes(), 3))
    graph.edata["h"] = F.randn((graph.num_edges(), 2))
213

214

215
if __name__ == "__main__":
    # Every test takes an ``idtype`` argument that pytest normally supplies
    # via ``@parametrize_idtype``. When this file is run directly, call each
    # test explicitly with an id type instead; calling them without one
    # raises TypeError.
    # NOTE(review): assumes the backend exposes ``int32``/``int64`` and that
    # these match the values ``parametrize_idtype`` iterates over — confirm.
    for idtype in (F.int32, F.int64):
        test_node_removal(idtype)
        test_edge_removal(idtype)
        test_multigraph_node_removal(idtype)
        test_multigraph_edge_removal(idtype)
        test_node_and_edge_removal(idtype)
        test_node_frame(idtype)
        test_edge_frame(idtype)
        test_issue1287(idtype)