test_distributed_sampling.py 49.6 KB
Newer Older
1
import multiprocessing as mp
Jinjing Zhou's avatar
Jinjing Zhou committed
2
import os
3
import random
4
import tempfile
Jinjing Zhou's avatar
Jinjing Zhou committed
5
import time
6
7
import traceback
import unittest
Jinjing Zhou's avatar
Jinjing Zhou committed
8
from pathlib import Path
9
10
11
12

import backend as F
import dgl
import numpy as np
13
import pytest
14
import torch
15
16
17
18
19
20
21
22
23
24
from dgl.data import CitationGraphDataset, WN18Dataset
from dgl.distributed import (
    DistGraph,
    DistGraphServer,
    load_partition,
    load_partition_book,
    partition_graph,
    sample_etype_neighbors,
    sample_neighbors,
)
25
from scipy import sparse as spsp
26
from utils import generate_ip_config, reset_envs
Jinjing Zhou's avatar
Jinjing Zhou committed
27
28


29
30
31
32
33
34
def start_server(
    rank,
    tmpdir,
    disable_shared_mem,
    graph_name,
    graph_format=None,
    use_graphbolt=False,
):
    """Create and start a ``DistGraphServer`` for a partitioned test graph.

    Used as the ``target`` of spawned server processes.

    Args:
        rank: Server id forwarded to ``DistGraphServer``.
        tmpdir: ``pathlib.Path`` containing ``<graph_name>.json``.
        disable_shared_mem: Forwarded to ``DistGraphServer``.
        graph_name: Base name of the partition config file.
        graph_format: Sparse formats for the server graph. ``None`` (the
            default) means ``["csc", "coo"]``, the previous list default.
        use_graphbolt: Forwarded to ``DistGraphServer``.
    """
    # Build the default per call instead of sharing one mutable list
    # across every invocation.
    if graph_format is None:
        graph_format = ["csc", "coo"]
    server = DistGraphServer(
        rank,
        "rpc_ip_config.txt",
        1,
        1,
        tmpdir / (graph_name + ".json"),
        disable_shared_mem=disable_shared_mem,
        graph_format=graph_format,
        use_graphbolt=use_graphbolt,
    )
    server.start()


def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Sample fanout-3 neighbors of a fixed seed set on "test_sampling".

    Returns the sampled graph, or ``None`` if sampling raised (the traceback
    is printed). ``exit_client`` is called in both cases.
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)
    try:
        seeds = torch.tensor(
            [0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype
        )
        result = sample_neighbors(dist_graph, seeds, 3)
    except Exception:
        print(traceback.format_exc())
        result = None
    dgl.distributed.exit_client()
    return result


def start_sample_client_shuffle(
    rank,
    tmpdir,
    disable_shared_mem,
    g,
    num_servers,
    group_id,
    orig_nid,
    orig_eid,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample a shuffled homogeneous DistGraph and verify it against ``g``.

    Runs in a spawned client process. Sampled endpoints are mapped back
    through ``orig_nid`` and must be edges of ``g``. When edge IDs are
    expected (no GraphBolt, or ``return_eids``), they are mapped through
    ``orig_eid`` and must match ``g.edge_ids``. Otherwise ``dgl.EID`` must
    be absent. ``num_servers`` is accepted for call compatibility but
    unused.
    """
    os.environ["DGL_GROUP_ID"] = str(group_id)
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)
    seeds = torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype)
    sampled = sample_neighbors(dist_graph, seeds, 3, use_graphbolt=use_graphbolt)

    assert (
        dgl.ETYPE not in sampled.edata
    ), "Etype should not be in homogeneous sampled graph."

    # Translate shuffled node IDs back into the original graph's ID space.
    shuffled_src, shuffled_dst = sampled.edges()
    src = orig_nid[shuffled_src]
    dst = orig_nid[shuffled_dst]
    assert sampled.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    if not use_graphbolt or return_eids:
        expected = g.edge_ids(src, dst)
        actual = orig_eid[sampled.edata[dgl.EID]]
        assert np.array_equal(F.asnumpy(actual), F.asnumpy(expected))
    else:
        assert (
            dgl.EID not in sampled.edata
        ), "EID should not be in sampled graph if use_graphbolt=True."


def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
    """Resolve endpoints of ``eids`` on the "test_find_edges" DistGraph.

    Returns ``(u, v)``, or ``(None, None)`` if ``find_edges`` raised (the
    traceback is printed). ``exit_client`` is called in both cases.
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_find_edges.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_find_edges", gpb=partition_book)
    try:
        src, dst = dist_graph.find_edges(eids, etype=etype)
    except Exception:
        print(traceback.format_exc())
        src = dst = None
    dgl.distributed.exit_client()
    return src, dst


def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Query in/out degrees on the "test_get_degrees" DistGraph.

    Returns ``(in_deg, out_deg, all_in_deg, all_out_deg)``: degrees of
    ``nids`` and of all nodes. All four are ``None`` if any query raised
    (the traceback is printed). ``exit_client`` is called in both cases.
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_get_degrees.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_get_degrees", gpb=partition_book)
    try:
        in_deg = dist_graph.in_degrees(nids)
        all_in_deg = dist_graph.in_degrees()
        out_deg = dist_graph.out_degrees(nids)
        all_out_deg = dist_graph.out_degrees()
    except Exception:
        print(traceback.format_exc())
        in_deg = out_deg = all_in_deg = all_out_deg = None
    dgl.distributed.exit_client()
    return in_deg, out_deg, all_in_deg, all_out_deg


def check_rpc_sampling(tmpdir, num_server):
    """Sample a non-shuffled Cora partition remotely and verify the edges.

    Spawns ``num_server`` servers, runs the sampling client in this process,
    waits for every server to exit cleanly, then checks that each sampled
    edge (and its EID) exists in the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, "test_sampling"),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    expected_eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(expected_eids)
    )


def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Verify ``DistGraph.find_edges`` on a shuffled Cora partition.

    100 random shuffled edge IDs are resolved both remotely and locally
    (after mapping through ``orig_eid``). The remote endpoints are mapped
    back through ``orig_nid`` and must equal the local ones. The spawned
    servers are joined and must exit with code 0, as in the other
    ``check_rpc_*`` helpers.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # `eids` live in the shuffled ID space; `orig_eid` maps them to `g`.
    eids = F.tensor(np.random.randint(g.num_edges(), size=100))
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

    # Reap the servers (the other check_* helpers show they exit once the
    # client disconnects). Otherwise a server could still be running when
    # the next check reuses rpc_ip_config.txt, and a crash would go
    # unnoticed.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)


def create_random_hetero(dense=False, empty=False):
    """Build a random heterograph with node types n1/n2/n3 and relations
    r12, r13, r23; only ``n1`` gets a ``feat`` tensor of ones (width 10).

    Args:
        dense: Use ~200 nodes per type with density 0.1 instead of ~1000
            nodes per type with density 0.001.
        empty: Shrink each relation's random matrix by 10 rows and 10
            columns, so the last 10 nodes of each endpoint type have no
            edges of that relation.
    """
    if dense:
        num_nodes = {"n1": 210, "n2": 200, "n3": 220}
    else:
        num_nodes = {"n1": 1010, "n2": 1000, "n3": 1020}
    canonical_etypes = [
        ("n1", "r12", "n2"),
        ("n1", "r13", "n3"),
        ("n2", "r23", "n3"),
    ]
    density = 0.1 if dense else 0.001
    trim = 10 if empty else 0
    edges = {}
    random.seed(42)
    for c_etype in canonical_etypes:
        src_type, _, dst_type = c_etype
        adj = spsp.random(
            num_nodes[src_type] - trim,
            num_nodes[dst_type] - trim,
            density=density,
            format="coo",
            random_state=100,
        )
        edges[c_etype] = (adj.row, adj.col)
    g = dgl.heterograph(edges, num_nodes)
    g.nodes["n1"].data["feat"] = F.ones(
        (g.num_nodes("n1"), 10), F.float32, F.cpu()
    )
    return g


def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
    """Verify typed ``DistGraph.find_edges`` on a shuffled heterograph.

    Local checks: a 2-tuple etype must be rejected, and the short name
    "r12" must resolve to the same edges as its canonical triple. Remote
    check: endpoints returned by the client for 100 random shuffled r12
    edge IDs must map back to the local result. The servers are joined and
    must exit with code 0.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    test_etype = g.to_canonical_etype("r12")
    eids = F.tensor(np.random.randint(g.num_edges(test_etype), size=100))
    local_eids = orig_eid[test_etype][eids]

    # A (src_type, etype) pair is not a valid edge-type spec. Use
    # pytest.raises instead of a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit.
    with pytest.raises(Exception):
        g.find_edges(local_eids, etype=("n1", "r12"))

    u, v = g.find_edges(local_eids, etype="r12")
    u1, v1 = g.find_edges(local_eids, etype=("n1", "r12", "n2"))
    assert F.array_equal(u, u1)
    assert F.array_equal(v, v1)

    du, dv = start_find_edges_client(
        0, tmpdir, num_server > 1, eids, etype="r12"
    )

    # Reap the servers like the other check_* helpers so that none is left
    # running when the next check reuses rpc_ip_config.txt.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    du = orig_nid["n1"][du]
    dv = orig_nid["n2"][dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)


# NOTE(review): original note read "Wait non shared memory graph store";
# presumably these tests await a non-shared-memory graph store. Confirm.
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_find_edges_shuffle(num_server):
    """Run the heterograph, then the homogeneous, find_edges checks."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_hetero_find_edges_shuffle(workdir, num_server)
        check_rpc_find_edges_shuffle(workdir, num_server)


def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Compare DistGraph in/out degrees with those of the original Cora graph.

    Degrees are fetched for 100 random shuffled node IDs and for all nodes.
    ``orig_nid`` maps shuffled IDs to original IDs for the local reference.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    orig_nid, _ = partition_graph(
        g,
        "test_get_degrees",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, "test_get_degrees"),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    nids = F.tensor(np.random.randint(g.num_nodes(), size=100))
    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(
        0, tmpdir, num_server > 1, nids
    )

    print("Done get_degree")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    print("check results")
    assert F.array_equal(g.in_degrees(orig_nid[nids]), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(orig_nid[nids]), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)


# NOTE(review): original note read "Wait non shared memory graph store";
# presumably these tests await a non-shared-memory graph store. Confirm.
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_get_degree_shuffle(num_server):
    """Run the degree check inside a fresh temporary directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_get_degree_shuffle(Path(tmpdirname), num_server)


# @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
# @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip("Only support partition with shuffle")
def test_rpc_sampling():
    """Sampling on a non-shuffled partition with one server (skipped)."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling(Path(tmpdirname), 1)


def check_rpc_sampling_shuffle(
    tmpdir, num_server, num_groups=1, use_graphbolt=False, return_eids=False
):
    """Homogeneous neighbor sampling on a shuffled Cora partition.

    Spawns ``num_server`` servers and one client process per (client,
    group) pair. Each client verifies its own sample through
    ``start_sample_client_shuffle``. Clients and then servers are joined,
    and each must exit with code 0.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    orig_nids, orig_eids = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    num_clients = 1
    clients = []
    for client_id in range(num_clients):
        for group_id in range(num_groups):
            client = ctx.Process(
                target=start_sample_client_shuffle,
                args=(
                    client_id,
                    tmpdir,
                    num_server > 1,
                    g,
                    num_server,
                    group_id,
                    orig_nids,
                    orig_eids,
                    use_graphbolt,
                    return_eids,
                ),
            )
            client.start()
            # Stagger start-up to avoid a race while DistGraph is created.
            time.sleep(1)
            clients.append(client)

    for proc in clients + servers:
        proc.join()
        assert proc.exitcode == 0


def start_hetero_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample fanout-3 neighbors of ``nodes`` on the heterograph "test_sampling".

    Returns ``(block, gpb)``. ``block`` is ``dgl.to_block`` of the sample,
    carrying ``dgl.EID`` when not using GraphBolt or when ``return_eids``
    is set. ``block`` is ``None`` if sampling raised (the traceback is
    printed). ``gpb`` is the partition book.
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)
    # Only n1 is expected to carry the "feat" node feature.
    assert "feat" in dist_graph.nodes["n1"].data
    assert "feat" not in dist_graph.nodes["n2"].data
    assert "feat" not in dist_graph.nodes["n3"].data

    if partition_book is None:
        partition_book = dist_graph.get_partition_book()
    try:
        # Turn on the sanity checks of distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        sampled_graph = sample_neighbors(
            dist_graph, nodes, 3, use_graphbolt=use_graphbolt
        )
        block = dgl.to_block(sampled_graph, nodes)
        keep_eids = (not use_graphbolt) or return_eids
        if keep_eids:
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, partition_book


def start_hetero_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    etype_sorted=False,
    use_graphbolt=False,
    return_eids=False,
):
    """Per-edge-type sampling of ``nodes`` on the heterograph "test_sampling".

    Without GraphBolt, and when a local partition is available, the
    function first checks that each local node's in-edges list edge types
    in ascending order of first appearance. Returns ``(block, gpb)``;
    ``block`` is ``None`` if sampling raised (the traceback is printed).
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)
    assert "feat" in dist_graph.nodes["n1"].data
    assert "feat" not in dist_graph.nodes["n2"].data
    assert "feat" not in dist_graph.nodes["n3"].data

    if not use_graphbolt and dist_graph.local_partition is not None:
        local_g = dist_graph.local_partition
        for lnid in np.arange(local_g.num_nodes()):
            in_eids = local_g.in_edges(lnid, form="eid")
            in_etypes = F.asnumpy(local_g.edata[dgl.ETYPE][in_eids])
            # np.unique returns sorted types; their first-occurrence
            # positions must be non-decreasing.
            _, first_pos = np.unique(in_etypes, return_index=True)
            assert np.all(first_pos[:-1] <= first_pos[1:])

    if partition_book is None:
        partition_book = dist_graph.get_partition_book()
    try:
        # Turn on the sanity checks of distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        sampled_graph = sample_etype_neighbors(
            dist_graph,
            nodes,
            fanout,
            etype_sorted=etype_sorted,
            use_graphbolt=use_graphbolt,
        )
        block = dgl.to_block(sampled_graph, nodes)
        if sampled_graph.num_edges() > 0 and (
            not use_graphbolt or return_eids
        ):
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, partition_book


def check_rpc_hetero_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample six n3 seeds on a shuffled random heterograph and verify edges.

    Each sampled edge is mapped back to original node IDs and must exist in
    ``g``. When edge IDs are available, they must map back to the same
    endpoints via ``g.find_edges``.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    seeds = {"n3": torch.tensor([0, 10, 99, 66, 124, 208], dtype=g.idtype)}
    block, _ = start_hetero_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    for server in servers:
        server.join()
        assert server.exitcode == 0

    has_eids = not use_graphbolt or return_eids
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # Block-local IDs -> global shuffled IDs -> original IDs.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        assert np.all(
            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
        )
        if not has_eids:
            continue
        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
        expected_src, expected_dst = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(expected_src) == orig_src)
        assert np.all(F.asnumpy(expected_dst) == orig_dst)


def get_degrees(g, nids, ntype):
    """Return the total degree of nodes ``nids`` of type ``ntype`` in ``g``.

    Sums, over every canonical edge type, the out-degree when ``ntype`` is
    the source type and the in-degree when it is the destination type.
    Relations from ``ntype`` to ``ntype`` contribute both directions.

    Args:
        g: A DGL (hetero)graph.
        nids: Node IDs of type ``ntype``.
        ntype: Node type name.

    Returns:
        An int64 tensor of length ``len(nids)``.
    """
    deg = F.zeros((len(nids),), dtype=F.int64)
    for c_etype in g.canonical_etypes:
        srctype, _, dsttype = c_etype
        # Use independent `if`s rather than `elif` so that a relation whose
        # source and destination are both `ntype` adds its in-degree too.
        # Pass the canonical triple so that relations sharing a short name
        # are not ambiguous.
        if srctype == ntype:
            deg += g.out_degrees(u=nids, etype=c_etype)
        if dsttype == ntype:
            deg += g.in_degrees(v=nids, etype=c_etype)
    return deg


def check_rpc_hetero_sampling_empty_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sampling n3 nodes that have no edges must yield an edgeless block.

    The block must still contain every edge type of ``g``.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(empty=True)
    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    # Shuffled n3 IDs whose original node has degree 0.
    isolated = F.nonzero_1d(get_degrees(g, orig_nids["n3"], "n3") == 0)
    block, _ = start_hetero_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes={"n3": isolated},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    for server in servers:
        server.join()
        assert server.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


def check_rpc_hetero_etype_sampling_shuffle(
    tmpdir,
    num_server,
    graph_formats=None,
    use_graphbolt=False,
    return_eids=False,
):
    """Per-edge-type sampling on a dense shuffled heterograph.

    Samples six n3 seeds with fanout 3 for every canonical edge type and
    expects 18 sampled edges each for r13 and r23. Every sampled edge is
    then checked against the original graph; when edge IDs are kept, they
    must map back to the same endpoints.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True)
    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        graph_formats=graph_formats,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    fanout = {c_etype: 3 for c_etype in g.canonical_etypes}
    # Request etype-sorted sampling only if the partitions were built with
    # CSC or CSR.
    etype_sorted = graph_formats is not None and (
        "csc" in graph_formats or "csr" in graph_formats
    )
    seeds = {"n3": torch.tensor([0, 10, 99, 66, 124, 208], dtype=g.idtype)}
    block, _ = start_hetero_etype_sample_client(
        0,
        tmpdir,
        num_server > 1,
        fanout,
        nodes=seeds,
        etype_sorted=etype_sorted,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    # 6 seeds x fanout 3 = 18 (assumes the dense graph gives every seed at
    # least 3 in-neighbors per relation).
    for c_etype in [("n1", "r13", "n3"), ("n2", "r23", "n3")]:
        src, _ = block.edges(etype=c_etype)
        assert len(src) == 18

    has_eids = not use_graphbolt or return_eids
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # Block-local IDs -> global shuffled IDs -> original IDs.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        assert np.all(
            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
        )
        if not has_eids:
            continue
        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
        expected_src, expected_dst = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(expected_src) == orig_src)
        assert np.all(F.asnumpy(expected_dst) == orig_dst)


def check_rpc_hetero_etype_sampling_empty_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Per-edge-type sampling of edgeless n3 nodes must give an empty block.

    The block must still contain every edge type of ``g``.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True, empty=True)
    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    fanout = 3
    # Shuffled n3 IDs whose original node has degree 0.
    isolated = F.nonzero_1d(get_degrees(g, orig_nids["n3"], "n3") == 0)
    block, _ = start_hetero_etype_sample_client(
        0,
        tmpdir,
        num_server > 1,
        fanout,
        nodes={"n3": isolated},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


def create_random_bipartite():
    """Build a random user-buys-game bipartite graph.

    The graph has 500 users, 1000 games and 1000 edges. Both node types
    get a ``feat`` tensor of ones with width 10.
    """
    g = dgl.rand_bipartite("user", "buys", "game", 500, 1000, 1000)
    for ntype in ("user", "game"):
        g.nodes[ntype].data["feat"] = F.ones(
            (g.num_nodes(ntype), 10), F.float32, F.cpu()
        )
    return g


def start_bipartite_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample ``nodes`` with fanout 3 on the bipartite "test_sampling" graph.

    Returns ``(block, gpb)``: ``dgl.to_block`` of the sample and the
    partition book. Errors propagate to the caller, but the distributed
    client is always shut down via ``finally``.

    Args:
        rank: Partition id to load when shared memory is disabled.
        tmpdir: ``pathlib.Path`` containing ``test_sampling.json``.
        disable_shared_mem: Load the partition book from disk instead of
            fetching it from the DistGraph.
        nodes: Seed nodes, keyed by node type.
        use_graphbolt: Forwarded to ``sample_neighbors``.
        return_eids: Copy ``dgl.EID`` onto the block even with GraphBolt.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph("test_sampling", gpb=gpb)
        assert "feat" in dist_graph.nodes["user"].data
        assert "feat" in dist_graph.nodes["game"].data
        if gpb is None:
            gpb = dist_graph.get_partition_book()
        # Enable sanity checks in distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        sampled_graph = sample_neighbors(
            dist_graph, nodes, 3, use_graphbolt=use_graphbolt
        )
        block = dgl.to_block(sampled_graph, nodes)
        if sampled_graph.num_edges() > 0 and (
            not use_graphbolt or return_eids
        ):
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    finally:
        # Previously a failure here skipped exit_client(). The servers
        # (joined by the caller) presumably wait for this client to
        # disconnect, so always disconnect.
        dgl.distributed.exit_client()
    return block, gpb


def start_bipartite_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    use_graphbolt=False,
    return_eids=False,
):
    """Per-edge-type sampling of ``nodes`` on the bipartite "test_sampling" graph.

    Without GraphBolt, and when a local partition exists, the function
    first checks that each local node's in-edges list edge types in
    ascending order of first appearance. Returns ``(block, gpb)``. Errors
    propagate, but the distributed client is always shut down via
    ``finally``.

    Args:
        rank: Partition id to load when shared memory is disabled.
        tmpdir: ``pathlib.Path`` containing ``test_sampling.json``.
        disable_shared_mem: Load the partition book from disk.
        fanout: Forwarded to ``sample_etype_neighbors``.
        nodes: Seed nodes keyed by node type. ``None`` (the default) means
            an empty dict, matching the previous ``{}`` default without
            sharing a mutable default between calls.
        use_graphbolt: Forwarded to ``sample_etype_neighbors``.
        return_eids: Copy ``dgl.EID`` onto the block even with GraphBolt.
    """
    if nodes is None:
        nodes = {}
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph("test_sampling", gpb=gpb)
        assert "feat" in dist_graph.nodes["user"].data
        assert "feat" in dist_graph.nodes["game"].data

        if not use_graphbolt and dist_graph.local_partition is not None:
            local_g = dist_graph.local_partition
            for lnid in np.arange(local_g.num_nodes()):
                leids = local_g.in_edges(lnid, form="eid")
                letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
                # First-occurrence positions of the sorted unique edge
                # types must be non-decreasing.
                _, first_pos = np.unique(letids, return_index=True)
                assert np.all(first_pos[:-1] <= first_pos[1:])

        if gpb is None:
            gpb = dist_graph.get_partition_book()
        sampled_graph = sample_etype_neighbors(
            dist_graph, nodes, fanout, use_graphbolt=use_graphbolt
        )
        block = dgl.to_block(sampled_graph, nodes)
        if sampled_graph.num_edges() > 0 and (
            not use_graphbolt or return_eids
        ):
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    finally:
        # Always disconnect so the servers are not left waiting on this
        # client when an assertion or sampling call fails.
        dgl.distributed.exit_client()
    return block, gpb


def check_rpc_bipartite_sampling_empty(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check sample_neighbors() on a bipartite graph when nothing is sampled.

    The seeds are the ``game`` nodes whose degree (per ``get_degrees``) is
    zero, plus ``user`` node 1. The resulting block must contain no edges
    while still carrying every edge type of the input graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    disable_shared_mem = num_server > 1

    g = create_random_bipartite()
    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,  # one partition per server
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    pserver_list = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                disable_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)  # presumably gives the server time to come up
        pserver_list.append(server)

    degrees = get_degrees(g, orig_nids["game"], "game")
    seeds = {
        "game": F.nonzero_1d(degrees == 0),
        "user": torch.tensor([1], dtype=g.idtype),
    }
    block, _ = start_bipartite_sample_client(
        0,
        tmpdir,
        disable_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for server in pserver_list:
        server.join()
        assert server.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


1002
1003
1004
def check_rpc_bipartite_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check sample_neighbors() on a bipartite graph when edges are sampled.

    The seeds are the ``game`` nodes with positive degree plus ``user``
    node 0. Every sampled edge is mapped back to original node IDs and must
    exist in the input graph. When edge IDs are available, they must
    resolve to the same endpoints.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    disable_shared_mem = num_server > 1

    g = create_random_bipartite()
    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,  # one partition per server
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    pserver_list = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                disable_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)  # presumably gives the server time to come up
        pserver_list.append(server)

    degrees = get_degrees(g, orig_nid_map["game"], "game")
    seeds = {
        "game": F.nonzero_1d(degrees > 0),
        "user": torch.tensor([0], dtype=g.idtype),
    }
    block, _ = start_bipartite_sample_client(
        0,
        tmpdir,
        disable_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for server in pserver_list:
        server.join()
        assert server.exitcode == 0

    # Edge-ID checks apply unless GraphBolt ran without return_eids.
    verify_eids = not use_graphbolt or return_eids
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # Block-local IDs -> global IDs after shuffling -> original IDs.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        assert np.all(
            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
        )

        if verify_eids:
            shuffled_eid = block.edges[etype].data[dgl.EID]
            orig_eid = F.asnumpy(
                F.gather_row(orig_eid_map[c_etype], shuffled_eid)
            )
            # Each original edge ID must resolve to the same endpoints.
            orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
            assert np.all(F.asnumpy(orig_src1) == orig_src)
            assert np.all(F.asnumpy(orig_dst1) == orig_dst)


1082
1083
1084
def check_rpc_bipartite_etype_sampling_empty(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check sample_etype_neighbors() on a bipartite graph with empty results.

    The seeds are the ``game`` nodes whose degree (per ``get_degrees``) is
    zero, plus ``user`` node 1. The resulting block must exist, contain no
    edges, and still carry every edge type of the input graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    disable_shared_mem = num_server > 1

    g = create_random_bipartite()
    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,  # one partition per server
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    pserver_list = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                disable_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)  # presumably gives the server time to come up
        pserver_list.append(server)

    degrees = get_degrees(g, orig_nids["game"], "game")
    seeds = {
        "game": F.nonzero_1d(degrees == 0),
        "user": torch.tensor([1], dtype=g.idtype),
    }
    block, _ = start_bipartite_etype_sample_client(
        0,
        tmpdir,
        disable_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for server in pserver_list:
        server.join()
        assert server.exitcode == 0

    assert block is not None
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


1144
1145
1146
def check_rpc_bipartite_etype_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check sample_etype_neighbors() on a bipartite graph with sampled edges.

    The seeds are the ``game`` nodes with positive degree plus ``user``
    node 0, sampled with a fanout of 3. Every sampled edge is mapped back to
    original node IDs and must exist in the input graph. When edge IDs are
    available, they must resolve to the same endpoints.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    disable_shared_mem = num_server > 1

    g = create_random_bipartite()
    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,  # one partition per server
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    pserver_list = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                disable_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)  # presumably gives the server time to come up
        pserver_list.append(server)

    fanout = 3
    degrees = get_degrees(g, orig_nid_map["game"], "game")
    seeds = {
        "game": F.nonzero_1d(degrees > 0),
        "user": torch.tensor([0], dtype=g.idtype),
    }
    block, _ = start_bipartite_etype_sample_client(
        0,
        tmpdir,
        disable_shared_mem,
        fanout,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for server in pserver_list:
        server.join()
        assert server.exitcode == 0

    # Edge-ID checks apply unless GraphBolt ran without return_eids.
    verify_eids = not use_graphbolt or return_eids
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # Block-local IDs -> global IDs after shuffling -> original IDs.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        assert np.all(
            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
        )

        if verify_eids:
            # Each original edge ID must resolve to the same endpoints.
            shuffled_eid = block.edges[etype].data[dgl.EID]
            orig_eid = F.asnumpy(
                F.gather_row(orig_eid_map[c_etype], shuffled_eid)
            )
            orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
            assert np.all(F.asnumpy(orig_src1) == orig_src)
            assert np.all(F.asnumpy(orig_dst1) == orig_dst)

1224

1225
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run ``check_rpc_sampling_shuffle`` in distributed mode in a temp dir."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = dict(use_graphbolt=use_graphbolt, return_eids=return_eids)
    with tempfile.TemporaryDirectory() as tmp_name:
        check_rpc_sampling_shuffle(Path(tmp_name), num_server, **options)
1238
1239
1240


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run ``check_rpc_hetero_sampling_shuffle`` in distributed mode.

    Parametrized over GraphBolt on/off and whether edge IDs are stored and
    returned. Runs inside a fresh temporary directory.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_sampling_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1253
1254
1255


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_hetero_sampling_empty_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmp_name:
        tmp_dir = Path(tmp_name)
        check_rpc_hetero_sampling_empty_shuffle(
            tmp_dir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1270
1271
1272
1273
1274
1275


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]]
)
def test_rpc_hetero_etype_sampling_shuffle_dgl(num_server, graph_formats):
    """Run ``check_rpc_hetero_etype_sampling_shuffle`` per ``graph_formats``."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmp_name:
        tmp_dir = Path(tmp_name)
        check_rpc_hetero_etype_sampling_shuffle(
            tmp_dir, num_server, graph_formats=graph_formats
        )
1283
1284
1285


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_shuffle_graphbolt(num_server, return_eids):
    """Run ``check_rpc_hetero_etype_sampling_shuffle`` with GraphBolt on."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmp_name:
        check_rpc_hetero_etype_sampling_shuffle(
            Path(tmp_name),
            num_server,
            return_eids=return_eids,
            use_graphbolt=True,
        )


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_hetero_etype_sampling_empty_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = dict(use_graphbolt=use_graphbolt, return_eids=return_eids)
    with tempfile.TemporaryDirectory() as tmp_name:
        check_rpc_hetero_etype_sampling_empty_shuffle(
            Path(tmp_name), num_server, **options
        )
1314
1315
1316


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_bipartite_sampling_empty`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmp_name:
        tmp_dir = Path(tmp_name)
        check_rpc_bipartite_sampling_empty(
            tmp_dir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1328
1329
1330


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run ``check_rpc_bipartite_sampling_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmp_name:
        tmp_dir = Path(tmp_name)
        check_rpc_bipartite_sampling_shuffle(
            tmp_dir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1340
1341
1342


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_bipartite_etype_sampling_empty`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = dict(use_graphbolt=use_graphbolt, return_eids=return_eids)
    with tempfile.TemporaryDirectory() as tmp_name:
        check_rpc_bipartite_etype_sampling_empty(
            Path(tmp_name), num_server, **options
        )
1357
1358
1359


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_bipartite_etype_sampling_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = dict(use_graphbolt=use_graphbolt, return_eids=return_eids)
    with tempfile.TemporaryDirectory() as tmp_name:
        check_rpc_bipartite_etype_sampling_shuffle(
            Path(tmp_name), num_server, **options
        )
Jinjing Zhou's avatar
Jinjing Zhou committed
1374

1375

1376
def check_standalone_sampling(tmpdir):
    """Check ``sample_neighbors`` on Cora in standalone mode.

    The same seed set is sampled three times with fanout 3:
    - unweighted;
    - weighted by the boolean ``mask`` edge feature;
    - weighted by the non-negative ``prob`` edge feature.
    The results are verified against the original graph.
    """
    g = CitationGraphDataset("cora")[0]
    # Random non-negative edge weights; about half are exactly zero.
    prob = np.maximum(np.random.randn(g.num_edges()), 0)
    mask = prob > 0
    g.edata["prob"] = F.tensor(prob)
    g.edata["mask"] = F.tensor(mask)

    partition_graph(
        g,
        "test_sampling",
        1,  # a single partition
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", part_config=tmpdir / "test_sampling.json"
    )

    def sample(**kwargs):
        # A fresh seed tensor for every sampling call.
        seeds = torch.tensor(
            [0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype
        )
        return sample_neighbors(dist_graph, seeds, 3, **kwargs)

    sampled_graph = sample()
    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
    )

    # Edges with a zero weight must never be picked.
    picked = F.asnumpy(sample(prob="mask").edata[dgl.EID])
    assert mask[picked].all()

    picked = F.asnumpy(sample(prob="prob").edata[dgl.EID])
    assert (prob[picked] > 0).all()

    dgl.distributed.exit_client()
1430

1431

1432
def check_standalone_etype_sampling(tmpdir):
    """Check ``sample_etype_neighbors`` on Cora in standalone mode.

    The same seed set is sampled three times with fanout 3:
    - unweighted;
    - weighted by the boolean ``mask`` edge feature;
    - weighted by the non-negative ``prob`` edge feature.
    The results are verified against the original graph.
    """
    hg = CitationGraphDataset("cora")[0]
    # Random non-negative edge weights; about half are exactly zero.
    prob = np.maximum(np.random.randn(hg.num_edges()), 0)
    mask = prob > 0
    hg.edata["prob"] = F.tensor(prob)
    hg.edata["mask"] = F.tensor(mask)

    partition_graph(
        hg,
        "test_sampling",
        1,  # a single partition
        tmpdir,
        num_hops=1,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", part_config=tmpdir / "test_sampling.json"
    )

    def sample(**kwargs):
        # A fresh seed tensor for every sampling call.
        seeds = torch.tensor([0, 10, 99, 66, 1023], dtype=dist_graph.idtype)
        return sample_etype_neighbors(dist_graph, seeds, 3, **kwargs)

    sampled_graph = sample()
    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == hg.num_nodes()
    assert np.all(F.asnumpy(hg.has_edges_between(src, dst)))
    eids = hg.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
    )

    # Edges with a zero weight must never be picked.
    picked = F.asnumpy(sample(prob="mask").edata[dgl.EID])
    assert mask[picked].all()

    picked = F.asnumpy(sample(prob="prob").edata[dgl.EID])
    assert (prob[picked] > 0).all()

    dgl.distributed.exit_client()

1487

1488
def check_standalone_etype_sampling_heterograph(tmpdir):
    """Check ``sample_etype_neighbors`` on a two-relation heterograph.

    Cora's edges become a ``cite`` relation, and the reversed edges become a
    ``cite-by`` relation. Ten seeds are sampled with fanout 1 in standalone
    mode.
    """
    cora = CitationGraphDataset("cora")[0]
    src, dst = cora.edges()
    new_hg = dgl.heterograph(
        {
            ("paper", "cite", "paper"): (src, dst),
            ("paper", "cite-by", "paper"): (dst, src),
        },
        {"paper": cora.num_nodes()},
    )
    partition_graph(
        new_hg,
        "test_hetero_sampling",
        1,  # a single partition
        tmpdir,
        num_hops=1,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_hetero_sampling", part_config=tmpdir / "test_hetero_sampling.json"
    )
    seeds = torch.tensor(
        [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701],
        dtype=dist_graph.idtype,
    )
    sampled_graph = sample_etype_neighbors(dist_graph, seeds, 1)

    # The test expects exactly one sampled edge per seed for each relation.
    for c_etype in (("paper", "cite", "paper"), ("paper", "cite-by", "paper")):
        src, dst = sampled_graph.edges(etype=c_etype)
        assert len(src) == 10
    assert sampled_graph.num_nodes() == new_hg.num_nodes()
    dgl.distributed.exit_client()

1528
1529
1530
1531
1532
1533

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_standalone_sampling():
    """Run ``check_standalone_sampling`` in standalone mode in a temp dir."""
    reset_envs()
    # ``tempfile`` is already imported at module level.
    os.environ["DGL_DIST_MODE"] = "standalone"
    with tempfile.TemporaryDirectory() as tmp_name:
        check_standalone_sampling(Path(tmp_name))
1541

1542

1543
1544
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Extract the distributed in-subgraph of ``nodes`` from one client.

    Returns the subgraph produced by ``dgl.distributed.in_subgraph``. If that
    call raises, the traceback is printed and ``None`` is returned. In both
    cases ``exit_client()`` runs before returning.
    """
    dgl.distributed.initialize("rpc_ip_config.txt")
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_in_subgraph.json", rank
        )
    else:
        gpb = None
    dist_graph = DistGraph("test_in_subgraph", gpb=gpb)

    result = None
    try:
        result = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception:
        # Best effort: report the failure but still shut the client down.
        print(traceback.format_exc())
    dgl.distributed.exit_client()
    return result


1560
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Compare distributed ``in_subgraph`` with ``dgl.in_subgraph`` on Cora.

    Partitions Cora into ``num_server`` parts with ID mappings and starts one
    server per part. A client then extracts the in-subgraph of a fixed seed
    set. Edges and edge IDs are checked against the original graph after
    shuffled IDs are mapped back to original IDs.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_in_subgraph",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_in_subgraph"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=g.idtype)
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0
    # The client returns None (after printing the traceback) when
    # in_subgraph raised. Fail here with a clear message instead of an
    # AttributeError on ``None.edges()`` below.
    assert (
        sampled_graph is not None
    ), "distributed in_subgraph failed; see the client traceback above"

    # Sampled node IDs are shuffled IDs; map them back to original IDs.
    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    # ``nodes`` are shuffled IDs, so the reference subgraph is built from
    # their original IDs. The sorted endpoint lists must match it.
    subg1 = dgl.in_subgraph(g, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
1606

1607
1608
1609
1610
1611
1612

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_rpc_in_subgraph():
    """Run ``check_rpc_in_subgraph_shuffle`` with one server in a temp dir."""
    reset_envs()
    # ``tempfile`` is already imported at module level.
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmp_name:
        tmp_dir = Path(tmp_name)
        check_rpc_in_subgraph_shuffle(tmp_dir, 1)
1620

1621
1622
1623
1624
1625
1626
1627
1628
1629

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone_etype_sampling():
    """Run both standalone etype-sampling checks, each in its own temp dir."""
    reset_envs()
    # ``tempfile`` is already imported at module level.
    for check in (
        check_standalone_etype_sampling_heterograph,
        check_standalone_etype_sampling,
    ):
        with tempfile.TemporaryDirectory() as tmp_name:
            os.environ["DGL_DIST_MODE"] = "standalone"
            check(Path(tmp_name))
1640

1641

Jinjing Zhou's avatar
Jinjing Zhou committed
1642
1643
if __name__ == "__main__":
    import tempfile
1644

Jinjing Zhou's avatar
Jinjing Zhou committed
1645
    with tempfile.TemporaryDirectory() as tmpdirname:
1646
        os.environ["DGL_DIST_MODE"] = "standalone"
1647
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
1648
1649

    with tempfile.TemporaryDirectory() as tmpdirname:
1650
        os.environ["DGL_DIST_MODE"] = "standalone"
1651
1652
        check_standalone_etype_sampling(Path(tmpdirname))
        check_standalone_sampling(Path(tmpdirname))
1653
        os.environ["DGL_DIST_MODE"] = "distributed"
1654
1655
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
1656
1657
        check_rpc_get_degree_shuffle(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 2)
1658
1659
        check_rpc_find_edges_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 1)
1660
1661
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 2)
1662
1663
1664
1665
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)
1666
1667
1668
1669
        check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), 1)