"docs/source/vscode:/vscode.git/clone" did not exist on "3cc32a9775c397f4174ca28852bee2818319c0ce"
test_subgraph.py 28.7 KB
Newer Older
1
2
import unittest

3
import backend as F
4
5

import dgl
6
7
8
9
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
10
from utils import parametrize_idtype
Minjie Wang's avatar
Minjie Wang committed
11
12
13

# Width of the random node/edge feature rows created by generate_graph().
D = 5

14

15
def generate_graph(grad=False, add_data=True):
    """Build the small source/sink graph shared by the tests on ``F.ctx()``.

    The graph has 10 nodes. For every ``i`` in 1..8 the edges ``0 -> i`` and
    ``i -> 9`` are added (in that order), followed by one back edge
    ``9 -> 0``, for 17 edges in total.

    Parameters
    ----------
    grad : bool
        If True, the random features are passed through ``F.attach_grad``
        before being stored. Only consulted when ``add_data`` is True.
    add_data : bool
        If True, random features of width ``D`` are stored under
        ``g.ndata["h"]`` and ``g.edata["l"]``.

    Returns
    -------
    The constructed graph.
    """
    g = dgl.DGLGraph().to(F.ctx())
    g.add_nodes(10)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edges(0, i)
        g.add_edges(i, 9)
    # add a back flow from 9 to 0
    g.add_edges(9, 0)
    if add_data:
        # Size the features from the graph instead of repeating the node and
        # edge counts, so they cannot drift from the structure built above.
        ncol = F.randn((g.num_nodes(), D))
        ecol = F.randn((g.num_edges(), D))
        if grad:
            ncol = F.attach_grad(ncol)
            ecol = F.attach_grad(ecol)
        g.ndata["h"] = ncol
        g.edata["l"] = ecol
    return g

34

35
def test_edge_subgraph():
    """Edge-induced subgraphs of a feature-less graph, relabeled and not."""

    def _attach_features(sub):
        # New node/edge features can be written onto the subgraph.
        sub.ndata["h"] = F.arange(0, sub.num_nodes())
        sub.edata["h"] = F.arange(0, sub.num_edges())

    # The parent graph carries neither node data nor edge data.
    parent = generate_graph(add_data=False)
    eid = [0, 2, 3, 6, 7, 9]

    # Default behaviour: nodes are relabeled, original IDs kept in NID/EID.
    relabeled = parent.edge_subgraph(eid)
    expected_nid = F.tensor([0, 2, 4, 5, 1, 9], parent.idtype)
    assert F.array_equal(relabeled.ndata[dgl.NID], expected_nid)
    assert F.array_equal(relabeled.edata[dgl.EID], F.tensor(eid, parent.idtype))
    _attach_features(relabeled)

    # relabel_nodes=False: the node set of the parent is preserved.
    preserved = parent.edge_subgraph(eid, relabel_nodes=False)
    assert parent.num_nodes() == preserved.num_nodes()
    assert F.array_equal(preserved.edata[dgl.EID], F.tensor(eid, parent.idtype))
    _attach_features(preserved)
55

56

57
def test_subgraph():
    """Node-induced subgraph of the source/sink graph from generate_graph().

    Checks the induced edge IDs, that the parent's "h"/"l" features are
    carried over to the subgraph, and that overwriting features on the
    subgraph leaves the parent's features untouched.
    """
    g = generate_graph()
    node_feat = g.ndata["h"]
    edge_feat = g.edata["l"]
    nid = [0, 2, 3, 6, 7, 9]
    sg = g.subgraph(nid)
    # Parent edge table (columns: s, d, eid). The edges flagged in column
    # "1" are exactly those with both endpoints in ``nid``, i.e. the expected
    # EID set below. NOTE(review): the meaning of column "3" is not visible
    # from this test -- confirm or drop.
    #
    #   s, d, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2    1
    #   2, 9, 3    1
    #   0, 3, 4    1
    #   3, 9, 5    1
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9       3
    #   0, 6, 10   1
    #   6, 9, 11   1  3
    #   0, 7, 12   1
    #   7, 9, 13   1  3
    #   0, 8, 14
    #   8, 9, 15      3
    #   9, 0, 16   1
    eid = {2, 3, 4, 5, 10, 11, 12, 13, 16}
    assert set(F.asnumpy(sg.edata[dgl.EID])) == eid
    eid = sg.edata[dgl.EID]
    # Each frame holds exactly two fields: the feature copied from the
    # parent ("h" / "l") plus the NID / EID mapping back to the parent.
    assert len(sg.ndata) == 2
    assert len(sg.edata) == 2
    sh = sg.ndata["h"]
    assert F.allclose(F.gather_row(node_feat, F.tensor(nid)), sh)
    assert F.allclose(F.gather_row(edge_feat, eid), sg.edata["l"])
    # update the node/edge features on the subgraph should NOT
    # reflect to the parent graph.
    sg.ndata["h"] = F.zeros((sg.num_nodes(), D))
    assert F.allclose(node_feat, g.ndata["h"])
    sg.edata["l"] = F.zeros((sg.num_edges(), D))
    assert F.allclose(edge_feat, g.edata["l"])

Minjie Wang's avatar
Minjie Wang committed
97

98
99
def _test_map_to_subgraph():
    """Map parent node IDs to their positions in a node-induced subgraph.

    The name does not start with ``test_``, so pytest's default collection
    does not pick this function up.
    """
    chain = dgl.DGLGraph()
    chain.add_nodes(10)
    # Path graph 0 -> 1 -> ... -> 9.
    chain.add_edges(F.arange(0, 9), F.arange(1, 10))
    sub = chain.subgraph([0, 1, 2, 5, 8])
    mapped = F.asnumpy(sub.map_to_subgraph_nid([0, 8, 2]))
    # Parent nodes 0, 8 and 2 sit at positions 0, 4 and 2 of the node list.
    assert np.array_equal(mapped, np.array([0, 4, 2]))
105

106

107
108
109
110
111
112
113
114
115
def create_test_heterograph(idtype):
    """Create the user/game/developer heterograph used by the subgraph tests.

    Relations (canonical edge types):
        ('user', 'follows', 'user')
        ('user', 'plays', 'game')
        ('user', 'wishes', 'game')
        ('developer', 'develops', 'game')

    The edge lists reference 3 users, 2 games and 2 developers. Every edge
    type gets a random 1-D "weight" feature with one entry per edge. The
    graph is created with the requested ``idtype`` on ``F.ctx()``, and both
    properties are asserted before returning.
    """
    relations = {
        ("user", "follows", "user"): ([0, 1], [1, 2]),
        ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "wishes", "game"): ([0, 2], [1, 0]),
        ("developer", "develops", "game"): ([0, 1], [0, 1]),
    }
    hg = dgl.heterograph(relations, idtype=idtype, device=F.ctx())
    for rel in hg.etypes:
        num_rel_edges = hg.num_edges(rel)
        hg.edges[rel].data["weight"] = F.randn((num_rel_edges,))
    assert hg.idtype == idtype
    assert hg.device == F.ctx()
    return hg

132

133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def create_test_heterograph2(idtype):
    """Create the user/game/developer heterograph with an empty relation.

    Same relations as ``create_test_heterograph`` except that
    ('developer', 'develops', 'game') is given empty source and destination
    lists:
        ('user', 'follows', 'user')
        ('user', 'plays', 'game')
        ('user', 'wishes', 'game')
        ('developer', 'develops', 'game')   # no edges

    Every edge type gets a random 1-D "weight" feature with one entry per
    edge. The graph is created with the requested ``idtype`` on ``F.ctx()``,
    and both properties are asserted before returning.
    """
    relations = {
        ("user", "follows", "user"): ([0, 1], [1, 2]),
        ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "wishes", "game"): ([0, 2], [1, 0]),
        ("developer", "develops", "game"): ([], []),
    }
    hg = dgl.heterograph(relations, idtype=idtype, device=F.ctx())
    for rel in hg.etypes:
        num_rel_edges = hg.num_edges(rel)
        hg.edges[rel].data["weight"] = F.randn((num_rel_edges,))
    assert hg.idtype == idtype
    assert hg.device == F.ctx()
    return hg


159
160
161
162
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="MXNet doesn't support bool tensor",
)
@parametrize_idtype
def test_subgraph_mask(idtype):
    """Node and edge subgraphs selected through boolean mask tensors.

    Both the node-mask subgraph and the edge-mask subgraph are expected to
    contain users 1 and 2, game 0, and edge 1 of each of the "follows",
    "plays" and "wishes" relations, with no developer nodes or "develops"
    edges.
    """
    g = create_test_heterograph(idtype)

    x = F.randn((3, 5))
    y = F.randn((2, 4))
    g.nodes["user"].data["h"] = x
    g.edges["follows"].data["h"] = y

    def _check_subgraph(g, sg):
        # Schema, index dtype and device are inherited from the parent.
        assert sg.idtype == g.idtype
        assert sg.device == g.device
        assert sg.ntypes == g.ntypes
        assert sg.etypes == g.etypes
        assert sg.canonical_etypes == g.canonical_etypes
        # NID / EID map subgraph nodes and edges back to the parent.
        assert F.array_equal(
            F.tensor(sg.nodes["user"].data[dgl.NID]), F.tensor([1, 2], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.nodes["game"].data[dgl.NID]), F.tensor([0], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["follows"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["plays"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["wishes"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert sg.num_nodes("developer") == 0
        assert sg.num_edges("develops") == 0
        # Features are the matching rows of the parent's features.
        assert F.array_equal(
            sg.nodes["user"].data["h"], g.nodes["user"].data["h"][1:3]
        )
        assert F.array_equal(
            sg.edges["follows"].data["h"], g.edges["follows"].data["h"][1:2]
        )

    # One mask entry per node of each type: 3 users and 2 games.
    sg1 = g.subgraph(
        {
            "user": F.tensor([False, True, True], dtype=F.bool),
            "game": F.tensor([True, False], dtype=F.bool),
        }
    )
    _check_subgraph(g, sg1)
    # One mask entry per edge of each listed relation.
    sg2 = g.edge_subgraph(
        {
            "follows": F.tensor([False, True], dtype=F.bool),
            "plays": F.tensor([False, True, False, False], dtype=F.bool),
            "wishes": F.tensor([False, True], dtype=F.bool),
        }
    )
    _check_subgraph(g, sg2)
219

220

nv-dlasalle's avatar
nv-dlasalle committed
221
@parametrize_idtype
def test_subgraph1(idtype):
    """Node, edge and type subgraphs of the test heterograph.

    Covers ID inputs given as Python lists, backend tensors and NumPy
    arrays, single-relation slices (with and without node relabeling),
    node-type / edge-type subgraphs, and graphs restricted to one sparse
    format.
    """
    g = create_test_heterograph(idtype)
    # Single-relation slices are taken before the "h" features are attached.
    g_graph = g["follows"]
    g_bipartite = g["plays"]

    x = F.randn((3, 5))
    y = F.randn((2, 4))
    g.nodes["user"].data["h"] = x
    g.edges["follows"].data["h"] = y

    def _check_subgraph(g, sg):
        # Schema, index dtype and device are inherited from the parent.
        assert sg.idtype == g.idtype
        assert sg.device == g.device
        assert sg.ntypes == g.ntypes
        assert sg.etypes == g.etypes
        assert sg.canonical_etypes == g.canonical_etypes
        assert F.array_equal(
            F.tensor(sg.nodes["user"].data[dgl.NID]), F.tensor([1, 2], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.nodes["game"].data[dgl.NID]), F.tensor([0], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["follows"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["plays"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert F.array_equal(
            F.tensor(sg.edges["wishes"].data[dgl.EID]), F.tensor([1], g.idtype)
        )
        assert sg.num_nodes("developer") == 0
        assert sg.num_edges("develops") == 0
        assert F.array_equal(
            sg.nodes["user"].data["h"], g.nodes["user"].data["h"][1:3]
        )
        assert F.array_equal(
            sg.edges["follows"].data["h"], g.edges["follows"].data["h"][1:2]
        )

    sg1 = g.subgraph({"user": [1, 2], "game": [0]})
    _check_subgraph(g, sg1)
    sg2 = g.edge_subgraph({"follows": [1], "plays": [1], "wishes": [1]})
    _check_subgraph(g, sg2)

    # backend tensor input
    sg1 = g.subgraph(
        {
            "user": F.tensor([1, 2], dtype=idtype),
            "game": F.tensor([0], dtype=idtype),
        }
    )
    _check_subgraph(g, sg1)
    sg2 = g.edge_subgraph(
        {
            "follows": F.tensor([1], dtype=idtype),
            "plays": F.tensor([1], dtype=idtype),
            "wishes": F.tensor([1], dtype=idtype),
        }
    )
    _check_subgraph(g, sg2)

    # numpy input
    sg1 = g.subgraph({"user": np.array([1, 2]), "game": np.array([0])})
    _check_subgraph(g, sg1)
    sg2 = g.edge_subgraph(
        {
            "follows": np.array([1]),
            "plays": np.array([1]),
            "wishes": np.array([1]),
        }
    )
    _check_subgraph(g, sg2)

    def _check_subgraph_single_ntype(g, sg, preserve_nodes=False):
        assert sg.idtype == g.idtype
        assert sg.device == g.device
        assert sg.ntypes == g.ntypes
        assert sg.etypes == g.etypes
        assert sg.canonical_etypes == g.canonical_etypes

        if not preserve_nodes:
            assert F.array_equal(
                F.tensor(sg.nodes["user"].data[dgl.NID]),
                F.tensor([1, 2], g.idtype),
            )
        else:
            # With preserved nodes the node counts match the parent's.
            for ntype in sg.ntypes:
                assert g.num_nodes(ntype) == sg.num_nodes(ntype)

        assert F.array_equal(
            F.tensor(sg.edges["follows"].data[dgl.EID]), F.tensor([1], g.idtype)
        )

        if not preserve_nodes:
            assert F.array_equal(
                sg.nodes["user"].data["h"], g.nodes["user"].data["h"][1:3]
            )
        assert F.array_equal(
            sg.edges["follows"].data["h"], g.edges["follows"].data["h"][1:2]
        )

    def _check_subgraph_single_etype(g, sg, preserve_nodes=False):
        assert sg.ntypes == g.ntypes
        assert sg.etypes == g.etypes
        assert sg.canonical_etypes == g.canonical_etypes

        if not preserve_nodes:
            assert F.array_equal(
                F.tensor(sg.nodes["user"].data[dgl.NID]),
                F.tensor([0, 1], g.idtype),
            )
            assert F.array_equal(
                F.tensor(sg.nodes["game"].data[dgl.NID]),
                F.tensor([0], g.idtype),
            )
        else:
            # With preserved nodes the node counts match the parent's.
            for ntype in sg.ntypes:
                assert g.num_nodes(ntype) == sg.num_nodes(ntype)

        assert F.array_equal(
            F.tensor(sg.edges["plays"].data[dgl.EID]),
            F.tensor([0, 1], g.idtype),
        )

    sg1_graph = g_graph.subgraph([1, 2])
    _check_subgraph_single_ntype(g_graph, sg1_graph)
    sg1_graph = g_graph.edge_subgraph([1])
    _check_subgraph_single_ntype(g_graph, sg1_graph)
    sg1_graph = g_graph.edge_subgraph([1], relabel_nodes=False)
    _check_subgraph_single_ntype(g_graph, sg1_graph, True)
    sg2_bipartite = g_bipartite.edge_subgraph([0, 1])
    _check_subgraph_single_etype(g_bipartite, sg2_bipartite)
    sg2_bipartite = g_bipartite.edge_subgraph([0, 1], relabel_nodes=False)
    _check_subgraph_single_etype(g_bipartite, sg2_bipartite, True)

    def _check_typed_subgraph1(g, sg):
        assert g.idtype == sg.idtype
        assert g.device == sg.device
        assert set(sg.ntypes) == {"user", "game"}
        assert set(sg.etypes) == {"follows", "plays", "wishes"}
        for ntype in sg.ntypes:
            assert sg.num_nodes(ntype) == g.num_nodes(ntype)
        for etype in sg.etypes:
            src_sg, dst_sg = sg.all_edges(etype=etype, order="eid")
            src_g, dst_g = g.all_edges(etype=etype, order="eid")
            assert F.array_equal(src_sg, src_g)
            assert F.array_equal(dst_sg, dst_g)
        assert F.array_equal(
            sg.nodes["user"].data["h"], g.nodes["user"].data["h"]
        )
        assert F.array_equal(
            sg.edges["follows"].data["h"], g.edges["follows"].data["h"]
        )
        # Overwrite one row of each parent feature, then check that the
        # typed subgraph still reports the same features as the parent.
        g.nodes["user"].data["h"] = F.scatter_row(
            g.nodes["user"].data["h"], F.tensor([2]), F.randn((1, 5))
        )
        g.edges["follows"].data["h"] = F.scatter_row(
            g.edges["follows"].data["h"], F.tensor([1]), F.randn((1, 4))
        )
        assert F.array_equal(
            sg.nodes["user"].data["h"], g.nodes["user"].data["h"]
        )
        assert F.array_equal(
            sg.edges["follows"].data["h"], g.edges["follows"].data["h"]
        )

    def _check_typed_subgraph2(g, sg):
        assert set(sg.ntypes) == {"developer", "game"}
        assert set(sg.etypes) == {"develops"}
        for ntype in sg.ntypes:
            assert sg.num_nodes(ntype) == g.num_nodes(ntype)
        for etype in sg.etypes:
            src_sg, dst_sg = sg.all_edges(etype=etype, order="eid")
            src_g, dst_g = g.all_edges(etype=etype, order="eid")
            assert F.array_equal(src_sg, src_g)
            assert F.array_equal(dst_sg, dst_g)

    sg3 = g.node_type_subgraph(["user", "game"])
    _check_typed_subgraph1(g, sg3)
    sg4 = g.edge_type_subgraph(["develops"])
    _check_typed_subgraph2(g, sg4)
    sg5 = g.edge_type_subgraph(["follows", "plays", "wishes"])
    _check_typed_subgraph1(g, sg5)

    # Test for restricted format
    for fmt in ["csr", "csc", "coo"]:
        # Separate names so the heterograph ``g`` above is not shadowed.
        fmt_g = dgl.graph(([0, 1], [1, 2])).formats(fmt)
        fmt_sg = fmt_g.subgraph({fmt_g.ntypes[0]: [1, 0]})
        nids = F.asnumpy(fmt_sg.ndata[dgl.NID])
        assert np.array_equal(nids, np.array([1, 0]))
        src, dst = fmt_sg.edges(order="eid")
        src = F.asnumpy(src)
        dst = F.asnumpy(dst)
        # Only parent edge 0 -> 1 survives. Parent node 1 is relabeled to 0
        # and parent node 0 to 1, so the edge becomes 1 -> 0.
        assert np.array_equal(src, np.array([1]))
        assert np.array_equal(dst, np.array([0]))

418

nv-dlasalle's avatar
nv-dlasalle committed
419
@parametrize_idtype
def test_in_subgraph(idtype):
    """``dgl.in_subgraph`` on a heterograph.

    Covers the default call, ``store_ids=False`` and ``relabel_nodes=True``.
    """

    def _pairs(src, dst):
        # Collapse two ID tensors into a set of (src, dst) tuples.
        return set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))

    def _parent_ids(graph, ntype, local_ids):
        # Translate subgraph node IDs to parent IDs via the stored NID field.
        return F.gather_row(graph.nodes[ntype].data[dgl.NID], local_ids)

    hg = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 0, 1, 3], [0, 1, 2, 2]),
            ("game", "liked-by", "user"): (
                [2, 2, 2, 1, 1, 0],
                [0, 1, 2, 0, 3, 0],
            ),
            ("user", "flips", "coin"): ([0, 1, 2, 3], [0, 0, 0, 0]),
        },
        idtype=idtype,
        num_nodes_dict={"user": 5, "game": 10, "coin": 8},
    ).to(F.ctx())

    # Default call: node IDs are not relabeled, so edge endpoints can be
    # looked up in the parent graph directly.
    subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0})
    assert subg.idtype == idtype
    assert len(subg.ntypes) == 3
    assert len(subg.etypes) == 4
    expected_edges = {
        "follow": {(1, 0), (2, 0), (3, 0), (0, 1), (2, 1), (3, 1)},
        "play": {(0, 0)},
        "liked-by": {(2, 0), (2, 1), (1, 0), (0, 0)},
    }
    for etype, wanted in expected_edges.items():
        u, v = subg[etype].edges()
        assert F.array_equal(
            hg[etype].edge_ids(u, v), subg[etype].edata[dgl.EID]
        )
        assert _pairs(u, v) == wanted
    assert subg["flips"].num_edges() == 0
    for ntype in subg.ntypes:
        assert dgl.NID not in subg.nodes[ntype].data

    # Test store_ids
    subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0}, store_ids=False)
    for etype in ["follow", "play", "liked-by"]:
        assert dgl.EID not in subg.edges[etype].data
    for ntype in subg.ntypes:
        assert dgl.NID not in subg.nodes[ntype].data

    # Test relabel nodes
    subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0}, relabel_nodes=True)
    assert subg.idtype == idtype
    assert len(subg.ntypes) == 3
    assert len(subg.etypes) == 4

    # (edge type, source node type, destination node type, expected edges
    # expressed in parent IDs)
    relabeled_cases = [
        (
            "follow",
            "user",
            "user",
            {(1, 0), (2, 0), (3, 0), (0, 1), (2, 1), (3, 1)},
        ),
        ("play", "user", "game", {(0, 0)}),
        ("liked-by", "game", "user", {(2, 0), (2, 1), (1, 0), (0, 0)}),
    ]
    for etype, src_type, dst_type, wanted in relabeled_cases:
        u, v = subg[etype].edges()
        old_u = _parent_ids(subg, src_type, u)
        old_v = _parent_ids(subg, dst_type, v)
        assert F.array_equal(
            hg[etype].edge_ids(old_u, old_v), subg[etype].edata[dgl.EID]
        )
        assert _pairs(old_u, old_v) == wanted

    assert subg.num_nodes("user") == 4
    assert subg.num_nodes("game") == 3
    assert subg.num_nodes("coin") == 0
    assert subg.num_edges("flips") == 0
505

506

nv-dlasalle's avatar
nv-dlasalle committed
507
@parametrize_idtype
def test_out_subgraph(idtype):
    """``dgl.out_subgraph`` on a heterograph.

    Covers the default call, ``store_ids=False`` and ``relabel_nodes=True``.
    """

    def _pairs(src, dst):
        # Collapse two ID tensors into a set of (src, dst) tuples.
        return set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))

    def _parent_ids(graph, ntype, local_ids):
        # Translate subgraph node IDs to parent IDs via the stored NID field.
        return F.gather_row(graph.nodes[ntype].data[dgl.NID], local_ids)

    hg = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 0, 1, 3], [0, 1, 2, 2]),
            ("game", "liked-by", "user"): (
                [2, 2, 2, 1, 1, 0],
                [0, 1, 2, 0, 3, 0],
            ),
            ("user", "flips", "coin"): ([0, 1, 2, 3], [0, 0, 0, 0]),
        },
        idtype=idtype,
    ).to(F.ctx())

    # Default call: node IDs are not relabeled, so edge endpoints can be
    # looked up in the parent graph directly.
    subg = dgl.out_subgraph(hg, {"user": [0, 1], "game": 0})
    assert subg.idtype == idtype
    assert len(subg.ntypes) == 3
    assert len(subg.etypes) == 4
    expected_edges = {
        "follow": {(1, 0), (0, 1), (0, 2)},
        "play": {(0, 0), (0, 1), (1, 2)},
        "liked-by": {(0, 0)},
        "flips": {(0, 0), (1, 0)},
    }
    for etype, wanted in expected_edges.items():
        u, v = subg[etype].edges()
        assert _pairs(u, v) == wanted
        assert F.array_equal(
            hg[etype].edge_ids(u, v), subg[etype].edata[dgl.EID]
        )
    for ntype in subg.ntypes:
        assert dgl.NID not in subg.nodes[ntype].data

    # Test store_ids
    subg = dgl.out_subgraph(hg, {"user": [0, 1], "game": 0}, store_ids=False)
    for etype in subg.canonical_etypes:
        assert dgl.EID not in subg.edges[etype].data
    for ntype in subg.ntypes:
        assert dgl.NID not in subg.nodes[ntype].data

    # Test relabel nodes
    subg = dgl.out_subgraph(hg, {"user": [1], "game": 0}, relabel_nodes=True)
    assert subg.idtype == idtype
    assert len(subg.ntypes) == 3
    assert len(subg.etypes) == 4

    # (edge type, source node type, destination node type, expected edges
    # expressed in parent IDs)
    relabeled_cases = [
        ("follow", "user", "user", {(1, 0)}),
        ("play", "user", "game", {(1, 2)}),
        ("liked-by", "game", "user", {(0, 0)}),
        ("flips", "user", "coin", {(1, 0)}),
    ]
    for etype, src_type, dst_type, wanted in relabeled_cases:
        u, v = subg[etype].edges()
        old_u = _parent_ids(subg, src_type, u)
        old_v = _parent_ids(subg, dst_type, v)
        assert _pairs(old_u, old_v) == wanted
        assert F.array_equal(
            hg[etype].edge_ids(old_u, old_v), subg[etype].edata[dgl.EID]
        )

    assert subg.num_nodes("user") == 2
    assert subg.num_nodes("game") == 2
    assert subg.num_nodes("coin") == 1

605
606
607
608

def test_subgraph_message_passing():
    """Unit test for PR #2055.

    Runs ``update_all`` with user-defined message/reduce functions on a
    subgraph that was extracted on the CPU and then moved to ``F.ctx()``.
    The test has no assertions; it passes if message passing completes.
    """

    def _copy_src(edges):
        # Message: forward the source node's "x" feature.
        return {"x": edges.src["x"]}

    def _sum_mailbox(nodes):
        # Reduce: sum the incoming "x" messages along dimension 1.
        return {"y": F.sum(nodes.mailbox["x"], 1)}

    parent = dgl.graph(([0, 1, 2], [2, 3, 4])).to(F.cpu())
    parent.ndata["x"] = F.copy_to(F.randn((5, 6)), F.cpu())
    moved = parent.subgraph([1, 2, 3]).to(F.ctx())
    moved.update_all(_copy_src, _sum_mailbox)

616

nv-dlasalle's avatar
nv-dlasalle committed
617
@parametrize_idtype
def test_khop_in_subgraph(idtype):
    """``dgl.khop_in_subgraph`` on a homogeneous graph and a heterograph.

    Each part checks a single seed, an isolated seed (no incoming edges
    within k hops) and multiple seeds, including the returned inverse
    indices of the seeds in the subgraph.
    """

    def _pairs(src, dst):
        # Collapse two ID tensors into a set of (src, dst) tuples.
        return set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))

    def _assert_inverse(got, wanted):
        # Compare returned seed positions after casting to ``idtype``.
        assert F.array_equal(F.astype(got, idtype), F.tensor(wanted, idtype))

    # ---- homogeneous graph ----
    homo = dgl.graph(
        ([1, 1, 2, 3, 4], [0, 2, 0, 4, 2]), idtype=idtype, device=F.ctx()
    )
    homo.edata["w"] = F.tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
    sg, inv = dgl.khop_in_subgraph(homo, 0, k=2)
    assert sg.idtype == homo.idtype
    u, v = sg.edges()
    assert _pairs(u, v) == {(1, 0), (1, 2), (2, 0), (3, 2)}
    assert F.array_equal(
        sg.edata[dgl.EID], F.tensor([0, 1, 2, 4], dtype=idtype)
    )
    assert F.array_equal(
        sg.edata["w"], F.tensor([[0, 1], [2, 3], [4, 5], [8, 9]])
    )
    _assert_inverse(inv, [0])

    # Test multiple nodes
    sg, inv = dgl.khop_in_subgraph(homo, [0, 2], k=1)
    assert sg.num_edges() == 4

    sg, inv = dgl.khop_in_subgraph(homo, F.tensor([0, 2], idtype), k=1)
    assert sg.num_edges() == 4

    # Test isolated node
    sg, inv = dgl.khop_in_subgraph(homo, 1, k=2)
    assert sg.idtype == homo.idtype
    assert sg.num_nodes() == 1
    assert sg.num_edges() == 0
    _assert_inverse(inv, [0])

    # ---- heterograph ----
    hetero = dgl.heterograph(
        {
            ("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 2, 1]),
            ("user", "follows", "user"): ([0, 1, 1], [1, 2, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    sg, inv = dgl.khop_in_subgraph(hetero, {"game": 0}, k=2)
    assert sg.idtype == idtype
    assert sg.num_nodes("game") == 1
    assert sg.num_nodes("user") == 2
    assert len(sg.ntypes) == 2
    assert len(sg.etypes) == 2
    u, v = sg["follows"].edges()
    assert _pairs(u, v) == {(0, 1)}
    u, v = sg["plays"].edges()
    assert _pairs(u, v) == {(0, 0), (1, 0)}
    _assert_inverse(inv["game"], [0])

    # Test isolated node
    sg, inv = dgl.khop_in_subgraph(hetero, {"user": 0}, k=2)
    assert sg.idtype == idtype
    assert sg.num_nodes("game") == 0
    assert sg.num_nodes("user") == 1
    assert sg.num_edges("follows") == 0
    assert sg.num_edges("plays") == 0
    _assert_inverse(inv["user"], [0])

    # Test multiple nodes
    sg, inv = dgl.khop_in_subgraph(
        hetero, {"user": F.tensor([0, 1], idtype), "game": 0}, k=1
    )
    u, v = sg["follows"].edges()
    assert _pairs(u, v) == {(0, 1)}
    u, v = sg["plays"].edges()
    assert _pairs(u, v) == {(0, 0), (1, 0)}
    _assert_inverse(inv["user"], [0, 1])
    _assert_inverse(inv["game"], [0])

696

nv-dlasalle's avatar
nv-dlasalle committed
697
@parametrize_idtype
def test_khop_out_subgraph(idtype):
    """``dgl.khop_out_subgraph`` on a homogeneous graph and a heterograph.

    Each part checks a single seed, an isolated seed (no outgoing edges
    within k hops) and multiple seeds, including the returned inverse
    indices of the seeds in the subgraph.
    """

    def _pairs(src, dst):
        # Collapse two ID tensors into a set of (src, dst) tuples.
        return set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))

    def _assert_inverse(got, wanted):
        # Compare returned seed positions after casting to ``idtype``.
        assert F.array_equal(F.astype(got, idtype), F.tensor(wanted, idtype))

    # ---- homogeneous graph ----
    homo = dgl.graph(
        ([0, 2, 0, 4, 2], [1, 1, 2, 3, 4]), idtype=idtype, device=F.ctx()
    )
    homo.edata["w"] = F.tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
    sg, inv = dgl.khop_out_subgraph(homo, 0, k=2)
    assert sg.idtype == homo.idtype
    u, v = sg.edges()
    assert _pairs(u, v) == {(0, 1), (2, 1), (0, 2), (2, 3)}
    assert F.array_equal(
        sg.edata[dgl.EID], F.tensor([0, 2, 1, 4], dtype=idtype)
    )
    assert F.array_equal(
        sg.edata["w"], F.tensor([[0, 1], [4, 5], [2, 3], [8, 9]])
    )
    _assert_inverse(inv, [0])

    # Test multiple nodes
    sg, inv = dgl.khop_out_subgraph(homo, [0, 2], k=1)
    assert sg.num_edges() == 4

    sg, inv = dgl.khop_out_subgraph(homo, F.tensor([0, 2], idtype), k=1)
    assert sg.num_edges() == 4

    # Test isolated node
    sg, inv = dgl.khop_out_subgraph(homo, 1, k=2)
    assert sg.idtype == homo.idtype
    assert sg.num_nodes() == 1
    assert sg.num_edges() == 0
    _assert_inverse(inv, [0])

    # ---- heterograph ----
    hetero = dgl.heterograph(
        {
            ("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 2, 1]),
            ("user", "follows", "user"): ([0, 1], [1, 3]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    sg, inv = dgl.khop_out_subgraph(hetero, {"user": 0}, k=2)
    assert sg.idtype == idtype
    assert sg.num_nodes("game") == 2
    assert sg.num_nodes("user") == 3
    assert len(sg.ntypes) == 2
    assert len(sg.etypes) == 2
    u, v = sg["follows"].edges()
    assert _pairs(u, v) == {(0, 1), (1, 2)}
    u, v = sg["plays"].edges()
    assert _pairs(u, v) == {(0, 0), (1, 0), (1, 1)}
    _assert_inverse(inv["user"], [0])

    # Test isolated node
    sg, inv = dgl.khop_out_subgraph(hetero, {"user": 3}, k=2)
    assert sg.idtype == idtype
    assert sg.num_nodes("game") == 0
    assert sg.num_nodes("user") == 1
    assert sg.num_edges("follows") == 0
    assert sg.num_edges("plays") == 0
    _assert_inverse(inv["user"], [0])

    # Test multiple nodes
    sg, inv = dgl.khop_out_subgraph(
        hetero, {"user": F.tensor([2], idtype), "game": 0}, k=1
    )
    assert sg.num_edges("follows") == 0
    u, v = sg["plays"].edges()
    assert _pairs(u, v) == {(0, 1)}
    _assert_inverse(inv["user"], [0])
    _assert_inverse(inv["game"], [0])
771

772
773

@unittest.skipIf(not F.gpu_ctx(), "only necessary with GPU")
@pytest.mark.parametrize(
    "parent_idx_device",
    [("cpu", F.cpu()), ("cuda", F.cuda()), ("uva", F.cpu()), ("uva", F.cuda())],
)
@pytest.mark.parametrize("child_device", [F.cpu(), F.cuda()])
def test_subframes(parent_idx_device, child_device):
    """Feature frames of a sampled subgraph follow the subgraph's device.

    Parameters
    ----------
    parent_idx_device : tuple
        ``(parent_device, idx_device)``. ``parent_device`` is one of
        ``"cpu"``, ``"cuda"`` or ``"uva"`` (a CPU graph with pinned memory)
        and ``idx_device`` is the context the seed-node tensor is copied to.
    child_device
        Context the sampled subgraph is moved to.
    """
    parent_device, idx_device = parent_idx_device
    g = dgl.graph(
        (F.tensor([1, 2, 3], dtype=F.int64), F.tensor([2, 3, 4], dtype=F.int64))
    )
    g.ndata["x"] = F.randn((5, 4))
    g.edata["a"] = F.randn((3, 6))
    idx = F.tensor([1, 2], dtype=F.int64)
    pinned = False
    if parent_device == "cuda":
        g = g.to(F.cuda())
    elif parent_device == "uva":
        if F.backend_name != "pytorch":
            pytest.skip("UVA only supported for PyTorch")
        g = g.to(F.cpu())
        g.create_formats_()
        g.pin_memory_()
        pinned = True
    elif parent_device == "cpu":
        g = g.to(F.cpu())
    try:
        idx = F.copy_to(idx, idx_device)
        # Sample on the parent, then move the result to the child device.
        sg = g.sample_neighbors(idx, 2).to(child_device)
        assert sg.device == F.context(sg.ndata["x"])
        assert sg.device == F.context(sg.edata["a"])
        assert sg.device == child_device
        if parent_device != "uva":
            # Move the parent first, then sample directly on the child device.
            sg = g.to(child_device).sample_neighbors(
                F.copy_to(idx, child_device), 2
            )
            assert sg.device == F.context(sg.ndata["x"])
            assert sg.device == F.context(sg.edata["a"])
            assert sg.device == child_device
    finally:
        # Unpin even when an assertion above fails, so a pinned graph is not
        # left behind by a failing test.
        if pinned:
            g.unpin_memory_()

813
814
815
816
817
818
819
820

@unittest.skipIf(
    F._default_context_str != "gpu", reason="UVA only available on GPU"
)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch",
    reason="UVA only supported for PyTorch",
)
@pytest.mark.parametrize("device", [F.cpu(), F.cuda()])
@parametrize_idtype
def test_uva_subgraph(idtype, device):
    """Subgraph operators on a pinned (UVA) graph follow the index device.

    The graph from ``create_test_heterograph2`` is moved to the CPU and
    pinned. Every subgraph / sampling operator called with indices on
    ``device`` must return a graph on ``device``.
    """
    g = create_test_heterograph2(idtype)
    g = g.to(F.cpu())
    g.create_formats_()
    g.pin_memory_()
    try:
        indices = {"user": F.copy_to(F.tensor([0], idtype), device)}
        edge_indices = {"follows": F.copy_to(F.tensor([0], idtype), device)}
        assert g.subgraph(indices).device == device
        assert g.edge_subgraph(edge_indices).device == device
        assert g.in_subgraph(indices).device == device
        assert g.out_subgraph(indices).device == device
        # The k-hop operators return a tuple; element 0 is the subgraph.
        assert g.khop_in_subgraph(indices, 1)[0].device == device
        assert g.khop_out_subgraph(indices, 1)[0].device == device
        assert g.sample_neighbors(indices, 1).device == device
    finally:
        # Unpin even when an assertion above fails, so a pinned graph is not
        # left behind by a failing test.
        g.unpin_memory_()

839
840

if __name__ == "__main__":
    # Ad-hoc entry point: runs the edge-subgraph test, then the UVA subgraph
    # test with int64 IDs on the CPU context followed by the CUDA context.
    test_edge_subgraph()
    for make_ctx in (F.cpu, F.cuda):
        test_uva_subgraph(F.int64, make_ctx())