# test_distributed_sampling.py
import multiprocessing as mp
import os
import random
import tempfile
import time
import traceback
import unittest
from pathlib import Path

import backend as F
import dgl
import numpy as np
import pytest
import torch
from dgl.data import CitationGraphDataset, WN18Dataset
from dgl.distributed import (
    DistGraph,
    DistGraphServer,
    load_partition,
    load_partition_book,
    partition_graph,
    sample_etype_neighbors,
    sample_neighbors,
)
from scipy import sparse as spsp
from utils import generate_ip_config, reset_envs


29
30
31
32
33
34
def start_server(
    rank,
    tmpdir,
    disable_shared_mem,
    graph_name,
    graph_format=None,
    use_graphbolt=False,
):
    """Create a ``DistGraphServer`` for one partition and start it.

    Intended to be used as the target of a spawned server process.

    Parameters
    ----------
    rank
        Passed as the first argument of ``DistGraphServer``.
    tmpdir : pathlib.Path
        Directory holding the partition config ``<graph_name>.json``.
    disable_shared_mem : bool
        Forwarded to ``DistGraphServer``.
    graph_name : str
        Name of the partitioned graph (without the ``.json`` suffix).
    graph_format : list of str, optional
        Sparse formats forwarded to the server; ``["csc", "coo"]`` when
        omitted.
    use_graphbolt : bool
        Forwarded to ``DistGraphServer``.
    """
    # A fresh list per call avoids sharing one mutable default object.
    if graph_format is None:
        graph_format = ["csc", "coo"]
    g = DistGraphServer(
        rank,
        "rpc_ip_config.txt",
        # NOTE(review): the two literal 1s are presumably num_servers and
        # num_clients -- confirm against the DistGraphServer signature.
        1,
        1,
        tmpdir / (graph_name + ".json"),
        disable_shared_mem=disable_shared_mem,
        graph_format=graph_format,
        use_graphbolt=use_graphbolt,
    )
    g.start()


50
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Run a neighbor-sampling client against the ``test_sampling`` graph.

    Samples with fanout 3 around a fixed set of six seed nodes.

    Returns
    -------
    The sampled graph, or ``None`` when sampling raised (the traceback is
    printed and the client still exits cleanly).
    """
    gpb = None
    if disable_shared_mem:
        # The partition book is loaded from disk only in this mode.
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    try:
        sampled_graph = sample_neighbors(
            dist_graph,
            torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype),
            3,
        )
    except Exception:
        # Best effort: report the failure but still reach exit_client().
        print(traceback.format_exc())
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph

70

71
72
73
74
75
76
77
78
79
def start_sample_client_shuffle(
    rank,
    tmpdir,
    disable_shared_mem,
    g,
    num_servers,
    group_id,
    orig_nid,
    orig_eid,
    use_graphbolt=False,
    return_eids=False,
    node_id_dtype=None,
):
    """Sample from the shuffled ``test_sampling`` graph and verify the result.

    Runs as a client process: it sets ``DGL_GROUP_ID``, samples with fanout 3
    around six fixed seeds, then maps the sampled (shuffled) IDs back through
    ``orig_nid`` / ``orig_eid`` and checks them against the original graph
    ``g``.

    ``num_servers`` is accepted for call-site compatibility but is not used
    in the body.
    """
    os.environ["DGL_GROUP_ID"] = str(group_id)
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    sampled_graph = sample_neighbors(
        dist_graph,
        torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=node_id_dtype),
        3,
        use_graphbolt=use_graphbolt,
    )
    assert sampled_graph.idtype == dist_graph.idtype
    if use_graphbolt:
        # dtype conversion is applied for GraphBolt partitions.
        assert sampled_graph.idtype == torch.int32
    else:
        # dtype conversion is not applied for non-GraphBolt partitions.
        assert sampled_graph.idtype == torch.int64

    assert (
        dgl.ETYPE not in sampled_graph.edata
    ), "Etype should not be in homogeneous sampled graph."
    src, dst = sampled_graph.edges()
    # Translate shuffled node IDs back to the IDs of the original graph.
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    if use_graphbolt and not return_eids:
        assert (
            dgl.EID not in sampled_graph.edata
        ), "EID should not be in sampled graph if use_graphbolt=True."
    else:
        eids = g.edge_ids(src, dst)
        eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
        assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
122

123

124
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
    """Call ``DistGraph.find_edges`` on the ``test_find_edges`` graph.

    Returns
    -------
    (u, v)
        The result of ``find_edges(eids, etype=etype)``, or ``(None, None)``
        when the call raised (the traceback is printed).
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_find_edges.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_find_edges", gpb=gpb)
    try:
        u, v = dist_graph.find_edges(eids, etype=etype)
    except Exception:
        # Best effort: report the failure but still reach exit_client().
        print(traceback.format_exc())
        u, v = None, None
    dgl.distributed.exit_client()
    return u, v
139

140

141
142
143
def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Query in/out degrees on the ``test_get_degrees`` distributed graph.

    Returns
    -------
    (in_deg, out_deg, all_in_deg, all_out_deg)
        Degrees for ``nids`` and for the no-argument calls; all four are
        ``None`` when any query raised (the traceback is printed).
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_get_degrees.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_get_degrees", gpb=gpb)
    try:
        in_deg = dist_graph.in_degrees(nids)
        all_in_deg = dist_graph.in_degrees()
        out_deg = dist_graph.out_degrees(nids)
        all_out_deg = dist_graph.out_degrees()
    except Exception:
        # Best effort: report the failure but still reach exit_client().
        print(traceback.format_exc())
        in_deg, out_deg, all_in_deg, all_out_deg = None, None, None, None
    dgl.distributed.exit_client()
    return in_deg, out_deg, all_in_deg, all_out_deg

160

161
def check_rpc_sampling(tmpdir, num_server):
    """Partition cora, serve it, sample from it and verify the sampled edges.

    One server process is spawned per partition; the sampling client runs in
    the calling process. The partitions are created without
    ``return_mapping``, so sampled IDs are compared to the original graph
    directly.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    num_parts = num_server
    num_hops = 1

    partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns None when sampling raised; fail with a clear message.
    assert sampled_graph is not None, "sampling failed in the client"
    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
    )
203

204
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on a shuffled cora partition.

    Random edge IDs (in the shuffled ID space) are resolved by the
    distributed graph; the endpoints are mapped back through ``orig_nid`` and
    compared with ``find_edges`` on the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    eids = F.tensor(np.random.randint(g.num_edges(), size=100))
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

    # Reap the servers and check their exit status, as the other checks do.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns (None, None) when find_edges raised.
    assert du is not None and dv is not None, "find_edges failed in the client"
    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

239

240
def create_random_hetero(dense=False, empty=False):
    """Build a random heterograph with node types n1/n2/n3.

    Parameters
    ----------
    dense : bool
        When true, use the smaller node counts (210/200/220) with edge
        density 0.1; otherwise 1010/1000/1020 nodes with density 0.001.
    empty : bool
        When true, each random adjacency matrix is 10 rows and 10 columns
        smaller than the node counts, so the last 10 node IDs of each type
        receive no edges.

    Returns
    -------
    The heterograph, with a ones-valued ``feat`` of width 10 on ``n1`` only.
    """
    num_nodes = (
        {"n1": 210, "n2": 200, "n3": 220}
        if dense
        else {"n1": 1010, "n2": 1000, "n3": 1020}
    )
    etypes = [("n1", "r12", "n2"), ("n1", "r13", "n3"), ("n2", "r23", "n3")]
    edges = {}
    # Seeds Python's `random`; the matrices below use their own fixed
    # random_state.
    random.seed(42)
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype] - 10 if empty else num_nodes[src_ntype],
            num_nodes[dst_ntype] - 10 if empty else num_nodes[dst_ntype],
            density=0.1 if dense else 0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    g.nodes["n1"].data["feat"] = F.ones(
        (g.num_nodes("n1"), 10), F.float32, F.cpu()
    )
    return g
264

265

266
def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` for edge type ``r12`` on a heterograph.

    Also verifies on the local graph that a malformed two-element etype
    raises, and that the short and canonical etype spellings agree.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    test_etype = g.to_canonical_etype("r12")
    eids = F.tensor(np.random.randint(g.num_edges(test_etype), size=100))
    # A two-element etype tuple is expected to be rejected.
    expect_except = False
    try:
        _, _ = g.find_edges(orig_eid[test_etype][eids], etype=("n1", "r12"))
    except Exception:
        expect_except = True
    assert expect_except
    u, v = g.find_edges(orig_eid[test_etype][eids], etype="r12")
    u1, v1 = g.find_edges(orig_eid[test_etype][eids], etype=("n1", "r12", "n2"))
    assert F.array_equal(u, u1)
    assert F.array_equal(v, v1)
    du, dv = start_find_edges_client(
        0, tmpdir, num_server > 1, eids, etype="r12"
    )

    # Reap the servers and check their exit status, as the other checks do.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns (None, None) when find_edges raised.
    assert du is not None and dv is not None, "find_edges failed in the client"
    du = orig_nid["n1"][du]
    dv = orig_nid["n2"][dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

313

314
# Wait non shared memory graph store
315
316
317
318
319
320
321
322
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_find_edges_shuffle(num_server):
    """Run the hetero and homogeneous find_edges checks in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # `tempfile` comes from the module-level import; both checks share one
    # temporary directory.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), num_server)
        check_rpc_find_edges_shuffle(Path(tmpdirname), num_server)

333

334
def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Check ``DistGraph`` in/out degree queries on a shuffled cora partition.

    Degrees returned by the distributed graph (shuffled ID space) are compared
    with degrees of the original graph looked up through ``orig_nid``.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, _ = partition_graph(
        g,
        "test_get_degrees",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_get_degrees"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nids = F.tensor(np.random.randint(g.num_nodes(), size=100))
    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(
        0, tmpdir, num_server > 1, nids
    )

    print("Done get_degree")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    print("check results")
    # The client returns four Nones when a degree query raised.
    assert in_degs is not None, "degree queries failed in the client"
    assert F.array_equal(g.in_degrees(orig_nid[nids]), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(orig_nid[nids]), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)

377

378
# Wait non shared memory graph store
379
380
381
382
383
384
385
386
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_get_degree_shuffle(num_server):
    """Run the degree-query check in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # `tempfile` comes from the module-level import.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_get_degree_shuffle(Path(tmpdirname), num_server)

396
397
398
399

# @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
# @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip("Only support partition with shuffle")
def test_rpc_sampling():
    """Run the non-shuffled sampling check with one server (currently skipped)."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # `tempfile` comes from the module-level import.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling(Path(tmpdirname), 1)
407

408

409
def check_rpc_sampling_shuffle(
    tmpdir,
    num_server,
    num_groups=1,
    use_graphbolt=False,
    return_eids=False,
    node_id_dtype=None,
):
    """Partition cora with ID shuffling and verify sampling in client processes.

    One server process is spawned per partition and one client process per
    group. Verification happens inside ``start_sample_client_shuffle``; this
    function only requires every process to exit with code 0.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server
    num_hops = 1

    orig_nids, orig_eids = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(
                i,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    pclient_list = []
    num_clients = 1
    for client_id in range(num_clients):
        for group_id in range(num_groups):
            p = ctx.Process(
                target=start_sample_client_shuffle,
                args=(
                    client_id,
                    tmpdir,
                    num_server > 1,
                    g,
                    num_server,
                    group_id,
                    orig_nids,
                    orig_eids,
                    use_graphbolt,
                    return_eids,
                    node_id_dtype,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            pclient_list.append(p)
    # Clients carry the assertions; a non-zero exit code means one failed.
    for p in pclient_list:
        p.join()
        assert p.exitcode == 0
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0
482

483

484
485
486
487
488
489
490
491
def start_hetero_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample neighbors (fanout 3) on the heterogeneous ``test_sampling`` graph.

    Parameters
    ----------
    nodes : dict
        Maps node type to seed node IDs; values are converted to tensors of
        the distributed graph's idtype.

    Returns
    -------
    (block, gpb)
        The block built from the sampled graph and the partition book.
        ``block`` is ``None`` when sampling raised (the traceback is printed).
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    # Only n1 is expected to carry node features.
    assert "feat" in dist_graph.nodes["n1"].data
    assert "feat" not in dist_graph.nodes["n2"].data
    assert "feat" not in dist_graph.nodes["n3"].data
    nodes = {
        k: torch.tensor(v, dtype=dist_graph.idtype) for k, v in nodes.items()
    }
    if gpb is None:
        gpb = dist_graph.get_partition_book()
    try:
        # Enable sanity check in distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        sampled_graph = sample_neighbors(
            dist_graph, nodes, 3, use_graphbolt=use_graphbolt
        )
        block = dgl.to_block(sampled_graph, nodes)
        if not use_graphbolt or return_eids:
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        # Best effort: report the failure but still reach exit_client().
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, gpb

522
523
524
525
526
527

def start_hetero_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    etype_sorted=False,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample per-edge-type neighbors on the heterogeneous ``test_sampling`` graph.

    Parameters
    ----------
    fanout
        Forwarded to ``sample_etype_neighbors``.
    nodes : dict, optional
        Maps node type to seed node IDs; treated as an empty dict when
        omitted.
    etype_sorted : bool
        Forwarded to ``sample_etype_neighbors``.

    Returns
    -------
    (block, gpb)
        The block built from the sampled graph and the partition book.
        ``block`` is ``None`` when sampling raised (the traceback is printed).
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    # Only n1 is expected to carry node features.
    assert "feat" in dist_graph.nodes["n1"].data
    assert "feat" not in dist_graph.nodes["n2"].data
    assert "feat" not in dist_graph.nodes["n3"].data
    # The default is None; without this guard `.items()` below would raise.
    if nodes is None:
        nodes = {}
    nodes = {
        k: torch.tensor(v, dtype=dist_graph.idtype) for k, v in nodes.items()
    }

    if (not use_graphbolt) and dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph
        local_g = dist_graph.local_partition
        local_nids = np.arange(local_g.num_nodes())
        for lnid in local_nids:
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            # First-occurrence index of each etype id must be non-decreasing.
            _, first_idx = np.unique(letids, return_index=True)
            assert np.all(first_idx[:-1] <= first_idx[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    try:
        # Enable sanity check in distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        sampled_graph = sample_etype_neighbors(
            dist_graph,
            nodes,
            fanout,
            etype_sorted=etype_sorted,
            use_graphbolt=use_graphbolt,
        )
        block = dgl.to_block(sampled_graph, nodes)
        if sampled_graph.num_edges() > 0:
            if not use_graphbolt or return_eids:
                block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        # Best effort: report the failure but still reach exit_client().
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, gpb

579

580
581
582
def check_rpc_hetero_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Verify neighbor sampling on a shuffled, partitioned heterograph.

    Samples around six ``n3`` seeds, then checks for every edge type of the
    returned block that the sampled edges exist in the original graph and,
    when edge IDs are available, that they map back to the same endpoints.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    num_parts = num_server
    num_hops = 1

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(
                i,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = {"n3": torch.tensor([0, 10, 99, 66, 124, 208], dtype=g.idtype)}
    block, gpb = start_hetero_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes=nodes,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns a None block when sampling raised.
    assert block is not None, "sampling failed in the client"
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))

        assert np.all(
            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
        )

        # Without stored edge IDs there is nothing further to compare.
        if use_graphbolt and not return_eids:
            continue

        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)
655

656

657
658
659
660
661
662
663
664
665
def get_degrees(g, nids, ntype):
    """Return the total degree of ``nids`` of type ``ntype`` over all edge types.

    Out-degrees are added for every edge type whose source type is ``ntype``
    and in-degrees for every edge type whose destination type is ``ntype``.
    An edge type with ``ntype`` on both ends contributes both.
    """
    deg = F.zeros((len(nids),), dtype=F.int64)
    for srctype, etype, dsttype in g.canonical_etypes:
        if srctype == ntype:
            deg += g.out_degrees(u=nids, etype=etype)
        # Independent `if`: a relation from ntype to ntype must count
        # incoming edges too.
        if dsttype == ntype:
            deg += g.in_degrees(v=nids, etype=etype)
    return deg

666

667
668
669
def check_rpc_hetero_sampling_empty_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Verify that sampling around zero-degree ``n3`` nodes yields no edges.

    The returned block must be edge-free while still exposing every edge type
    of the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(empty=True)
    num_parts = num_server
    num_hops = 1

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(
                i,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Degrees are computed on original IDs, indexed by shuffled position, so
    # the positions with zero degree are the seeds passed to the client.
    deg = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)
    block, gpb = start_hetero_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes={"n3": empty_nids},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns a None block when sampling raised.
    assert block is not None, "sampling failed in the client"
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

723
724

def check_rpc_hetero_etype_sampling_shuffle(
    tmpdir,
    num_server,
    graph_formats=None,
    use_graphbolt=False,
    return_eids=False,
):
    """Verify per-edge-type sampling on a dense, shuffled heterograph.

    ``graph_formats`` is forwarded to ``partition_graph``; ``etype_sorted`` is
    requested from the client only when it contains ``"csc"`` or ``"csr"``.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True)
    num_parts = num_server
    num_hops = 1

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
        graph_formats=graph_formats,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(
                i,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    fanout = {etype: 3 for etype in g.canonical_etypes}
    etype_sorted = False
    if graph_formats is not None:
        etype_sorted = "csc" in graph_formats or "csr" in graph_formats
    nodes = {"n3": torch.tensor([0, 10, 99, 66, 124, 208], dtype=g.idtype)}
    block, gpb = start_hetero_etype_sample_client(
        0,
        tmpdir,
        num_server > 1,
        fanout,
        nodes=nodes,
        etype_sorted=etype_sorted,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns a None block when sampling raised.
    assert block is not None, "sampling failed in the client"
    # 6 seed nodes x fanout 3 per edge type into n3.
    src, dst = block.edges(etype=("n1", "r13", "n3"))
    assert len(src) == 18
    src, dst = block.edges(etype=("n2", "r23", "n3"))
    assert len(src) == 18

    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        assert np.all(
            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
        )

        # Without stored edge IDs there is nothing further to compare.
        if use_graphbolt and not return_eids:
            continue

        # Check the node Ids and edge Ids.
        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)

815

816
817
818
def check_rpc_hetero_etype_sampling_empty_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Verify per-edge-type sampling around zero-degree ``n3`` nodes.

    The returned block must be edge-free while still exposing every edge type
    of the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True, empty=True)
    num_parts = num_server
    num_hops = 1

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(
                i,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    fanout = 3
    # Degrees are computed on original IDs, indexed by shuffled position, so
    # the positions with zero degree are the seeds passed to the client.
    deg = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)
    block, gpb = start_hetero_etype_sample_client(
        0,
        tmpdir,
        num_server > 1,
        fanout,
        nodes={"n3": empty_nids},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns a None block when sampling raised.
    assert block is not None, "sampling failed in the client"
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

875
876

def create_random_bipartite():
    """Build a random user-buys-game bipartite graph with node features.

    Both node types receive a ones-valued ``feat`` of width 10.
    """
    g = dgl.rand_bipartite("user", "buys", "game", 500, 1000, 1000)
    g.nodes["user"].data["feat"] = F.ones(
        (g.num_nodes("user"), 10), F.float32, F.cpu()
    )
    g.nodes["game"].data["feat"] = F.ones(
        (g.num_nodes("game"), 10), F.float32, F.cpu()
    )
    return g


887
888
889
890
891
892
893
894
def start_bipartite_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample neighbors (fanout 3) on the bipartite ``test_sampling`` graph.

    Unlike the hetero clients, sampling errors are not caught here; they
    propagate to the caller.

    Returns
    -------
    (block, gpb)
        The block built from the sampled graph and the partition book.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["user"].data
    assert "feat" in dist_graph.nodes["game"].data
    nodes = {
        k: torch.tensor(v, dtype=dist_graph.idtype) for k, v in nodes.items()
    }
    if gpb is None:
        gpb = dist_graph.get_partition_book()
    # Enable sanity check in distributed sampling.
    os.environ["DGL_DIST_DEBUG"] = "1"
    sampled_graph = sample_neighbors(
        dist_graph, nodes, 3, use_graphbolt=use_graphbolt
    )
    block = dgl.to_block(sampled_graph, nodes)
    # Edge IDs are copied only when the sample is non-empty.
    if sampled_graph.num_edges() > 0:
        if not use_graphbolt or return_eids:
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb


922
def start_bipartite_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    use_graphbolt=False,
    return_eids=False,
):
    """Run a client that samples per-etype neighbors on ``test_sampling``.

    Parameters
    ----------
    rank
        Partition id passed to ``load_partition`` when shared memory is
        disabled.
    tmpdir
        Directory (``pathlib.Path``) containing ``test_sampling.json``.
    disable_shared_mem
        When true, the partition book is loaded from disk.
    fanout
        Fanout forwarded to ``sample_etype_neighbors``.
    nodes
        Mapping ``node type -> seed node ids``. ``None`` (the default) is
        treated as an empty mapping.
    use_graphbolt
        Forwarded to ``sample_etype_neighbors``; also disables the local
        etype-ordering check.
    return_eids
        With GraphBolt, edge ids are copied to the block only if this is true.

    Returns
    -------
    tuple
        ``(block, gpb)``: the block built from the sampled graph and the
        partition book in use.
    """
    # Avoid a shared mutable default argument.
    if nodes is None:
        nodes = {}

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["user"].data
    assert "feat" in dist_graph.nodes["game"].data

    nodes = {
        k: torch.tensor(v, dtype=dist_graph.idtype) for k, v in nodes.items()
    }

    if not use_graphbolt and dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph
        local_g = dist_graph.local_partition
        for lnid in np.arange(local_g.num_nodes()):
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            # First-occurrence positions of the (sorted) unique etype ids
            # must be non-decreasing.
            _, first_pos = np.unique(letids, return_index=True)
            assert np.all(first_pos[:-1] <= first_pos[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()

    sampled_graph = sample_etype_neighbors(
        dist_graph, nodes, fanout, use_graphbolt=use_graphbolt
    )
    block = dgl.to_block(sampled_graph, nodes)
    if sampled_graph.num_edges() > 0 and (not use_graphbolt or return_eids):
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb


967
968
969
def check_rpc_bipartite_sampling_empty(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample on bipartite via sample_neighbors() which yields empty sample results.

    Partitions a random bipartite graph into ``num_server`` parts, starts one
    server process per part, samples from ``game`` nodes whose degree (as
    reported by ``get_degrees``) is zero plus ``user`` node 1, and asserts
    that the resulting block has no edges but keeps every edge type.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(
                i,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    deg = get_degrees(g, orig_nids["game"], "game")
    empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)
    nodes = {"game": empty_nids, "user": torch.tensor([1], dtype=g.idtype)}
    block, _ = start_bipartite_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes=nodes,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


1028
1029
1030
def check_rpc_bipartite_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample on bipartite via sample_neighbors() which yields non-empty sample results.

    Partitions a random bipartite graph, starts the servers, samples from the
    ``game`` nodes with positive degree plus ``user`` node 0, then maps the
    shuffled node (and, when available, edge) ids back to the original graph
    and checks that every sampled edge exists there.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(
                i,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    deg = get_degrees(g, orig_nid_map["game"], "game")
    nids = F.nonzero_1d(deg > 0)
    nodes = {"game": nids, "user": torch.tensor([0], dtype=g.idtype)}
    block, _ = start_bipartite_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes=nodes,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        assert np.all(
            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
        )

        # Without edge ids on the block there is nothing more to verify.
        if use_graphbolt and not return_eids:
            continue

        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)


1108
1109
1110
def check_rpc_bipartite_etype_sampling_empty(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample on bipartite via sample_etype_neighbors() which yields empty sample results.

    Partitions a random bipartite graph, starts the servers, samples from
    ``game`` nodes whose degree (as reported by ``get_degrees``) is zero plus
    ``user`` node 1, and asserts that the resulting block has no edges but
    keeps every edge type.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(
                i,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    deg = get_degrees(g, orig_nids["game"], "game")
    empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)
    nodes = {"game": empty_nids, "user": torch.tensor([1], dtype=g.idtype)}
    block, _ = start_bipartite_etype_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes=nodes,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    assert block is not None
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


1170
1171
1172
def check_rpc_bipartite_etype_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample on bipartite via sample_etype_neighbors() which yields non-empty sample results.

    Partitions a random bipartite graph, starts the servers, samples with
    fanout 3 from the ``game`` nodes with positive degree plus ``user`` node
    0, then maps the shuffled node (and, when available, edge) ids back to
    the original graph and checks that every sampled edge exists there.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(
                i,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    fanout = 3
    deg = get_degrees(g, orig_nid_map["game"], "game")
    nids = F.nonzero_1d(deg > 0)
    nodes = {"game": nids, "user": torch.tensor([0], dtype=g.idtype)}
    block, _ = start_bipartite_etype_sample_client(
        0,
        tmpdir,
        num_server > 1,
        fanout,
        nodes=nodes,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        assert np.all(
            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
        )

        # Without edge ids on the block there is nothing more to verify.
        if use_graphbolt and not return_eids:
            continue

        # Check the node Ids and edge Ids.
        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)

1250

1251
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
@pytest.mark.parametrize("node_id_dtype", [torch.int32, torch.int64])
def test_rpc_sampling_shuffle(
    num_server, use_graphbolt, return_eids, node_id_dtype
):
    """Run ``check_rpc_sampling_shuffle`` in distributed mode in a temp dir."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
            node_id_dtype=node_id_dtype,
        )
1268
1269
1270


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run ``check_rpc_hetero_sampling_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_sampling_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1283
1284
1285


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_hetero_sampling_empty_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_sampling_empty_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1300
1301
1302
1303
1304
1305


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]]
)
def test_rpc_hetero_etype_sampling_shuffle_dgl(num_server, graph_formats):
    """Run ``check_rpc_hetero_etype_sampling_shuffle`` for each graph format."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_etype_sampling_shuffle(
            Path(tmpdirname), num_server, graph_formats=graph_formats
        )
1313
1314
1315


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_shuffle_graphbolt(num_server, return_eids):
    """Run ``check_rpc_hetero_etype_sampling_shuffle`` with GraphBolt enabled."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_etype_sampling_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=True,
            return_eids=return_eids,
        )


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_hetero_etype_sampling_empty_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_etype_sampling_empty_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1344
1345
1346


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_bipartite_sampling_empty`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_bipartite_sampling_empty(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1358
1359
1360


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run ``check_rpc_bipartite_sampling_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_bipartite_sampling_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1370
1371
1372


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_bipartite_etype_sampling_empty`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_bipartite_etype_sampling_empty(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1387
1388
1389


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_bipartite_etype_sampling_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_bipartite_etype_sampling_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
Jinjing Zhou's avatar
Jinjing Zhou committed
1404

1405

1406
def check_standalone_sampling(tmpdir):
    """Check ``sample_neighbors`` on a single-partition graph in standalone mode.

    Loads the ``cora`` citation graph, attaches a random non-negative edge
    weight ``"prob"`` and the matching boolean ``"mask"`` (``prob > 0``),
    partitions it into one part under ``tmpdir`` and samples three times:
    unweighted, with ``prob="mask"`` and with ``prob="prob"``. The unweighted
    result is checked against the original graph's edges and edge ids; the
    weighted results must only contain edges whose mask/probability is
    positive.
    """
    g = CitationGraphDataset("cora")[0]
    prob = np.maximum(np.random.randn(g.num_edges()), 0)
    mask = prob > 0
    g.edata["prob"] = F.tensor(prob)
    g.edata["mask"] = F.tensor(mask)
    num_parts = 1
    num_hops = 1
    partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )

    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", part_config=tmpdir / "test_sampling.json"
    )

    def _sample(**kwargs):
        # A fresh seed tensor is built for every call, as each sampling
        # request is independent.
        seeds = torch.tensor(
            [0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype
        )
        return sample_neighbors(dist_graph, seeds, 3, **kwargs)

    sampled_graph = _sample()
    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
    )

    sampled_graph = _sample(prob="mask")
    eid = F.asnumpy(sampled_graph.edata[dgl.EID])
    assert mask[eid].all()

    sampled_graph = _sample(prob="prob")
    eid = F.asnumpy(sampled_graph.edata[dgl.EID])
    assert (prob[eid] > 0).all()
    dgl.distributed.exit_client()
1460

1461

1462
def check_standalone_etype_sampling(tmpdir):
    """Check ``sample_etype_neighbors`` on a single-partition graph in standalone mode.

    Loads the ``cora`` citation graph, attaches a random non-negative edge
    weight ``"prob"`` and the matching boolean ``"mask"`` (``prob > 0``),
    partitions it into one part under ``tmpdir`` and samples three times:
    unweighted, with ``prob="mask"`` and with ``prob="prob"``. The unweighted
    result is checked against the original graph's edges and edge ids; the
    weighted results must only contain edges whose mask/probability is
    positive.
    """
    hg = CitationGraphDataset("cora")[0]
    prob = np.maximum(np.random.randn(hg.num_edges()), 0)
    mask = prob > 0
    hg.edata["prob"] = F.tensor(prob)
    hg.edata["mask"] = F.tensor(mask)
    num_parts = 1
    num_hops = 1

    partition_graph(
        hg,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", part_config=tmpdir / "test_sampling.json"
    )

    def _sample(**kwargs):
        # A fresh seed tensor is built for every call, as each sampling
        # request is independent.
        seeds = torch.tensor([0, 10, 99, 66, 1023], dtype=dist_graph.idtype)
        return sample_etype_neighbors(dist_graph, seeds, 3, **kwargs)

    sampled_graph = _sample()
    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == hg.num_nodes()
    assert np.all(F.asnumpy(hg.has_edges_between(src, dst)))
    eids = hg.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
    )

    sampled_graph = _sample(prob="mask")
    eid = F.asnumpy(sampled_graph.edata[dgl.EID])
    assert mask[eid].all()

    sampled_graph = _sample(prob="prob")
    eid = F.asnumpy(sampled_graph.edata[dgl.EID])
    assert (prob[eid] > 0).all()
    dgl.distributed.exit_client()

1517

1518
def check_standalone_etype_sampling_heterograph(tmpdir):
    """Check per-etype sampling on a two-relation heterograph in standalone mode.

    Builds a heterograph over ``paper`` nodes from the ``cora`` edges with a
    ``cite`` relation and its reverse ``cite-by``, partitions it into one
    part under ``tmpdir`` and samples with fanout 1 from ten seed nodes. The
    test expects exactly ten sampled edges for each relation and an
    unchanged node count.
    """
    hg = CitationGraphDataset("cora")[0]
    num_parts = 1
    num_hops = 1
    src, dst = hg.edges()
    new_hg = dgl.heterograph(
        {
            ("paper", "cite", "paper"): (src, dst),
            ("paper", "cite-by", "paper"): (dst, src),
        },
        {"paper": hg.num_nodes()},
    )
    partition_graph(
        new_hg,
        "test_hetero_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_hetero_sampling", part_config=tmpdir / "test_hetero_sampling.json"
    )
    sampled_graph = sample_etype_neighbors(
        dist_graph,
        torch.tensor(
            [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701],
            dtype=dist_graph.idtype,
        ),
        1,
    )
    src, dst = sampled_graph.edges(etype=("paper", "cite", "paper"))
    assert len(src) == 10
    src, dst = sampled_graph.edges(etype=("paper", "cite-by", "paper"))
    assert len(src) == 10
    assert sampled_graph.num_nodes() == new_hg.num_nodes()
    dgl.distributed.exit_client()

1558
1559
1560
1561
1562
1563

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_standalone_sampling():
    """Run ``check_standalone_sampling`` in standalone mode in a temp dir."""
    reset_envs()

    os.environ["DGL_DIST_MODE"] = "standalone"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_standalone_sampling(Path(tmpdirname))
1571

1572

1573
1574
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Run a client that extracts the in-subgraph of ``nodes``.

    Connects to the ``test_in_subgraph`` distributed graph, loading the
    partition book from ``tmpdir / "test_in_subgraph.json"`` when shared
    memory is disabled, and calls ``dgl.distributed.in_subgraph``.

    Returns
    -------
    The sampled subgraph, or ``None`` if ``in_subgraph`` raised; in that
    case the traceback is printed.
    """
    gpb = None
    dgl.distributed.initialize("rpc_ip_config.txt")
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_in_subgraph.json", rank
        )
    dist_graph = DistGraph("test_in_subgraph", gpb=gpb)
    try:
        sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception:
        # Keep going so the client still exits cleanly; the caller sees None.
        print(traceback.format_exc())
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph


1590
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Check ``dgl.distributed.in_subgraph`` against ``dgl.in_subgraph``.

    Partitions the ``cora`` graph into ``num_server`` parts, starts one
    server process per part, extracts the in-subgraph of a fixed set of
    nodes through a client, then maps the shuffled node and edge ids back to
    the original graph and compares with the in-subgraph computed locally.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_in_subgraph",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_in_subgraph"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=g.idtype)
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns None when in_subgraph raised; fail with a clear
    # message instead of an AttributeError below.
    assert sampled_graph is not None, "in_subgraph client failed"

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    subg1 = dgl.in_subgraph(g, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
1636

1637
1638
1639
1640
1641
1642

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_rpc_in_subgraph():
    """Run ``check_rpc_in_subgraph_shuffle`` with one server in distributed mode."""
    reset_envs()

    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 1)
1650

1651
1652
1653
1654
1655
1656
1657
1658
1659

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone_etype_sampling():
    """Run both standalone etype-sampling checks, each in its own temp dir."""
    reset_envs()

    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling(Path(tmpdirname))
1670

1671

Jinjing Zhou's avatar
Jinjing Zhou committed
1672
1673
if __name__ == "__main__":
    import tempfile
1674

Jinjing Zhou's avatar
Jinjing Zhou committed
1675
    with tempfile.TemporaryDirectory() as tmpdirname:
1676
        os.environ["DGL_DIST_MODE"] = "standalone"
1677
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
1678
1679

    with tempfile.TemporaryDirectory() as tmpdirname:
1680
        os.environ["DGL_DIST_MODE"] = "standalone"
1681
1682
        check_standalone_etype_sampling(Path(tmpdirname))
        check_standalone_sampling(Path(tmpdirname))
1683
        os.environ["DGL_DIST_MODE"] = "distributed"
1684
1685
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
1686
1687
        check_rpc_get_degree_shuffle(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 2)
1688
1689
        check_rpc_find_edges_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 1)
1690
1691
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 2)
1692
1693
1694
1695
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)
1696
1697
1698
1699
        check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), 1)