test_distributed_sampling.py 51 KB
Newer Older
1
import multiprocessing as mp
Jinjing Zhou's avatar
Jinjing Zhou committed
2
import os
3
import random
4
import tempfile
Jinjing Zhou's avatar
Jinjing Zhou committed
5
import time
6
7
import traceback
import unittest
Jinjing Zhou's avatar
Jinjing Zhou committed
8
from pathlib import Path
9
10
11
12

import backend as F
import dgl
import numpy as np
13
import pytest
14
import torch
15
16
17
18
19
20
21
22
23
24
from dgl.data import CitationGraphDataset, WN18Dataset
from dgl.distributed import (
    DistGraph,
    DistGraphServer,
    load_partition,
    load_partition_book,
    partition_graph,
    sample_etype_neighbors,
    sample_neighbors,
)
25
from scipy import sparse as spsp
26
from utils import generate_ip_config, reset_envs
Jinjing Zhou's avatar
Jinjing Zhou committed
27
28


29
30
31
32
33
34
def start_server(
    rank,
    tmpdir,
    disable_shared_mem,
    graph_name,
    graph_format=None,
    use_graphbolt=False,
):
    """Construct a ``DistGraphServer`` for one partition and call ``start()``.

    Parameters
    ----------
    rank : int
        First positional argument handed to ``DistGraphServer``.
    tmpdir : pathlib.Path
        Directory holding ``<graph_name>.json``; must support the ``/``
        operator.
    disable_shared_mem : bool
        Forwarded as ``disable_shared_mem``.
    graph_name : str
        Base name of the partition config file (``.json`` is appended).
    graph_format : list of str, optional
        Forwarded as ``graph_format``. ``None`` (the default) selects
        ``["csc", "coo"]``.
    use_graphbolt : bool, optional
        Forwarded as ``use_graphbolt``.
    """
    if graph_format is None:
        # Build the default per call so that no list object is shared
        # between invocations (avoids a mutable default argument).
        graph_format = ["csc", "coo"]
    server = DistGraphServer(
        rank,
        "rpc_ip_config.txt",
        1,  # presumably the number of servers -- TODO confirm
        1,  # presumably the number of clients -- TODO confirm
        tmpdir / (graph_name + ".json"),
        disable_shared_mem=disable_shared_mem,
        graph_format=graph_format,
        use_graphbolt=use_graphbolt,
    )
    server.start()


50
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Connect to the "test_sampling" graph and sample for six fixed seeds.

    The partition book is read from ``tmpdir / "test_sampling.json"`` only
    when ``disable_shared_mem`` is true; otherwise ``gpb=None`` is passed to
    ``DistGraph``.

    Returns
    -------
    The graph returned by ``sample_neighbors`` (fanout 3), or ``None`` if
    sampling raised; in that case the traceback is printed.
    """
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    else:
        partition_book = None

    dgl.distributed.initialize("rpc_ip_config.txt")
    remote_graph = DistGraph("test_sampling", gpb=partition_book)

    result = None
    try:
        seeds = torch.tensor(
            [0, 10, 99, 66, 1024, 2008], dtype=remote_graph.idtype
        )
        result = sample_neighbors(remote_graph, seeds, 3)
    except Exception:
        # Report the failure but still shut the client down below.
        print(traceback.format_exc())
    dgl.distributed.exit_client()
    return result

70

71
72
73
74
75
76
77
78
79
def start_sample_client_shuffle(
    rank,
    tmpdir,
    disable_shared_mem,
    g,
    num_servers,
    group_id,
    orig_nid,
    orig_eid,
    use_graphbolt=False,
    return_eids=False,
    node_id_dtype=None,
    replace=False,
):
    """Sample from the "test_sampling" DistGraph and verify the result.

    Sets ``DGL_GROUP_ID``, connects as a client, samples with fanout 3 for
    six fixed seed nodes and compares the sampled edges with the original
    graph ``g`` through the ``orig_nid`` / ``orig_eid`` mappings. All checks
    are ``assert`` statements and nothing is returned. ``num_servers`` is
    accepted but not used.
    """
    fanout = 3
    os.environ["DGL_GROUP_ID"] = str(group_id)

    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    else:
        partition_book = None
    dgl.distributed.initialize("rpc_ip_config.txt")
    remote_graph = DistGraph("test_sampling", gpb=partition_book)

    subgraph = sample_neighbors(
        remote_graph,
        torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=node_id_dtype),
        fanout,
        replace=replace,
        use_graphbolt=use_graphbolt,
    )
    assert subgraph.idtype == remote_graph.idtype
    assert subgraph.idtype == torch.int64
    assert (
        dgl.ETYPE not in subgraph.edata
    ), "Etype should not be in homogeneous sampled graph."

    sampled_src, sampled_dst = subgraph.edges()
    sampled_in_degrees = subgraph.in_degrees(sampled_dst)
    # Map the sampled endpoints through orig_nid before comparing with g.
    src = orig_nid[sampled_src]
    dst = orig_nid[sampled_dst]
    assert subgraph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    if not use_graphbolt or return_eids:
        expected_eids = g.edge_ids(src, dst)
        actual_eids = orig_eid[subgraph.edata[dgl.EID]]
        assert np.array_equal(
            F.asnumpy(actual_eids), F.asnumpy(expected_eids)
        )
    else:
        assert (
            dgl.EID not in subgraph.edata
        ), "EID should not be in sampled graph if use_graphbolt=True."

    # Verify replace argument.
    orig_in_degrees = g.in_degrees(dst)
    if replace:
        hit_fanout = sampled_in_degrees == fanout
        took_every_edge = sampled_in_degrees == orig_in_degrees
        assert torch.all(hit_fanout | took_every_edge)
    else:
        assert torch.all(sampled_in_degrees <= fanout)
128

129

130
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
    """Call ``find_edges`` on the "test_find_edges" DistGraph as a client.

    Returns
    -------
    tuple
        ``(u, v)`` as returned by ``DistGraph.find_edges(eids, etype=etype)``,
        or ``(None, None)`` if the call raised (the traceback is printed).
    """
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_find_edges.json", rank
        )
    else:
        partition_book = None

    dgl.distributed.initialize("rpc_ip_config.txt")
    remote_graph = DistGraph("test_find_edges", gpb=partition_book)

    endpoints = (None, None)
    try:
        src, dst = remote_graph.find_edges(eids, etype=etype)
        endpoints = (src, dst)
    except Exception:
        print(traceback.format_exc())
    dgl.distributed.exit_client()
    return endpoints
Jinjing Zhou's avatar
Jinjing Zhou committed
145

146

147
148
149
def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Query in/out degrees from the "test_get_degrees" DistGraph.

    Returns
    -------
    tuple
        ``(in_deg, out_deg, all_in_deg, all_out_deg)`` where the first two
        are for ``nids`` and the last two come from calls without arguments.
        If any query raises, the traceback is printed and all four entries
        are ``None``.
    """
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_get_degrees.json", rank
        )
    else:
        partition_book = None

    dgl.distributed.initialize("rpc_ip_config.txt")
    remote_graph = DistGraph("test_get_degrees", gpb=partition_book)

    degrees = (None, None, None, None)
    try:
        in_for_nids = remote_graph.in_degrees(nids)
        in_for_all = remote_graph.in_degrees()
        out_for_nids = remote_graph.out_degrees(nids)
        out_for_all = remote_graph.out_degrees()
        # Only publish the results once every query has succeeded.
        degrees = (in_for_nids, out_for_nids, in_for_all, out_for_all)
    except Exception:
        print(traceback.format_exc())
    dgl.distributed.exit_client()
    return degrees

166

167
def check_rpc_sampling(tmpdir, num_server):
    """Partition cora, serve it, sample once and validate the sampled edges.

    The graph is partitioned into ``num_server`` parts under ``tmpdir``, one
    server process is spawned per part, and ``start_sample_client`` is run
    in the current process. Shared memory is disabled when more than one
    server is used. No ID mapping is requested from ``partition_graph``, so
    sampled edge IDs are compared with ``g.edge_ids`` directly.
    """
    graph_name = "test_sampling"
    no_shared_mem = num_server > 1
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    partition_graph(
        g,
        graph_name,
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn.Process(
            target=start_server,
            args=(server_id, tmpdir, no_shared_mem, graph_name),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    sampled_graph = start_sample_client(0, tmpdir, no_shared_mem)
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    expected_eids = F.asnumpy(g.edge_ids(src, dst))
    actual_eids = F.asnumpy(sampled_graph.edata[dgl.EID])
    assert np.array_equal(actual_eids, expected_eids)

Jinjing Zhou's avatar
Jinjing Zhou committed
209

210
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on a partitioned cora graph.

    The graph is partitioned with ``return_mapping=True``; 100 random edge
    IDs are looked up both locally (through ``orig_eid``) and through a
    distributed client, and the endpoints must match after mapping the
    distributed result through ``orig_nid``.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory receiving the partition files.
    num_server : int
        Number of partitions and of server processes. Shared memory is
        disabled when it is greater than 1.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    eids = F.tensor(np.random.randint(g.num_edges(), size=100))
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

    # Wait for the servers and check their exit codes, as the other check_*
    # helpers in this file do. Without this the server processes were left
    # unreaped and a server-side failure went unnoticed.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

245

246
def create_random_hetero(dense=False, empty=False):
    """Build a random heterograph with node types n1/n2/n3.

    Parameters
    ----------
    dense : bool
        Use the smaller node counts (210/200/220) with edge density 0.1
        instead of the larger ones (1010/1000/1020) with density 0.001.
    empty : bool
        Generate edges only among the first ``count - 10`` nodes of each
        type, leaving the last 10 nodes of every type without edges.

    Returns
    -------
    The heterograph, with an all-ones float32 "feat" of width 10 on "n1".
    """
    if dense:
        num_nodes = {"n1": 210, "n2": 200, "n3": 220}
        density = 0.1
    else:
        num_nodes = {"n1": 1010, "n2": 1000, "n3": 1020}
        density = 0.001
    # Number of trailing nodes per type that receive no edges.
    reserved = 10 if empty else 0

    random.seed(42)
    edges = {}
    for etype in [
        ("n1", "r12", "n2"),
        ("n1", "r13", "n3"),
        ("n2", "r23", "n3"),
    ]:
        src_ntype, _, dst_ntype = etype
        adjacency = spsp.random(
            num_nodes[src_ntype] - reserved,
            num_nodes[dst_ntype] - reserved,
            density=density,
            format="coo",
            random_state=100,
        )
        edges[etype] = (adjacency.row, adjacency.col)

    g = dgl.heterograph(edges, num_nodes)
    feat_shape = (g.num_nodes("n1"), 10)
    g.nodes["n1"].data["feat"] = F.ones(feat_shape, F.float32, F.cpu())
    return g
270

271

272
def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on a partitioned random heterograph.

    Verifies three things for edge type "r12":

    * ``g.find_edges`` raises for the malformed etype ``("n1", "r12")``;
    * the short and the canonical etype give the same local result;
    * the distributed lookup matches the local one after mapping the
      endpoints through ``orig_nid``.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory receiving the partition files.
    num_server : int
        Number of partitions and of server processes. Shared memory is
        disabled when it is greater than 1.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    test_etype = g.to_canonical_etype("r12")
    eids = F.tensor(np.random.randint(g.num_edges(test_etype), size=100))
    expect_except = False
    try:
        _, _ = g.find_edges(orig_eid[test_etype][eids], etype=("n1", "r12"))
    except Exception:
        # Narrowed from a bare ``except:`` so that KeyboardInterrupt and
        # SystemExit are not mistaken for the expected failure.
        expect_except = True
    assert expect_except
    u, v = g.find_edges(orig_eid[test_etype][eids], etype="r12")
    u1, v1 = g.find_edges(orig_eid[test_etype][eids], etype=("n1", "r12", "n2"))
    assert F.array_equal(u, u1)
    assert F.array_equal(v, v1)
    du, dv = start_find_edges_client(
        0, tmpdir, num_server > 1, eids, etype="r12"
    )

    # Wait for the servers and check their exit codes, as the other check_*
    # helpers in this file do. Without this the server processes were left
    # unreaped and a server-side failure went unnoticed.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    du = orig_nid["n1"][du]
    dv = orig_nid["n2"][dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

319

320
# Wait non shared memory graph store
321
322
323
324
325
326
327
328
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_find_edges_shuffle(num_server):
    """Run the hetero and then the homogeneous find_edges checks.

    Both checks share one temporary directory and run in distributed mode.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch_dir:
        for check in (
            check_rpc_hetero_find_edges_shuffle,
            check_rpc_find_edges_shuffle,
        ):
            check(Path(scratch_dir), num_server)

339

340
def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Check distributed in/out degree queries on a partitioned cora graph.

    Degrees for 100 random node IDs and for all nodes are fetched through
    ``start_get_degrees_client`` and compared with the degrees of the
    original graph, indexed through the ``orig_nid`` mapping returned by
    ``partition_graph``.
    """
    graph_name = "test_get_degrees"
    no_shared_mem = num_server > 1
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    orig_nid, _ = partition_graph(
        g,
        graph_name,
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn.Process(
            target=start_server,
            args=(server_id, tmpdir, no_shared_mem, graph_name),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    nids = F.tensor(np.random.randint(g.num_nodes(), size=100))
    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(
        0, tmpdir, no_shared_mem, nids
    )

    print("Done get_degree")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    print("check results")
    queried = orig_nid[nids]
    assert F.array_equal(g.in_degrees(queried), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(queried), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)

383

384
# Wait non shared memory graph store
385
386
387
388
389
390
391
392
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_get_degree_shuffle(num_server):
    """Run the degree check in distributed mode in a temporary directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch_dir:
        workdir = Path(scratch_dir)
        check_rpc_get_degree_shuffle(workdir, num_server)

402
403
404
405

# @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
# @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip("Only support partition with shuffle")
def test_rpc_sampling():
    """Run ``check_rpc_sampling`` with one server (currently skipped)."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch_dir:
        workdir = Path(scratch_dir)
        check_rpc_sampling(workdir, 1)
Jinjing Zhou's avatar
Jinjing Zhou committed
413

414

415
def check_rpc_sampling_shuffle(
    tmpdir,
    num_server,
    num_groups=1,
    use_graphbolt=False,
    return_eids=False,
    node_id_dtype=None,
    replace=False,
):
    """Partition cora and run sampling clients in separate processes.

    One server process is spawned per partition and one client process per
    group (``num_groups``). Each client runs ``start_sample_client_shuffle``,
    which performs the checks; this function only requires that every client
    and every server process exits with code 0.
    """
    graph_name = "test_sampling"
    no_shared_mem = num_server > 1
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    orig_nids, orig_eids = partition_graph(
        g,
        graph_name,
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server_args = (
            server_id,
            tmpdir,
            no_shared_mem,
            graph_name,
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = spawn.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    clients = []
    num_clients = 1
    for client_id in range(num_clients):
        for group_id in range(num_groups):
            client_args = (
                client_id,
                tmpdir,
                no_shared_mem,
                g,
                num_server,
                group_id,
                orig_nids,
                orig_eids,
                use_graphbolt,
                return_eids,
                node_id_dtype,
                replace,
            )
            proc = spawn.Process(
                target=start_sample_client_shuffle, args=client_args
            )
            proc.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            clients.append(proc)

    # Clients are awaited first, then the servers.
    for proc in clients + servers:
        proc.join()
        assert proc.exitcode == 0
Jinjing Zhou's avatar
Jinjing Zhou committed
490

491

492
493
494
495
496
497
498
def start_hetero_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
    replace=False,
):
    """Sample from the hetero "test_sampling" DistGraph and build a block.

    Asserts that "feat" exists on "n1" only, converts ``nodes`` (a dict of
    node type to IDs) to tensors of the graph's idtype, samples with fanout
    3 and converts the result with ``dgl.to_block``.

    Returns
    -------
    tuple
        ``(block, gpb)``. ``block`` is ``None`` if sampling or block
        construction raised (the traceback is printed). ``gpb`` is the
        loaded partition book, or the one reported by the DistGraph when
        none was loaded.
    """
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    else:
        gpb = None
    dgl.distributed.initialize("rpc_ip_config.txt")
    remote_graph = DistGraph("test_sampling", gpb=gpb)

    assert "feat" in remote_graph.nodes["n1"].data
    for featureless_ntype in ("n2", "n3"):
        assert "feat" not in remote_graph.nodes[featureless_ntype].data

    seeds = {}
    for ntype, ids in nodes.items():
        seeds[ntype] = torch.tensor(ids, dtype=remote_graph.idtype)

    if gpb is None:
        gpb = remote_graph.get_partition_book()

    keep_eids = not use_graphbolt or return_eids
    try:
        # Enable sanity check in distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        subgraph = sample_neighbors(
            remote_graph,
            seeds,
            3,
            replace=replace,
            use_graphbolt=use_graphbolt,
        )
        block = dgl.to_block(subgraph, seeds)
        if keep_eids:
            block.edata[dgl.EID] = subgraph.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, gpb

531
532
533
534
535
536

def start_hetero_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    etype_sorted=False,
    use_graphbolt=False,
    return_eids=False,
):
    """Run ``sample_etype_neighbors`` on the hetero "test_sampling" graph.

    Asserts that "feat" exists on "n1" only. When a local partition is
    available and GraphBolt is off, also asserts for every local node that
    the first occurrences of the edge-type IDs among its in-edges appear in
    increasing order.

    ``nodes`` must be a dict of node type to IDs; the ``None`` default is
    not handled and fails on ``nodes.items()``.

    Returns
    -------
    tuple
        ``(block, gpb)``; ``block`` is ``None`` if sampling or block
        construction raised (the traceback is printed).
    """
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    else:
        gpb = None
    dgl.distributed.initialize("rpc_ip_config.txt")
    remote_graph = DistGraph("test_sampling", gpb=gpb)

    assert "feat" in remote_graph.nodes["n1"].data
    for featureless_ntype in ("n2", "n3"):
        assert "feat" not in remote_graph.nodes[featureless_ntype].data

    seeds = {}
    for ntype, ids in nodes.items():
        seeds[ntype] = torch.tensor(ids, dtype=remote_graph.idtype)

    if (not use_graphbolt) and remote_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph
        local_g = remote_graph.local_partition
        for local_nid in np.arange(local_g.num_nodes()):
            in_eids = local_g.in_edges(local_nid, form="eid")
            etype_ids = F.asnumpy(local_g.edata[dgl.ETYPE][in_eids])
            _, first_seen = np.unique(etype_ids, return_index=True)
            assert np.all(first_seen[:-1] <= first_seen[1:])

    if gpb is None:
        gpb = remote_graph.get_partition_book()

    keep_eids = not use_graphbolt or return_eids
    try:
        # Enable sanity check in distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        subgraph = sample_etype_neighbors(
            remote_graph,
            seeds,
            fanout,
            etype_sorted=etype_sorted,
            use_graphbolt=use_graphbolt,
        )
        block = dgl.to_block(subgraph, seeds)
        # Edge IDs are copied only when the sample is non-empty.
        if subgraph.num_edges() > 0 and keep_eids:
            block.edata[dgl.EID] = subgraph.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, gpb

588

589
def check_rpc_hetero_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False, replace=False
):
    """Check neighbor sampling on a partitioned random heterograph.

    Samples for six fixed "n3" seeds through ``start_hetero_sample_client``
    and, for every canonical edge type of the returned block, verifies that
    the sampled edges exist in the original graph. Edge IDs are also checked
    unless GraphBolt is used without ``return_eids``.
    """
    graph_name = "test_sampling"
    no_shared_mem = num_server > 1
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    orig_nid_map, orig_eid_map = partition_graph(
        g,
        graph_name,
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server_args = (
            server_id,
            tmpdir,
            no_shared_mem,
            graph_name,
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = spawn.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    seeds = {"n3": torch.tensor([0, 10, 99, 66, 124, 208], dtype=g.idtype)}
    block, gpb = start_hetero_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
        replace=replace,
    )
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    verify_eids = not use_graphbolt or return_eids
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))

        edge_exists = g.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(edge_exists))

        if verify_eids:
            # Check the node Ids and edge Ids.
            shuffled_eid = block.edges[etype].data[dgl.EID]
            orig_eid = F.asnumpy(
                F.gather_row(orig_eid_map[c_etype], shuffled_eid)
            )
            orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
            assert np.all(F.asnumpy(orig_src1) == orig_src)
            assert np.all(F.asnumpy(orig_dst1) == orig_dst)
665

666

667
668
669
670
671
672
673
674
675
def get_degrees(g, nids, ntype):
    """Sum the degrees of ``nids`` over all edge types touching ``ntype``.

    For each canonical edge type, out-degrees are added when ``ntype`` is
    the source type, otherwise in-degrees when it is the destination type.
    An edge type whose source and destination are both ``ntype`` therefore
    contributes its out-degrees only.

    Returns an int64 tensor with one entry per element of ``nids``.
    """
    total = F.zeros((len(nids),), dtype=F.int64)
    for src_ntype, relation, dst_ntype in g.canonical_etypes:
        if src_ntype == ntype:
            contribution = g.out_degrees(u=nids, etype=relation)
        elif dst_ntype == ntype:
            contribution = g.in_degrees(v=nids, etype=relation)
        else:
            # This edge type does not touch ``ntype``.
            continue
        total += contribution
    return total

676

677
678
679
def check_rpc_hetero_sampling_empty_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check that sampling from zero-degree "n3" seeds yields no edges.

    The seeds are the positions in ``orig_nids["n3"]`` whose total degree in
    the original graph is 0. The resulting block must have no edges but must
    keep as many edge types as the original graph.
    """
    graph_name = "test_sampling"
    no_shared_mem = num_server > 1
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(empty=True)
    orig_nids, _ = partition_graph(
        g,
        graph_name,
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server_args = (
            server_id,
            tmpdir,
            no_shared_mem,
            graph_name,
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = spawn.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    n3_degrees = get_degrees(g, orig_nids["n3"], "n3")
    isolated = F.nonzero_1d(n3_degrees == 0).to(g.idtype)
    block, gpb = start_hetero_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        nodes={"n3": isolated},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

733
734

def check_rpc_hetero_etype_sampling_shuffle(
    tmpdir,
    num_server,
    graph_formats=None,
    use_graphbolt=False,
    return_eids=False,
):
    """Check per-edge-type sampling on a dense random heterograph.

    Samples with fanout 3 for every canonical edge type from six fixed "n3"
    seeds, expects exactly 18 sampled edges for each of the two edge types
    ending in "n3", and verifies the sampled edges (and edge IDs, unless
    GraphBolt is used without ``return_eids``) against the original graph.

    ``graph_formats`` is forwarded to ``partition_graph`` and determines
    ``etype_sorted``; the servers are always started with
    ``["csc", "coo"]``.
    """
    graph_name = "test_sampling"
    no_shared_mem = num_server > 1
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True)
    orig_nid_map, orig_eid_map = partition_graph(
        g,
        graph_name,
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        graph_formats=graph_formats,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server_args = (
            server_id,
            tmpdir,
            no_shared_mem,
            graph_name,
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = spawn.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    fanout = {}
    for canonical in g.canonical_etypes:
        fanout[canonical] = 3
    etype_sorted = graph_formats is not None and (
        "csc" in graph_formats or "csr" in graph_formats
    )
    seeds = {"n3": torch.tensor([0, 10, 99, 66, 124, 208], dtype=g.idtype)}
    block, gpb = start_hetero_etype_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        fanout,
        nodes=seeds,
        etype_sorted=etype_sorted,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # Presumably 18 = 6 seed nodes x fanout 3 for each etype into "n3".
    for into_n3 in (("n1", "r13", "n3"), ("n2", "r23", "n3")):
        src, dst = block.edges(etype=into_n3)
        assert len(src) == 18

    verify_eids = not use_graphbolt or return_eids
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))

        edge_exists = g.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(edge_exists))

        if verify_eids:
            # Check the node Ids and edge Ids.
            shuffled_eid = block.edges[etype].data[dgl.EID]
            orig_eid = F.asnumpy(
                F.gather_row(orig_eid_map[c_etype], shuffled_eid)
            )
            orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
            assert np.all(F.asnumpy(orig_src1) == orig_src)
            assert np.all(F.asnumpy(orig_dst1) == orig_dst)

825

826
827
828
def check_rpc_hetero_etype_sampling_empty_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check that etype sampling from zero-degree "n3" seeds yields no edges.

    Uses a dense random heterograph built with ``empty=True``. The seeds are
    the positions in ``orig_nids["n3"]`` whose total degree is 0. The block
    must have no edges but as many edge types as the original graph.
    """
    graph_name = "test_sampling"
    no_shared_mem = num_server > 1
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True, empty=True)
    orig_nids, _ = partition_graph(
        g,
        graph_name,
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        server_args = (
            server_id,
            tmpdir,
            no_shared_mem,
            graph_name,
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = spawn.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    n3_degrees = get_degrees(g, orig_nids["n3"], "n3")
    isolated = F.nonzero_1d(n3_degrees == 0).to(g.idtype)
    block, gpb = start_hetero_etype_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        3,
        nodes={"n3": isolated},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

885
886

def create_random_bipartite():
    """Build a random user-buys-game bipartite graph with node features.

    The graph comes from ``dgl.rand_bipartite`` with the arguments
    ``500, 1000, 1000``. Both node types receive an all-ones float32 "feat"
    tensor of width 10 on the CPU.
    """
    graph = dgl.rand_bipartite("user", "buys", "game", 500, 1000, 1000)
    for ntype in ("user", "game"):
        feat_shape = (graph.num_nodes(ntype), 10)
        graph.nodes[ntype].data["feat"] = F.ones(
            feat_shape, F.float32, F.cpu()
        )
    return graph


897
898
899
900
901
902
903
904
def start_bipartite_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample from the bipartite "test_sampling" DistGraph and build a block.

    Asserts that both "user" and "game" carry "feat", samples with fanout 3
    and converts the result with ``dgl.to_block``. Unlike the hetero client
    helpers, exceptions are not caught here.

    Returns
    -------
    tuple
        ``(block, gpb)``. Edge IDs are copied onto the block only when the
        sample is non-empty and they are available (no GraphBolt, or
        ``return_eids``).
    """
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    else:
        gpb = None
    dgl.distributed.initialize("rpc_ip_config.txt")
    remote_graph = DistGraph("test_sampling", gpb=gpb)

    for ntype_with_feat in ("user", "game"):
        assert "feat" in remote_graph.nodes[ntype_with_feat].data

    seeds = {}
    for ntype, ids in nodes.items():
        seeds[ntype] = torch.tensor(ids, dtype=remote_graph.idtype)

    if gpb is None:
        gpb = remote_graph.get_partition_book()

    # Enable sanity check in distributed sampling.
    os.environ["DGL_DIST_DEBUG"] = "1"
    subgraph = sample_neighbors(
        remote_graph, seeds, 3, use_graphbolt=use_graphbolt
    )
    block = dgl.to_block(subgraph, seeds)
    keep_eids = not use_graphbolt or return_eids
    if subgraph.num_edges() > 0 and keep_eids:
        block.edata[dgl.EID] = subgraph.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb


932
def start_bipartite_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample per-etype neighbors of ``nodes`` from the "test_sampling" graph.

    Args:
        rank: Partition id passed to ``load_partition`` when shared memory
            is disabled.
        tmpdir: Directory (path object) holding "test_sampling.json".
        disable_shared_mem: If true, the partition book is loaded from disk.
        fanout: Fanout forwarded to ``sample_etype_neighbors``.
        nodes: Mapping of node type to seed node ids. ``None`` (the default)
            is treated as an empty mapping.
        use_graphbolt: Forwarded to ``sample_etype_neighbors``; also skips
            the local etype-ordering check.
        return_eids: With GraphBolt, controls whether edge ids are copied
            onto the returned block.

    Returns:
        ``(block, gpb)``: the block built from the sampled graph and the
        partition book that was used.
    """
    # A ``{}`` default would be one dict shared by every call; use None.
    if nodes is None:
        nodes = {}

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["user"].data
    assert "feat" in dist_graph.nodes["game"].data
    nodes = {
        k: torch.tensor(v, dtype=dist_graph.idtype) for k, v in nodes.items()
    }

    if not use_graphbolt and dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph
        local_g = dist_graph.local_partition
        for lnid in np.arange(local_g.num_nodes()):
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            # np.unique yields, ordered by etype id, the index of each id's
            # first occurrence; non-decreasing indices mean the etype ids
            # first appear in ascending order.
            _, first_idx = np.unique(letids, return_index=True)
            assert np.all(first_idx[:-1] <= first_idx[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    sampled_graph = sample_etype_neighbors(
        dist_graph, nodes, fanout, use_graphbolt=use_graphbolt
    )
    block = dgl.to_block(sampled_graph, nodes)
    if sampled_graph.num_edges() > 0:
        if not use_graphbolt or return_eids:
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb


977
978
979
def check_rpc_bipartite_sampling_empty(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check sample_neighbors() on a bipartite graph when nothing is sampled.

    Partitions a random bipartite graph into ``num_server`` parts, starts one
    server process per part, then samples from the "game" nodes for which
    ``get_degrees`` reports 0 plus "user" node 1. The resulting block must
    have no edges while still carrying every edge type of the input graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    nid_map, _ = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    # Shared memory is disabled as soon as there is more than one server.
    no_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server = spawn_ctx.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    game_deg = get_degrees(graph, nid_map["game"], "game")
    isolated_games = F.nonzero_1d(game_deg == 0).to(graph.idtype)
    seeds = {
        "game": isolated_games,
        "user": torch.tensor([1], dtype=graph.idtype),
    }
    block, _ = start_bipartite_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(graph.etypes)


1038
1039
1040
def check_rpc_bipartite_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check sample_neighbors() on a bipartite graph with non-empty results.

    Partitions a random bipartite graph, starts the servers, samples from
    the "game" nodes for which ``get_degrees`` reports a positive value plus
    "user" node 0, and maps the sampled node/edge ids back through the
    partition mappings to verify them against the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    nid_map, eid_map = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    # Shared memory is disabled as soon as there is more than one server.
    no_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server = spawn_ctx.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    game_deg = get_degrees(graph, nid_map["game"], "game")
    seeds = {
        "game": F.nonzero_1d(game_deg > 0),
        "user": torch.tensor([0], dtype=graph.idtype),
    }
    block, gpb = start_bipartite_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    eids_available = return_eids or not use_graphbolt
    for canonical in block.canonical_etypes:
        src_type, etype, dst_type = canonical
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        orig_src = F.asnumpy(F.gather_row(nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(nid_map[dst_type], shuffled_dst))
        edge_exists = graph.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(edge_exists))

        if eids_available:
            # Check the node Ids and edge Ids.
            shuffled_eid = block.edges[etype].data[dgl.EID]
            orig_eid = F.asnumpy(
                F.gather_row(eid_map[canonical], shuffled_eid)
            )
            found_src, found_dst = graph.find_edges(orig_eid, etype=etype)
            assert np.all(F.asnumpy(found_src) == orig_src)
            assert np.all(F.asnumpy(found_dst) == orig_dst)


1118
1119
1120
def check_rpc_bipartite_etype_sampling_empty(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check sample_etype_neighbors() on a bipartite graph with no results.

    Partitions a random bipartite graph into ``num_server`` parts, starts one
    server process per part, then samples from the "game" nodes for which
    ``get_degrees`` reports 0 plus "user" node 1. The resulting block must
    exist, have no edges, and still carry every edge type of the input graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    nid_map, _ = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    # Shared memory is disabled as soon as there is more than one server.
    no_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server = spawn_ctx.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    game_deg = get_degrees(graph, nid_map["game"], "game")
    isolated_games = F.nonzero_1d(game_deg == 0).to(graph.idtype)
    seeds = {
        "game": isolated_games,
        "user": torch.tensor([1], dtype=graph.idtype),
    }
    block, _ = start_bipartite_etype_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    assert block is not None
    assert block.num_edges() == 0
    assert len(block.etypes) == len(graph.etypes)


1180
1181
1182
def check_rpc_bipartite_etype_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check sample_etype_neighbors() on a bipartite graph with results.

    Partitions a random bipartite graph, starts the servers, samples with
    fanout 3 from the "game" nodes for which ``get_degrees`` reports a
    positive value plus "user" node 0, and maps the sampled node/edge ids
    back through the partition mappings to verify them against the original
    graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    graph = create_random_bipartite()
    nid_map, eid_map = partition_graph(
        graph,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    # Shared memory is disabled as soon as there is more than one server.
    no_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server = spawn_ctx.Process(
            target=start_server,
            args=(
                server_rank,
                tmpdir,
                no_shared_mem,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        server.start()
        time.sleep(1)
        servers.append(server)

    fanout = 3
    game_deg = get_degrees(graph, nid_map["game"], "game")
    seeds = {
        "game": F.nonzero_1d(game_deg > 0),
        "user": torch.tensor([0], dtype=graph.idtype),
    }
    block, gpb = start_bipartite_etype_sample_client(
        0,
        tmpdir,
        no_shared_mem,
        fanout,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    eids_available = return_eids or not use_graphbolt
    for canonical in block.canonical_etypes:
        src_type, etype, dst_type = canonical
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        orig_src = F.asnumpy(F.gather_row(nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(nid_map[dst_type], shuffled_dst))
        edge_exists = graph.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(edge_exists))

        if eids_available:
            # Check the node Ids and edge Ids.
            shuffled_eid = block.edges[etype].data[dgl.EID]
            orig_eid = F.asnumpy(
                F.gather_row(eid_map[canonical], shuffled_eid)
            )
            found_src, found_dst = graph.find_edges(orig_eid, etype=etype)
            assert np.all(F.asnumpy(found_src) == orig_src)
            assert np.all(F.asnumpy(found_dst) == orig_dst)

1260

1261
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
@pytest.mark.parametrize("node_id_dtype", [torch.int64])
@pytest.mark.parametrize("replace", [False, True])
def test_rpc_sampling_shuffle(
    num_server, use_graphbolt, return_eids, node_id_dtype, replace
):
    """Run check_rpc_sampling_shuffle in distributed mode for each combo."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = dict(
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
        node_id_dtype=node_id_dtype,
        replace=replace,
    )
    # The partition files live in a throwaway directory per test case.
    with tempfile.TemporaryDirectory() as workdir:
        check_rpc_sampling_shuffle(Path(workdir), num_server, **options)
1280
1281
1282


# The parametrize name below used to be "use_graphbolt," with a stray
# trailing comma; it now matches the spelling used by the sibling tests.
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
@pytest.mark.parametrize("replace", [False, True])
def test_rpc_hetero_sampling_shuffle(
    num_server, use_graphbolt, return_eids, replace
):
    """Run check_rpc_hetero_sampling_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # The partition files live in a throwaway directory per test case.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_sampling_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
            replace=replace,
        )
1299
1300
1301


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_hetero_sampling_empty_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = dict(use_graphbolt=use_graphbolt, return_eids=return_eids)
    # The partition files live in a throwaway directory per test case.
    with tempfile.TemporaryDirectory() as workdir:
        check_rpc_hetero_sampling_empty_shuffle(
            Path(workdir), num_server, **options
        )
1316
1317
1318
1319
1320
1321


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]]
)
def test_rpc_hetero_etype_sampling_shuffle_dgl(num_server, graph_formats):
    """Run check_rpc_hetero_etype_sampling_shuffle for each graph format."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # The partition files live in a throwaway directory per test case.
    with tempfile.TemporaryDirectory() as workdir:
        partition_dir = Path(workdir)
        check_rpc_hetero_etype_sampling_shuffle(
            partition_dir, num_server, graph_formats=graph_formats
        )
1329
1330
1331


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_shuffle_graphbolt(num_server, return_eids):
    """Run check_rpc_hetero_etype_sampling_shuffle with GraphBolt enabled."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = dict(use_graphbolt=True, return_eids=return_eids)
    # The partition files live in a throwaway directory per test case.
    with tempfile.TemporaryDirectory() as workdir:
        check_rpc_hetero_etype_sampling_shuffle(
            Path(workdir), num_server, **options
        )


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_hetero_etype_sampling_empty_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = dict(use_graphbolt=use_graphbolt, return_eids=return_eids)
    # The partition files live in a throwaway directory per test case.
    with tempfile.TemporaryDirectory() as workdir:
        check_rpc_hetero_etype_sampling_empty_shuffle(
            Path(workdir), num_server, **options
        )
1360
1361
1362


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_bipartite_sampling_empty in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # The partition files live in a throwaway directory per test case.
    with tempfile.TemporaryDirectory() as workdir:
        check_rpc_bipartite_sampling_empty(
            Path(workdir),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1374
1375
1376


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run check_rpc_bipartite_sampling_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # The partition files live in a throwaway directory per test case.
    with tempfile.TemporaryDirectory() as workdir:
        check_rpc_bipartite_sampling_shuffle(
            Path(workdir),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1386
1387
1388


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_bipartite_etype_sampling_empty in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = dict(use_graphbolt=use_graphbolt, return_eids=return_eids)
    # The partition files live in a throwaway directory per test case.
    with tempfile.TemporaryDirectory() as workdir:
        check_rpc_bipartite_etype_sampling_empty(
            Path(workdir), num_server, **options
        )
1403
1404
1405


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run check_rpc_bipartite_etype_sampling_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = dict(use_graphbolt=use_graphbolt, return_eids=return_eids)
    # The partition files live in a throwaway directory per test case.
    with tempfile.TemporaryDirectory() as workdir:
        check_rpc_bipartite_etype_sampling_shuffle(
            Path(workdir), num_server, **options
        )
Jinjing Zhou's avatar
Jinjing Zhou committed
1420

1421

1422
def check_standalone_sampling(tmpdir):
    """Check sample_neighbors() on a single-partition graph in standalone mode.

    Loads the "cora" citation graph, attaches a random non-negative "prob"
    edge feature and its boolean "mask" (prob > 0), partitions the graph into
    one part under ``tmpdir`` and samples with fanout 3 three times: without
    weights, with ``prob="mask"`` and with ``prob="prob"``.
    """
    graph = CitationGraphDataset("cora")[0]
    prob = np.maximum(np.random.randn(graph.num_edges()), 0)
    mask = prob > 0
    graph.edata["prob"] = F.tensor(prob)
    graph.edata["mask"] = F.tensor(mask)
    partition_graph(
        graph,
        "test_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", part_config=tmpdir / "test_sampling.json"
    )
    seed_ids = [0, 10, 99, 66, 1024, 2008]

    # Unweighted sampling: every sampled edge must exist in the original
    # graph and carry the matching edge id.
    sampled = sample_neighbors(
        dist_graph, torch.tensor(seed_ids, dtype=dist_graph.idtype), 3
    )
    src, dst = sampled.edges()
    assert sampled.num_nodes() == graph.num_nodes()
    assert np.all(F.asnumpy(graph.has_edges_between(src, dst)))
    expected_eids = graph.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled.edata[dgl.EID]), F.asnumpy(expected_eids)
    )

    # Weighted sampling: only edges with a positive weight may be picked,
    # whether the weight is the boolean mask or the raw probability.
    for weight_key, positive in (("mask", mask), ("prob", prob > 0)):
        sampled = sample_neighbors(
            dist_graph,
            torch.tensor(seed_ids, dtype=dist_graph.idtype),
            3,
            prob=weight_key,
        )
        picked = F.asnumpy(sampled.edata[dgl.EID])
        assert positive[picked].all()
    dgl.distributed.exit_client()
1476

1477

1478
def check_standalone_etype_sampling(tmpdir):
    """Check sample_etype_neighbors() on one partition in standalone mode.

    Loads the "cora" citation graph, attaches a random non-negative "prob"
    edge feature and its boolean "mask" (prob > 0), partitions the graph into
    one part under ``tmpdir`` and samples with fanout 3 three times: without
    weights, with ``prob="mask"`` and with ``prob="prob"``.
    """
    graph = CitationGraphDataset("cora")[0]
    prob = np.maximum(np.random.randn(graph.num_edges()), 0)
    mask = prob > 0
    graph.edata["prob"] = F.tensor(prob)
    graph.edata["mask"] = F.tensor(mask)

    partition_graph(
        graph,
        "test_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", part_config=tmpdir / "test_sampling.json"
    )
    seed_ids = [0, 10, 99, 66, 1023]

    # Unweighted sampling: every sampled edge must exist in the original
    # graph and carry the matching edge id.
    sampled = sample_etype_neighbors(
        dist_graph, torch.tensor(seed_ids, dtype=dist_graph.idtype), 3
    )
    src, dst = sampled.edges()
    assert sampled.num_nodes() == graph.num_nodes()
    assert np.all(F.asnumpy(graph.has_edges_between(src, dst)))
    expected_eids = graph.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled.edata[dgl.EID]), F.asnumpy(expected_eids)
    )

    # Weighted sampling: only edges with a positive weight may be picked,
    # whether the weight is the boolean mask or the raw probability.
    for weight_key, positive in (("mask", mask), ("prob", prob > 0)):
        sampled = sample_etype_neighbors(
            dist_graph,
            torch.tensor(seed_ids, dtype=dist_graph.idtype),
            3,
            prob=weight_key,
        )
        picked = F.asnumpy(sampled.edata[dgl.EID])
        assert positive[picked].all()
    dgl.distributed.exit_client()

1533

1534
def check_standalone_etype_sampling_heterograph(tmpdir):
    """Check sample_etype_neighbors() on a two-etype graph, standalone mode.

    Builds a "paper" heterograph from the "cora" edges with a "cite" relation
    and its reverse "cite-by", partitions it into one part under ``tmpdir``
    and samples with fanout 1 from 10 seed nodes. Each relation must then
    contribute exactly 10 sampled edges.
    """
    cora = CitationGraphDataset("cora")[0]
    cite_src, cite_dst = cora.edges()
    relations = {
        ("paper", "cite", "paper"): (cite_src, cite_dst),
        ("paper", "cite-by", "paper"): (cite_dst, cite_src),
    }
    hetero = dgl.heterograph(relations, {"paper": cora.num_nodes()})
    partition_graph(
        hetero,
        "test_hetero_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_hetero_sampling", part_config=tmpdir / "test_hetero_sampling.json"
    )
    seeds = torch.tensor(
        [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701],
        dtype=dist_graph.idtype,
    )
    sampled = sample_etype_neighbors(dist_graph, seeds, 1)
    for canonical in (
        ("paper", "cite", "paper"),
        ("paper", "cite-by", "paper"),
    ):
        sampled_src, _ = sampled.edges(etype=canonical)
        assert len(sampled_src) == 10
    assert sampled.num_nodes() == hetero.num_nodes()
    dgl.distributed.exit_client()

1574
1575
1576
1577
1578
1579

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_standalone_sampling():
    """Run check_standalone_sampling in standalone mode.

    Skipped on Windows and with the tensorflow backend. Relies on the
    module-level ``tempfile`` import; the redundant function-local import
    was dropped.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"
    # The partition files live in a throwaway directory.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_standalone_sampling(Path(tmpdirname))
1587

1588

1589
1590
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Run ``dgl.distributed.in_subgraph`` for ``nodes`` as a client.

    Args:
        rank: Partition id passed to ``load_partition`` when shared memory
            is disabled.
        tmpdir: Directory (path object) holding "test_in_subgraph.json".
        disable_shared_mem: If true, the partition book is loaded from disk.
        nodes: Seed nodes forwarded to ``in_subgraph``.

    Returns:
        The sampled graph, or ``None`` if ``in_subgraph`` raised (the
        traceback is printed in that case).
    """
    gpb = None
    dgl.distributed.initialize("rpc_ip_config.txt")
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_in_subgraph.json", rank
        )
    dist_graph = DistGraph("test_in_subgraph", gpb=gpb)
    try:
        sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception:
        # Deliberately broad: report the failure but still reach
        # exit_client() below; the caller sees None.
        print(traceback.format_exc())
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph


1606
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Check ``dgl.distributed.in_subgraph`` against ``dgl.in_subgraph``.

    Partitions the "cora" graph into ``num_server`` parts, starts one server
    process per part, extracts the in-subgraph of six seed nodes through a
    client and compares its edges (mapped back through the partition
    node/edge mappings) with the original graph.

    Args:
        tmpdir: Directory (path object) for the partition files.
        num_server: Number of partitions and server processes.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_in_subgraph",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_in_subgraph"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=g.idtype)
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns None when in_subgraph raised; fail with a clear
    # message instead of an AttributeError on the next line.
    assert sampled_graph is not None, "in_subgraph client failed"

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    subg1 = dgl.in_subgraph(g, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    # array_equal also compares shapes, so a differing number of edges
    # fails cleanly rather than broadcasting or raising inside ``==``.
    assert np.array_equal(np.sort(F.asnumpy(src)), np.sort(F.asnumpy(src1)))
    assert np.array_equal(np.sort(F.asnumpy(dst)), np.sort(F.asnumpy(dst1)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
1652

1653
1654
1655
1656
1657
1658

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_rpc_in_subgraph():
    """Run check_rpc_in_subgraph_shuffle with one server in distributed mode.

    Skipped on Windows and with the tensorflow backend. Relies on the
    module-level ``tempfile`` import; the redundant function-local import
    was dropped.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # The partition files live in a throwaway directory.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 1)
1666

1667
1668
1669
1670
1671
1672
1673
1674
1675

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone_etype_sampling():
    """Run both standalone etype-sampling checks, each in a fresh directory.

    Skipped on Windows and with the tensorflow or mxnet backends. Relies on
    the module-level ``tempfile`` import; the redundant function-local
    import was dropped.
    """
    reset_envs()
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling(Path(tmpdirname))
1686

1687

Jinjing Zhou's avatar
Jinjing Zhou committed
1688
1689
if __name__ == "__main__":
    import tempfile
1690

Jinjing Zhou's avatar
Jinjing Zhou committed
1691
    with tempfile.TemporaryDirectory() as tmpdirname:
1692
        os.environ["DGL_DIST_MODE"] = "standalone"
1693
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
1694
1695

    with tempfile.TemporaryDirectory() as tmpdirname:
1696
        os.environ["DGL_DIST_MODE"] = "standalone"
1697
1698
        check_standalone_etype_sampling(Path(tmpdirname))
        check_standalone_sampling(Path(tmpdirname))
1699
        os.environ["DGL_DIST_MODE"] = "distributed"
1700
1701
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
1702
1703
        check_rpc_get_degree_shuffle(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 2)
1704
1705
        check_rpc_find_edges_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 1)
1706
1707
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 2)
1708
1709
1710
1711
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)
1712
1713
1714
1715
        check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), 1)