test_distributed_sampling.py 50.3 KB
Newer Older
1
import multiprocessing as mp
Jinjing Zhou's avatar
Jinjing Zhou committed
2
import os
3
import random
4
import tempfile
Jinjing Zhou's avatar
Jinjing Zhou committed
5
import time
6
7
import traceback
import unittest
Jinjing Zhou's avatar
Jinjing Zhou committed
8
from pathlib import Path
9
10
11
12

import backend as F
import dgl
import numpy as np
13
import pytest
14
import torch
15
16
17
18
19
20
21
22
23
24
from dgl.data import CitationGraphDataset, WN18Dataset
from dgl.distributed import (
    DistGraph,
    DistGraphServer,
    load_partition,
    load_partition_book,
    partition_graph,
    sample_etype_neighbors,
    sample_neighbors,
)
25
from scipy import sparse as spsp
26
from utils import generate_ip_config, reset_envs
Jinjing Zhou's avatar
Jinjing Zhou committed
27
28


29
30
31
32
33
34
def start_server(
    rank,
    tmpdir,
    disable_shared_mem,
    graph_name,
    graph_format=None,
    use_graphbolt=False,
):
    """Create a ``DistGraphServer`` for one partition and run it.

    Parameters
    ----------
    rank
        First positional argument of ``DistGraphServer``.
    tmpdir : pathlib.Path
        Directory that holds the partition config ``<graph_name>.json``.
    disable_shared_mem : bool
        Forwarded to ``DistGraphServer``.
    graph_name : str
        Base name of the partition config file.
    graph_format : list of str, optional
        Forwarded to ``DistGraphServer``; ``None`` selects ``["csc", "coo"]``.
    use_graphbolt : bool
        Forwarded to ``DistGraphServer``.
    """
    # Use a None sentinel instead of a list literal in the signature so the
    # default object is never shared between calls.
    if graph_format is None:
        graph_format = ["csc", "coo"]
    server = DistGraphServer(
        rank,
        "rpc_ip_config.txt",
        # The next two arguments are fixed at 1; presumably the number of
        # servers and clients per machine -- confirm against DistGraphServer.
        1,
        1,
        tmpdir / (graph_name + ".json"),
        disable_shared_mem=disable_shared_mem,
        graph_format=graph_format,
        use_graphbolt=use_graphbolt,
    )
    server.start()


50
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Sample neighbors of a fixed seed set from the "test_sampling" graph.

    Parameters
    ----------
    rank
        Forwarded to ``load_partition`` when shared memory is disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk and handed to
        ``DistGraph``; otherwise ``gpb=None`` is passed.

    Returns
    -------
    The result of ``sample_neighbors`` with fanout 3, or ``None`` if
    sampling raised (the traceback is printed).
    """
    partition_book = None
    if disable_shared_mem:
        # The partition book is the fourth item returned by load_partition.
        partition_book = load_partition(tmpdir / "test_sampling.json", rank)[3]
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)
    try:
        seeds = torch.tensor(
            [0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype
        )
        sampled_graph = sample_neighbors(dist_graph, seeds, 3)
    except Exception:
        # Report the failure, then fall through to the client shutdown.
        print(traceback.format_exc())
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph

70

71
72
73
74
75
76
77
78
79
def start_sample_client_shuffle(
    rank,
    tmpdir,
    disable_shared_mem,
    g,
    num_servers,
    group_id,
    orig_nid,
    orig_eid,
    use_graphbolt=False,
    return_eids=False,
    node_id_dtype=None,
):
    """Sample from the shuffled "test_sampling" graph and verify the result.

    The sampled edges are mapped back to original node IDs through
    ``orig_nid`` and checked against the unpartitioned graph ``g``.
    ``num_servers`` is accepted but not used inside this function.

    Parameters
    ----------
    rank
        Forwarded to ``load_partition`` when shared memory is disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk.
    g
        The original (unpartitioned) graph.
    group_id
        Written to the ``DGL_GROUP_ID`` environment variable.
    orig_nid, orig_eid
        Index-able mappings from shuffled IDs to original IDs.
    use_graphbolt : bool
        Forwarded to ``sample_neighbors``.
    return_eids : bool
        Whether edge IDs are expected on the sampled graph under GraphBolt.
    node_id_dtype
        dtype of the seed tensor.
    """
    os.environ["DGL_GROUP_ID"] = str(group_id)
    partition_book = None
    if disable_shared_mem:
        partition_book = load_partition(tmpdir / "test_sampling.json", rank)[3]
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)

    seeds = torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=node_id_dtype)
    sampled_graph = sample_neighbors(
        dist_graph, seeds, 3, use_graphbolt=use_graphbolt
    )

    # The sampled graph must follow the distributed graph's id type, which
    # is required to be int64 regardless of the seed dtype.
    assert sampled_graph.idtype == dist_graph.idtype
    assert sampled_graph.idtype == torch.int64
    assert (
        dgl.ETYPE not in sampled_graph.edata
    ), "Etype should not be in homogeneous sampled graph."

    # Translate shuffled endpoints back to the original node IDs.
    shuffled_src, shuffled_dst = sampled_graph.edges()
    src = orig_nid[shuffled_src]
    dst = orig_nid[shuffled_dst]
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    if not use_graphbolt or return_eids:
        expected_eids = g.edge_ids(src, dst)
        sampled_eids = orig_eid[sampled_graph.edata[dgl.EID]]
        assert np.array_equal(
            F.asnumpy(sampled_eids), F.asnumpy(expected_eids)
        )
    else:
        assert (
            dgl.EID not in sampled_graph.edata
        ), "EID should not be in sampled graph if use_graphbolt=True."
117

118

119
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
    """Call ``DistGraph.find_edges`` on the "test_find_edges" graph.

    Parameters
    ----------
    rank
        Forwarded to ``load_partition`` when shared memory is disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_find_edges.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk.
    eids
        Edge IDs to look up.
    etype : optional
        Edge type forwarded to ``find_edges``.

    Returns
    -------
    tuple
        ``(u, v)`` as returned by ``find_edges``, or ``(None, None)`` if the
        call raised (the traceback is printed).
    """
    partition_book = None
    if disable_shared_mem:
        partition_book = load_partition(
            tmpdir / "test_find_edges.json", rank
        )[3]
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_find_edges", gpb=partition_book)
    try:
        u, v = dist_graph.find_edges(eids, etype=etype)
        endpoints = (u, v)
    except Exception:
        print(traceback.format_exc())
        endpoints = (None, None)
    dgl.distributed.exit_client()
    return endpoints
Jinjing Zhou's avatar
Jinjing Zhou committed
134

135

136
137
138
def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Query in/out degrees from the "test_get_degrees" distributed graph.

    Parameters
    ----------
    rank
        Forwarded to ``load_partition`` when shared memory is disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_get_degrees.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk.
    nids : optional
        Node IDs for the subset queries.

    Returns
    -------
    tuple
        ``(in_deg, out_deg, all_in_deg, all_out_deg)``; all four are ``None``
        if any query raised (the traceback is printed).
    """
    partition_book = None
    if disable_shared_mem:
        partition_book = load_partition(
            tmpdir / "test_get_degrees.json", rank
        )[3]
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_get_degrees", gpb=partition_book)
    try:
        in_deg = dist_graph.in_degrees(nids)
        all_in_deg = dist_graph.in_degrees()
        out_deg = dist_graph.out_degrees(nids)
        all_out_deg = dist_graph.out_degrees()
        degrees = (in_deg, out_deg, all_in_deg, all_out_deg)
    except Exception:
        print(traceback.format_exc())
        # A failure in any query discards every result.
        degrees = (None, None, None, None)
    dgl.distributed.exit_client()
    return degrees

155

156
def check_rpc_sampling(tmpdir, num_server):
    """Partition cora, sample through RPC and validate the sampled edges.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)

    partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    # Shared memory is only used in the single-server configuration.
    no_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, no_shared_mem, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    sampled_graph = start_sample_client(0, tmpdir, no_shared_mem)
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # Every sampled edge must exist in g and carry g's edge ID.
    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    expected_eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(expected_eids)
    )

Jinjing Zhou's avatar
Jinjing Zhou committed
198

199
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on a shuffled partition of cora.

    Random edge IDs are looked up through the distributed graph and the
    endpoints, mapped back through ``orig_nid``, are compared with
    ``g.find_edges`` on the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    eids = F.tensor(np.random.randint(g.num_edges(), size=100))
    # Expected endpoints come from the original graph via the eid mapping.
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

    # Reap the servers and require a clean exit, as the other check_*
    # helpers in this file do, so no server process outlives the check.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

234

235
def create_random_hetero(dense=False, empty=False):
    """Build a random heterograph with node types n1, n2 and n3.

    Parameters
    ----------
    dense : bool
        Use the smaller node counts and an edge density of 0.1 instead of
        0.001.
    empty : bool
        Restrict the sparse matrices to 10 fewer rows and columns than the
        node counts, so the last 10 nodes of each type get no edges.

    Returns
    -------
    The heterograph, with an all-ones ``"feat"`` feature on ``n1`` only.
    """
    if dense:
        num_nodes = {"n1": 210, "n2": 200, "n3": 220}
        density = 0.1
    else:
        num_nodes = {"n1": 1010, "n2": 1000, "n3": 1020}
        density = 0.001
    # Rows/columns left out of the random matrix when `empty` is requested.
    margin = 10 if empty else 0

    random.seed(42)
    edges = {}
    for etype in [
        ("n1", "r12", "n2"),
        ("n1", "r13", "n3"),
        ("n2", "r23", "n3"),
    ]:
        src_ntype, _, dst_ntype = etype
        adj = spsp.random(
            num_nodes[src_ntype] - margin,
            num_nodes[dst_ntype] - margin,
            density=density,
            format="coo",
            random_state=100,
        )
        edges[etype] = (adj.row, adj.col)

    g = dgl.heterograph(edges, num_nodes)
    feat_shape = (g.num_nodes("n1"), 10)
    g.nodes["n1"].data["feat"] = F.ones(feat_shape, F.float32, F.cpu())
    return g
259

260

261
def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on a shuffled random heterograph.

    Random ``r12`` edge IDs are looked up through the distributed graph and
    the endpoints, mapped back through ``orig_nid``, are compared with
    ``g.find_edges`` on the original graph. Along the way the local graph
    is required to reject a malformed two-element edge type.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    test_etype = g.to_canonical_etype("r12")
    eids = F.tensor(np.random.randint(g.num_edges(test_etype), size=100))

    # A two-element edge type must be rejected. Catch Exception rather than
    # using a bare except so KeyboardInterrupt/SystemExit still propagate.
    expect_except = False
    try:
        _, _ = g.find_edges(orig_eid[test_etype][eids], etype=("n1", "r12"))
    except Exception:
        expect_except = True
    assert expect_except, "find_edges should reject a 2-tuple edge type"

    # The short name and the canonical triple must agree.
    u, v = g.find_edges(orig_eid[test_etype][eids], etype="r12")
    u1, v1 = g.find_edges(orig_eid[test_etype][eids], etype=("n1", "r12", "n2"))
    assert F.array_equal(u, u1)
    assert F.array_equal(v, v1)

    du, dv = start_find_edges_client(
        0, tmpdir, num_server > 1, eids, etype="r12"
    )

    # Reap the servers and require a clean exit, as the other check_*
    # helpers in this file do, so no server process outlives the check.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    du = orig_nid["n1"][du]
    dv = orig_nid["n2"][dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

308

309
# Wait non shared memory graph store
310
311
312
313
314
315
316
317
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_find_edges_shuffle(num_server):
    """Run the hetero and homogeneous find_edges checks in one temp dir."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_hetero_find_edges_shuffle(workdir, num_server)
        check_rpc_find_edges_shuffle(workdir, num_server)

328

329
def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Compare distributed in/out degrees on shuffled cora with the original.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]

    orig_nid, _ = partition_graph(
        g,
        "test_get_degrees",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    no_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, no_shared_mem, "test_get_degrees"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    nids = F.tensor(np.random.randint(g.num_nodes(), size=100))
    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(
        0, tmpdir, no_shared_mem, nids
    )

    print("Done get_degree")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    print("check results")
    # orig_nid maps shuffled node IDs to the IDs used by the original graph.
    queried_orig = orig_nid[nids]
    assert F.array_equal(g.in_degrees(queried_orig), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(queried_orig), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)

372

373
# Wait non shared memory graph store
374
375
376
377
378
379
380
381
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_get_degree_shuffle(num_server):
    """Run the distributed degree check inside a fresh temp directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_get_degree_shuffle(workdir, num_server)

391
392
393
394

# @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
# @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip("Only support partition with shuffle")
def test_rpc_sampling():
    """Run the non-shuffled sampling check with one server (always skipped)."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_sampling(workdir, 1)
Jinjing Zhou's avatar
Jinjing Zhou committed
402

403

404
def check_rpc_sampling_shuffle(
    tmpdir,
    num_server,
    num_groups=1,
    use_graphbolt=False,
    return_eids=False,
    node_id_dtype=None,
):
    """Partition cora with shuffling and sample from client subprocesses.

    The verification itself happens inside ``start_sample_client_shuffle``;
    this function only requires every client and server process to exit
    with code 0.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
    num_groups : int
        Number of client groups; one client process is started per group.
    use_graphbolt : bool
        Forwarded to partitioning, the servers and the clients.
    return_eids : bool
        Forwarded to ``partition_graph`` as ``store_eids`` and to the clients.
    node_id_dtype
        dtype of the seed tensor built by the clients.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]

    orig_nids, orig_eids = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    no_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")

    servers = []
    for server_rank in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    num_clients = 1
    clients = []
    for client_id in range(num_clients):
        for group_id in range(num_groups):
            proc = ctx.Process(
                target=start_sample_client_shuffle,
                args=(
                    client_id,
                    tmpdir,
                    no_shared_mem,
                    g,
                    num_server,
                    group_id,
                    orig_nids,
                    orig_eids,
                    use_graphbolt,
                    return_eids,
                    node_id_dtype,
                ),
            )
            proc.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            clients.append(proc)

    # Clients are reaped first, then servers; all must exit cleanly.
    for proc in clients + servers:
        proc.join()
        assert proc.exitcode == 0
Jinjing Zhou's avatar
Jinjing Zhou committed
477

478

479
480
481
482
483
484
485
486
def start_hetero_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample neighbors on the heterogeneous "test_sampling" graph.

    Parameters
    ----------
    rank
        Forwarded to ``load_partition`` when shared memory is disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk.
    nodes : dict
        Seed node IDs keyed by node type; values are converted to tensors of
        the distributed graph's id type.
    use_graphbolt : bool
        Forwarded to ``sample_neighbors``.
    return_eids : bool
        Under GraphBolt, whether edge IDs are copied onto the block.

    Returns
    -------
    tuple
        ``(block, gpb)`` where ``block`` is ``None`` if sampling raised (the
        traceback is printed).
    """
    gpb = None
    if disable_shared_mem:
        gpb = load_partition(tmpdir / "test_sampling.json", rank)[3]
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)

    # Only n1 is expected to carry the "feat" node feature.
    assert "feat" in dist_graph.nodes["n1"].data
    for featureless in ("n2", "n3"):
        assert "feat" not in dist_graph.nodes[featureless].data

    idtype = dist_graph.idtype
    nodes = {
        ntype: torch.tensor(ids, dtype=idtype) for ntype, ids in nodes.items()
    }
    if gpb is None:
        gpb = dist_graph.get_partition_book()

    keep_eids = not use_graphbolt or return_eids
    try:
        # Enable sanity check in distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        sampled_graph = sample_neighbors(
            dist_graph, nodes, 3, use_graphbolt=use_graphbolt
        )
        block = dgl.to_block(sampled_graph, nodes)
        if keep_eids:
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, gpb

517
518
519
520
521
522

def start_hetero_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    etype_sorted=False,
    use_graphbolt=False,
    return_eids=False,
):
    """Run per-edge-type sampling on the heterogeneous "test_sampling" graph.

    Parameters
    ----------
    rank
        Forwarded to ``load_partition`` when shared memory is disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk.
    fanout
        Forwarded to ``sample_etype_neighbors``.
    nodes : dict, optional
        Seed node IDs keyed by node type. ``None`` is treated as an empty
        dict.
    etype_sorted : bool
        Forwarded to ``sample_etype_neighbors``.
    use_graphbolt : bool
        Forwarded to ``sample_etype_neighbors``; also disables the local
        edge-type ordering check.
    return_eids : bool
        Under GraphBolt, whether edge IDs are copied onto the block.

    Returns
    -------
    tuple
        ``(block, gpb)`` where ``block`` is ``None`` if sampling raised (the
        traceback is printed).
    """
    # The declared default is None; without this the `.items()` call below
    # would raise AttributeError when the caller omits `nodes`.
    if nodes is None:
        nodes = {}

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["n1"].data
    assert "feat" not in dist_graph.nodes["n2"].data
    assert "feat" not in dist_graph.nodes["n3"].data

    nodes = {
        k: torch.tensor(v, dtype=dist_graph.idtype) for k, v in nodes.items()
    }

    if (not use_graphbolt) and dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph
        local_g = dist_graph.local_partition
        local_nids = np.arange(local_g.num_nodes())
        for lnid in local_nids:
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            # np.unique yields the first position of each distinct type id,
            # ordered by id; non-decreasing positions mean smaller type ids
            # first appear earlier in this node's in-edge list.
            _, idices = np.unique(letids, return_index=True)
            assert np.all(idices[:-1] <= idices[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    try:
        # Enable sanity check in distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        sampled_graph = sample_etype_neighbors(
            dist_graph,
            nodes,
            fanout,
            etype_sorted=etype_sorted,
            use_graphbolt=use_graphbolt,
        )
        block = dgl.to_block(sampled_graph, nodes)
        if sampled_graph.num_edges() > 0:
            if not use_graphbolt or return_eids:
                block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, gpb

574

575
576
577
def check_rpc_hetero_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample on a shuffled random heterograph and validate the block.

    Every edge of the returned block is mapped back to original node IDs and
    must exist in the original graph. Unless GraphBolt is used without
    stored edge IDs, the block's edge IDs are mapped back as well and must
    point at the same endpoints.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
    use_graphbolt : bool
        Forwarded to partitioning, the servers and the client.
    return_eids : bool
        Forwarded to ``partition_graph`` as ``store_eids`` and to the client.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    no_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    nodes = {"n3": torch.tensor([0, 10, 99, 66, 124, 208], dtype=g.idtype)}
    block, gpb = start_hetero_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        nodes=nodes,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # Edge IDs are only available without GraphBolt or when they were stored.
    has_eids = not use_graphbolt or return_eids
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))

        edge_exists = g.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(edge_exists))

        if not has_eids:
            continue

        # Check the node Ids and edge Ids.
        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)
650

651

652
653
654
655
656
657
658
659
660
def get_degrees(g, nids, ntype):
    """Sum the degrees of ``nids`` over every relation touching ``ntype``.

    A relation whose source type is ``ntype`` contributes out-degrees;
    otherwise a relation whose destination type is ``ntype`` contributes
    in-degrees. A relation with ``ntype`` on both ends is therefore counted
    through its out-degrees only.

    Returns
    -------
    An int64 tensor with one entry per node in ``nids``.
    """
    total = F.zeros((len(nids),), dtype=F.int64)
    for src_t, rel, dst_t in g.canonical_etypes:
        if src_t == ntype:
            total += g.out_degrees(u=nids, etype=rel)
            continue
        if dst_t == ntype:
            total += g.in_degrees(v=nids, etype=rel)
    return total

661

662
663
664
def check_rpc_hetero_sampling_empty_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample from zero-degree n3 seeds and expect an empty block.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
    use_graphbolt : bool
        Forwarded to partitioning, the servers and the client.
    return_eids : bool
        Forwarded to ``partition_graph`` as ``store_eids`` and to the client.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(empty=True)

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    no_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    # Positions in the shuffled ordering whose original node has degree 0.
    deg = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)
    block, gpb = start_hetero_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        nodes={"n3": empty_nids},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # No edges may be sampled, yet every edge type must still be present.
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

718
719

def check_rpc_hetero_etype_sampling_shuffle(
    tmpdir,
    num_server,
    graph_formats=None,
    use_graphbolt=False,
    return_eids=False,
):
    """Run per-edge-type sampling on a dense heterograph and validate it.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
    graph_formats : optional
        Forwarded to ``partition_graph``; when it contains "csc" or "csr"
        the client is told the edge types are sorted.
    use_graphbolt : bool
        Forwarded to partitioning, the servers and the client.
    return_eids : bool
        Forwarded to ``partition_graph`` as ``store_eids`` and to the client.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True)

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        graph_formats=graph_formats,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    no_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    fanout = {etype: 3 for etype in g.canonical_etypes}
    etype_sorted = graph_formats is not None and (
        "csc" in graph_formats or "csr" in graph_formats
    )
    nodes = {"n3": torch.tensor([0, 10, 99, 66, 124, 208], dtype=g.idtype)}
    block, gpb = start_hetero_etype_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        fanout,
        nodes=nodes,
        etype_sorted=etype_sorted,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # Six n3 seeds with fanout 3 per edge type: 18 edges are expected on
    # each relation that ends in n3.
    for into_n3 in (("n1", "r13", "n3"), ("n2", "r23", "n3")):
        src, dst = block.edges(etype=into_n3)
        assert len(src) == 18

    # Edge IDs are only available without GraphBolt or when they were stored.
    has_eids = not use_graphbolt or return_eids
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))

        edge_exists = g.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(edge_exists))

        if not has_eids:
            continue

        # Check the node Ids and edge Ids.
        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)

810

811
812
813
def check_rpc_hetero_etype_sampling_empty_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Per-edge-type sampling from zero-degree n3 seeds yields an empty block.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
    use_graphbolt : bool
        Forwarded to partitioning, the servers and the client.
    return_eids : bool
        Forwarded to ``partition_graph`` as ``store_eids`` and to the client.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True, empty=True)

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    no_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    fanout = 3
    # Positions in the shuffled ordering whose original node has degree 0.
    deg = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)
    block, gpb = start_hetero_etype_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        fanout,
        nodes={"n3": empty_nids},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # No edges may be sampled, yet every edge type must still be present.
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

870
871

def create_random_bipartite():
    """Build a random user-buys-game bipartite graph with node features.

    The graph comes from ``dgl.rand_bipartite`` called with the arguments
    500, 1000, 1000. Both node types receive an all-ones float32 ``"feat"``
    feature with 10 columns on the CPU.
    """
    g = dgl.rand_bipartite("user", "buys", "game", 500, 1000, 1000)
    for ntype in ("user", "game"):
        feat_shape = (g.num_nodes(ntype), 10)
        g.nodes[ntype].data["feat"] = F.ones(feat_shape, F.float32, F.cpu())
    return g


882
883
884
885
886
887
888
889
def start_bipartite_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample neighbors on the bipartite "test_sampling" graph.

    Unlike the heterogeneous client helpers, exceptions raised by sampling
    are not caught here.

    Parameters
    ----------
    rank
        Forwarded to ``load_partition`` when shared memory is disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk.
    nodes : dict
        Seed node IDs keyed by node type.
    use_graphbolt : bool
        Forwarded to ``sample_neighbors``.
    return_eids : bool
        Under GraphBolt, whether edge IDs are copied onto the block.

    Returns
    -------
    tuple
        ``(block, gpb)``.
    """
    gpb = None
    if disable_shared_mem:
        gpb = load_partition(tmpdir / "test_sampling.json", rank)[3]
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    for ntype in ("user", "game"):
        assert "feat" in dist_graph.nodes[ntype].data

    idtype = dist_graph.idtype
    nodes = {
        ntype: torch.tensor(ids, dtype=idtype) for ntype, ids in nodes.items()
    }
    if gpb is None:
        gpb = dist_graph.get_partition_book()

    # Enable sanity check in distributed sampling.
    os.environ["DGL_DIST_DEBUG"] = "1"
    sampled_graph = sample_neighbors(
        dist_graph, nodes, 3, use_graphbolt=use_graphbolt
    )
    block = dgl.to_block(sampled_graph, nodes)
    # Edge IDs are only copied when something was sampled and IDs exist.
    if sampled_graph.num_edges() > 0 and (not use_graphbolt or return_eids):
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb


917
def start_bipartite_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    use_graphbolt=False,
    return_eids=False,
):
    """Run a sample_etype_neighbors() client against "test_sampling".

    Parameters
    ----------
    rank : partition rank handed to load_partition().
    tmpdir : directory (Path) containing "test_sampling.json".
    disable_shared_mem : when true, the partition book is loaded from disk
        instead of being obtained from the DistGraph.
    fanout : forwarded to sample_etype_neighbors().
    nodes : dict mapping node type to seed node IDs; ``None`` (the default)
        is treated as an empty dict.
    use_graphbolt : forwarded to sample_etype_neighbors(); also disables the
        local etype-ordering check.
    return_eids : with GraphBolt, edge IDs are copied onto the block only
        when this is true.

    Returns
    -------
    (block, gpb) : the block built from the sampled graph and the partition
        book that was used.
    """
    # Avoid a shared mutable default argument.
    if nodes is None:
        nodes = {}

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["user"].data
    assert "feat" in dist_graph.nodes["game"].data
    nodes = {
        k: torch.tensor(v, dtype=dist_graph.idtype) for k, v in nodes.items()
    }

    if not use_graphbolt and dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph: for every local
        # node, the first occurrence of each edge-type ID among its in-edges
        # must appear in ascending edge-type order.
        local_g = dist_graph.local_partition
        local_nids = np.arange(local_g.num_nodes())
        for lnid in local_nids:
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            _, first_indices = np.unique(letids, return_index=True)
            assert np.all(first_indices[:-1] <= first_indices[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    sampled_graph = sample_etype_neighbors(
        dist_graph, nodes, fanout, use_graphbolt=use_graphbolt
    )
    block = dgl.to_block(sampled_graph, nodes)
    # Edge IDs are carried over only for non-empty samples, and with
    # GraphBolt only when they were explicitly requested.
    if sampled_graph.num_edges() > 0 and (not use_graphbolt or return_eids):
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb


962
963
964
def check_rpc_bipartite_sampling_empty(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample a bipartite graph via sample_neighbors() expecting no edges.

    The random bipartite graph is split into ``num_server`` METIS partitions,
    one server process is spawned per partition, and a client samples from
    "game" nodes whose degree (as reported by get_degrees) is zero, plus
    "user" node 1. The resulting block must contain no edges while keeping
    every edge type of the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    # Shared memory is disabled as soon as more than one server is used.
    no_shared_mem = num_server > 1

    nid_map, _ = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server = spawn.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        # Pause one second after launching each server.
        time.sleep(1)
        servers.append(server)

    game_deg = get_degrees(graph, nid_map["game"], "game")
    isolated_games = F.nonzero_1d(game_deg == 0).to(graph.idtype)
    seeds = {
        "game": isolated_games,
        "user": torch.tensor([1], dtype=graph.idtype),
    }
    block, _ = start_bipartite_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(graph.etypes)


1023
1024
1025
def check_rpc_bipartite_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample a bipartite graph via sample_neighbors() expecting edges.

    The random bipartite graph is split into ``num_server`` METIS partitions,
    one server process is spawned per partition, and a client samples from
    every "game" node with non-zero degree plus "user" node 0. Sampled edges
    are mapped back to original IDs through the mappings returned by
    partition_graph() and compared against the unpartitioned graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    # Shared memory is disabled as soon as more than one server is used.
    no_shared_mem = num_server > 1

    orig_nid_map, orig_eid_map = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server = spawn.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        # Pause one second after launching each server.
        time.sleep(1)
        servers.append(server)

    game_deg = get_degrees(graph, orig_nid_map["game"], "game")
    connected_games = F.nonzero_1d(game_deg > 0)
    seeds = {
        "game": connected_games,
        "user": torch.tensor([0], dtype=graph.idtype),
    }
    block, _ = start_bipartite_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    # Without GraphBolt edge IDs are always checked; with GraphBolt only
    # when they were requested.
    check_eids = return_eids or not use_graphbolt
    for canonical in block.canonical_etypes:
        src_type, etype, dst_type = canonical
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        exists = graph.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(exists))

        if check_eids:
            shuffled_eid = block.edges[etype].data[dgl.EID]
            orig_eid = F.asnumpy(
                F.gather_row(orig_eid_map[canonical], shuffled_eid)
            )

            # Check the node Ids and edge Ids.
            found_src, found_dst = graph.find_edges(orig_eid, etype=etype)
            assert np.all(F.asnumpy(found_src) == orig_src)
            assert np.all(F.asnumpy(found_dst) == orig_dst)


1103
1104
1105
def check_rpc_bipartite_etype_sampling_empty(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample a bipartite graph via sample_etype_neighbors() expecting no edges.

    The random bipartite graph is split into ``num_server`` METIS partitions,
    one server process is spawned per partition, and a client samples from
    "game" nodes whose degree (as reported by get_degrees) is zero, plus
    "user" node 1. The resulting block must exist, contain no edges, and keep
    every edge type of the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    # Shared memory is disabled as soon as more than one server is used.
    no_shared_mem = num_server > 1

    nid_map, _ = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server = spawn.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        # Pause one second after launching each server.
        time.sleep(1)
        servers.append(server)

    game_deg = get_degrees(graph, nid_map["game"], "game")
    isolated_games = F.nonzero_1d(game_deg == 0).to(graph.idtype)
    seeds = {
        "game": isolated_games,
        "user": torch.tensor([1], dtype=graph.idtype),
    }
    block, _ = start_bipartite_etype_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    assert block is not None
    assert block.num_edges() == 0
    assert len(block.etypes) == len(graph.etypes)


1165
1166
1167
def check_rpc_bipartite_etype_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample a bipartite graph via sample_etype_neighbors() expecting edges.

    The random bipartite graph is split into ``num_server`` METIS partitions,
    one server process is spawned per partition, and a client samples (with
    fanout 3) from every "game" node with non-zero degree plus "user" node 0.
    Sampled edges are mapped back to original IDs through the mappings
    returned by partition_graph() and compared against the unpartitioned
    graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    # Shared memory is disabled as soon as more than one server is used.
    no_shared_mem = num_server > 1

    orig_nid_map, orig_eid_map = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server = spawn.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        # Pause one second after launching each server.
        time.sleep(1)
        servers.append(server)

    game_deg = get_degrees(graph, orig_nid_map["game"], "game")
    connected_games = F.nonzero_1d(game_deg > 0)
    seeds = {
        "game": connected_games,
        "user": torch.tensor([0], dtype=graph.idtype),
    }
    block, _ = start_bipartite_etype_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        3,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    # Without GraphBolt edge IDs are always checked; with GraphBolt only
    # when they were requested.
    check_eids = return_eids or not use_graphbolt
    for canonical in block.canonical_etypes:
        src_type, etype, dst_type = canonical
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        exists = graph.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(exists))

        if check_eids:
            # Check the node Ids and edge Ids.
            shuffled_eid = block.edges[etype].data[dgl.EID]
            orig_eid = F.asnumpy(
                F.gather_row(orig_eid_map[canonical], shuffled_eid)
            )
            found_src, found_dst = graph.find_edges(orig_eid, etype=etype)
            assert np.all(F.asnumpy(found_src) == orig_src)
            assert np.all(F.asnumpy(found_dst) == orig_dst)

1245

1246
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
@pytest.mark.parametrize("node_id_dtype", [torch.int64])
def test_rpc_sampling_shuffle(
    num_server, use_graphbolt, return_eids, node_id_dtype
):
    """Run check_rpc_sampling_shuffle() in distributed mode in a temp dir."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_sampling_shuffle(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
            node_id_dtype=node_id_dtype,
        )
1263
1264
1265


@pytest.mark.parametrize("num_server", [1])
# The argname previously carried a stray trailing comma ("use_graphbolt,");
# it is spelled like every sibling test now.
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run check_rpc_hetero_sampling_shuffle() in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_sampling_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1278
1279
1280


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_hetero_sampling_empty_shuffle() in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_hetero_sampling_empty_shuffle(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1295
1296
1297
1298
1299
1300


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]]
)
def test_rpc_hetero_etype_sampling_shuffle_dgl(num_server, graph_formats):
    """Run check_rpc_hetero_etype_sampling_shuffle() for each graph format."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_hetero_etype_sampling_shuffle(
            workdir, num_server, graph_formats=graph_formats
        )
1308
1309
1310


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_shuffle_graphbolt(num_server, return_eids):
    """Run check_rpc_hetero_etype_sampling_shuffle() with GraphBolt enabled."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_hetero_etype_sampling_shuffle(
            workdir,
            num_server,
            use_graphbolt=True,
            return_eids=return_eids,
        )


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_hetero_etype_sampling_empty_shuffle() in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_hetero_etype_sampling_empty_shuffle(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1339
1340
1341


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_bipartite_sampling_empty() in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_bipartite_sampling_empty(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1353
1354
1355


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run check_rpc_bipartite_sampling_shuffle() in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_bipartite_sampling_shuffle(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1365
1366
1367


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_bipartite_etype_sampling_empty() in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_bipartite_etype_sampling_empty(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1382
1383
1384


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_bipartite_etype_sampling_shuffle() in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_bipartite_etype_sampling_shuffle(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
Jinjing Zhou's avatar
Jinjing Zhou committed
1399

1400

1401
def check_standalone_sampling(tmpdir):
    """Check sample_neighbors() on a single-partition graph in standalone mode.

    The "cora" citation graph gets a random non-negative "prob" edge feature
    and the matching boolean "mask" (prob > 0). Three samples are drawn from
    the same seed nodes: unweighted, with prob="mask", and with prob="prob".
    The weighted samples must only contain edges whose mask is set, i.e.
    whose probability is positive.
    """
    graph = CitationGraphDataset("cora")[0]
    # Negative draws are clamped to zero.
    prob = np.maximum(np.random.randn(graph.num_edges()), 0)
    mask = prob > 0
    graph.edata["prob"] = F.tensor(prob)
    graph.edata["mask"] = F.tensor(mask)
    partition_graph(
        graph,
        "test_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", part_config=tmpdir / "test_sampling.json"
    )

    def fresh_seeds():
        # A new seed tensor is built for every sampling call.
        return torch.tensor(
            [0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype
        )

    # Unweighted sampling: edges and edge IDs must match the source graph.
    plain = sample_neighbors(dist_graph, fresh_seeds(), 3)
    src, dst = plain.edges()
    assert plain.num_nodes() == graph.num_nodes()
    assert np.all(F.asnumpy(graph.has_edges_between(src, dst)))
    expected_eids = graph.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(plain.edata[dgl.EID]), F.asnumpy(expected_eids)
    )

    # Sampling restricted by the boolean mask.
    masked = sample_neighbors(dist_graph, fresh_seeds(), 3, prob="mask")
    masked_eids = F.asnumpy(masked.edata[dgl.EID])
    assert mask[masked_eids].all()

    # Sampling weighted by the probability feature.
    weighted = sample_neighbors(dist_graph, fresh_seeds(), 3, prob="prob")
    weighted_eids = F.asnumpy(weighted.edata[dgl.EID])
    assert (prob[weighted_eids] > 0).all()
    dgl.distributed.exit_client()
1455

1456

1457
def check_standalone_etype_sampling(tmpdir):
    """Check sample_etype_neighbors() on a single partition in standalone mode.

    The "cora" citation graph gets a random non-negative "prob" edge feature
    and the matching boolean "mask" (prob > 0). Three samples are drawn from
    the same seed nodes: unweighted, with prob="mask", and with prob="prob".
    The weighted samples must only contain edges whose mask is set, i.e.
    whose probability is positive.
    """
    graph = CitationGraphDataset("cora")[0]
    # Negative draws are clamped to zero.
    prob = np.maximum(np.random.randn(graph.num_edges()), 0)
    mask = prob > 0
    graph.edata["prob"] = F.tensor(prob)
    graph.edata["mask"] = F.tensor(mask)

    partition_graph(
        graph,
        "test_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", part_config=tmpdir / "test_sampling.json"
    )

    def fresh_seeds():
        # A new seed tensor is built for every sampling call.
        return torch.tensor([0, 10, 99, 66, 1023], dtype=dist_graph.idtype)

    # Unweighted sampling: edges and edge IDs must match the source graph.
    plain = sample_etype_neighbors(dist_graph, fresh_seeds(), 3)
    src, dst = plain.edges()
    assert plain.num_nodes() == graph.num_nodes()
    assert np.all(F.asnumpy(graph.has_edges_between(src, dst)))
    expected_eids = graph.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(plain.edata[dgl.EID]), F.asnumpy(expected_eids)
    )

    # Sampling restricted by the boolean mask.
    masked = sample_etype_neighbors(
        dist_graph, fresh_seeds(), 3, prob="mask"
    )
    masked_eids = F.asnumpy(masked.edata[dgl.EID])
    assert mask[masked_eids].all()

    # Sampling weighted by the probability feature.
    weighted = sample_etype_neighbors(
        dist_graph, fresh_seeds(), 3, prob="prob"
    )
    weighted_eids = F.asnumpy(weighted.edata[dgl.EID])
    assert (prob[weighted_eids] > 0).all()
    dgl.distributed.exit_client()

1512

1513
def check_standalone_etype_sampling_heterograph(tmpdir):
    """Check sample_etype_neighbors() on a two-relation graph, standalone mode.

    The "cora" edges are turned into a heterograph over a single "paper"
    node type with a "cite" relation and its reverse "cite-by". Sampling
    with the value 1 as fanout from 10 seed nodes is expected to yield
    exactly 10 edges for each of the two relations.
    """
    cora = CitationGraphDataset("cora")[0]
    cite_src, cite_dst = cora.edges()
    relations = {
        ("paper", "cite", "paper"): (cite_src, cite_dst),
        ("paper", "cite-by", "paper"): (cite_dst, cite_src),
    }
    hetero = dgl.heterograph(relations, {"paper": cora.num_nodes()})
    partition_graph(
        hetero,
        "test_hetero_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_hetero_sampling", part_config=tmpdir / "test_hetero_sampling.json"
    )
    seeds = torch.tensor(
        [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701],
        dtype=dist_graph.idtype,
    )
    sampled = sample_etype_neighbors(dist_graph, seeds, 1)
    cite_edges_src, _ = sampled.edges(etype=("paper", "cite", "paper"))
    assert len(cite_edges_src) == 10
    cited_by_src, _ = sampled.edges(etype=("paper", "cite-by", "paper"))
    assert len(cited_by_src) == 10
    assert sampled.num_nodes() == hetero.num_nodes()
    dgl.distributed.exit_client()

1553
1554
1555
1556
1557
1558

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_standalone_sampling():
    """Run check_standalone_sampling() in standalone mode in a temp dir.

    Skipped on Windows and with the TensorFlow backend.
    """
    reset_envs()
    # tempfile is imported at module level; no local import is needed.
    os.environ["DGL_DIST_MODE"] = "standalone"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_standalone_sampling(Path(tmpdirname))
1566

1567

1568
1569
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Run dgl.distributed.in_subgraph() on the "test_in_subgraph" graph.

    Parameters
    ----------
    rank : partition rank handed to load_partition().
    tmpdir : directory (Path) containing "test_in_subgraph.json".
    disable_shared_mem : when true, the partition book is loaded from disk
        and passed to the DistGraph.
    nodes : seed nodes forwarded to in_subgraph().

    Returns
    -------
    The sampled graph, or ``None`` when in_subgraph() raised; in that case
    the traceback is printed.
    """
    dgl.distributed.initialize("rpc_ip_config.txt")
    partition_book = None
    if disable_shared_mem:
        # Only the fourth item of the 7-tuple is needed here.
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_in_subgraph.json", rank
        )
    dist_graph = DistGraph("test_in_subgraph", gpb=partition_book)
    result = None
    try:
        result = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception:
        # Best effort: report the failure and let the caller see None.
        print(traceback.format_exc())
    dgl.distributed.exit_client()
    return result


1585
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Check distributed in_subgraph() against dgl.in_subgraph() on "cora".

    The graph is split into ``num_server`` METIS partitions, one server
    process is spawned per partition, and a client extracts the in-subgraph
    of a fixed set of seed nodes. Node and edge IDs of the result are mapped
    back through the mappings returned by partition_graph() and compared
    with the in-subgraph computed on the unpartitioned graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = CitationGraphDataset("cora")[0]
    # Shared memory is disabled as soon as more than one server is used.
    no_shared_mem = num_server > 1

    orig_nid, orig_eid = partition_graph(
        graph,
        "test_in_subgraph",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server = spawn.Process(
            target=start_server,
            args=(server_rank, tmpdir, no_shared_mem, "test_in_subgraph"),
        )
        server.start()
        # Pause one second after launching each server.
        time.sleep(1)
        servers.append(server)

    seeds = torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=graph.idtype)
    sampled = start_in_subgraph_client(0, tmpdir, no_shared_mem, seeds)
    for server in servers:
        server.join()
        assert server.exitcode == 0

    shuffled_src, shuffled_dst = sampled.edges()
    src = orig_nid[shuffled_src]
    dst = orig_nid[shuffled_dst]
    assert sampled.num_nodes() == graph.num_nodes()
    assert np.all(F.asnumpy(graph.has_edges_between(src, dst)))

    # Reference result computed directly on the unpartitioned graph.
    reference = dgl.in_subgraph(graph, orig_nid[seeds])
    ref_src, ref_dst = reference.edges()
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(ref_src)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(ref_dst)))
    expected_eids = graph.edge_ids(src, dst)
    mapped_eids = orig_eid[sampled.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(mapped_eids), F.asnumpy(expected_eids))
1631

1632
1633
1634
1635
1636
1637

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_rpc_in_subgraph():
    """Run check_rpc_in_subgraph_shuffle() with one server in a temp dir.

    Skipped on Windows and with the TensorFlow backend.
    """
    reset_envs()
    # tempfile is imported at module level; no local import is needed.
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 1)
1645

1646
1647
1648
1649
1650
1651
1652
1653
1654

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone_etype_sampling():
    """Run both standalone etype-sampling checks, each in its own temp dir.

    Skipped on Windows and with the TensorFlow or MXNet backend.
    """
    reset_envs()
    # tempfile is imported at module level; no local import is needed.
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling(Path(tmpdirname))
1665

1666

Jinjing Zhou's avatar
Jinjing Zhou committed
1667
1668
if __name__ == "__main__":
    import tempfile
1669

Jinjing Zhou's avatar
Jinjing Zhou committed
1670
    with tempfile.TemporaryDirectory() as tmpdirname:
1671
        os.environ["DGL_DIST_MODE"] = "standalone"
1672
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
1673
1674

    with tempfile.TemporaryDirectory() as tmpdirname:
1675
        os.environ["DGL_DIST_MODE"] = "standalone"
1676
1677
        check_standalone_etype_sampling(Path(tmpdirname))
        check_standalone_sampling(Path(tmpdirname))
1678
        os.environ["DGL_DIST_MODE"] = "distributed"
1679
1680
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
1681
1682
        check_rpc_get_degree_shuffle(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 2)
1683
1684
        check_rpc_find_edges_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 1)
1685
1686
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 2)
1687
1688
1689
1690
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)
1691
1692
1693
1694
        check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), 1)