test_distributed_sampling.py 48.7 KB
Newer Older
1
import multiprocessing as mp
Jinjing Zhou's avatar
Jinjing Zhou committed
2
import os
3
import random
4
import tempfile
Jinjing Zhou's avatar
Jinjing Zhou committed
5
import time
6
7
import traceback
import unittest
Jinjing Zhou's avatar
Jinjing Zhou committed
8
from pathlib import Path
9
10
11
12

import backend as F
import dgl
import numpy as np
13
import pytest
14
15
16
17
18
19
20
21
22
23
from dgl.data import CitationGraphDataset, WN18Dataset
from dgl.distributed import (
    DistGraph,
    DistGraphServer,
    load_partition,
    load_partition_book,
    partition_graph,
    sample_etype_neighbors,
    sample_neighbors,
)
24
from scipy import sparse as spsp
25
from utils import generate_ip_config, reset_envs
Jinjing Zhou's avatar
Jinjing Zhou committed
26
27


28
29
30
31
32
33
def start_server(
    rank,
    tmpdir,
    disable_shared_mem,
    graph_name,
    graph_format=None,
    use_graphbolt=False,
):
    """Create a ``DistGraphServer`` for ``graph_name`` and start it.

    Used as the ``target`` of the spawned server processes in this file.

    Args:
        rank: Server ID, passed as the first argument of ``DistGraphServer``.
        tmpdir: ``pathlib.Path`` of the directory that holds the partition
            config ``<graph_name>.json``.
        disable_shared_mem: Forwarded to ``DistGraphServer``.
        graph_name: Basename of the partition config file.
        graph_format: Sparse formats forwarded as ``graph_format``; ``None``
            (the default) means ``["csc", "coo"]``.
        use_graphbolt: Forwarded to ``DistGraphServer``.
    """
    # Build the default per call instead of sharing one mutable list
    # default object across every invocation.
    if graph_format is None:
        graph_format = ["csc", "coo"]
    g = DistGraphServer(
        rank,
        "rpc_ip_config.txt",
        # NOTE(review): presumably num_servers and num_clients — confirm
        # against the DistGraphServer signature.
        1,
        1,
        tmpdir / (graph_name + ".json"),
        disable_shared_mem=disable_shared_mem,
        graph_format=graph_format,
        use_graphbolt=use_graphbolt,
    )
    g.start()


49
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Sample 3 neighbors of a fixed seed set from the "test_sampling" graph.

    Returns the sampled graph, or ``None`` when ``sample_neighbors`` raised
    (the traceback is printed). The client is shut down in both cases.
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)
    seeds = [0, 10, 99, 66, 1024, 2008]
    try:
        result = sample_neighbors(dist_graph, seeds, 3)
    except Exception:
        print(traceback.format_exc())
        result = None
    dgl.distributed.exit_client()
    return result

67

68
69
70
71
72
73
74
75
76
def start_sample_client_shuffle(
    rank,
    tmpdir,
    disable_shared_mem,
    g,
    num_servers,
    group_id,
    orig_nid,
    orig_eid,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample the shuffled "test_sampling" graph and verify it against ``g``.

    Runs inside a spawned client process. Sampled node/edge IDs are mapped
    through ``orig_nid`` / ``orig_eid`` before being compared with ``g``.
    ``num_servers`` is accepted but not used.
    """
    os.environ["DGL_GROUP_ID"] = str(group_id)
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)
    sampled = sample_neighbors(
        dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt
    )

    assert (
        dgl.ETYPE not in sampled.edata
    ), "Etype should not be in homogeneous sampled graph."

    # Map sampled endpoints back to the node IDs of the original graph.
    shuffled_src, shuffled_dst = sampled.edges()
    src = orig_nid[shuffled_src]
    dst = orig_nid[shuffled_dst]
    assert sampled.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    if not use_graphbolt or return_eids:
        expected = g.edge_ids(src, dst)
        actual = orig_eid[sampled.edata[dgl.EID]]
        assert np.array_equal(F.asnumpy(actual), F.asnumpy(expected))
    else:
        assert (
            dgl.EID not in sampled.edata
        ), "EID should not be in sampled graph if use_graphbolt=True."
108

109

110
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
    """Look up the endpoints of ``eids`` on the "test_find_edges" DistGraph.

    Returns ``(u, v)``; both are ``None`` if ``find_edges`` raised (the
    traceback is printed). The client is shut down in both cases.
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_find_edges.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_find_edges", gpb=partition_book)
    try:
        src, dst = dist_graph.find_edges(eids, etype=etype)
    except Exception:
        print(traceback.format_exc())
        src = dst = None
    dgl.distributed.exit_client()
    return src, dst
Jinjing Zhou's avatar
Jinjing Zhou committed
125

126

127
128
129
def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Query degrees on the "test_get_degrees" DistGraph.

    Returns ``(in_deg, out_deg, all_in_deg, all_out_deg)``: the degrees of
    ``nids`` and the degrees returned without a node argument. All four are
    ``None`` if any query raised (the traceback is printed).
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_get_degrees.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_get_degrees", gpb=partition_book)
    try:
        in_deg = dist_graph.in_degrees(nids)
        all_in_deg = dist_graph.in_degrees()
        out_deg = dist_graph.out_degrees(nids)
        all_out_deg = dist_graph.out_degrees()
    except Exception:
        print(traceback.format_exc())
        in_deg = out_deg = all_in_deg = all_out_deg = None
    dgl.distributed.exit_client()
    return in_deg, out_deg, all_in_deg, all_out_deg

146

147
def check_rpc_sampling(tmpdir, num_server):
    """Sample Cora through RPC on a partition built without ID mapping.

    Starts one server process per partition, samples from a client in this
    process, and checks that every sampled edge (and its EID) exists in the
    original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    partition_graph(
        g,
        "test_sampling",
        num_server,  # one partition per server
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, "test_sampling"),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    expected_eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(expected_eids)
    )

Jinjing Zhou's avatar
Jinjing Zhou committed
189

190
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on a shuffled homogeneous partition.

    Partitions Cora with ID mapping, starts ``num_server`` server processes,
    queries 100 random edge IDs from a client, and compares the endpoints
    (mapped back through ``orig_nid``) with ``g.find_edges`` on the original
    edge IDs. All server processes must exit with code 0.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # ``eids`` are queried on the DistGraph; orig_eid maps them to the IDs
    # of the original graph for the local reference lookup.
    eids = F.tensor(np.random.randint(g.num_edges(), size=100))
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

    # Join the servers like the other checks in this file do. Otherwise they
    # can outlive this check and clash with the next check that reuses the
    # same "rpc_ip_config.txt".
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns (None, None) when find_edges raised.
    assert du is not None and dv is not None, "find_edges failed on client"
    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

225

226
def create_random_hetero(dense=False, empty=False):
    """Build a random heterograph with node types ``n1``, ``n2``, ``n3``.

    Args:
        dense: Use smaller node counts with 10% edge density instead of
            larger node counts with 0.1% density.
        empty: Draw each adjacency 10 rows/columns smaller than the node
            counts. The last 10 nodes of each type then have no edges.

    Returns:
        The heterograph. Only ``n1`` carries ``feat``: float32 ones of shape
        ``(num_nodes("n1"), 10)`` on CPU.
    """
    if dense:
        num_nodes = {"n1": 210, "n2": 200, "n3": 220}
    else:
        num_nodes = {"n1": 1010, "n2": 1000, "n3": 1020}
    shrink = 10 if empty else 0
    density = 0.1 if dense else 0.001
    canonical_etypes = [
        ("n1", "r12", "n2"),
        ("n1", "r13", "n3"),
        ("n2", "r23", "n3"),
    ]

    random.seed(42)
    edges = {}
    for c_etype in canonical_etypes:
        src_type, _, dst_type = c_etype
        adj = spsp.random(
            num_nodes[src_type] - shrink,
            num_nodes[dst_type] - shrink,
            density=density,
            format="coo",
            random_state=100,
        )
        edges[c_etype] = (adj.row, adj.col)

    g = dgl.heterograph(edges, num_nodes)
    g.nodes["n1"].data["feat"] = F.ones(
        (g.num_nodes("n1"), 10), F.float32, F.cpu()
    )
    return g
250

251

252
def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` with an etype on a shuffled heterograph.

    The check has three parts:

    - Locally, ``g.find_edges`` must reject the 2-tuple etype ``("n1", "r12")``.
    - Locally, the relation name ``"r12"`` and its canonical etype must give
      the same endpoints.
    - The distributed result for ``"r12"`` (mapped back through ``orig_nid``)
      must match the local one, and every server process must exit with
      code 0.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    test_etype = g.to_canonical_etype("r12")
    eids = F.tensor(np.random.randint(g.num_edges(test_etype), size=100))
    # The same edges expressed as IDs of the original graph.
    orig_r12_eids = orig_eid[test_etype][eids]

    # A (src_type, etype) pair is neither a relation name nor a canonical
    # etype, so find_edges must reject it. Catch ``Exception`` rather than
    # using a bare ``except:`` so KeyboardInterrupt/SystemExit still
    # propagate.
    expect_except = False
    try:
        _, _ = g.find_edges(orig_r12_eids, etype=("n1", "r12"))
    except Exception:
        expect_except = True
    assert expect_except, "find_edges should reject a 2-tuple etype"

    u, v = g.find_edges(orig_r12_eids, etype="r12")
    u1, v1 = g.find_edges(orig_r12_eids, etype=("n1", "r12", "n2"))
    assert F.array_equal(u, u1)
    assert F.array_equal(v, v1)

    du, dv = start_find_edges_client(
        0, tmpdir, num_server > 1, eids, etype="r12"
    )

    # Join the servers like the other checks in this file do. Otherwise they
    # can outlive this check and clash with the next check that reuses the
    # same "rpc_ip_config.txt".
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns (None, None) when find_edges raised.
    assert du is not None and dv is not None, "find_edges failed on client"
    du = orig_nid["n1"][du]
    dv = orig_nid["n2"][dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

299

300
# Only num_server=1 is tested until a non-shared-memory graph store is supported.
301
302
303
304
305
306
307
308
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_find_edges_shuffle(num_server):
    """Run the heterogeneous, then homogeneous, distributed find_edges checks."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmp:
        workdir = Path(tmp)
        check_rpc_hetero_find_edges_shuffle(workdir, num_server)
        check_rpc_find_edges_shuffle(workdir, num_server)

319

320
def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Compare distributed in/out degrees with those of the original Cora graph.

    Distributed results are indexed by the partitioned node IDs, so the
    reference degrees are looked up through ``orig_nid``.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    orig_nid, _ = partition_graph(
        g,
        "test_get_degrees",
        num_server,  # one partition per server
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, "test_get_degrees"),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    nids = F.tensor(np.random.randint(g.num_nodes(), size=100))
    degrees = start_get_degrees_client(0, tmpdir, num_server > 1, nids)
    in_degs, out_degs, all_in_degs, all_out_degs = degrees

    print("Done get_degree")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    print("check results")
    assert F.array_equal(g.in_degrees(orig_nid[nids]), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(orig_nid[nids]), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)

363

364
# Only num_server=1 is tested until a non-shared-memory graph store is supported.
365
366
367
368
369
370
371
372
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_get_degree_shuffle(num_server):
    """Distributed in/out degree queries on a shuffled partition."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmp:
        check_rpc_get_degree_shuffle(Path(tmp), num_server)

382
383
384
385

# @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
# @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip("Only support partition with shuffle")
def test_rpc_sampling():
    """RPC sampling on a partition built without ID mapping (skipped)."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmp:
        check_rpc_sampling(Path(tmp), 1)
Jinjing Zhou's avatar
Jinjing Zhou committed
393

394

395
def check_rpc_sampling_shuffle(
    tmpdir, num_server, num_groups=1, use_graphbolt=False, return_eids=False
):
    """Sample a shuffled Cora partition from ``num_groups`` client processes.

    Partitions Cora into ``num_server`` parts with ID mapping and starts one
    server process per part. It then spawns one client per group, each
    running ``start_sample_client_shuffle`` with its own group ID. Every
    client and server process must exit with code 0.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    orig_nids, orig_eids = partition_graph(
        g,
        "test_sampling",
        num_server,  # one partition per server
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    disable_shared_mem = num_server > 1

    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                disable_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    num_clients = 1
    clients = []
    for client_id in range(num_clients):
        for group_id in range(num_groups):
            client = ctx.Process(
                target=start_sample_client_shuffle,
                args=(
                    client_id,
                    tmpdir,
                    disable_shared_mem,
                    g,
                    num_server,
                    group_id,
                    orig_nids,
                    orig_eids,
                    use_graphbolt,
                    return_eids,
                ),
            )
            client.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            clients.append(client)

    # Join clients first, then servers; each must exit cleanly.
    for proc in [*clients, *servers]:
        proc.join()
        assert proc.exitcode == 0
Jinjing Zhou's avatar
Jinjing Zhou committed
462

463

464
465
466
467
468
469
470
471
def start_hetero_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample 3 neighbors per seed on the heterogeneous "test_sampling" graph.

    Returns ``(block, gpb)``:

    - ``block`` is the sample converted with ``dgl.to_block`` using ``nodes``
      as destination nodes. It is ``None`` if sampling raised, in which case
      the traceback is printed.
    - ``gpb`` is the partition book.
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)

    # Only n1 is expected to expose the "feat" field.
    assert "feat" in dist_graph.nodes["n1"].data
    for ntype in ("n2", "n3"):
        assert "feat" not in dist_graph.nodes[ntype].data

    if partition_book is None:
        partition_book = dist_graph.get_partition_book()

    keep_eids = not use_graphbolt or return_eids
    try:
        # Turn on the sanity checks of distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        sampled = sample_neighbors(
            dist_graph, nodes, 3, use_graphbolt=use_graphbolt
        )
        block = dgl.to_block(sampled, nodes)
        if keep_eids:
            block.edata[dgl.EID] = sampled.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, partition_book

499
500
501
502
503
504
505
506

def start_hetero_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    etype_sorted=False,
    use_graphbolt=False,
    return_eids=False,
):
    """Run ``sample_etype_neighbors`` on the heterogeneous "test_sampling" graph.

    Args:
        rank: Partition ID passed to ``load_partition``.
        tmpdir: ``pathlib.Path`` holding ``test_sampling.json``.
        disable_shared_mem: Load the partition book with ``load_partition``
            and pass it to ``DistGraph``.
        fanout: Fanout for ``sample_etype_neighbors`` (an int or a per-etype
            dict).
        nodes: Seed nodes per node type. ``None`` (the default) means
            ``{"n3": [0, 10, 99, 66, 124, 208]}``.
        etype_sorted: Forwarded to ``sample_etype_neighbors``.
        use_graphbolt: Forwarded to ``sample_etype_neighbors``. When set,
            the local-partition etype-order check is skipped.
        return_eids: With ``use_graphbolt``, whether ``dgl.EID`` is copied
            from the sample onto the block.

    Returns:
        ``(block, gpb)``:

        - ``block`` is the sample converted with ``dgl.to_block`` using
          ``nodes`` as destination nodes. It is ``None`` if sampling raised,
          in which case the traceback is printed.
        - ``gpb`` is the partition book.
    """
    # Build the default per call. A dict literal default is one object
    # shared across calls, so any mutation by a callee would leak into
    # later calls.
    if nodes is None:
        nodes = {"n3": [0, 10, 99, 66, 124, 208]}

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)

    # Only n1 is expected to expose the "feat" field.
    assert "feat" in dist_graph.nodes["n1"].data
    assert "feat" not in dist_graph.nodes["n2"].data
    assert "feat" not in dist_graph.nodes["n3"].data

    if (not use_graphbolt) and dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph. For each local node,
        # the first occurrence of each etype ID among its in-edges must
        # appear in ascending ID order.
        local_g = dist_graph.local_partition
        local_nids = np.arange(local_g.num_nodes())
        for lnid in local_nids:
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            _, first_pos = np.unique(letids, return_index=True)
            assert np.all(first_pos[:-1] <= first_pos[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    try:
        # Enable sanity checks in distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        sampled_graph = sample_etype_neighbors(
            dist_graph,
            nodes,
            fanout,
            etype_sorted=etype_sorted,
            use_graphbolt=use_graphbolt,
        )
        block = dgl.to_block(sampled_graph, nodes)
        if sampled_graph.num_edges() > 0:
            if not use_graphbolt or return_eids:
                block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, gpb

553

554
555
556
def check_rpc_hetero_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample a shuffled random heterograph and verify every sampled edge.

    For each edge type in the returned block, the block-local endpoints are
    first mapped to partitioned global IDs via ``dgl.NID``, then to original
    IDs via ``orig_nid_map``. Each resulting edge must exist in ``g``.
    When EIDs are available, the edge IDs are also checked against the
    endpoints.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,  # one partition per server
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    block, _ = start_hetero_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes={"n3": [0, 10, 99, 66, 124, 208]},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    for server in servers:
        server.join()
        assert server.exitcode == 0

    check_eids = not use_graphbolt or return_eids
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        local_src, local_dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        global_src = F.gather_row(
            block.srcnodes[src_type].data[dgl.NID], local_src
        )
        global_dst = F.gather_row(
            block.dstnodes[dst_type].data[dgl.NID], local_dst
        )
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], global_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], global_dst))
        assert np.all(
            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
        )

        if not check_eids:
            continue

        # Edge IDs must resolve to exactly the same endpoints in g.
        global_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], global_eid))
        found_src, found_dst = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(found_src) == orig_src)
        assert np.all(F.asnumpy(found_dst) == orig_dst)
628

629

630
631
632
633
634
635
636
637
638
def get_degrees(g, nids, ntype):
    """Return the total degree of nodes ``nids`` of type ``ntype`` in ``g``.

    Adds out-degrees over every edge type whose source type is ``ntype`` and
    in-degrees over every edge type whose destination type is ``ntype``. A
    relation from ``ntype`` to itself contributes both directions.

    Returns:
        An int64 backend tensor of length ``len(nids)``.
    """
    deg = F.zeros((len(nids),), dtype=F.int64)
    for c_etype in g.canonical_etypes:
        srctype, _, dsttype = c_etype
        # Use the canonical etype so that relation names shared by several
        # (src, dst) pairs stay unambiguous.
        if srctype == ntype:
            deg += g.out_degrees(u=nids, etype=c_etype)
        # Not ``elif``: for a self-relation (srctype == dsttype == ntype) the
        # in-edges must be counted too. Otherwise nodes that only receive
        # such edges would be reported as having degree 0.
        if dsttype == ntype:
            deg += g.in_degrees(v=nids, etype=c_etype)
    return deg

639

640
641
642
def check_rpc_hetero_sampling_empty_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sampling from edgeless ``n3`` seeds must give a block with no edges.

    Uses ``create_random_hetero(empty=True)`` so that some ``n3`` nodes have
    no edges, and uses exactly those nodes as seeds. The block must contain
    zero edges but still list every edge type of ``g``.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(empty=True)
    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,  # one partition per server
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    # deg[i] is the degree in g of the n3 node orig_nids["n3"][i], so the
    # zero positions are the partitioned IDs of edgeless n3 nodes.
    deg = get_degrees(g, orig_nids["n3"], "n3")
    isolated = F.nonzero_1d(deg == 0)
    block, _ = start_hetero_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes={"n3": isolated},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    for server in servers:
        server.join()
        assert server.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

696
697

def check_rpc_hetero_etype_sampling_shuffle(
    tmpdir,
    num_server,
    graph_formats=None,
    use_graphbolt=False,
    return_eids=False,
):
    """Per-etype sampling (fanout 3 per etype) on a dense shuffled heterograph.

    Expects 18 sampled edges (6 seeds x fanout 3) for each of ``r13`` and
    ``r23``, then verifies every sampled edge, and its EID when available,
    against the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True)
    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,  # one partition per server
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        graph_formats=graph_formats,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    fanout = {c_etype: 3 for c_etype in g.canonical_etypes}
    # Request etype-sorted sampling only when a compressed format was given.
    etype_sorted = graph_formats is not None and (
        "csc" in graph_formats or "csr" in graph_formats
    )
    block, _ = start_hetero_etype_sample_client(
        0,
        tmpdir,
        num_server > 1,
        fanout,
        nodes={"n3": [0, 10, 99, 66, 124, 208]},
        etype_sorted=etype_sorted,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    for c_etype in (("n1", "r13", "n3"), ("n2", "r23", "n3")):
        sampled_src, _ = block.edges(etype=c_etype)
        assert len(sampled_src) == 18

    check_eids = not use_graphbolt or return_eids
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        local_src, local_dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        global_src = F.gather_row(
            block.srcnodes[src_type].data[dgl.NID], local_src
        )
        global_dst = F.gather_row(
            block.dstnodes[dst_type].data[dgl.NID], local_dst
        )
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], global_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], global_dst))
        assert np.all(
            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
        )

        if not check_eids:
            continue

        # Edge IDs must resolve to exactly the same endpoints in g.
        global_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], global_eid))
        found_src, found_dst = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(found_src) == orig_src)
        assert np.all(F.asnumpy(found_dst) == orig_dst)

787

788
789
790
def check_rpc_hetero_etype_sampling_empty_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Per-etype sampling from edgeless ``n3`` seeds must return no edges.

    Uses ``create_random_hetero(dense=True, empty=True)`` and seeds only the
    ``n3`` nodes whose total degree is zero. The block must have no edges
    but still list every edge type of ``g``.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True, empty=True)
    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,  # one partition per server
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server = ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    # deg[i] is the degree in g of the n3 node orig_nids["n3"][i], so the
    # zero positions are the partitioned IDs of edgeless n3 nodes.
    deg = get_degrees(g, orig_nids["n3"], "n3")
    isolated = F.nonzero_1d(deg == 0)
    block, _ = start_hetero_etype_sample_client(
        0,
        tmpdir,
        num_server > 1,
        3,  # fanout
        nodes={"n3": isolated},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

847
848

def create_random_bipartite():
    """Build a random ``user -buys-> game`` graph with a ``feat`` field.

    ``dgl.rand_bipartite`` is called with 500 users, 1000 games and
    1000 edges. Both node types get ``feat``: float32 ones of shape
    ``(num_nodes, 10)`` on CPU.
    """
    g = dgl.rand_bipartite("user", "buys", "game", 500, 1000, 1000)
    for ntype in ("user", "game"):
        g.nodes[ntype].data["feat"] = F.ones(
            (g.num_nodes(ntype), 10), F.float32, F.cpu()
        )
    return g


859
860
861
862
863
864
865
866
def start_bipartite_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample 3 neighbors per seed on the bipartite "test_sampling" graph.

    Exceptions are not caught here. If sampling raises, ``exit_client`` is
    not reached.

    Returns:
        ``(block, gpb)``: the sample converted with ``dgl.to_block`` using
        ``nodes`` as destination nodes, and the partition book.
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)
    for ntype in ("user", "game"):
        assert "feat" in dist_graph.nodes[ntype].data

    if partition_book is None:
        partition_book = dist_graph.get_partition_book()

    # Turn on the sanity checks of distributed sampling.
    os.environ["DGL_DIST_DEBUG"] = "1"
    sampled = sample_neighbors(
        dist_graph, nodes, 3, use_graphbolt=use_graphbolt
    )
    block = dgl.to_block(sampled, nodes)
    # Copy EIDs only for a non-empty sample, and only when they are present
    # (i.e. unless GraphBolt was asked not to return them).
    if sampled.num_edges() > 0 and (not use_graphbolt or return_eids):
        block.edata[dgl.EID] = sampled.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, partition_book


891
def start_bipartite_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    use_graphbolt=False,
    return_eids=False,
):
    """Run ``sample_etype_neighbors`` on the bipartite "test_sampling" graph.

    Args:
        rank: Partition ID passed to ``load_partition``.
        tmpdir: ``pathlib.Path`` holding ``test_sampling.json``.
        disable_shared_mem: Load the partition book with ``load_partition``
            and pass it to ``DistGraph``.
        fanout: Fanout forwarded to ``sample_etype_neighbors``.
        nodes: Seed nodes per node type. ``None`` (the default) means an
            empty dict.
        use_graphbolt: Forwarded to ``sample_etype_neighbors``. When set,
            the local-partition etype-order check is skipped.
        return_eids: With ``use_graphbolt``, whether ``dgl.EID`` is copied
            from the sample onto the block.

    Returns:
        ``(block, gpb)``: the sample converted with ``dgl.to_block`` using
        ``nodes`` as destination nodes, and the partition book. Exceptions
        are not caught, so ``exit_client`` is skipped if sampling raises.
    """
    # Build the default per call instead of sharing one mutable ``{}``
    # default object across every invocation.
    if nodes is None:
        nodes = {}

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["user"].data
    assert "feat" in dist_graph.nodes["game"].data

    if not use_graphbolt and dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph. For each local node,
        # the first occurrence of each etype ID among its in-edges must
        # appear in ascending ID order.
        local_g = dist_graph.local_partition
        local_nids = np.arange(local_g.num_nodes())
        for lnid in local_nids:
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            _, first_pos = np.unique(letids, return_index=True)
            assert np.all(first_pos[:-1] <= first_pos[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    sampled_graph = sample_etype_neighbors(
        dist_graph, nodes, fanout, use_graphbolt=use_graphbolt
    )
    block = dgl.to_block(sampled_graph, nodes)
    if sampled_graph.num_edges() > 0:
        if not use_graphbolt or return_eids:
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb


933
934
935
def check_rpc_bipartite_sampling_empty(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Bipartite sample_neighbors() whose result must be empty.

    Seeds are "game" nodes with zero degree (as reported by get_degrees),
    so no edge is sampled. The resulting block must still expose every
    edge type of the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    nid_map, _ = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server_args = (
            server_id,
            tmpdir,
            num_server > 1,
            "test_sampling",
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = spawn_ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    game_degrees = get_degrees(graph, nid_map["game"], "game")
    zero_degree_nids = F.nonzero_1d(game_degrees == 0)
    block, _ = start_bipartite_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes={"game": zero_degree_nids, "user": [1]},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(graph.etypes)


993
994
995
def check_rpc_bipartite_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Bipartite sample_neighbors() with a non-empty result.

    Seeds are "game" nodes with non-zero degree. Every sampled edge is
    mapped back to original node IDs and must exist in the original graph.
    When edge IDs are available, they must resolve to the same endpoints.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    nid_map, eid_map = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server_args = (
            server_id,
            tmpdir,
            num_server > 1,
            "test_sampling",
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = spawn_ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    game_degrees = get_degrees(graph, nid_map["game"], "game")
    seed_nids = F.nonzero_1d(game_degrees > 0)
    block, _ = start_bipartite_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes={"game": seed_nids, "user": [0]},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # Edge IDs are only present unless GraphBolt ran without return_eids.
    check_eids = not use_graphbolt or return_eids
    for canonical in block.canonical_etypes:
        src_ntype, etype, dst_ntype = canonical
        src, dst = block.edges(etype=etype)
        # Block NIDs are global IDs after shuffling; map them to originals.
        src_nid = block.srcnodes[src_ntype].data[dgl.NID]
        dst_nid = block.dstnodes[dst_ntype].data[dgl.NID]
        shuffled_src = F.gather_row(src_nid, src)
        shuffled_dst = F.gather_row(dst_nid, dst)
        orig_src = F.asnumpy(F.gather_row(nid_map[src_ntype], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(nid_map[dst_ntype], shuffled_dst))
        exists = graph.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(exists))

        if not check_eids:
            continue

        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(eid_map[canonical], shuffled_eid))
        # Endpoints looked up by original edge ID must match the mapped ones.
        eid_src, eid_dst = graph.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(eid_src) == orig_src)
        assert np.all(F.asnumpy(eid_dst) == orig_dst)


1072
1073
1074
def check_rpc_bipartite_etype_sampling_empty(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Bipartite sample_etype_neighbors() whose result must be empty.

    Seeds are "game" nodes with zero degree (as reported by get_degrees),
    so no edge is sampled. The resulting block must still expose every
    edge type of the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    nid_map, _ = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn_ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    game_degrees = get_degrees(graph, nid_map["game"], "game")
    zero_degree_nids = F.nonzero_1d(game_degrees == 0)
    block, _ = start_bipartite_etype_sample_client(
        0,
        tmpdir,
        num_server > 1,
        nodes={"game": zero_degree_nids, "user": [1]},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block is not None
    assert block.num_edges() == 0
    assert len(block.etypes) == len(graph.etypes)


1133
1134
1135
def check_rpc_bipartite_etype_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Bipartite sample_etype_neighbors() with a non-empty result.

    Seeds are "game" nodes with non-zero degree and the fanout is 3. Every
    sampled edge is mapped back to original node IDs and must exist in the
    original graph. When edge IDs are available, they must resolve to the
    same endpoints.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    nid_map, eid_map = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn_ctx.Process(
            target=start_server,
            args=(
                server_id,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    fanout = 3
    game_degrees = get_degrees(graph, nid_map["game"], "game")
    seed_nids = F.nonzero_1d(game_degrees > 0)
    block, _ = start_bipartite_etype_sample_client(
        0,
        tmpdir,
        num_server > 1,
        fanout,
        nodes={"game": seed_nids, "user": [0]},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # Edge IDs are only present unless GraphBolt ran without return_eids.
    check_eids = not use_graphbolt or return_eids
    for canonical in block.canonical_etypes:
        src_ntype, etype, dst_ntype = canonical
        src, dst = block.edges(etype=etype)
        # Block NIDs are global IDs after shuffling; map them to originals.
        src_nid = block.srcnodes[src_ntype].data[dgl.NID]
        dst_nid = block.dstnodes[dst_ntype].data[dgl.NID]
        shuffled_src = F.gather_row(src_nid, src)
        shuffled_dst = F.gather_row(dst_nid, dst)
        orig_src = F.asnumpy(F.gather_row(nid_map[src_ntype], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(nid_map[dst_ntype], shuffled_dst))
        exists = graph.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(exists))

        if not check_eids:
            continue

        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_eid = F.asnumpy(F.gather_row(eid_map[canonical], shuffled_eid))
        # Endpoints looked up by original edge ID must match the mapped ones.
        eid_src, eid_dst = graph.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(eid_src) == orig_src)
        assert np.all(F.asnumpy(eid_dst) == orig_dst)

1212

1213
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run check_rpc_sampling_shuffle in distributed mode in a temp dir."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_sampling_shuffle(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1226
1227
1228


@pytest.mark.parametrize("num_server", [1])
# The argname is "use_graphbolt" (the old "use_graphbolt," had a stray comma).
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run check_rpc_hetero_sampling_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_sampling_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1241
1242
1243


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_hetero_sampling_empty_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_hetero_sampling_empty_shuffle(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1258
1259
1260
1261
1262
1263


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]]
)
def test_rpc_hetero_etype_sampling_shuffle_dgl(num_server, graph_formats):
    """Per-etype hetero sampling (non-GraphBolt) for several graph formats."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_hetero_etype_sampling_shuffle(
            workdir,
            num_server,
            graph_formats=graph_formats,
        )
1271
1272
1273


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_shuffle_graphbolt(num_server, return_eids):
    """Per-etype hetero sampling with GraphBolt, with and without edge IDs."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_hetero_etype_sampling_shuffle(
            workdir,
            num_server,
            use_graphbolt=True,
            return_eids=return_eids,
        )


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_hetero_etype_sampling_empty_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_hetero_etype_sampling_empty_shuffle(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1302
1303
1304


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_bipartite_sampling_empty in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_bipartite_sampling_empty(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1316
1317
1318


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run check_rpc_bipartite_sampling_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_bipartite_sampling_shuffle(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1328
1329
1330


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_bipartite_etype_sampling_empty in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_bipartite_etype_sampling_empty(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1345
1346
1347


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_bipartite_etype_sampling_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_bipartite_etype_sampling_shuffle(
            workdir,
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
Jinjing Zhou's avatar
Jinjing Zhou committed
1362

1363

1364
def check_standalone_sampling(tmpdir):
    """sample_neighbors() on a single-partition Cora graph in standalone mode.

    Checks unweighted sampling against the original graph. It then checks
    that sampling weighted by the "mask" or "prob" edge feature never picks
    an edge whose random weight was clipped to zero.
    """
    graph = CitationGraphDataset("cora")[0]
    weights = np.maximum(np.random.randn(graph.num_edges()), 0)
    positive = weights > 0
    graph.edata["prob"] = F.tensor(weights)
    graph.edata["mask"] = F.tensor(positive)

    partition_graph(
        graph,
        "test_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", part_config=tmpdir / "test_sampling.json"
    )

    seeds = [0, 10, 99, 66, 1024, 2008]
    sampled = sample_neighbors(dist_graph, seeds, 3)
    src, dst = sampled.edges()
    assert sampled.num_nodes() == graph.num_nodes()
    assert np.all(F.asnumpy(graph.has_edges_between(src, dst)))
    expected_eids = graph.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled.edata[dgl.EID]), F.asnumpy(expected_eids)
    )

    masked = sample_neighbors(dist_graph, seeds, 3, prob="mask")
    masked_eids = F.asnumpy(masked.edata[dgl.EID])
    assert positive[masked_eids].all()

    weighted = sample_neighbors(dist_graph, seeds, 3, prob="prob")
    weighted_eids = F.asnumpy(weighted.edata[dgl.EID])
    assert (weights[weighted_eids] > 0).all()

    dgl.distributed.exit_client()
1408

1409

1410
def check_standalone_etype_sampling(tmpdir):
    """sample_etype_neighbors() on single-partition Cora in standalone mode.

    Checks unweighted per-etype sampling against the original graph. It then
    checks that sampling weighted by the "mask" or "prob" edge feature never
    picks an edge whose random weight was clipped to zero.
    """
    hg = CitationGraphDataset("cora")[0]
    weights = np.maximum(np.random.randn(hg.num_edges()), 0)
    positive = weights > 0
    hg.edata["prob"] = F.tensor(weights)
    hg.edata["mask"] = F.tensor(positive)

    partition_graph(
        hg,
        "test_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", part_config=tmpdir / "test_sampling.json"
    )

    seeds = [0, 10, 99, 66, 1023]
    sampled = sample_etype_neighbors(dist_graph, seeds, 3)
    src, dst = sampled.edges()
    assert sampled.num_nodes() == hg.num_nodes()
    assert np.all(F.asnumpy(hg.has_edges_between(src, dst)))
    expected_eids = hg.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled.edata[dgl.EID]), F.asnumpy(expected_eids)
    )

    masked = sample_etype_neighbors(dist_graph, seeds, 3, prob="mask")
    masked_eids = F.asnumpy(masked.edata[dgl.EID])
    assert positive[masked_eids].all()

    weighted = sample_etype_neighbors(dist_graph, seeds, 3, prob="prob")
    weighted_eids = F.asnumpy(weighted.edata[dgl.EID])
    assert (weights[weighted_eids] > 0).all()

    dgl.distributed.exit_client()

1455

1456
def check_standalone_etype_sampling_heterograph(tmpdir):
    """Per-etype sampling on a two-relation Cora heterograph (standalone).

    Builds "cite" and its reverse "cite-by" over the same "paper" nodes.
    It samples with fanout 1 for ten seeds and expects ten sampled edges
    for each of the two edge types.
    """
    cora = CitationGraphDataset("cora")[0]
    cite_src, cite_dst = cora.edges()
    new_hg = dgl.heterograph(
        {
            ("paper", "cite", "paper"): (cite_src, cite_dst),
            ("paper", "cite-by", "paper"): (cite_dst, cite_src),
        },
        {"paper": cora.num_nodes()},
    )
    partition_graph(
        new_hg,
        "test_hetero_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_hetero_sampling", part_config=tmpdir / "test_hetero_sampling.json"
    )

    seeds = [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701]
    sampled_graph = sample_etype_neighbors(dist_graph, seeds, 1)
    for relation in ("cite", "cite-by"):
        rel_src, _ = sampled_graph.edges(etype=("paper", relation, "paper"))
        assert len(rel_src) == 10
    assert sampled_graph.num_nodes() == new_hg.num_nodes()
    dgl.distributed.exit_client()

1491
1492
1493
1494
1495
1496

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_standalone_sampling():
    """Run check_standalone_sampling in standalone mode in a temp dir."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_standalone_sampling(Path(tmpdirname))
1504

1505

1506
1507
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Run ``dgl.distributed.in_subgraph`` on ``nodes`` from a client.

    Parameters
    ----------
    rank : int
        Partition id to load when ``disable_shared_mem`` is set.
    tmpdir : pathlib.Path
        Directory that holds ``test_in_subgraph.json``.
    disable_shared_mem : bool
        If True, the partition book is loaded from the partition files.
    nodes : list
        Seed node IDs.

    Returns
    -------
    The sampled subgraph, or None if ``in_subgraph`` raised. In that case
    the traceback is printed.
    """
    gpb = None
    dgl.distributed.initialize("rpc_ip_config.txt")
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_in_subgraph.json", rank
        )
    dist_graph = DistGraph("test_in_subgraph", gpb=gpb)
    try:
        sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception:
        # Best effort: report the failure and still reach exit_client();
        # the caller decides how to fail on a None result.
        print(traceback.format_exc())
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph


1523
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Compare distributed in_subgraph() with local dgl.in_subgraph on Cora.

    Cora is partitioned into ``num_server`` parts with one server per part.
    A client runs ``in_subgraph`` on a fixed seed list. The sampled edges,
    mapped back to original IDs, must match ``dgl.in_subgraph`` on the
    original graph, including their edge IDs.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_in_subgraph",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_in_subgraph"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = [0, 10, 99, 66, 1024, 2008]
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0
    # start_in_subgraph_client returns None when in_subgraph raised. Fail
    # here with a clear message instead of an AttributeError below.
    assert sampled_graph is not None, "in_subgraph failed on the client"

    src, dst = sampled_graph.edges()
    # Translate shuffled node IDs back to original Cora node IDs.
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    subg1 = dgl.in_subgraph(g, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
1569

1570
1571
1572
1573
1574
1575

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_rpc_in_subgraph():
    """Run check_rpc_in_subgraph_shuffle with one server in a temp dir."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 1)
1583

1584
1585
1586
1587
1588
1589
1590
1591
1592

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone_etype_sampling():
    """Run both standalone per-etype sampling checks, each in a fresh dir."""
    reset_envs()
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling(Path(tmpdirname))
1603

1604

Jinjing Zhou's avatar
Jinjing Zhou committed
1605
1606
if __name__ == "__main__":
    import tempfile
1607

Jinjing Zhou's avatar
Jinjing Zhou committed
1608
    with tempfile.TemporaryDirectory() as tmpdirname:
1609
        os.environ["DGL_DIST_MODE"] = "standalone"
1610
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
1611
1612

    with tempfile.TemporaryDirectory() as tmpdirname:
1613
        os.environ["DGL_DIST_MODE"] = "standalone"
1614
1615
        check_standalone_etype_sampling(Path(tmpdirname))
        check_standalone_sampling(Path(tmpdirname))
1616
        os.environ["DGL_DIST_MODE"] = "distributed"
1617
1618
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
1619
1620
        check_rpc_get_degree_shuffle(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 2)
1621
1622
        check_rpc_find_edges_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 1)
1623
1624
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 2)
1625
1626
1627
1628
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)
1629
1630
1631
1632
        check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), 1)