# test_sampling.py (64.7 KB)
1
import unittest
2
from collections import defaultdict
3
4
5
6
7

import backend as F

import dgl
import numpy as np
8
import pytest
9

10
11
12
13
14
# Dispatch table used by the tests below: indexed with the boolean ``fused``
# flag to select the neighbor-sampling entry point under test.
sample_neighbors_fusing_mode = {
    True: dgl.sampling.sample_neighbors_fused,
    False: dgl.sampling.sample_neighbors,
}

15

# (repository-viewer blame annotation: Quan (Andy) Gan)
16
def check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None):
    """Assert that random-walk output is consistent with graph ``g``.

    Parameters
    ----------
    g : graph the walks were sampled from.
    metapath : sequence of edge types; entry ``j`` is the edge type followed
        between columns ``j`` and ``j + 1`` of ``traces``.
    traces : 2-D array-like of node IDs, one walk per row.
    ntypes : 1-D array-like of node type IDs, one per column of ``traces``.
    prob : optional name of an edge feature.  For every edge type that
        carries this feature, the edge between consecutive trace nodes must
        have a non-zero value.
    trace_eids : optional 2-D array-like of edge IDs; entry ``[i, j]`` must
        be an edge from ``traces[i, j]`` to ``traces[i, j + 1]``.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    traces = F.asnumpy(traces)
    ntypes = F.asnumpy(ntypes)
    num_walks, num_cols = traces.shape[0], traces.shape[1]

    # Node types along the walk must match the endpoints of each edge type.
    for j in range(num_cols - 1):
        src_type, _, dst_type = g.to_canonical_etype(metapath[j])
        assert ntypes[j] == g.get_ntype_id(src_type)
        assert ntypes[j + 1] == g.get_ntype_id(dst_type)

    for i in range(num_walks):
        for j in range(num_cols - 1):
            etype = metapath[j]
            src, dst = traces[i, j], traces[i, j + 1]
            assert g.has_edges_between(src, dst, etype=etype)
            if prob is not None and prob in g.edges[etype].data:
                # Read the feature named by ``prob`` (previously the key "p"
                # was hard-coded, ignoring the argument that was just tested).
                p = F.asnumpy(g.edges[etype].data[prob])
                eids = g.edge_ids(src, dst, etype=etype)
                assert p[eids] != 0
            if trace_eids is not None:
                u, v = g.find_edges(trace_eids[i, j], etype=etype)
                assert (u == src) and (v == dst)
39

40
41

@pytest.mark.parametrize("use_uva", [True, False])
def test_non_uniform_random_walk(use_uva):
    """Exercise ``dgl.sampling.random_walk`` with an edge-weight feature.

    Covers a homogeneous graph (``g2``) and a heterogeneous graph (``g4``),
    with scalar and per-step ``restart_prob`` values, and checks that a
    2-D weight feature (``"p2"``) is rejected with ``dgl.DGLError``.

    With ``use_uva`` the graphs are pinned; otherwise they are moved to
    ``F.ctx()`` when the default context is the GPU.
    """
    if use_uva:
        if F.ctx() == F.cpu():
            pytest.skip("UVA biased random walk requires a GPU.")
        if dgl.backend.backend_name != "pytorch":
            pytest.skip(
                "UVA biased random walk is only supported with PyTorch."
            )
    g2 = dgl.heterograph(
        {("user", "follow", "user"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])}
    )
    g4 = dgl.heterograph(
        {
            ("user", "follow", "user"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),
            ("user", "view", "item"): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),
            ("item", "viewed-by", "user"): (
                [0, 1, 1, 2, 2, 1],
                [0, 0, 1, 2, 3, 3],
            ),
        }
    )

    # Edge 1 (1 -> 2) of "follow" gets zero weight and must never be walked.
    g2.edata["p"] = F.copy_to(
        F.tensor([3, 0, 3, 3, 3], dtype=F.float32), F.cpu()
    )
    # Same weights with an extra trailing dimension; used for the error case.
    g2.edata["p2"] = F.copy_to(
        F.tensor([[3], [0], [3], [3], [3]], dtype=F.float32), F.cpu()
    )
    g4.edges["follow"].data["p"] = F.copy_to(
        F.tensor([3, 0, 3, 3, 3], dtype=F.float32), F.cpu()
    )
    g4.edges["viewed-by"].data["p"] = F.copy_to(
        F.tensor([1, 1, 1, 1, 1, 1], dtype=F.float32), F.cpu()
    )

    if use_uva:
        for g in (g2, g4):
            g.create_formats_()
            g.pin_memory_()
    elif F._default_context_str == "gpu":
        g2 = g2.to(F.ctx())
        g4 = g4.to(F.ctx())

    try:
        traces, eids, ntypes = dgl.sampling.random_walk(
            g2,
            F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g2.idtype),
            length=4,
            prob="p",
            return_eids=True,
        )
        check_random_walk(
            g2, ["follow"] * 4, traces, ntypes, "p", trace_eids=eids
        )

        with pytest.raises(dgl.DGLError):
            traces, ntypes = dgl.sampling.random_walk(
                g2,
                F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g2.idtype),
                length=4,
                prob="p2",
            )

        metapath = ["follow", "view", "viewed-by"] * 2
        traces, eids, ntypes = dgl.sampling.random_walk(
            g4,
            F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
            metapath=metapath,
            prob="p",
            return_eids=True,
        )
        check_random_walk(g4, metapath, traces, ntypes, "p", trace_eids=eids)
        traces, eids, ntypes = dgl.sampling.random_walk(
            g4,
            F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
            metapath=metapath,
            prob="p",
            restart_prob=0.0,
            return_eids=True,
        )
        check_random_walk(g4, metapath, traces, ntypes, "p", trace_eids=eids)
        traces, eids, ntypes = dgl.sampling.random_walk(
            g4,
            F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
            metapath=metapath,
            prob="p",
            restart_prob=F.zeros((6,), F.float32, F.ctx()),
            return_eids=True,
        )
        check_random_walk(g4, metapath, traces, ntypes, "p", trace_eids=eids)
        # Restart probability 1 on the 7th step: the walk must stop there,
        # leaving -1 in the last trace column.
        traces, eids, ntypes = dgl.sampling.random_walk(
            g4,
            F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
            metapath=metapath + ["follow"],
            prob="p",
            restart_prob=F.tensor([0, 0, 0, 0, 0, 0, 1], F.float32),
            return_eids=True,
        )
        check_random_walk(
            g4, metapath, traces[:, :7], ntypes[:7], "p", trace_eids=eids
        )
        assert (F.asnumpy(traces[:, 7]) == -1).all()
    finally:
        # Only unpin what was actually pinned (the non-UVA path never pins),
        # consistent with the other UVA tests in this file.
        for g in (g2, g4):
            if g.is_pinned():
                g.unpin_memory_()

148
149

@pytest.mark.parametrize("use_uva", [True, False])
def test_uniform_random_walk(use_uva):
    """Exercise unweighted ``dgl.sampling.random_walk`` on four small graphs.

    ``g1``/``g2`` are homogeneous "follow" graphs; ``g3``/``g4`` add "view"
    and "viewed-by" relations and are walked along a metapath.  Every result
    is validated with ``check_random_walk``.
    """
    if use_uva and F.ctx() == F.cpu():
        pytest.skip("UVA random walk requires a GPU.")

    g1 = dgl.heterograph({("user", "follow", "user"): ([0, 1, 2], [1, 2, 0])})
    g2 = dgl.heterograph(
        {("user", "follow", "user"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])}
    )
    g3 = dgl.heterograph(
        {
            ("user", "follow", "user"): ([0, 1, 2], [1, 2, 0]),
            ("user", "view", "item"): ([0, 1, 2], [0, 1, 2]),
            ("item", "viewed-by", "user"): ([0, 1, 2], [0, 1, 2]),
        }
    )
    g4 = dgl.heterograph(
        {
            ("user", "follow", "user"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),
            ("user", "view", "item"): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),
            ("item", "viewed-by", "user"): (
                [0, 1, 1, 2, 2, 1],
                [0, 0, 1, 2, 3, 3],
            ),
        }
    )

    if use_uva:
        for graph in (g1, g2, g3, g4):
            graph.create_formats_()
            graph.pin_memory_()
    elif F._default_context_str == "gpu":
        g1, g2, g3, g4 = (graph.to(F.ctx()) for graph in (g1, g2, g3, g4))

    def _seeds(graph, ids):
        # Seed tensor with the ID dtype of the graph it is used with.
        return F.tensor(ids, dtype=graph.idtype)

    three_nodes_twice = [0, 1, 2, 0, 1, 2]
    four_nodes_twice = [0, 1, 2, 3, 0, 1, 2, 3]
    follow_x4 = ["follow"] * 4
    metapath = ["follow", "view", "viewed-by"] * 2

    try:
        traces, eids, ntypes = dgl.sampling.random_walk(
            g1, _seeds(g1, three_nodes_twice), length=4, return_eids=True
        )
        check_random_walk(g1, follow_x4, traces, ntypes, trace_eids=eids)

        if F._default_context_str == "cpu":
            # Seed 10 is not a node of g1.
            with pytest.raises(dgl.DGLError):
                dgl.sampling.random_walk(
                    g1, _seeds(g1, [0, 1, 2, 10]), length=4, return_eids=True
                )

        traces, eids, ntypes = dgl.sampling.random_walk(
            g1,
            _seeds(g1, three_nodes_twice),
            length=4,
            restart_prob=0.0,
            return_eids=True,
        )
        check_random_walk(g1, follow_x4, traces, ntypes, trace_eids=eids)

        traces, ntypes = dgl.sampling.random_walk(
            g1,
            _seeds(g1, three_nodes_twice),
            length=4,
            restart_prob=F.zeros((4,), F.float32),
        )
        check_random_walk(g1, follow_x4, traces, ntypes)

        # Restart probability 1 on the 5th step: column 5 must be all -1.
        traces, ntypes = dgl.sampling.random_walk(
            g1,
            _seeds(g1, three_nodes_twice),
            length=5,
            restart_prob=F.tensor([0, 0, 0, 0, 1], dtype=F.float32),
        )
        head_traces = F.slice_axis(traces, 1, 0, 5)
        head_ntypes = F.slice_axis(ntypes, 0, 0, 5)
        check_random_walk(g1, follow_x4, head_traces, head_ntypes)
        assert (F.asnumpy(traces)[:, 5] == -1).all()

        traces, eids, ntypes = dgl.sampling.random_walk(
            g2, _seeds(g2, four_nodes_twice), length=4, return_eids=True
        )
        check_random_walk(g2, follow_x4, traces, ntypes, trace_eids=eids)

        traces, eids, ntypes = dgl.sampling.random_walk(
            g3,
            _seeds(g3, three_nodes_twice),
            metapath=metapath,
            return_eids=True,
        )
        check_random_walk(g3, metapath, traces, ntypes, trace_eids=eids)

        traces, eids, ntypes = dgl.sampling.random_walk(
            g4,
            _seeds(g4, four_nodes_twice),
            metapath=metapath,
            return_eids=True,
        )
        check_random_walk(g4, metapath, traces, ntypes, trace_eids=eids)

        traces, eids, ntypes = dgl.sampling.random_walk(
            g4,
            _seeds(g4, three_nodes_twice),
            metapath=metapath,
            return_eids=True,
        )
        check_random_walk(g4, metapath, traces, ntypes, trace_eids=eids)
    finally:  # make sure to unpin the graphs even if some test fails
        for graph in (g1, g2, g3, g4):
            if graph.is_pinned():
                graph.unpin_memory_()

268
269
270
271

@unittest.skipIf(
    F._default_context_str == "gpu", reason="GPU random walk not implemented"
)
def test_node2vec():
    """Validate ``node2vec_random_walk`` with and without edge weights."""
    uniform_graph = dgl.heterograph(
        {("user", "follow", "user"): ([0, 1, 2], [1, 2, 0])}
    )
    weighted_graph = dgl.heterograph(
        {("user", "follow", "user"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])}
    )
    weighted_graph.edata["p"] = F.tensor([3, 0, 3, 3, 3], dtype=F.float32)

    # node2vec_random_walk returns no node types here; the checker still
    # needs one type ID per trace column, and there is a single node type.
    ntypes = F.zeros((5,), dtype=F.int64)
    follow_x4 = ["follow"] * 4

    traces, eids = dgl.sampling.node2vec_random_walk(
        uniform_graph, [0, 1, 2, 0, 1, 2], 1, 1, 4, return_eids=True
    )
    check_random_walk(
        uniform_graph, follow_x4, traces, ntypes, trace_eids=eids
    )

    traces, eids = dgl.sampling.node2vec_random_walk(
        weighted_graph,
        [0, 1, 2, 3, 0, 1, 2, 3],
        1,
        1,
        4,
        prob="p",
        return_eids=True,
    )
    check_random_walk(
        weighted_graph, follow_x4, traces, ntypes, "p", trace_eids=eids
    )

# (repository-viewer blame annotation: Quan (Andy) Gan)
291

292
293
294
@unittest.skipIf(
    F._default_context_str == "gpu", reason="GPU pack traces not implemented"
)
def test_pack_traces():
    """Check the four outputs of ``pack_traces`` on two -1 padded traces."""
    traces = F.zerocopy_from_numpy(
        np.array(
            [[0, 1, -1, -1, -1, -1, -1], [0, 1, 1, 3, 0, 0, 0]], dtype="int64"
        )
    )
    types = F.zerocopy_from_numpy(
        np.array([0, 0, 1, 0, 0, 1, 0], dtype="int64")
    )

    result = dgl.sampling.pack_traces(traces, types)

    # Expected values for result[0] .. result[3], in that order.
    expected = (
        [0, 1, 0, 1, 1, 3, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0],
        [2, 7],
        [0, 2],
    )
    for position, values in enumerate(expected):
        assert F.array_equal(
            result[position], F.tensor(values, dtype=F.int64)
        )

314
315

@pytest.mark.parametrize("use_uva", [True, False])
def test_pinsage_sampling(use_uva):
    """Run PinSAGE / random-walk neighbor samplers on three small graphs.

    Each graph is pinned (``use_uva``) or moved to ``F.ctx()`` (GPU default
    context) before sampling and unpinned afterwards if it was pinned.
    """
    if use_uva and F.ctx() == F.cpu():
        pytest.skip("UVA sampling requires a GPU.")

    def _test_sampler(g, sampler, ntype):
        # Sample around seeds 0 and 2 and check the node type and that each
        # seed received at least one of its two admissible in-edges.
        seeds = F.copy_to(F.tensor([0, 2], dtype=g.idtype), F.ctx())
        neighbor_g = sampler(seeds)
        assert neighbor_g.ntypes == [ntype]
        src, dst = neighbor_g.all_edges(form="uv", order="eid")
        pairs = list(zip(F.asnumpy(src).tolist(), F.asnumpy(dst).tolist()))
        assert (1, 0) in pairs or (0, 0) in pairs
        assert (2, 2) in pairs or (3, 2) in pairs

    def _place(graph):
        # Pin in place for UVA, or return a copy on the default GPU context.
        if use_uva:
            graph.create_formats_()
            graph.pin_memory_()
        elif F._default_context_str == "gpu":
            graph = graph.to(F.ctx())
        return graph

    def _release(graph):
        if graph.is_pinned():
            graph.unpin_memory_()

    # Bipartite item/user graph.
    bipartite = _place(
        dgl.heterograph(
            {
                ("item", "bought-by", "user"): (
                    [0, 0, 1, 1, 2, 2, 3, 3],
                    [0, 1, 0, 1, 2, 3, 2, 3],
                ),
                ("user", "bought", "item"): (
                    [0, 1, 0, 1, 2, 3, 2, 3],
                    [0, 0, 1, 1, 2, 2, 3, 3],
                ),
            }
        )
    )
    try:
        sampler = dgl.sampling.PinSAGESampler(
            bipartite, "item", "user", 4, 0.5, 3, 2
        )
        _test_sampler(bipartite, sampler, "item")
        sampler = dgl.sampling.RandomWalkNeighborSampler(
            bipartite, 4, 0.5, 3, 2, ["bought-by", "bought"]
        )
        _test_sampler(bipartite, sampler, "item")
        sampler = dgl.sampling.RandomWalkNeighborSampler(
            bipartite,
            4,
            0.5,
            3,
            2,
            [("item", "bought-by", "user"), ("user", "bought", "item")],
        )
        _test_sampler(bipartite, sampler, "item")
    finally:
        _release(bipartite)

    # Homogeneous graph, default metapath.
    homogeneous = _place(
        dgl.graph(([0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 0, 1, 2, 3, 2, 3]))
    )
    try:
        sampler = dgl.sampling.RandomWalkNeighborSampler(
            homogeneous, 4, 0.5, 3, 2
        )
        _test_sampler(homogeneous, sampler, homogeneous.ntypes[0])
    finally:
        _release(homogeneous)

    # Three node types connected in a cycle A -> B -> C -> A.
    cyclic = _place(
        dgl.heterograph(
            {
                ("A", "AB", "B"): ([0, 2], [1, 3]),
                ("B", "BC", "C"): ([1, 3], [2, 1]),
                ("C", "CA", "A"): ([2, 1], [0, 2]),
            }
        )
    )
    try:
        sampler = dgl.sampling.RandomWalkNeighborSampler(
            cyclic, 4, 0.5, 3, 2, ["AB", "BC", "CA"]
        )
        _test_sampler(cyclic, sampler, "A")
    finally:
        _release(cyclic)
399

400

401
402
403
404
def _gen_neighbor_sampling_test_graph(hypersparse, reverse):
    """Build the graph pair used by the neighbor-sampling tests.

    Parameters
    ----------
    hypersparse : bool
        When true, every node type is declared with ``1 << 50`` nodes.
    reverse : bool
        When true, every relation is created with source and destination
        swapped (both the node types and the ID lists).

    Returns
    -------
    (g, hg)
        ``g`` holds only the "follow" relation with ``"prob"`` and ``"mask"``
        edge features; ``hg`` holds "follow", "play", "liked-by" and "flips".
        Both are moved to ``F.ctx()``.
    """
    if hypersparse:
        # should crash if allocated a CSR
        card = 1 << 50
        num_nodes_dict = {"user": card, "game": card, "coin": card}
        follow_only_num_nodes = {"user": card}
    else:
        num_nodes_dict = None
        follow_only_num_nodes = {"user": 4}

    follow = ("user", "follow", "user")
    # Relations in the non-reversed orientation: (src IDs, dst IDs).
    relations = {
        follow: (
            [1, 2, 3, 0, 2, 3, 0],
            [0, 0, 0, 1, 1, 1, 2],
        ),
        ("user", "play", "game"): ([0, 0, 1, 3], [0, 1, 2, 2]),
        ("game", "liked-by", "user"): (
            [2, 2, 2, 1, 1, 0],
            [0, 1, 2, 0, 3, 0],
        ),
        ("user", "flips", "coin"): ([0, 1, 2, 3], [0, 0, 0, 0]),
    }
    if reverse:
        relations = {
            (dst_type, rel, src_type): (dst_ids, src_ids)
            for (src_type, rel, dst_type), (src_ids, dst_ids) in (
                relations.items()
            )
        }

    follow_prob = [0.5, 0.5, 0.0, 0.5, 0.5, 0.0, 1.0]
    follow_mask = [True, True, False, True, True, False, True]

    follow_src, follow_dst = relations[follow]
    g = dgl.heterograph(
        {follow: (list(follow_src), list(follow_dst))}, follow_only_num_nodes
    )
    g = g.to(F.ctx())
    g.edata["prob"] = F.tensor(follow_prob, dtype=F.float32)
    g.edata["mask"] = F.tensor(follow_mask)

    hg = dgl.heterograph(relations, num_nodes_dict)
    hg = hg.to(F.ctx())
    hg.edges["follow"].data["prob"] = F.tensor(follow_prob, dtype=F.float32)
    hg.edges["follow"].data["mask"] = F.tensor(follow_mask)
    hg.edges["play"].data["prob"] = F.tensor(
        [0.8, 0.5, 0.5, 0.5], dtype=F.float32
    )
    # Leave out the mask of play and liked-by since all of them are True anyway.
    hg.edges["liked-by"].data["prob"] = F.tensor(
        [0.3, 0.5, 0.2, 0.5, 0.1, 0.1], dtype=F.float32
    )

    return g, hg

488

489
490
491
492
493
494
def _gen_neighbor_topk_test_graph(hypersparse, reverse):
    """Build the graph pair used by the top-k neighbor selection tests.

    Parameters
    ----------
    hypersparse : bool
        Only selects the value of a local cardinality; the graphs built
        below do not use it.
    reverse : bool
        When true, every relation is created with source and destination
        swapped (both the node types and the ID lists).

    Returns
    -------
    (g, hg)
        ``g`` holds only the "follow" relation; ``hg`` holds "follow",
        "play", "liked-by" and "flips".  Every relation carries a
        ``"weight"`` edge feature.
    """
    # should crash if allocated a CSR
    card = (1 << 50) if hypersparse else None  # noqa: F841 (kept, unused)

    follow = ("user", "follow", "user")
    # Relations in the non-reversed orientation: (src IDs, dst IDs).
    relations = {
        follow: (
            [1, 2, 3, 0, 2, 3, 0],
            [0, 0, 0, 1, 1, 1, 2],
        ),
        ("user", "play", "game"): ([0, 0, 1, 3], [0, 1, 2, 2]),
        ("game", "liked-by", "user"): (
            [2, 2, 2, 1, 1, 0],
            [0, 1, 2, 0, 3, 0],
        ),
        ("user", "flips", "coin"): ([0, 1, 2, 3], [0, 0, 0, 0]),
    }
    if reverse:
        relations = {
            (dst_type, rel, src_type): (dst_ids, src_ids)
            for (src_type, rel, dst_type), (src_ids, dst_ids) in (
                relations.items()
            )
        }

    follow_weight = [0.5, 0.3, 0.0, -5.0, 22.0, 0.0, 1.0]

    follow_src, follow_dst = relations[follow]
    g = dgl.heterograph({follow: (list(follow_src), list(follow_dst))})
    g.edata["weight"] = F.tensor(follow_weight, dtype=F.float32)

    hg = dgl.heterograph(relations)
    weights_by_relation = (
        ("follow", follow_weight),
        ("play", [0.8, 0.5, 0.4, 0.5]),
        ("liked-by", [0.3, 0.5, 0.2, 0.5, 0.1, 0.1]),
        ("flips", [10, 2, 13, -1]),
    )
    for relation, weights in weights_by_relation:
        hg.edges[relation].data["weight"] = F.tensor(
            weights, dtype=F.float32
        )
    return g, hg

562

563
def _test_sample_neighbors(hypersparse, prob, fused):
    """Check in-edge neighbor sampling on the shared test graphs.

    Parameters
    ----------
    hypersparse : bool
        Forwarded to ``_gen_neighbor_sampling_test_graph``.
    prob : str or None
        Name of the edge feature passed as ``prob=`` to the sampler
        (``None`` for no feature).
    fused : bool
        Selects the sampler via ``sample_neighbors_fusing_mode``.
    """
    g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False)
    sample = sample_neighbors_fusing_mode[fused]

    def _endpoints(subg):
        # The fused sampler returns relabeled IDs; map them back through NID.
        u, v = subg.edges()
        if fused:
            u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
        return u, v

    def _check_homogeneous(seeds, p, replace, num_sampled, excluded):
        # Fanout -1: the result must be exactly the in-edges of ``seeds``
        # (restricted to non-zero / True entries of feature ``p``).
        subg = sample(g, seeds, -1, prob=p, replace=replace)
        if not fused:
            assert subg.num_nodes() == g.num_nodes()
        u, v = _endpoints(subg)
        u_ans, v_ans, e_ans = g.in_edges(seeds, form="all")
        if p is not None:
            emask = F.gather_row(g.edata[p], e_ans)
            if p == "prob":
                emask = emask != 0
            u_ans = F.boolean_mask(u_ans, emask)
            v_ans = F.boolean_mask(v_ans, emask)
        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
        assert uv == uv_ans

        # Fanout 2, repeated because sampling is random.
        for _ in range(10):
            subg = sample(g, seeds, 2, prob=p, replace=replace)
            if not fused:
                assert subg.num_nodes() == g.num_nodes()
            assert subg.num_edges() == num_sampled
            u, v = _endpoints(subg)
            assert set(F.asnumpy(F.unique(v))) == set(seeds)
            assert F.array_equal(
                F.astype(g.has_edges_between(u, v), F.int64),
                F.ones((num_sampled,), dtype=F.int64),
            )
            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
            if not replace:
                # check no duplication
                assert len(edge_set) == num_sampled
            if p is not None:
                for edge in excluded:
                    assert edge not in edge_set

    # Seeds [0, 1]: 4 edges are sampled with or without replacement.
    _check_homogeneous([0, 1], prob, True, 4, [(3, 0), (3, 1)])
    _check_homogeneous([0, 1], prob, False, 4, [(3, 0), (3, 1)])

    # Seeds [0, 2]: fanout > #neighbors, so only replacement reaches 4.
    _check_homogeneous([0, 2], prob, True, 4, [(3, 0)])
    _check_homogeneous([0, 2], prob, False, 3, [(3, 0)])

    def _check_type_counts(subg):
        if not fused:
            assert len(subg.ntypes) == 3
        assert len(subg.srctypes) == 3
        assert len(subg.dsttypes) == 3
        assert len(subg.etypes) == 4

    def _check_heterogeneous(p, replace):
        seeds = {"user": [0, 1], "game": 0}
        subg = sample(hg, seeds, -1, prob=p, replace=replace)
        _check_type_counts(subg)
        # Parenthesised on purpose: ``a == 6 if c else 4`` parses as
        # ``(a == 6) if c else 4``, which asserts nothing when c is false.
        assert subg["follow"].num_edges() == (6 if p is None else 4)
        assert subg["play"].num_edges() == 1
        assert subg["liked-by"].num_edges() == 4
        assert subg["flips"].num_edges() == 0

        for _ in range(10):
            subg = sample(hg, seeds, 2, prob=p, replace=replace)
            _check_type_counts(subg)
            assert subg["follow"].num_edges() == 4
            assert subg["play"].num_edges() == (2 if replace else 1)
            assert subg["liked-by"].num_edges() == (4 if replace else 3)
            assert subg["flips"].num_edges() == 0

    _check_heterogeneous(prob, True)  # w/ replacement
    _check_heterogeneous(prob, False)  # w/o replacement

    # test different fanouts for different relations
    for _ in range(10):
        subg = sample(
            hg,
            {"user": [0, 1], "game": 0, "coin": 0},
            {"follow": 1, "play": 2, "liked-by": 0, "flips": -1},
            replace=True,
        )
        _check_type_counts(subg)
        assert subg["follow"].num_edges() == 2
        assert subg["play"].num_edges() == 2
        assert subg["liked-by"].num_edges() == 0
        assert subg["flips"].num_edges() == 4
710

711

712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
def _test_sample_labors(hypersparse, prob):
    """Check ``dgl.sampling.sample_labors`` on the shared test graphs.

    Parameters
    ----------
    hypersparse : bool
        Forwarded to ``_gen_neighbor_sampling_test_graph``.
    prob : str or None
        Name of the edge feature passed as ``prob=`` to the sampler
        (``None`` for no feature).
    """
    g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False)

    def _check_homogeneous(seeds, p, excluded):
        # Fanout -1: the result must be exactly the in-edges of ``seeds``
        # (restricted to non-zero / True entries of feature ``p``).
        subg = dgl.sampling.sample_labors(g, seeds, -1, prob=p)[0]
        assert subg.num_nodes() == g.num_nodes()
        u, v = subg.edges()
        u_ans, v_ans, e_ans = g.in_edges(seeds, form="all")
        if p is not None:
            emask = F.gather_row(g.edata[p], e_ans)
            if p == "prob":
                emask = emask != 0
            u_ans = F.boolean_mask(u_ans, emask)
            v_ans = F.boolean_mask(v_ans, emask)
        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
        assert uv == uv_ans

        # Fanout 2, repeated because sampling is random; the number of
        # sampled edges is not fixed, so only structural checks are made.
        for _ in range(10):
            subg = dgl.sampling.sample_labors(g, seeds, 2, prob=p)[0]
            assert subg.num_nodes() == g.num_nodes()
            assert subg.num_edges() >= 0
            u, v = subg.edges()
            assert set(F.asnumpy(F.unique(v))).issubset(set(seeds))
            assert F.array_equal(
                F.astype(g.has_edges_between(u, v), F.int64),
                F.ones((subg.num_edges(),), dtype=F.int64),
            )
            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
            # check no duplication
            assert len(edge_set) == subg.num_edges()
            if p is not None:
                for edge in excluded:
                    assert edge not in edge_set

    # test with seed nodes [0, 1]
    _check_homogeneous([0, 1], prob, [(3, 0), (3, 1)])

    # test with seed nodes [0, 2]
    _check_homogeneous([0, 2], prob, [(3, 0)])

    # test with heterogenous seed nodes
    def _check_heterogeneous(p):
        seeds = {"user": [0, 1], "game": 0}
        subg = dgl.sampling.sample_labors(hg, seeds, -1, prob=p)[0]
        assert len(subg.ntypes) == 3
        assert len(subg.etypes) == 4
        # Parenthesised on purpose: ``a == 6 if c else 4`` parses as
        # ``(a == 6) if c else 4``, which asserts nothing when c is false.
        assert subg["follow"].num_edges() == (6 if p is None else 4)
        assert subg["play"].num_edges() == 1
        assert subg["liked-by"].num_edges() == 4
        assert subg["flips"].num_edges() == 0

        for _ in range(10):
            subg = dgl.sampling.sample_labors(hg, seeds, 2, prob=p)[0]
            assert len(subg.ntypes) == 3
            assert len(subg.etypes) == 4
            assert subg["follow"].num_edges() >= 0
            assert subg["play"].num_edges() >= 0
            assert subg["liked-by"].num_edges() >= 0
            assert subg["flips"].num_edges() >= 0

    _check_heterogeneous(prob)

    # test different fanouts for different relations
    for _ in range(10):
        subg = dgl.sampling.sample_labors(
            hg,
            {"user": [0, 1], "game": 0, "coin": 0},
            {"follow": 1, "play": 2, "liked-by": 0, "flips": g.num_nodes()},
        )[0]
        assert len(subg.ntypes) == 3
        assert len(subg.etypes) == 4
        assert subg["follow"].num_edges() >= 0
        assert subg["play"].num_edges() >= 0
        assert subg["liked-by"].num_edges() == 0
        assert subg["flips"].num_edges() == 4


826
def _test_sample_neighbors_outedge(hypersparse, fused):
    """Check out-edge neighbor sampling on the generated test graphs.

    Parameters
    ----------
    hypersparse : bool
        Forwarded to ``_gen_neighbor_sampling_test_graph``.
    fused : bool
        Selects the sampler through ``sample_neighbors_fusing_mode``
        (``True`` -> fused variant, ``False`` -> regular variant).

    The expected edge counts and edge pairs below are tied to the graphs
    returned by ``_gen_neighbor_sampling_test_graph(hypersparse, True)``.
    """
    g, hg = _gen_neighbor_sampling_test_graph(hypersparse, True)
    sampler = sample_neighbors_fusing_mode[fused]
    # (prob, replace) combinations, in the order they are exercised.
    configs = [(None, True), (None, False), ("prob", True), ("prob", False)]

    def _original_ids(subg):
        # The fused sampler's result carries original node IDs in dgl.NID;
        # translate the returned endpoints back before comparing with ``g``.
        u, v = subg.edges()
        if fused:
            u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]
        return u, v

    def _check_homogeneous(seeds, p, replace, num_sampled, never_sampled):
        # fanout == -1: the result must be exactly the out-edges of the seeds
        # (restricted to edges whose ``p`` entry is non-zero / True).
        subg = sampler(g, seeds, -1, prob=p, replace=replace, edge_dir="out")
        if not fused:
            assert subg.num_nodes() == g.num_nodes()
        u, v = _original_ids(subg)
        u_ans, v_ans, e_ans = g.out_edges(seeds, form="all")
        if p is not None:
            emask = F.gather_row(g.edata[p], e_ans)
            if p == "prob":
                emask = emask != 0
            u_ans = F.boolean_mask(u_ans, emask)
            v_ans = F.boolean_mask(v_ans, emask)
        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
        assert uv == uv_ans

        # fanout == 2: sampling is random, so repeat a few times.
        for _ in range(10):
            subg = sampler(g, seeds, 2, prob=p, replace=replace, edge_dir="out")
            if not fused:
                assert subg.num_nodes() == g.num_nodes()
            assert subg.num_edges() == num_sampled
            u, v = _original_ids(subg)
            assert set(F.asnumpy(F.unique(u))) == set(seeds)
            assert F.array_equal(
                F.astype(g.has_edges_between(u, v), F.int64),
                F.ones((num_sampled,), dtype=F.int64),
            )
            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
            if not replace:
                # check no duplication
                assert len(edge_set) == num_sampled
            if p is not None:
                # Edges that must not show up once ``prob`` is applied
                # (presumably zero-weight edges of the generated graph).
                for edge in never_sampled:
                    assert edge not in edge_set

    def _test1(p, replace):
        _check_homogeneous([0, 1], p, replace, 4, [(0, 3), (1, 3)])

    def _test2(p, replace):  # fanout > #neighbors
        _check_homogeneous([0, 2], p, replace, 4 if replace else 3, [(0, 3)])

    def _check_types(subg):
        if not fused:
            assert len(subg.ntypes) == 3
        assert len(subg.srctypes) == 3
        assert len(subg.dsttypes) == 3
        assert len(subg.etypes) == 4

    def _test3(p, replace):
        seeds = {"user": [0, 1], "game": 0}
        subg = sampler(hg, seeds, -1, prob=p, replace=replace, edge_dir="out")
        _check_types(subg)
        # NOTE: the conditional must be parenthesised. Without parentheses
        # ``assert x == 6 if p is None else 4`` asserts the constant 4
        # (always true) whenever ``p`` is given.
        assert subg["follow"].num_edges() == (6 if p is None else 4)
        assert subg["play"].num_edges() == 1
        assert subg["liked-by"].num_edges() == 4
        assert subg["flips"].num_edges() == 0

        for _ in range(10):
            subg = sampler(
                hg, seeds, 2, prob=p, replace=replace, edge_dir="out"
            )
            _check_types(subg)
            assert subg["follow"].num_edges() == 4
            assert subg["play"].num_edges() == (2 if replace else 1)
            assert subg["liked-by"].num_edges() == (4 if replace else 3)
            assert subg["flips"].num_edges() == 0

    for check in (_test1, _test2, _test3):
        for p, replace in configs:
            check(p, replace)

973
974
975
976
977

def _test_sample_neighbors_topk(hypersparse):
    """Check ``dgl.sampling.select_topk`` along in-edges.

    Parameters
    ----------
    hypersparse : bool
        Forwarded to ``_gen_neighbor_topk_test_graph``.

    The expected edge sets are tied to the graphs returned by
    ``_gen_neighbor_topk_test_graph(hypersparse, False)``.
    """
    g, hg = _gen_neighbor_topk_test_graph(hypersparse, False)

    def _edge_pairs(u, v):
        # Set of (src, dst) pairs as plain numpy scalars.
        return set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))

    def _check_homogeneous(seeds, expected_top2):
        # k == -1: every in-edge of the seeds is expected in the result.
        subg = dgl.sampling.select_topk(g, -1, "weight", seeds)
        assert subg.num_nodes() == g.num_nodes()
        u, v = subg.edges()
        # The reference answer comes from the ORIGINAL graph. Querying the
        # subgraph itself would only show that its edges end at the seeds
        # and could never detect a missing edge.
        u_ans, v_ans = g.in_edges(seeds)
        assert _edge_pairs(u, v) == _edge_pairs(u_ans, v_ans)

        # k == 2: the result is deterministic, compare with the known answer.
        subg = dgl.sampling.select_topk(g, 2, "weight", seeds)
        assert subg.num_nodes() == g.num_nodes()
        assert subg.num_edges() == len(expected_top2)
        u, v = subg.edges()
        assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
        assert _edge_pairs(u, v) == expected_top2

    def _test1():
        _check_homogeneous([0, 1], {(2, 0), (1, 0), (2, 1), (3, 1)})

    _test1()

    def _test2():  # k > #neighbors
        _check_homogeneous([0, 2], {(2, 0), (1, 0), (0, 2)})

    _test2()

    def _test3():
        subg = dgl.sampling.select_topk(
            hg, 2, "weight", {"user": [0, 1], "game": 0}
        )
        assert len(subg.ntypes) == 3
        assert len(subg.etypes) == 4
        expected = {
            "follow": {(2, 0), (1, 0), (2, 1), (3, 1)},
            "play": {(0, 0)},
            "liked-by": {(2, 0), (2, 1), (1, 0)},
        }
        for etype, answer in expected.items():
            u, v = subg[etype].edges()
            assert F.array_equal(
                hg[etype].edge_ids(u, v), subg[etype].edata[dgl.EID]
            )
            assert _edge_pairs(u, v) == answer
        assert subg["flips"].num_edges() == 0

    _test3()

    # test different k for different relations
    subg = dgl.sampling.select_topk(
        hg,
        {"follow": 1, "play": 2, "liked-by": 0, "flips": -1},
        "weight",
        {"user": [0, 1], "game": 0, "coin": 0},
    )
    assert len(subg.ntypes) == 3
    assert len(subg.etypes) == 4
    assert subg["follow"].num_edges() == 2
    assert subg["play"].num_edges() == 1
    assert subg["liked-by"].num_edges() == 0
    assert subg["flips"].num_edges() == 4
1056

1057
1058
1059
1060
1061

def _test_sample_neighbors_topk_outedge(hypersparse):
    """Check ``dgl.sampling.select_topk`` along out-edges.

    Parameters
    ----------
    hypersparse : bool
        Forwarded to ``_gen_neighbor_topk_test_graph``.

    The expected edge sets are tied to the graphs returned by
    ``_gen_neighbor_topk_test_graph(hypersparse, True)``.
    """
    g, hg = _gen_neighbor_topk_test_graph(hypersparse, True)

    def _edge_pairs(u, v):
        # Set of (src, dst) pairs as plain numpy scalars.
        return set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))

    def _check_homogeneous(seeds, expected_top2):
        # k == -1: every out-edge of the seeds is expected in the result.
        subg = dgl.sampling.select_topk(
            g, -1, "weight", seeds, edge_dir="out"
        )
        assert subg.num_nodes() == g.num_nodes()
        u, v = subg.edges()
        # The reference answer comes from the ORIGINAL graph. Querying the
        # subgraph itself would only show that its edges start at the seeds
        # and could never detect a missing edge.
        u_ans, v_ans = g.out_edges(seeds)
        assert _edge_pairs(u, v) == _edge_pairs(u_ans, v_ans)

        # k == 2: the result is deterministic, compare with the known answer.
        subg = dgl.sampling.select_topk(g, 2, "weight", seeds, edge_dir="out")
        assert subg.num_nodes() == g.num_nodes()
        assert subg.num_edges() == len(expected_top2)
        u, v = subg.edges()
        assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
        assert _edge_pairs(u, v) == expected_top2

    def _test1():
        _check_homogeneous([0, 1], {(0, 2), (0, 1), (1, 2), (1, 3)})

    _test1()

    def _test2():  # k > #neighbors
        _check_homogeneous([0, 2], {(0, 2), (0, 1), (2, 0)})

    _test2()

    def _test3():
        subg = dgl.sampling.select_topk(
            hg, 2, "weight", {"user": [0, 1], "game": 0}, edge_dir="out"
        )
        assert len(subg.ntypes) == 3
        assert len(subg.etypes) == 4
        expected = {
            "follow": {(0, 2), (0, 1), (1, 2), (1, 3)},
            "play": {(0, 0)},
            "liked-by": {(0, 2), (1, 2), (0, 1)},
        }
        for etype, answer in expected.items():
            u, v = subg[etype].edges()
            assert F.array_equal(
                hg[etype].edge_ids(u, v), subg[etype].edata[dgl.EID]
            )
            assert _edge_pairs(u, v) == answer
        assert subg["flips"].num_edges() == 0

    _test3()

1127

1128
def test_sample_neighbors_noprob():
    """Run ``_test_sample_neighbors`` without edge probabilities.

    The regular sampler always runs; the fused variant only runs when the
    default context is not GPU and the backend is PyTorch.
    """
    for fused in (False, True):
        if fused and (
            F._default_context_str == "gpu" or F.backend_name != "pytorch"
        ):
            continue
        _test_sample_neighbors(False, None, fused)
    # NOTE: the variant with a first argument of True is currently disabled.

1134

1135
1136
1137
1138
def test_sample_labors_noprob():
    """Run ``_test_sample_labors`` without edge probabilities."""
    labor_args = (False, None)
    _test_sample_labors(*labor_args)


1139
def test_sample_neighbors_prob():
    """Run ``_test_sample_neighbors`` with the ``"prob"`` edge feature.

    The regular sampler always runs; the fused variant only runs when the
    default context is not GPU and the backend is PyTorch.
    """
    for fused in (False, True):
        if fused and (
            F._default_context_str == "gpu" or F.backend_name != "pytorch"
        ):
            continue
        _test_sample_neighbors(False, "prob", fused)
    # NOTE: the variant with a first argument of True is currently disabled.

1145

1146
1147
1148
1149
def test_sample_labors_prob():
    """Run ``_test_sample_labors`` with the ``"prob"`` edge feature."""
    labor_args = (False, "prob")
    _test_sample_labors(*labor_args)


1150
def test_sample_neighbors_outedge():
    """Run the out-edge sampling checks on the non-hypersparse graph.

    The regular sampler always runs; the fused variant only runs when the
    default context is not GPU and the backend is PyTorch.
    """
    for fused in (False, True):
        if fused and (
            F._default_context_str == "gpu" or F.backend_name != "pytorch"
        ):
            continue
        _test_sample_neighbors_outedge(hypersparse=False, fused=fused)
    # NOTE: the hypersparse variant is currently disabled.

1156

1157
1158
1159
1160
1161
1162
1163
@unittest.skipIf(
    F.backend_name == "mxnet", reason="MXNet has problem converting bool arrays"
)
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="GPU sample neighbors with mask not implemented",
)
def test_sample_neighbors_mask():
    """Run ``_test_sample_neighbors`` with the boolean ``"mask"`` edge feature.

    The regular sampler always runs; the fused variant only runs when the
    default context is not GPU and the backend is PyTorch.
    """
    for fused in (False, True):
        if fused and (
            F._default_context_str == "gpu" or F.backend_name != "pytorch"
        ):
            continue
        _test_sample_neighbors(False, "mask", fused)
1168

1169
1170
1171
1172
1173

@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="GPU sample neighbors not implemented",
)
def test_sample_neighbors_topk():
    """Run the in-edge top-k selection checks on the non-hypersparse graph."""
    _test_sample_neighbors_topk(hypersparse=False)
    # NOTE: the hypersparse variant is currently disabled.

1178

1179
1180
1181
1182
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="GPU sample neighbors not implemented",
)
def test_sample_neighbors_topk_outedge():
    """Run the out-edge top-k selection checks on the non-hypersparse graph."""
    _test_sample_neighbors_topk_outedge(hypersparse=False)
    # NOTE: the hypersparse variant is currently disabled.

1187

1188
1189
1190
1191
1192
1193
@pytest.mark.parametrize("fused", [False, True])
def test_sample_neighbors_with_0deg(fused):
    """Sampling from a graph without edges must yield no edges.

    Covers both edge directions, with and without replacement, on a
    5-node graph that has an empty edge list.
    """
    if fused:
        if F._default_context_str == "gpu" or F.backend_name != "pytorch":
            pytest.skip("Fused sampling support CPU with backend PyTorch.")
    sampler = sample_neighbors_fusing_mode[fused]
    g = dgl.graph(([], []), num_nodes=5).to(F.ctx())
    for edge_dir in ("in", "out"):
        for replace in (False, True):
            seeds = F.tensor([1, 2], dtype=F.int64)
            sg = sampler(g, seeds, 2, edge_dir=edge_dir, replace=replace)
            assert sg.num_edges() == 0
1211

1212

1213
1214
def create_test_graph(num_nodes, num_edges_per_node, bipartite=False):
    """Build a random graph where each source has a fixed out-degree.

    Every node ``i`` in ``range(num_nodes)`` is the source of
    ``num_edges_per_node`` edges whose destinations are drawn without
    replacement from ``range(num_nodes)`` using numpy's global RNG.

    Parameters
    ----------
    num_nodes : int
        Number of source nodes (and size of the destination ID range).
    num_edges_per_node : int
        Number of outgoing edges per source node.
    bipartite : bool, optional
        If True, build a heterograph with the single relation
        ``("u", "e", "v")``; otherwise build a homogeneous ``dgl.graph``.

    Returns
    -------
    The constructed DGL graph.
    """
    src_parts = []
    dst_parts = []
    for node in range(num_nodes):
        src_parts.append(np.array([node] * num_edges_per_node))
        dst_parts.append(
            np.random.choice(num_nodes, num_edges_per_node, replace=False)
        )
    src = np.concatenate(src_parts)
    dst = np.concatenate(dst_parts)
    if not bipartite:
        return dgl.graph((src, dst))
    return dgl.heterograph({("u", "e", "v"): (src, dst)})

1229

1230
1231
def create_etype_test_graph(num_nodes, num_edges_per_node, rare_cnt):
    """Build a random heterograph with five relations plus edge weights.

    Relations (node types ``u``, ``v``, ``v2``):

    * ``e_major`` / ``e_major_rev``: ``num_edges_per_node`` random
      neighbors for each of the ``num_nodes`` ``u`` nodes, and the reverse.
    * ``e_minor`` / ``e_minor_rev``: 2 random neighbors per ``u`` node,
      and the reverse.
    * ``most_zero``: ``num_edges_per_node`` random neighbors, but only for
      the first ``rare_cnt`` ``u`` nodes.

    Each relation gets an edge feature ``"p"`` (uniform random values where
    everything above 0.2 is set to 0) and a boolean ``"mask"`` (``p != 0``).
    All randomness comes from numpy's global RNG.

    Returns
    -------
    The constructed DGL heterograph.
    """

    def _random_sources(num_targets, degree):
        # ``degree`` distinct random node IDs for every target, flattened.
        return np.concatenate(
            [
                np.random.choice(num_nodes, degree, replace=False)
                for _ in range(num_targets)
            ]
        )

    def _repeated_targets(num_targets, degree):
        # Target ID ``i`` repeated ``degree`` times, for every target.
        return np.concatenate(
            [np.array([i] * degree) for i in range(num_targets)]
        )

    src = _random_sources(num_nodes, num_edges_per_node)
    dst = _repeated_targets(num_nodes, num_edges_per_node)

    minor_src = _random_sources(num_nodes, 2)
    minor_dst = _repeated_targets(num_nodes, 2)

    most_zero_src = _random_sources(rare_cnt, num_edges_per_node)
    most_zero_dst = _repeated_targets(rare_cnt, num_edges_per_node)

    relations = {
        ("v", "e_major", "u"): (src, dst),
        ("u", "e_major_rev", "v"): (dst, src),
        ("v2", "e_minor", "u"): (minor_src, minor_dst),
        ("v2", "most_zero", "u"): (most_zero_src, most_zero_dst),
        ("u", "e_minor_rev", "v2"): (minor_dst, minor_src),
    }
    g = dgl.heterograph(relations)

    for etype in g.etypes:
        weights = np.random.rand(g.num_edges(etype))
        # Zero out every weight above 0.2.
        weights[weights > 0.2] = 0
        g.edges[etype].data["p"] = F.zerocopy_from_numpy(weights)
        g.edges[etype].data["mask"] = F.zerocopy_from_numpy(weights != 0)

    return g

1276
1277
1278
1279
1280

@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="GPU sample neighbors not implemented",
)
def test_sample_neighbors_biased_homogeneous():
    """Check tag-biased neighbor sampling on a homogeneous random graph.

    Nodes carry a random tag in ``{0, 1, 2, 3}`` with biases
    ``[0, 0.1, 10, 10]``. Sampled neighbors must never have tag 0 and
    should rarely have tag 1. Without replacement no edge may repeat.
    """
    g = create_test_graph(100, 30)

    def check_num(nodes, tag):
        nodes, tag = F.asnumpy(nodes), F.asnumpy(tag)
        sampled_tags = tag[nodes]
        cnt = [sum(sampled_tags == t) for t in range(4)]
        # No tag 0
        assert cnt[0] == 0

        # very rare tag 1
        for common_tag in (2, 3):
            assert cnt[common_tag] > 2 * cnt[1]

    tag = F.tensor(np.random.choice(4, 100))
    bias = F.tensor([0, 0.1, 10, 10], dtype=F.float32)

    def _sample_and_check(g_sorted, neighbor_side, **direction):
        # First 5 rounds without replacement, then 5 rounds with replacement.
        for replace in (False, True):
            for _ in range(5):
                subg = dgl.sampling.sample_neighbors_biased(
                    g_sorted, g.nodes(), 5, bias, replace=replace, **direction
                )
                check_num(subg.edges()[neighbor_side], tag)
                if replace:
                    continue
                u, v = subg.edges()
                edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
                assert len(edge_set) == subg.num_edges()

    # inedge: neighbors are the edge sources
    _sample_and_check(dgl.sort_csc_by_tag(g, tag), 0)

    # outedge: neighbors are the edge destinations
    _sample_and_check(dgl.sort_csr_by_tag(g, tag), 1, edge_dir="out")

1332
1333
1334
1335
1336

@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="GPU sample neighbors not implemented",
)
def test_sample_neighbors_biased_bipartite():
    """Check tag-biased neighbor sampling on a bipartite random graph.

    Neighbor nodes carry a random tag in ``{0, 1, 2, 3}`` with biases
    ``[0, 0.01, 10, 10]``. Sampled neighbors must never have tag 0 and
    should rarely have tag 1. Without replacement no edge may repeat.
    """
    g = create_test_graph(100, 30, True)
    num_dst = g.num_dst_nodes()
    bias = F.tensor([0, 0.01, 10, 10], dtype=F.float32)

    def check_num(nodes, tag):
        nodes, tag = F.asnumpy(nodes), F.asnumpy(tag)
        sampled_tags = tag[nodes]
        cnt = [sum(sampled_tags == t) for t in range(4)]
        # No tag 0
        assert cnt[0] == 0

        # very rare tag 1
        for common_tag in (2, 3):
            assert cnt[common_tag] > 2 * cnt[1]

    def _sample_and_check(g_sorted, seeds, tag, neighbor_side, **direction):
        # First 5 rounds without replacement, then 5 rounds with replacement.
        for replace in (False, True):
            for _ in range(5):
                subg = dgl.sampling.sample_neighbors_biased(
                    g_sorted, seeds, 5, bias, replace=replace, **direction
                )
                check_num(subg.edges()[neighbor_side], tag)
                if replace:
                    continue
                u, v = subg.edges()
                edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
                assert len(edge_set) == subg.num_edges()

    # inedge: tags are attached to the source side
    tag = F.tensor(np.random.choice(4, 100))
    _sample_and_check(dgl.sort_csc_by_tag(g, tag), g.dstnodes(), tag, 0)

    # outedge: tags are attached to the destination side
    tag = F.tensor(np.random.choice(4, num_dst))
    _sample_and_check(
        dgl.sort_csr_by_tag(g, tag), g.srcnodes(), tag, 1, edge_dir="out"
    )

1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400

@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="GPU sample neighbors not implemented",
)
@unittest.skipIf(
    F.backend_name == "mxnet", reason="MXNet has problem converting bool arrays"
)
@pytest.mark.parametrize("format_", ["coo", "csr", "csc"])
@pytest.mark.parametrize("direction", ["in", "out"])
@pytest.mark.parametrize("replace", [False, True])
def test_sample_neighbors_etype_homogeneous(format_, direction, replace):
    """Check per-edge-type neighbor sampling on a homogenized heterograph.

    A random heterograph is converted with ``dgl.to_homogeneous`` and
    sampled with ``dgl.sampling.sample_etype_neighbors`` three ways:
    uniformly, weighted by the ``"p"`` edge feature, and restricted by the
    boolean ``"mask"`` edge feature. The two weighted runs are validated
    against the subgraph that only keeps edges with ``mask == True``.
    """
    num_nodes = 100
    rare_cnt = 4
    g = create_etype_test_graph(num_nodes, 30, rare_cnt)
    h_g = dgl.to_homogeneous(g, edata=["p", "mask"])
    h_g_etype = F.asnumpy(h_g.edata[dgl.ETYPE])
    # Prefix sums of the per-type edge counts: [0, n0, n0 + n1, ...].
    h_g_offset = np.cumsum(np.insert(np.bincount(h_g_etype), 0, 0)).tolist()
    # Reference graph for the weighted runs: only edges with mask == True.
    h_sg = h_g.edge_subgraph(h_g.edata["mask"], relabel_nodes=False)

    seed_ntype = g.get_ntype_id("u")
    seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype)
    fanouts = F.tensor([6, 5, 4, 3, 2], dtype=F.int64)

    def check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction):
        """Validate ``subg`` against the full edge list of ``h_g``.

        With replacement, every seed endpoint that has edges of a type must
        have received exactly ``fanouts[etype]`` samples of that type.
        Without replacement, for each seed endpoint and edge type the
        sampled neighbors either number ``fanouts[etype]`` or equal the
        complete neighbor set.
        """
        src, dst = subg.edges()
        all_etype_array = F.asnumpy(h_g.edata[dgl.ETYPE])
        etype_array = F.asnumpy(subg.edata[dgl.ETYPE])
        src = F.asnumpy(src)
        dst = F.asnumpy(dst)
        fanouts = F.asnumpy(fanouts)
        all_src = F.asnumpy(all_src)
        all_dst = F.asnumpy(all_dst)

        # "seed" is the endpoint sampling was done for, "nbr" the other one.
        if direction == "in":
            seed, nbr, all_seed, all_nbr = dst, src, all_dst, all_src
        else:
            seed, nbr, all_seed, all_nbr = src, dst, all_src, all_dst

        if replace:
            for etype in range(len(fanouts)):
                degree = np.bincount(seed[etype_array == etype])
                ans = np.zeros_like(degree)
                if len(degree) > 0:
                    ans[all_seed[all_etype_array == etype]] = fanouts[etype]
                assert np.all(degree == ans)
        else:
            for node in set(seed):
                picked = seed == node
                available = all_seed == node
                nbrs, ets = nbr[picked], etype_array[picked]
                all_nbrs = all_nbr[available]
                all_ets = all_etype_array[available]
                for etype in set(ets):
                    sampled = set(nbrs[ets == etype])
                    candidates = set(all_nbrs[all_ets == etype])
                    assert (len(sampled) == fanouts[etype]) or (
                        sampled == candidates
                    )

    all_src, all_dst = h_g.edges()
    all_sub_src, all_sub_dst = h_sg.edges()
    h_g = h_g.formats(format_)
    if (direction, format_) in [("in", "csr"), ("out", "csc")]:
        h_g = h_g.formats(["csc", "csr", "coo"])

    # Per-edge-type weights; they do not change between rounds.
    edge_probs = [g.edges[etype].data["p"] for etype in g.etypes]
    edge_masks = [g.edges[etype].data["mask"] for etype in g.etypes]

    for _ in range(5):
        subg = dgl.sampling.sample_etype_neighbors(
            h_g, seeds, h_g_offset, fanouts, replace=replace, edge_dir=direction
        )
        check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction)

        for p in (edge_probs, edge_masks):
            subg = dgl.sampling.sample_etype_neighbors(
                h_g,
                seeds,
                h_g_offset,
                fanouts,
                replace=replace,
                edge_dir=direction,
                prob=p,
            )
            check_num(
                h_sg,
                all_sub_src,
                all_sub_dst,
                subg,
                replace,
                fanouts,
                direction,
            )


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="GPU sample neighbors not implemented",
)
@unittest.skipIf(
    F.backend_name == "mxnet", reason="MXNet has problem converting bool arrays"
)
@pytest.mark.parametrize("format_", ["csr", "csc"])
@pytest.mark.parametrize("direction", ["in", "out"])
def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
    """Smoke-test ``sample_etype_neighbors`` with ``etype_sorted=True``.

    The homogenized graph is sorted by edge type (CSC for in-edges, CSR for
    out-edges) before sampling. Only successful completion is checked; the
    sampled graph itself is not inspected.
    """
    rare_cnt = 4
    g = create_etype_test_graph(100, 30, rare_cnt)
    h_g = dgl.to_homogeneous(g)
    seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == g.get_ntype_id("u"))
    fanouts = F.tensor([6, 5, -1, 3, 2], dtype=F.int64)

    h_g = h_g.formats(format_)
    mismatched = (direction == "in" and format_ == "csr") or (
        direction == "out" and format_ == "csc"
    )
    if mismatched:
        h_g = h_g.formats(["csc", "csr", "coo"])

    sort_by_tag = (
        dgl.sort_csc_by_tag if direction == "in" else dgl.sort_csr_by_tag
    )
    h_g = sort_by_tag(h_g, h_g.edata[dgl.ETYPE], tag_type="edge")

    # Prefix sums of the per-type edge counts: [0, n0, n0 + n1, ...].
    etype_counts = np.bincount(F.asnumpy(h_g.edata[dgl.ETYPE]))
    h_g_offset = np.cumsum(np.insert(etype_counts, 0, 0)).tolist()
    sg = dgl.sampling.sample_etype_neighbors(
        h_g, seeds, h_g_offset, fanouts, edge_dir=direction, etype_sorted=True
    )
1554

1555
1556

@pytest.mark.parametrize("dtype", ["int32", "int64"])
@pytest.mark.parametrize("fused", [False, True])
def test_sample_neighbors_exclude_edges_heteroG(dtype, fused):
    """Check that ``exclude_edges`` is honoured on a heterogeneous graph.

    A random graph with three relations is built, a random window of edge IDs
    per relation is passed as ``exclude_edges``, and the sampled subgraph is
    checked to contain none of the corresponding ``(u, v)`` pairs.

    Parameters
    ----------
    dtype : str
        NumPy integer dtype name used for the random node IDs and indices.
    fused : bool
        Selects the sampler via ``sample_neighbors_fusing_mode``.  The fused
        case is skipped on GPU or with a backend other than PyTorch.
    """
    if fused:
        if F._default_context_str == "gpu" or F.backend_name != "pytorch":
            pytest.skip("Fused sampling support CPU with backend PyTorch.")

    relations = (
        ("drug", "interacts", "drug"),
        ("drug", "interacts", "gene"),
        ("drug", "treats", "disease"),
    )

    def random_endpoints():
        # Sources are de-duplicated, so every (u, v) pair within a relation
        # identifies exactly one edge.
        src = F.zerocopy_from_numpy(
            np.unique(np.random.randint(300, size=100, dtype=dtype))
        )
        dst = F.zerocopy_from_numpy(
            np.random.randint(25, size=src.shape, dtype=dtype)
        )
        return src, dst

    def random_window():
        # Random slice whose start lies in [1, 24) and stop in [25, 49).
        start = np.random.randint(low=1, high=24, dtype=dtype)
        stop = np.random.randint(low=25, high=49, dtype=dtype)
        return slice(start, stop)

    g = dgl.heterograph({rel: random_endpoints() for rel in relations}).to(
        F.ctx()
    )

    seeds = {}  # destination node type -> seed nodes
    excluded_eids = {}  # relation -> edge IDs handed to the sampler
    excluded_pairs = {}  # relation -> (u, v) endpoints of those edges
    fanout = None
    for rel in relations:
        seed_window = random_window()
        excluded_window = random_window()
        if fanout is None:
            # Drawn right after the first relation's windows.
            fanout = np.random.randint(low=1, high=10, dtype=dtype)

        src, dst, eid = g.all_edges(form="all", etype=rel)
        seeds[rel[2]] = dst[seed_window]
        excluded_eids[rel] = eid[excluded_window]
        excluded_pairs[rel] = (src[excluded_window], dst[excluded_window])

    sg = sample_neighbors_fusing_mode[fused](
        g, seeds, fanout, exclude_edges=excluded_eids
    )

    for rel in relations:
        excl_u, excl_v = excluded_pairs[rel]
        if fused:
            # Map the sampled edges back to parent-graph endpoints through
            # their stored original edge IDs before comparing.
            parent_edges = g.find_edges(sg.edges[rel].data[dgl.EID], rel)
            sampled_pairs = {
                tuple(pair) for pair in np.stack(parent_edges, axis=1)
            }
            banned_pairs = {
                tuple(pair) for pair in np.stack((excl_u, excl_v), axis=1)
            }
            assert sampled_pairs.isdisjoint(banned_pairs)
        else:
            present = sg.has_edges_between(excl_u, excl_v, etype=rel)
            assert not np.any(F.asnumpy(present))


@pytest.mark.parametrize("dtype", ["int32", "int64"])
@pytest.mark.parametrize("fused", [False, True])
def test_sample_neighbors_exclude_edges_homoG(dtype, fused):
    """Check that ``exclude_edges`` is honoured on a homogeneous graph.

    A random graph is built, a random window of its edge IDs is passed as
    ``exclude_edges``, and the sampled subgraph is checked to contain none of
    the corresponding ``(u, v)`` pairs.

    Parameters
    ----------
    dtype : str
        NumPy integer dtype name used for the random node IDs and indices.
    fused : bool
        Selects the sampler via ``sample_neighbors_fusing_mode``.  The fused
        case is skipped on GPU or with a backend other than PyTorch.
    """
    if fused:
        if F._default_context_str == "gpu" or F.backend_name != "pytorch":
            pytest.skip("Fused sampling support CPU with backend PyTorch.")

    # Sources are de-duplicated, so every (u, v) pair identifies one edge.
    src = F.zerocopy_from_numpy(
        np.unique(np.random.randint(300, size=100, dtype=dtype))
    )
    dst = F.zerocopy_from_numpy(
        np.random.randint(25, size=src.shape, dtype=dtype)
    )
    g = dgl.graph((src, dst)).to(F.ctx())

    seed_start = np.random.randint(low=1, high=24, dtype=dtype)
    seed_stop = np.random.randint(low=25, high=49, dtype=dtype)
    excl_start = np.random.randint(low=1, high=24, dtype=dtype)
    excl_stop = np.random.randint(low=25, high=49, dtype=dtype)
    fanout = np.random.randint(low=1, high=10, dtype=dtype)

    all_u, all_v, all_eid = g.all_edges(form="all")
    excluded_window = slice(excl_start, excl_stop)
    excluded_eids = all_eid[excluded_window]
    seed_nodes = all_v[seed_start:seed_stop]
    excl_u = all_u[excluded_window]
    excl_v = all_v[excluded_window]

    sg = sample_neighbors_fusing_mode[fused](
        g, seed_nodes, fanout, exclude_edges=excluded_eids
    )

    if fused:
        # Map the sampled edges back to parent-graph endpoints through their
        # stored original edge IDs before comparing.
        parent_edges = g.find_edges(sg.edges["_E"].data[dgl.EID])
        sampled_pairs = {tuple(pair) for pair in np.stack(parent_edges, axis=1)}
        banned_pairs = {
            tuple(pair) for pair in np.stack((excl_u, excl_v), axis=1)
        }
        assert sampled_pairs.isdisjoint(banned_pairs)
    else:
        present = sg.has_edges_between(excl_u, excl_v)
        assert not np.any(F.asnumpy(present))
1773
1774


1775
@pytest.mark.parametrize("dtype", ["int32", "int64"])
def test_global_uniform_negative_sampling(dtype):
    """Exercise ``dgl.sampling.global_uniform_negative_sampling``.

    Covers an empty graph, a dense random graph sampled with and without
    replacement, a two-node graph with self-loops excluded, and one relation
    of a heterogeneous graph.

    Parameters
    ----------
    dtype : str
        Name of the graph ID type (``"int32"`` or ``"int64"``).  It is
        resolved on the backend module and passed as ``idtype`` to every graph
        constructor, so that each parametrization really runs on a graph of
        that ID type (the argument used to be ignored).
    """
    idtype = getattr(F, dtype)

    # Empty graph: every node pair is a valid negative, so sampling with
    # replacement must return exactly the requested number of pairs.
    g = dgl.graph(([], []), num_nodes=1000, idtype=idtype).to(F.ctx())
    src, dst = dgl.sampling.global_uniform_negative_sampling(
        g, 2000, False, True
    )
    assert len(src) == 2000
    assert len(dst) == 2000

    # Dense random graph: sampled pairs must never be existing edges.
    g = dgl.graph(
        (np.random.randint(0, 20, (300,)), np.random.randint(0, 20, (300,))),
        idtype=idtype,
    ).to(F.ctx())
    src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, True)
    assert not F.asnumpy(g.has_edges_between(src, dst)).any()

    # Without replacement the returned pairs must additionally be distinct.
    src, dst = dgl.sampling.global_uniform_negative_sampling(
        g, 20, False, False
    )
    assert not F.asnumpy(g.has_edges_between(src, dst)).any()
    src = F.asnumpy(src)
    dst = F.asnumpy(dst)
    s = set(zip(src.tolist(), dst.tolist()))
    assert len(s) == len(src)

    # Two nodes with the single edge 0 -> 1 and self-loops excluded: the only
    # possible negative pair is (1, 0).
    g = dgl.graph(([0], [1]), idtype=idtype).to(F.ctx())
    src, dst = dgl.sampling.global_uniform_negative_sampling(
        g, 20, True, False, redundancy=10
    )
    src = F.asnumpy(src)
    dst = F.asnumpy(dst)
    # should have either no element or (1, 0)
    assert len(src) < 2
    assert len(dst) < 2
    if len(src) == 1:
        assert src[0] == 1
        assert dst[0] == 0

    # Heterogeneous graph: sample negatives for the "AB" relation only.
    g = dgl.heterograph(
        {
            ("A", "AB", "B"): (
                np.random.randint(0, 20, (300,)),
                np.random.randint(0, 40, (300,)),
            ),
            ("B", "BA", "A"): (
                np.random.randint(0, 40, (200,)),
                np.random.randint(0, 20, (200,)),
            ),
        },
        idtype=idtype,
    ).to(F.ctx())
    src, dst = dgl.sampling.global_uniform_negative_sampling(
        g, 20, False, etype="AB"
    )
    assert not F.asnumpy(g.has_edges_between(src, dst, etype="AB")).any()
1828

1829

1830
# Script entry point: run a selection of the tests above directly, without
# going through pytest's collection and parametrization.
if __name__ == "__main__":
    from itertools import product

    test_sample_neighbors_noprob()
    test_sample_labors_noprob()
    test_sample_neighbors_prob()
    test_sample_labors_prob()
    test_sample_neighbors_mask()
    for args in product(["coo", "csr", "csc"], ["in", "out"], [False, True]):
        test_sample_neighbors_etype_homogeneous(*args)
    for args in product(["csr", "csc"], ["in", "out"]):
        test_sample_neighbors_etype_sorted_homogeneous(*args)
    test_non_uniform_random_walk(False)
    test_uniform_random_walk(False)
    test_pack_traces()
    test_pinsage_sampling(False)
    test_sample_neighbors_outedge()
    test_sample_neighbors_topk()
    test_sample_neighbors_topk_outedge()
    test_sample_neighbors_with_0deg()
    test_sample_neighbors_biased_homogeneous()
    test_sample_neighbors_biased_bipartite()
    # Both exclude-edges tests take (dtype, fused); omitting ``fused`` raised
    # TypeError.  The non-fused path is used here because the fused one may
    # call pytest.skip depending on the device and backend.
    test_sample_neighbors_exclude_edges_heteroG("int32", False)
    test_sample_neighbors_exclude_edges_homoG("int32", False)
    test_global_uniform_negative_sampling("int32")
    test_global_uniform_negative_sampling("int64")