import unittest

import backend as F

import dgl
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
from test_utils import parametrize_idtype

# Width of the random node ("h") and edge ("l") features created by
# ``generate_graph``.
D = 5


def generate_graph(grad=False, add_data=True):
    """Build a 10-node test graph with one source (0) and one sink (9).

    Node 0 fans out to nodes 1..8, each of which feeds node 9, and a back
    edge 9 -> 0 closes the loop.  Edges are added one pair at a time, so the
    edge IDs interleave as (0 -> i, i -> 9) for i = 1..8 and the back edge
    gets ID 16.  ``test_subgraph`` relies on this ordering.

    Parameters
    ----------
    grad : bool
        If True, the random features are passed through ``F.attach_grad``.
    add_data : bool
        If True, attach random node features ``"h"`` and edge features
        ``"l"``, both of width ``D``.

    Returns
    -------
    DGLGraph
        The constructed graph on ``F.ctx()``.
    """
    g = dgl.DGLGraph().to(F.ctx())
    g.add_nodes(10)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edges(0, i)
        g.add_edges(i, 9)
    # add a back flow from 9 to 0
    g.add_edges(9, 0)
    if add_data:
        # Size the features from the graph itself rather than hard-coding
        # 10 nodes / 17 edges, so the shapes stay in sync with the topology.
        ncol = F.randn((g.num_nodes(), D))
        ecol = F.randn((g.num_edges(), D))
        if grad:
            ncol = F.attach_grad(ncol)
            ecol = F.attach_grad(ecol)
        g.ndata["h"] = ncol
        g.edata["l"] = ecol
    return g


def test_edge_subgraph():
    """Check ``edge_subgraph`` ID bookkeeping with and without relabeling."""
    # Test when the graph has no node data and edge data.
    g = generate_graph(add_data=False)
    eid = [0, 2, 3, 6, 7, 9]

    # relabel=True: only the endpoints of the selected edges are kept.
    sg = g.edge_subgraph(eid)
    assert F.array_equal(
        sg.ndata[dgl.NID], F.tensor([0, 2, 4, 5, 1, 9], g.idtype)
    )
    assert F.array_equal(sg.edata[dgl.EID], F.tensor(eid, g.idtype))
    # The subgraph must accept fresh features sized to itself.
    sg.ndata["h"] = F.arange(0, sg.num_nodes())
    sg.edata["h"] = F.arange(0, sg.num_edges())

    # relabel=False: every parent node is preserved.
    sg = g.edge_subgraph(eid, relabel_nodes=False)
    assert g.num_nodes() == sg.num_nodes()
    assert F.array_equal(sg.edata[dgl.EID], F.tensor(eid, g.idtype))
    sg.ndata["h"] = F.arange(0, sg.num_nodes())
    sg.edata["h"] = F.arange(0, sg.num_edges())


def test_subgraph():
    """Check the node-induced ``subgraph`` of the ``generate_graph`` graph."""
    g = generate_graph()
    h = g.ndata["h"]
    edge_feat = g.edata["l"]
    nid = [0, 2, 3, 6, 7, 9]
    sg = g.subgraph(nid)
    expected_eids = {2, 3, 4, 5, 10, 11, 12, 13, 16}
    assert set(F.asnumpy(sg.edata[dgl.EID])) == expected_eids
    eid = sg.edata[dgl.EID]
    # Each frame holds the ID field (NID/EID) plus the parent's feature
    # ("h" for nodes, "l" for edges) sliced to the induced nodes/edges.
    assert len(sg.ndata) == 2
    assert len(sg.edata) == 2
    sh = sg.ndata["h"]
    assert F.allclose(F.gather_row(h, F.tensor(nid)), sh)
    # Edge list of ``generate_graph``.  The first marker column ("1") flags
    # the edges induced by ``nid`` (exactly ``expected_eids`` above).
    # NOTE(review): the second marker column ("3") flags edges 9, 11, 13 and
    # 15, but nothing in this test uses it; it looks left over from an older
    # version of the test.
    #
    #   s, d, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2    1
    #   2, 9, 3    1
    #   0, 3, 4    1
    #   3, 9, 5    1
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9       3
    #   0, 6, 10   1
    #   6, 9, 11   1  3
    #   0, 7, 12   1
    #   7, 9, 13   1  3
    #   0, 8, 14
    #   8, 9, 15      3
    #   9, 0, 16   1
    assert F.allclose(F.gather_row(edge_feat, eid), sg.edata["l"])
    # update the node/edge features on the subgraph should NOT
    # reflect to the parent graph.
    sg.ndata["h"] = F.zeros((sg.num_nodes(), D))
    assert F.allclose(h, g.ndata["h"])
    sg.edata["l"] = F.zeros((sg.num_edges(), D))
    assert F.allclose(edge_feat, g.edata["l"])


def _test_map_to_subgraph():
    """Map parent node IDs to subgraph node IDs (currently disabled).

    The leading underscore keeps pytest from collecting this test.
    NOTE(review): ``map_to_subgraph_nid`` is not used anywhere else in this
    file and may no longer exist on the graph API; confirm before re-enabling.
    """
    g = dgl.DGLGraph()
    g.add_nodes(10)
    g.add_edges(F.arange(0, 9), F.arange(1, 10))
    h = g.subgraph([0, 1, 2, 5, 8])
    v = h.map_to_subgraph_nid([0, 8, 2])
    assert np.array_equal(F.asnumpy(v), np.array([0, 4, 2]))


def create_test_heterograph(idtype):
    """Build the docstring example heterograph plus a user-wishes-game relation.

    3 users, 2 games, 2 developers; metagraph:
        ('user', 'follows', 'user'),
        ('user', 'plays', 'game'),
        ('user', 'wishes', 'game'),
        ('developer', 'develops', 'game')

    Every relation gets a random 1-D ``"weight"`` edge feature.  The graph is
    created with the requested ``idtype`` on ``F.ctx()``.
    """
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 2], [1, 0]),
            ("developer", "develops", "game"): ([0, 1], [0, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    for etype in g.etypes:
        g.edges[etype].data["weight"] = F.randn((g.num_edges(etype),))
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g


def create_test_heterograph2(idtype):
    """Build the docstring example heterograph with an empty ``develops`` relation.

    3 users, 2 games; metagraph:
        ('user', 'follows', 'user'),
        ('user', 'plays', 'game'),
        ('user', 'wishes', 'game'),
        ('developer', 'develops', 'game')  -- no edges

    Every relation (including the empty one) gets a random 1-D ``"weight"``
    edge feature.  The graph is created with ``idtype`` on ``F.ctx()``.
    """
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 2], [1, 0]),
            ("developer", "develops", "game"): ([], []),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    for etype in g.etypes:
        g.edges[etype].data["weight"] = F.randn((g.num_edges(etype),))
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g


@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="MXNet doesn't support bool tensor",
)
@parametrize_idtype
def test_subgraph_mask(idtype):
    """Check ``subgraph``/``edge_subgraph`` when selectors are boolean masks."""
    g = create_test_heterograph(idtype)

    x = F.randn((3, 5))
    y = F.randn((2, 4))
    g.nodes["user"].data["h"] = x
    g.edges["follows"].data["h"] = y

    def _check_subgraph(g, sg):
        # Expected result of keeping users {1, 2} and game {0}, or
        # equivalently edge 1 of "follows", "plays" and "wishes".
        assert sg.idtype == g.idtype
        assert sg.device == g.device
        assert sg.ntypes == g.ntypes
        assert sg.etypes == g.etypes
        assert sg.canonical_etypes == g.canonical_etypes
        assert F.array_equal(
            F.tensor(sg.nodes["user"].data[dgl.NID]), F.tensor([1, 2], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.nodes["game"].data[dgl.NID]), F.tensor([0], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["follows"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["plays"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["wishes"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert sg.number_of_nodes("developer") == 0
        assert sg.number_of_edges("develops") == 0
        assert F.array_equal(
            sg.nodes["user"].data["h"], g.nodes["user"].data["h"][1:3]
        )
        assert F.array_equal(
            sg.edges["follows"].data["h"], g.edges["follows"].data["h"][1:2]
        )

    # Node masks carry one entry per node of that type (3 users, 2 games).
    sg1 = g.subgraph(
        {
            "user": F.tensor([False, True, True], dtype=F.bool),
            "game": F.tensor([True, False], dtype=F.bool),
        }
    )
    _check_subgraph(g, sg1)
    # Edge masks carry one entry per edge of that relation.
    sg2 = g.edge_subgraph(
        {
            "follows": F.tensor([False, True], dtype=F.bool),
            "plays": F.tensor([False, True, False, False], dtype=F.bool),
            "wishes": F.tensor([False, True], dtype=F.bool),
        }
    )
    _check_subgraph(g, sg2)


@parametrize_idtype
def test_subgraph1(idtype):
    """Exercise node, edge and type subgraphs on the test heterograph.

    Covers list, backend-tensor and numpy selectors, single-relation graphs
    with and without node relabeling, node/edge type subgraphs, and node
    subgraphs of graphs restricted to a single sparse format.
    """
    g = create_test_heterograph(idtype)
    g_graph = g["follows"]
    g_bipartite = g["plays"]

    x = F.randn((3, 5))
    y = F.randn((2, 4))
    g.nodes["user"].data["h"] = x
    g.edges["follows"].data["h"] = y

    def _check_subgraph(g, sg):
        # Expected result of keeping users {1, 2} and game {0}, or
        # equivalently edge 1 of "follows", "plays" and "wishes".
        assert sg.idtype == g.idtype
        assert sg.device == g.device
        assert sg.ntypes == g.ntypes
        assert sg.etypes == g.etypes
        assert sg.canonical_etypes == g.canonical_etypes
        assert F.array_equal(
            F.tensor(sg.nodes["user"].data[dgl.NID]), F.tensor([1, 2], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.nodes["game"].data[dgl.NID]), F.tensor([0], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["follows"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["plays"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["wishes"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert sg.number_of_nodes("developer") == 0
        assert sg.number_of_edges("develops") == 0
        assert F.array_equal(
            sg.nodes["user"].data["h"], g.nodes["user"].data["h"][1:3]
        )
        assert F.array_equal(
            sg.edges["follows"].data["h"], g.edges["follows"].data["h"][1:2]
        )

    # python list input
    sg1 = g.subgraph({"user": [1, 2], "game": [0]})
    _check_subgraph(g, sg1)
    sg2 = g.edge_subgraph({"follows": [1], "plays": [1], "wishes": [1]})
    _check_subgraph(g, sg2)

    # backend tensor input
    sg1 = g.subgraph(
        {
            "user": F.tensor([1, 2], dtype=idtype),
            "game": F.tensor([0], dtype=idtype),
        }
    )
    _check_subgraph(g, sg1)
    sg2 = g.edge_subgraph(
        {
            "follows": F.tensor([1], dtype=idtype),
            "plays": F.tensor([1], dtype=idtype),
            "wishes": F.tensor([1], dtype=idtype),
        }
    )
    _check_subgraph(g, sg2)

    # numpy input
    sg1 = g.subgraph({"user": np.array([1, 2]), "game": np.array([0])})
    _check_subgraph(g, sg1)
    sg2 = g.edge_subgraph(
        {
            "follows": np.array([1]),
            "plays": np.array([1]),
            "wishes": np.array([1]),
        }
    )
    _check_subgraph(g, sg2)

    def _check_subgraph_single_ntype(g, sg, preserve_nodes=False):
        # Expected result of restricting the "follows" relation to users
        # {1, 2} (or to its edge 1).
        assert sg.idtype == g.idtype
        assert sg.device == g.device
        assert sg.ntypes == g.ntypes
        assert sg.etypes == g.etypes
        assert sg.canonical_etypes == g.canonical_etypes

        if not preserve_nodes:
            assert F.array_equal(
                F.tensor(sg.nodes["user"].data[dgl.NID]),
                F.tensor([1, 2], g.idtype),
            )
        else:
            for ntype in sg.ntypes:
                assert g.number_of_nodes(ntype) == sg.number_of_nodes(ntype)

        assert F.array_equal(
            F.tensor(sg.edges["follows"].data[dgl.EID]), F.tensor([1], g.idtype)
        )

        if not preserve_nodes:
            assert F.array_equal(
                sg.nodes["user"].data["h"], g.nodes["user"].data["h"][1:3]
            )
        assert F.array_equal(
            sg.edges["follows"].data["h"], g.edges["follows"].data["h"][1:2]
        )

    def _check_subgraph_single_etype(g, sg, preserve_nodes=False):
        # Expected result of keeping edges 0 and 1 of the "plays" relation.
        assert sg.idtype == g.idtype
        assert sg.device == g.device
        assert sg.ntypes == g.ntypes
        assert sg.etypes == g.etypes
        assert sg.canonical_etypes == g.canonical_etypes

        if not preserve_nodes:
            assert F.array_equal(
                F.tensor(sg.nodes["user"].data[dgl.NID]),
                F.tensor([0, 1], g.idtype),
            )
            assert F.array_equal(
                F.tensor(sg.nodes["game"].data[dgl.NID]),
                F.tensor([0], g.idtype),
            )
        else:
            for ntype in sg.ntypes:
                assert g.number_of_nodes(ntype) == sg.number_of_nodes(ntype)

        assert F.array_equal(
            F.tensor(sg.edges["plays"].data[dgl.EID]),
            F.tensor([0, 1], g.idtype),
        )

    sg1_graph = g_graph.subgraph([1, 2])
    _check_subgraph_single_ntype(g_graph, sg1_graph)
    sg1_graph = g_graph.edge_subgraph([1])
    _check_subgraph_single_ntype(g_graph, sg1_graph)
    sg1_graph = g_graph.edge_subgraph([1], relabel_nodes=False)
    _check_subgraph_single_ntype(g_graph, sg1_graph, True)
    sg2_bipartite = g_bipartite.edge_subgraph([0, 1])
    _check_subgraph_single_etype(g_bipartite, sg2_bipartite)
    sg2_bipartite = g_bipartite.edge_subgraph([0, 1], relabel_nodes=False)
    _check_subgraph_single_etype(g_bipartite, sg2_bipartite, True)

    def _check_typed_subgraph1(g, sg):
        # A type subgraph over users/games keeps every node and edge of those
        # types, and (per the checks below) sees later feature updates made
        # on the parent.
        assert g.idtype == sg.idtype
        assert g.device == sg.device
        assert set(sg.ntypes) == {"user", "game"}
        assert set(sg.etypes) == {"follows", "plays", "wishes"}
        for ntype in sg.ntypes:
            assert sg.number_of_nodes(ntype) == g.number_of_nodes(ntype)
        for etype in sg.etypes:
            src_sg, dst_sg = sg.all_edges(etype=etype, order="eid")
            src_g, dst_g = g.all_edges(etype=etype, order="eid")
            assert F.array_equal(src_sg, src_g)
            assert F.array_equal(dst_sg, dst_g)
        assert F.array_equal(
            sg.nodes["user"].data["h"], g.nodes["user"].data["h"]
        )
        assert F.array_equal(
            sg.edges["follows"].data["h"], g.edges["follows"].data["h"]
        )
        g.nodes["user"].data["h"] = F.scatter_row(
            g.nodes["user"].data["h"], F.tensor([2]), F.randn((1, 5))
        )
        g.edges["follows"].data["h"] = F.scatter_row(
            g.edges["follows"].data["h"], F.tensor([1]), F.randn((1, 4))
        )
        assert F.array_equal(
            sg.nodes["user"].data["h"], g.nodes["user"].data["h"]
        )
        assert F.array_equal(
            sg.edges["follows"].data["h"], g.edges["follows"].data["h"]
        )

    def _check_typed_subgraph2(g, sg):
        # The "develops" type subgraph keeps developers, games and all of
        # their edges.
        assert g.idtype == sg.idtype
        assert g.device == sg.device
        assert set(sg.ntypes) == {"developer", "game"}
        assert set(sg.etypes) == {"develops"}
        for ntype in sg.ntypes:
            assert sg.number_of_nodes(ntype) == g.number_of_nodes(ntype)
        for etype in sg.etypes:
            src_sg, dst_sg = sg.all_edges(etype=etype, order="eid")
            src_g, dst_g = g.all_edges(etype=etype, order="eid")
            assert F.array_equal(src_sg, src_g)
            assert F.array_equal(dst_sg, dst_g)

    sg3 = g.node_type_subgraph(["user", "game"])
    _check_typed_subgraph1(g, sg3)
    sg4 = g.edge_type_subgraph(["develops"])
    _check_typed_subgraph2(g, sg4)
    sg5 = g.edge_type_subgraph(["follows", "plays", "wishes"])
    _check_typed_subgraph1(g, sg5)

    # Test for restricted format
    for fmt in ["csr", "csc", "coo"]:
        fg = dgl.graph(([0, 1], [1, 2])).formats(fmt)
        sg = fg.subgraph({fg.ntypes[0]: [1, 0]})
        nids = F.asnumpy(sg.ndata[dgl.NID])
        assert np.array_equal(nids, np.array([1, 0]))
        src, dst = sg.edges(order="eid")
        src = F.asnumpy(src)
        dst = F.asnumpy(dst)
        # Parent edge 0 -> 1 becomes 1 -> 0 once nodes [1, 0] are relabeled
        # to [0, 1]; parent edge 1 -> 2 is dropped.
        assert np.array_equal(src, np.array([1]))
        assert np.array_equal(dst, np.array([0]))


@parametrize_idtype
def test_in_subgraph(idtype):
    """Check ``dgl.in_subgraph`` with default, store_ids and relabel options."""
    hg = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 0, 1, 3], [0, 1, 2, 2]),
            ("game", "liked-by", "user"): (
                [2, 2, 2, 1, 1, 0],
                [0, 1, 2, 0, 3, 0],
            ),
            ("user", "flips", "coin"): ([0, 1, 2, 3], [0, 0, 0, 0]),
        },
        idtype=idtype,
        num_nodes_dict={"user": 5, "game": 10, "coin": 8},
    ).to(F.ctx())
    subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0})
    assert subg.idtype == idtype
    assert len(subg.ntypes) == 3
    assert len(subg.etypes) == 4
    u, v = subg["follow"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert F.array_equal(
        hg["follow"].edge_ids(u, v), subg["follow"].edata[dgl.EID]
    )
    assert edge_set == {(1, 0), (2, 0), (3, 0), (0, 1), (2, 1), (3, 1)}
    u, v = subg["play"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert F.array_equal(hg["play"].edge_ids(u, v), subg["play"].edata[dgl.EID])
    assert edge_set == {(0, 0)}
    u, v = subg["liked-by"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert F.array_equal(
        hg["liked-by"].edge_ids(u, v), subg["liked-by"].edata[dgl.EID]
    )
    assert edge_set == {(2, 0), (2, 1), (1, 0), (0, 0)}
    assert subg["flips"].number_of_edges() == 0
    # Without relabeling the subgraph reuses the parent's node IDs (the edge
    # sets above are in parent IDs), so no NID field is stored.
    for ntype in subg.ntypes:
        assert dgl.NID not in subg.nodes[ntype].data

    # Test store_ids: no ID fields on any relation, including empty ones.
    subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0}, store_ids=False)
    for etype in subg.canonical_etypes:
        assert dgl.EID not in subg.edges[etype].data
    for ntype in subg.ntypes:
        assert dgl.NID not in subg.nodes[ntype].data

    # Test relabel nodes: map subgraph IDs back through NID before comparing
    # against the parent.
    subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0}, relabel_nodes=True)
    assert subg.idtype == idtype
    assert len(subg.ntypes) == 3
    assert len(subg.etypes) == 4

    u, v = subg["follow"].edges()
    old_u = F.gather_row(subg.nodes["user"].data[dgl.NID], u)
    old_v = F.gather_row(subg.nodes["user"].data[dgl.NID], v)
    assert F.array_equal(
        hg["follow"].edge_ids(old_u, old_v), subg["follow"].edata[dgl.EID]
    )
    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))
    assert edge_set == {(1, 0), (2, 0), (3, 0), (0, 1), (2, 1), (3, 1)}

    u, v = subg["play"].edges()
    old_u = F.gather_row(subg.nodes["user"].data[dgl.NID], u)
    old_v = F.gather_row(subg.nodes["game"].data[dgl.NID], v)
    assert F.array_equal(
        hg["play"].edge_ids(old_u, old_v), subg["play"].edata[dgl.EID]
    )
    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))
    assert edge_set == {(0, 0)}

    u, v = subg["liked-by"].edges()
    old_u = F.gather_row(subg.nodes["game"].data[dgl.NID], u)
    old_v = F.gather_row(subg.nodes["user"].data[dgl.NID], v)
    assert F.array_equal(
        hg["liked-by"].edge_ids(old_u, old_v), subg["liked-by"].edata[dgl.EID]
    )
    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))
    assert edge_set == {(2, 0), (2, 1), (1, 0), (0, 0)}

    assert subg.num_nodes("user") == 4
    assert subg.num_nodes("game") == 3
    assert subg.num_nodes("coin") == 0
    assert subg.num_edges("flips") == 0


@parametrize_idtype
def test_out_subgraph(idtype):
    """Check ``dgl.out_subgraph`` with default, store_ids and relabel options."""
    hg = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 0, 1, 3], [0, 1, 2, 2]),
            ("game", "liked-by", "user"): (
                [2, 2, 2, 1, 1, 0],
                [0, 1, 2, 0, 3, 0],
            ),
            ("user", "flips", "coin"): ([0, 1, 2, 3], [0, 0, 0, 0]),
        },
        idtype=idtype,
    ).to(F.ctx())
    subg = dgl.out_subgraph(hg, {"user": [0, 1], "game": 0})
    assert subg.idtype == idtype
    assert len(subg.ntypes) == 3
    assert len(subg.etypes) == 4
    u, v = subg["follow"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(1, 0), (0, 1), (0, 2)}
    assert F.array_equal(
        hg["follow"].edge_ids(u, v), subg["follow"].edata[dgl.EID]
    )
    u, v = subg["play"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(0, 0), (0, 1), (1, 2)}
    assert F.array_equal(hg["play"].edge_ids(u, v), subg["play"].edata[dgl.EID])
    u, v = subg["liked-by"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(0, 0)}
    assert F.array_equal(
        hg["liked-by"].edge_ids(u, v), subg["liked-by"].edata[dgl.EID]
    )
    u, v = subg["flips"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(0, 0), (1, 0)}
    assert F.array_equal(
        hg["flips"].edge_ids(u, v), subg["flips"].edata[dgl.EID]
    )
    # Without relabeling the subgraph reuses the parent's node IDs, so no NID
    # field is stored.
    for ntype in subg.ntypes:
        assert dgl.NID not in subg.nodes[ntype].data

    # Test store_ids
    subg = dgl.out_subgraph(hg, {"user": [0, 1], "game": 0}, store_ids=False)
    for etype in subg.canonical_etypes:
        assert dgl.EID not in subg.edges[etype].data
    for ntype in subg.ntypes:
        assert dgl.NID not in subg.nodes[ntype].data

    # Test relabel nodes: map subgraph IDs back through NID before comparing
    # against the parent.
    subg = dgl.out_subgraph(hg, {"user": [1], "game": 0}, relabel_nodes=True)
    assert subg.idtype == idtype
    assert len(subg.ntypes) == 3
    assert len(subg.etypes) == 4

    u, v = subg["follow"].edges()
    old_u = F.gather_row(subg.nodes["user"].data[dgl.NID], u)
    old_v = F.gather_row(subg.nodes["user"].data[dgl.NID], v)
    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))
    assert edge_set == {(1, 0)}
    assert F.array_equal(
        hg["follow"].edge_ids(old_u, old_v), subg["follow"].edata[dgl.EID]
    )

    u, v = subg["play"].edges()
    old_u = F.gather_row(subg.nodes["user"].data[dgl.NID], u)
    old_v = F.gather_row(subg.nodes["game"].data[dgl.NID], v)
    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))
    assert edge_set == {(1, 2)}
    assert F.array_equal(
        hg["play"].edge_ids(old_u, old_v), subg["play"].edata[dgl.EID]
    )

    u, v = subg["liked-by"].edges()
    old_u = F.gather_row(subg.nodes["game"].data[dgl.NID], u)
    old_v = F.gather_row(subg.nodes["user"].data[dgl.NID], v)
    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))
    assert edge_set == {(0, 0)}
    assert F.array_equal(
        hg["liked-by"].edge_ids(old_u, old_v), subg["liked-by"].edata[dgl.EID]
    )

    u, v = subg["flips"].edges()
    old_u = F.gather_row(subg.nodes["user"].data[dgl.NID], u)
    old_v = F.gather_row(subg.nodes["coin"].data[dgl.NID], v)
    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))
    assert edge_set == {(1, 0)}
    assert F.array_equal(
        hg["flips"].edge_ids(old_u, old_v), subg["flips"].edata[dgl.EID]
    )
    assert subg.num_nodes("user") == 2
    assert subg.num_nodes("game") == 2
    assert subg.num_nodes("coin") == 1


def test_subgraph_message_passing():
    """Run ``update_all`` on a CPU-built subgraph moved to ``F.ctx()``.

    Regression test for PR #2055: the parent graph and its features live on
    CPU, and the induced subgraph is moved to the test context before
    message passing.
    """
    g = dgl.graph(([0, 1, 2], [2, 3, 4])).to(F.cpu())
    g.ndata["x"] = F.copy_to(F.randn((5, 6)), F.cpu())
    sg = g.subgraph([1, 2, 3]).to(F.ctx())
    sg.update_all(
        lambda edges: {"x": edges.src["x"]},
        lambda nodes: {"y": F.sum(nodes.mailbox["x"], 1)},
    )
    # Subgraph edges are 0 -> 1 and 1 -> 2 (parent 1 -> 2 and 2 -> 3), so
    # subgraph nodes 1 and 2 each receive exactly one message: the features
    # of parent nodes 1 and 2.  Node 0 has no in-edges and is not checked.
    y = F.asnumpy(sg.ndata["y"])
    x = F.asnumpy(g.ndata["x"])
    assert np.allclose(y[1:], x[1:3])


@parametrize_idtype
def test_khop_in_subgraph(idtype):
    """Check ``dgl.khop_in_subgraph`` on a graph and on a heterograph."""
    g = dgl.graph(
        ([1, 1, 2, 3, 4], [0, 2, 0, 4, 2]), idtype=idtype, device=F.ctx()
    )
    g.edata["w"] = F.tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
    sg, inv = dgl.khop_in_subgraph(g, 0, k=2)
    assert sg.idtype == g.idtype
    u, v = sg.edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(1, 0), (1, 2), (2, 0), (3, 2)}
    assert F.array_equal(
        sg.edata[dgl.EID], F.tensor([0, 1, 2, 4], dtype=idtype)
    )
    assert F.array_equal(
        sg.edata["w"], F.tensor([[0, 1], [2, 3], [4, 5], [8, 9]])
    )
    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))

    # Test multiple nodes, given as a python list and as a backend tensor
    for seeds in ([0, 2], F.tensor([0, 2], idtype)):
        sg, inv = dgl.khop_in_subgraph(g, seeds, k=1)
        assert sg.num_edges() == 4
        # ``inv`` gives each seed's position in the subgraph; mapping those
        # positions back through NID must recover the seeds.  This does not
        # depend on how the subgraph orders its nodes.
        assert F.array_equal(
            F.astype(F.gather_row(sg.ndata[dgl.NID], inv), idtype),
            F.tensor([0, 2], idtype),
        )

    # Test isolated node
    sg, inv = dgl.khop_in_subgraph(g, 1, k=2)
    assert sg.idtype == g.idtype
    assert sg.num_nodes() == 1
    assert sg.num_edges() == 0
    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))

    g = dgl.heterograph(
        {
            ("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 2, 1]),
            ("user", "follows", "user"): ([0, 1, 1], [1, 2, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    sg, inv = dgl.khop_in_subgraph(g, {"game": 0}, k=2)
    assert sg.idtype == idtype
    assert sg.num_nodes("game") == 1
    assert sg.num_nodes("user") == 2
    assert len(sg.ntypes) == 2
    assert len(sg.etypes) == 2
    u, v = sg["follows"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(0, 1)}
    u, v = sg["plays"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(0, 0), (1, 0)}
    assert F.array_equal(F.astype(inv["game"], idtype), F.tensor([0], idtype))

    # Test isolated node
    sg, inv = dgl.khop_in_subgraph(g, {"user": 0}, k=2)
    assert sg.idtype == idtype
    assert sg.num_nodes("game") == 0
    assert sg.num_nodes("user") == 1
    assert sg.num_edges("follows") == 0
    assert sg.num_edges("plays") == 0
    assert F.array_equal(F.astype(inv["user"], idtype), F.tensor([0], idtype))

    # Test multiple nodes
    sg, inv = dgl.khop_in_subgraph(
        g, {"user": F.tensor([0, 1], idtype), "game": 0}, k=1
    )
    u, v = sg["follows"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(0, 1)}
    u, v = sg["plays"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(0, 0), (1, 0)}
    assert F.array_equal(
        F.astype(inv["user"], idtype), F.tensor([0, 1], idtype)
    )
    assert F.array_equal(F.astype(inv["game"], idtype), F.tensor([0], idtype))


@parametrize_idtype
def test_khop_out_subgraph(idtype):
    """Check ``dgl.khop_out_subgraph`` on a graph and on a heterograph."""
    g = dgl.graph(
        ([0, 2, 0, 4, 2], [1, 1, 2, 3, 4]), idtype=idtype, device=F.ctx()
    )
    g.edata["w"] = F.tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
    sg, inv = dgl.khop_out_subgraph(g, 0, k=2)
    assert sg.idtype == g.idtype
    u, v = sg.edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(0, 1), (2, 1), (0, 2), (2, 3)}
    assert F.array_equal(
        sg.edata[dgl.EID], F.tensor([0, 2, 1, 4], dtype=idtype)
    )
    assert F.array_equal(
        sg.edata["w"], F.tensor([[0, 1], [4, 5], [2, 3], [8, 9]])
    )
    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))

    # Test multiple nodes, given as a python list and as a backend tensor
    for seeds in ([0, 2], F.tensor([0, 2], idtype)):
        sg, inv = dgl.khop_out_subgraph(g, seeds, k=1)
        assert sg.num_edges() == 4
        # ``inv`` gives each seed's position in the subgraph; mapping those
        # positions back through NID must recover the seeds.  This does not
        # depend on how the subgraph orders its nodes.
        assert F.array_equal(
            F.astype(F.gather_row(sg.ndata[dgl.NID], inv), idtype),
            F.tensor([0, 2], idtype),
        )

    # Test isolated node
    sg, inv = dgl.khop_out_subgraph(g, 1, k=2)
    assert sg.idtype == g.idtype
    assert sg.num_nodes() == 1
    assert sg.num_edges() == 0
    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))

    g = dgl.heterograph(
        {
            ("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 2, 1]),
            ("user", "follows", "user"): ([0, 1], [1, 3]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    sg, inv = dgl.khop_out_subgraph(g, {"user": 0}, k=2)
    assert sg.idtype == idtype
    assert sg.num_nodes("game") == 2
    assert sg.num_nodes("user") == 3
    assert len(sg.ntypes) == 2
    assert len(sg.etypes) == 2
    u, v = sg["follows"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(0, 1), (1, 2)}
    u, v = sg["plays"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(0, 0), (1, 0), (1, 1)}
    assert F.array_equal(F.astype(inv["user"], idtype), F.tensor([0], idtype))

    # Test isolated node
    sg, inv = dgl.khop_out_subgraph(g, {"user": 3}, k=2)
    assert sg.idtype == idtype
    assert sg.num_nodes("game") == 0
    assert sg.num_nodes("user") == 1
    assert sg.num_edges("follows") == 0
    assert sg.num_edges("plays") == 0
    assert F.array_equal(F.astype(inv["user"], idtype), F.tensor([0], idtype))

    # Test multiple nodes
    sg, inv = dgl.khop_out_subgraph(
        g, {"user": F.tensor([2], idtype), "game": 0}, k=1
    )
    assert sg.num_edges("follows") == 0
    u, v = sg["plays"].edges()
    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert edge_set == {(0, 1)}
    assert F.array_equal(F.astype(inv["user"], idtype), F.tensor([0], idtype))
    assert F.array_equal(F.astype(inv["game"], idtype), F.tensor([0], idtype))


@unittest.skipIf(not F.gpu_ctx(), "only necessary with GPU")
@pytest.mark.parametrize(
    "parent_idx_device",
    [("cpu", F.cpu()), ("cuda", F.cuda()), ("uva", F.cpu()), ("uva", F.cuda())],
)
@pytest.mark.parametrize("child_device", [F.cpu(), F.cuda()])
def test_subframes(parent_idx_device, child_device):
    """Check that sampled-subgraph feature frames follow the subgraph device.

    ``parent_idx_device`` is ``(parent placement, seed-index device)``.  The
    parent placement is ``"cpu"``, ``"cuda"``, or ``"uva"``; ``"uva"`` means
    the graph stays on CPU with its memory pinned.
    """
    parent_device, idx_device = parent_idx_device
    g = dgl.graph(
        (F.tensor([1, 2, 3], dtype=F.int64), F.tensor([2, 3, 4], dtype=F.int64))
    )
    g.ndata["x"] = F.randn((5, 4))
    g.edata["a"] = F.randn((3, 6))
    idx = F.tensor([1, 2], dtype=F.int64)
    if parent_device == "cuda":
        g = g.to(F.cuda())
    elif parent_device == "uva":
        if F.backend_name != "pytorch":
            pytest.skip("UVA only supported for PyTorch")
        g = g.to(F.cpu())
        g.create_formats_()
        g.pin_memory_()
    elif parent_device == "cpu":
        g = g.to(F.cpu())
    try:
        idx = F.copy_to(idx, idx_device)
        sg = g.sample_neighbors(idx, 2).to(child_device)
        assert sg.device == F.context(sg.ndata["x"])
        assert sg.device == F.context(sg.edata["a"])
        assert sg.device == child_device
        if parent_device != "uva":
            sg = g.to(child_device).sample_neighbors(
                F.copy_to(idx, child_device), 2
            )
            assert sg.device == F.context(sg.ndata["x"])
            assert sg.device == F.context(sg.edata["a"])
            assert sg.device == child_device
    finally:
        # Unpin even when an assertion fails, so a failure here does not
        # leave the graph's memory pinned for the rest of the session.
        if parent_device == "uva":
            g.unpin_memory_()


@unittest.skipIf(
    F._default_context_str != "gpu", reason="UVA only available on GPU"
)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch",
    reason="UVA only supported for PyTorch",
)
@pytest.mark.parametrize("device", [F.cpu(), F.cuda()])
@parametrize_idtype
def test_uva_subgraph(idtype, device):
    """Check that subgraph APIs on a pinned CPU graph return on ``device``.

    The graph stays in pinned CPU memory.  Every API below is expected to
    produce its result on the device of the index tensors it is given.
    """
    g = create_test_heterograph2(idtype)
    g = g.to(F.cpu())
    g.create_formats_()
    g.pin_memory_()
    try:
        indices = {"user": F.copy_to(F.tensor([0], idtype), device)}
        edge_indices = {"follows": F.copy_to(F.tensor([0], idtype), device)}
        assert g.subgraph(indices).device == device
        assert g.edge_subgraph(edge_indices).device == device
        assert g.in_subgraph(indices).device == device
        assert g.out_subgraph(indices).device == device
        assert g.khop_in_subgraph(indices, 1)[0].device == device
        assert g.khop_out_subgraph(indices, 1)[0].device == device
        assert g.sample_neighbors(indices, 1).device == device
    finally:
        # Unpin even when an assertion fails so later tests start clean.
        g.unpin_memory_()


if __name__ == "__main__":
    # Ad-hoc entry point for running a few tests without pytest.
    # NOTE(review): ``test_uva_subgraph`` is wrapped by ``unittest.skipIf``, so
    # on setups without a GPU/PyTorch these direct calls raise
    # ``unittest.SkipTest`` instead of silently skipping.
    test_edge_subgraph()
    test_uva_subgraph(F.int64, F.cpu())
    test_uva_subgraph(F.int64, F.cuda())