test_distributed_sampling.py 45.4 KB
Newer Older
1
import multiprocessing as mp
Jinjing Zhou's avatar
Jinjing Zhou committed
2
import os
3
import random
4
import tempfile
Jinjing Zhou's avatar
Jinjing Zhou committed
5
import time
6
7
import traceback
import unittest
Jinjing Zhou's avatar
Jinjing Zhou committed
8
from pathlib import Path
9
10
11
12

import backend as F
import dgl
import numpy as np
13
import pytest
14
15
16
17
18
19
20
21
22
23
from dgl.data import CitationGraphDataset, WN18Dataset
from dgl.distributed import (
    DistGraph,
    DistGraphServer,
    load_partition,
    load_partition_book,
    partition_graph,
    sample_etype_neighbors,
    sample_neighbors,
)
24
from scipy import sparse as spsp
25
from utils import generate_ip_config, reset_envs
Jinjing Zhou's avatar
Jinjing Zhou committed
26
27


28
29
30
31
32
33
def start_server(
    rank,
    tmpdir,
    disable_shared_mem,
    graph_name,
    graph_format=["csc", "coo"],
    use_graphbolt=False,
):
    """Create a ``DistGraphServer`` for one partition and start it.

    Used as the ``target`` of the server processes spawned by the
    ``check_*`` helpers in this file.

    Args:
        rank: Server id, passed as the first positional argument.
        tmpdir: ``Path`` of the directory holding ``<graph_name>.json``.
        disable_shared_mem: Forwarded to ``DistGraphServer``.
        graph_name: Name of the partitioned graph; selects the config file.
        graph_format: Forwarded to ``DistGraphServer``.
        use_graphbolt: Forwarded to ``DistGraphServer``.
    """
    partition_config = tmpdir / (graph_name + ".json")
    # NOTE(review): the two literal 1s are presumably the server count and
    # the client count -- confirm against the DistGraphServer signature.
    server = DistGraphServer(
        rank,
        "rpc_ip_config.txt",
        1,
        1,
        partition_config,
        disable_shared_mem=disable_shared_mem,
        graph_format=graph_format,
        use_graphbolt=use_graphbolt,
    )
    server.start()


49
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Sample neighbors of a fixed seed set from the "test_sampling" graph.

    Args:
        rank: Partition id used when the partition book is loaded from disk.
        tmpdir: ``Path`` of the directory holding ``test_sampling.json``.
        disable_shared_mem: When true, the partition book is loaded with
            ``load_partition`` and handed to ``DistGraph`` explicitly.

    Returns:
        The sampled graph, or ``None`` if sampling raised (the traceback is
        printed in that case).
    """
    partition_book = None
    if disable_shared_mem:
        config_path = tmpdir / "test_sampling.json"
        # Only the 4th of the seven returned items is needed here.
        _, _, _, partition_book, _, _, _ = load_partition(config_path, rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)

    seeds = [0, 10, 99, 66, 1024, 2008]
    try:
        sampled_graph = sample_neighbors(dist_graph, seeds, 3)
    except Exception:
        # Report the failure but still shut the client down below.
        print(traceback.format_exc())
        sampled_graph = None

    dgl.distributed.exit_client()
    return sampled_graph

67

68
69
70
71
72
73
74
75
76
def start_sample_client_shuffle(
    rank,
    tmpdir,
    disable_shared_mem,
    g,
    num_servers,
    group_id,
    orig_nid,
    orig_eid,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample from the shuffled "test_sampling" graph and validate in place.

    Runs inside a spawned client process: every check is an ``assert`` so a
    failure shows up as a non-zero exit code of the process.

    Args:
        rank: Partition id used when loading the partition book from disk.
        tmpdir: ``Path`` of the directory holding ``test_sampling.json``.
        disable_shared_mem: When true, load the partition book explicitly.
        g: The original (unpartitioned) graph used as ground truth.
        num_servers: Unused by this function.
        group_id: Exported as the ``DGL_GROUP_ID`` environment variable.
        orig_nid: Maps node ids of the sampled graph to ids in ``g``.
        orig_eid: Maps edge ids of the sampled graph to ids in ``g``.
        use_graphbolt: Forwarded to ``DistGraph`` and ``sample_neighbors``.
        return_eids: Whether edge ids are expected when GraphBolt is used.
    """
    os.environ["DGL_GROUP_ID"] = str(group_id)

    partition_book = None
    if disable_shared_mem:
        config_path = tmpdir / "test_sampling.json"
        _, _, _, partition_book, _, _, _ = load_partition(config_path, rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", gpb=partition_book, use_graphbolt=use_graphbolt
    )
    seeds = [0, 10, 99, 66, 1024, 2008]
    sampled_graph = sample_neighbors(
        dist_graph, seeds, 3, use_graphbolt=use_graphbolt
    )

    assert (
        dgl.ETYPE not in sampled_graph.edata
    ), "Etype should not be in homogeneous sampled graph."

    # Translate the sampled endpoints back into ids of the original graph.
    sampled_src, sampled_dst = sampled_graph.edges()
    src = orig_nid[sampled_src]
    dst = orig_nid[sampled_dst]
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    # Edge ids are expected unless GraphBolt is used without return_eids.
    expect_eids = return_eids or not use_graphbolt
    if expect_eids:
        expected = g.edge_ids(src, dst)
        actual = orig_eid[sampled_graph.edata[dgl.EID]]
        assert np.array_equal(F.asnumpy(actual), F.asnumpy(expected))
    else:
        assert (
            dgl.EID not in sampled_graph.edata
        ), "EID should not be in sampled graph if use_graphbolt=True."
110

111

112
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
    """Call ``DistGraph.find_edges`` on the "test_find_edges" graph.

    Args:
        rank: Partition id used when loading the partition book from disk.
        tmpdir: ``Path`` of the directory holding ``test_find_edges.json``.
        disable_shared_mem: When true, load the partition book explicitly.
        eids: Edge ids to look up.
        etype: Optional edge type forwarded to ``find_edges``.

    Returns:
        The ``(u, v)`` pair returned by ``find_edges``, or ``(None, None)``
        if the call raised (the traceback is printed in that case).
    """
    partition_book = None
    if disable_shared_mem:
        config_path = tmpdir / "test_find_edges.json"
        _, _, _, partition_book, _, _, _ = load_partition(config_path, rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_find_edges", gpb=partition_book)

    try:
        endpoints = dist_graph.find_edges(eids, etype=etype)
        u, v = endpoints
    except Exception:
        # Report the failure but still shut the client down below.
        print(traceback.format_exc())
        u = v = None

    dgl.distributed.exit_client()
    return u, v
Jinjing Zhou's avatar
Jinjing Zhou committed
127

128

129
130
131
def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Query in/out degrees on the "test_get_degrees" distributed graph.

    Args:
        rank: Partition id used when loading the partition book from disk.
        tmpdir: ``Path`` of the directory holding ``test_get_degrees.json``.
        disable_shared_mem: When true, load the partition book explicitly.
        nids: Node ids passed to the per-node degree queries.

    Returns:
        ``(in_deg, out_deg, all_in_deg, all_out_deg)``; all four are ``None``
        if any query raised (the traceback is printed in that case).
    """
    partition_book = None
    if disable_shared_mem:
        config_path = tmpdir / "test_get_degrees.json"
        _, _, _, partition_book, _, _, _ = load_partition(config_path, rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_get_degrees", gpb=partition_book)

    try:
        in_deg = dist_graph.in_degrees(nids)
        all_in_deg = dist_graph.in_degrees()
        out_deg = dist_graph.out_degrees(nids)
        all_out_deg = dist_graph.out_degrees()
    except Exception:
        print(traceback.format_exc())
        # A partial result is not useful to the caller; reset everything.
        in_deg = out_deg = all_in_deg = all_out_deg = None

    dgl.distributed.exit_client()
    return in_deg, out_deg, all_in_deg, all_out_deg

148

149
def check_rpc_sampling(tmpdir, num_server):
    """Partition Cora, serve it, sample once and verify the sampled edges.

    The graph is partitioned without ``return_mapping``, so ids of the
    sampled graph are compared against the original graph directly.

    Args:
        tmpdir: ``Path`` of the directory the partitions are written to.
        num_server: Number of server processes (and of partitions).
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)

    partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    disable_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, disable_shared_mem, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    sampled_graph = start_sample_client(0, tmpdir, disable_shared_mem)
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # Every sampled edge must exist in the original graph with the same id.
    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    expected_eids = g.edge_ids(src, dst)
    sampled_eids = sampled_graph.edata[dgl.EID]
    assert np.array_equal(F.asnumpy(sampled_eids), F.asnumpy(expected_eids))

Jinjing Zhou's avatar
Jinjing Zhou committed
191

192
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on a shuffled partition of Cora.

    Random edge ids are looked up through the distributed graph and, after
    mapping the returned node ids back with ``orig_nid``, compared against
    ``find_edges`` on the original graph.

    Args:
        tmpdir: ``Path`` of the directory the partitions are written to.
        num_server: Number of server processes (and of partitions).
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    eids = F.tensor(np.random.randint(g.num_edges(), size=100))
    # Ground truth from the original graph, using the original edge ids.
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

    # Wait for the servers and verify they exited cleanly, as the other
    # check_* helpers in this file do, so no server process is left behind.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns (None, None) when find_edges raised; fail clearly
    # instead of indexing the id mapping with None.
    assert du is not None and dv is not None, "find_edges client failed"
    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

227

228
def create_random_hetero(dense=False, empty=False):
    """Build a random heterograph with node types n1/n2/n3.

    Relations are ``r12`` (n1->n2), ``r13`` (n1->n3) and ``r23`` (n2->n3).
    Node type ``n1`` gets an all-ones ``feat`` feature with 10 columns.

    Args:
        dense: Use the smaller node counts with edge density 0.1 instead of
            the larger node counts with density 0.001.
        empty: Draw each adjacency matrix 10 rows and 10 columns smaller
            than the node counts, so the last 10 ids of every node type
            never appear as an endpoint.

    Returns:
        The constructed ``dgl`` heterograph.
    """
    if dense:
        num_nodes = {"n1": 210, "n2": 200, "n3": 220}
        density = 0.1
    else:
        num_nodes = {"n1": 1010, "n2": 1000, "n3": 1020}
        density = 0.001
    shrink = 10 if empty else 0

    # Seeds Python's ``random`` module; the scipy sampling below is seeded
    # separately through ``random_state``.
    random.seed(42)

    edges = {}
    for etype in [("n1", "r12", "n2"), ("n1", "r13", "n3"), ("n2", "r23", "n3")]:
        src_ntype, _, dst_ntype = etype
        adjacency = spsp.random(
            num_nodes[src_ntype] - shrink,
            num_nodes[dst_ntype] - shrink,
            density=density,
            format="coo",
            random_state=100,
        )
        edges[etype] = (adjacency.row, adjacency.col)

    g = dgl.heterograph(edges, num_nodes)
    feat_shape = (g.num_nodes("n1"), 10)
    g.nodes["n1"].data["feat"] = F.ones(feat_shape, F.float32, F.cpu())
    return g
252

253

254
def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` for relation ``r12`` of a heterograph.

    First verifies on the original graph that an incomplete edge type is
    rejected and that the string and canonical forms of the edge type agree.
    Then looks the same edges up through the distributed graph and compares
    the endpoints after mapping them back with ``orig_nid``.

    Args:
        tmpdir: ``Path`` of the directory the partitions are written to.
        num_server: Number of server processes (and of partitions).
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    test_etype = g.to_canonical_etype("r12")
    eids = F.tensor(np.random.randint(g.num_edges(test_etype), size=100))
    orig_test_eids = orig_eid[test_etype][eids]

    # Local checks run before any server is spawned, so a failure here does
    # not leave server processes behind.
    # NOTE(review): narrow ``Exception`` to the concrete error type raised
    # for a malformed edge type once it is confirmed.
    with pytest.raises(Exception):
        g.find_edges(orig_test_eids, etype=("n1", "r12"))

    u, v = g.find_edges(orig_test_eids, etype="r12")
    u1, v1 = g.find_edges(orig_test_eids, etype=("n1", "r12", "n2"))
    assert F.array_equal(u, u1)
    assert F.array_equal(v, v1)

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    du, dv = start_find_edges_client(
        0, tmpdir, num_server > 1, eids, etype="r12"
    )

    # Wait for the servers and verify they exited cleanly, as the other
    # check_* helpers in this file do.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns (None, None) when find_edges raised.
    assert du is not None and dv is not None, "find_edges client failed"
    du = orig_nid["n1"][du]
    dv = orig_nid["n2"][dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

301

302
# Wait non shared memory graph store
303
304
305
306
307
308
309
310
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_find_edges_shuffle(num_server):
    """Run the heterogeneous, then the homogeneous find_edges check."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Both checks share one scratch directory for their partition files.
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_hetero_find_edges_shuffle(workdir, num_server)
        check_rpc_find_edges_shuffle(workdir, num_server)

321

322
def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Check in/out degree queries on a shuffled partition of Cora.

    Degrees returned by the distributed graph are compared with degrees of
    the original graph looked up through the ``orig_nid`` mapping.

    Args:
        tmpdir: ``Path`` of the directory the partitions are written to.
        num_server: Number of server processes (and of partitions).
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]

    orig_nid, _ = partition_graph(
        g,
        "test_get_degrees",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    disable_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, disable_shared_mem, "test_get_degrees"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    nids = F.tensor(np.random.randint(g.num_nodes(), size=100))
    degrees = start_get_degrees_client(0, tmpdir, disable_shared_mem, nids)
    in_degs, out_degs, all_in_degs, all_out_degs = degrees

    print("Done get_degree")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    print("check results")
    queried = orig_nid[nids]
    assert F.array_equal(g.in_degrees(queried), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(queried), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)

365

366
# Wait non shared memory graph store
367
368
369
370
371
372
373
374
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_get_degree_shuffle(num_server):
    """Run the degree-query check in a fresh temporary directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_get_degree_shuffle(workdir, num_server)

384
385
386
387

# @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
# @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip("Only support partition with shuffle")
def test_rpc_sampling():
    """Run the non-shuffled sampling check with one server (always skipped)."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_sampling(workdir, 1)
Jinjing Zhou's avatar
Jinjing Zhou committed
395

396

397
def check_rpc_sampling_shuffle(
    tmpdir, num_server, num_groups=1, use_graphbolt=False, return_eids=False
):
    """Partition Cora with id shuffling and sample from client processes.

    One client process is spawned per group id; each validates its own
    sample (see ``start_sample_client_shuffle``), so this function only
    checks the exit codes of clients and servers.

    Args:
        tmpdir: ``Path`` of the directory the partitions are written to.
        num_server: Number of server processes (and of partitions).
        num_groups: Number of client groups to spawn.
        use_graphbolt: Forwarded to partitioning, servers and clients.
        return_eids: Passed to ``partition_graph`` as ``store_eids`` and
            forwarded to the clients.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]

    orig_nids, orig_eids = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    disable_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")

    servers = []
    for server_rank in range(num_server):
        server_args = (
            server_rank,
            tmpdir,
            disable_shared_mem,
            "test_sampling",
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    clients = []
    num_clients = 1
    for client_id in range(num_clients):
        for group_id in range(num_groups):
            client_args = (
                client_id,
                tmpdir,
                disable_shared_mem,
                g,
                num_server,
                group_id,
                orig_nids,
                orig_eids,
                use_graphbolt,
                return_eids,
            )
            proc = ctx.Process(
                target=start_sample_client_shuffle, args=client_args
            )
            proc.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            clients.append(proc)

    # Clients are awaited first, then the servers.
    for proc in clients + servers:
        proc.join()
        assert proc.exitcode == 0
Jinjing Zhou's avatar
Jinjing Zhou committed
464

465

466
467
468
469
470
471
472
473
def start_hetero_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample neighbors on the n1/n2/n3 heterograph and return a block.

    Args:
        rank: Partition id used when loading the partition book from disk.
        tmpdir: ``Path`` of the directory holding ``test_sampling.json``.
        disable_shared_mem: When true, load the partition book explicitly.
        nodes: Seed nodes passed to ``sample_neighbors`` and ``to_block``.
        use_graphbolt: Forwarded to ``DistGraph`` and ``sample_neighbors``.
        return_eids: Whether edge ids are copied when GraphBolt is used.

    Returns:
        ``(block, gpb)`` where ``block`` is ``None`` if sampling raised (the
        traceback is printed in that case) and ``gpb`` is the partition book.
    """
    partition_book = None
    if disable_shared_mem:
        config_path = tmpdir / "test_sampling.json"
        _, _, _, partition_book, _, _, _ = load_partition(config_path, rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", gpb=partition_book, use_graphbolt=use_graphbolt
    )

    # Only node type n1 carries the "feat" feature.
    assert "feat" in dist_graph.nodes["n1"].data
    for featureless_ntype in ("n2", "n3"):
        assert "feat" not in dist_graph.nodes[featureless_ntype].data

    if partition_book is None:
        partition_book = dist_graph.get_partition_book()

    # Edge ids are copied unless GraphBolt is used without return_eids.
    copy_eids = return_eids or not use_graphbolt
    try:
        # Enable sanity check in distributed sampling.
        os.environ["DGL_DIST_DEBUG"] = "1"
        sampled_graph = sample_neighbors(
            dist_graph, nodes, 3, use_graphbolt=use_graphbolt
        )
        block = dgl.to_block(sampled_graph, nodes)
        if copy_eids:
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None

    dgl.distributed.exit_client()
    return block, partition_book

503
504
505
506
507
508
509
510
511

def start_hetero_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes={"n3": [0, 10, 99, 66, 124, 208]},
    etype_sorted=False,
):
    """Run ``sample_etype_neighbors`` on the n1/n2/n3 heterograph.

    Args:
        rank: Partition id used when loading the partition book from disk.
        tmpdir: ``Path`` of the directory holding ``test_sampling.json``.
        disable_shared_mem: When true, load the partition book explicitly.
        fanout: Fanout passed to ``sample_etype_neighbors``.
        nodes: Seed nodes passed to the sampler and to ``to_block``.
        etype_sorted: Forwarded to ``sample_etype_neighbors``.

    Returns:
        ``(block, gpb)`` where ``block`` is ``None`` if sampling raised (the
        traceback is printed in that case) and ``gpb`` is the partition book.
    """
    partition_book = None
    if disable_shared_mem:
        config_path = tmpdir / "test_sampling.json"
        _, _, _, partition_book, _, _, _ = load_partition(config_path, rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)

    # Only node type n1 carries the "feat" feature.
    assert "feat" in dist_graph.nodes["n1"].data
    for featureless_ntype in ("n2", "n3"):
        assert "feat" not in dist_graph.nodes[featureless_ntype].data

    if dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph: for every local
        # node, the first occurrence of each edge type among its in-edges
        # must come in ascending type order.
        local_g = dist_graph.local_partition
        for local_nid in np.arange(local_g.num_nodes()):
            in_eids = local_g.in_edges(local_nid, form="eid")
            in_etypes = F.asnumpy(local_g.edata[dgl.ETYPE][in_eids])
            first_positions = np.unique(in_etypes, return_index=True)[1]
            assert np.all(np.diff(first_positions) >= 0)

    if partition_book is None:
        partition_book = dist_graph.get_partition_book()

    try:
        sampled_graph = sample_etype_neighbors(
            dist_graph, nodes, fanout, etype_sorted=etype_sorted
        )
        block = dgl.to_block(sampled_graph, nodes)
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None

    dgl.distributed.exit_client()
    return block, partition_book

547

548
549
550
def check_rpc_hetero_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check neighbor sampling on a shuffled partition of a heterograph.

    For every edge type of the sampled block, the endpoints (and the edge
    ids, when available) are mapped back to the original graph and checked.

    Args:
        tmpdir: ``Path`` of the directory the partitions are written to.
        num_server: Number of server processes (and of partitions).
        use_graphbolt: Forwarded to partitioning, servers and the client.
        return_eids: Passed to ``partition_graph`` as ``store_eids`` and
            forwarded to the client.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    disable_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server_args = (
            server_rank,
            tmpdir,
            disable_shared_mem,
            "test_sampling",
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    block, _ = start_hetero_sample_client(
        0,
        tmpdir,
        disable_shared_mem,
        nodes={"n3": [0, 10, 99, 66, 124, 208]},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # Edge ids are present unless GraphBolt is used without return_eids.
    has_eids = return_eids or not use_graphbolt
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)

        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))

        edge_exists = g.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(edge_exists))

        if has_eids:
            shuffled_eid = block.edges[etype].data[dgl.EID]
            orig_eid = F.asnumpy(
                F.gather_row(orig_eid_map[c_etype], shuffled_eid)
            )

            # Check the node Ids and edge Ids.
            orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
            assert np.all(F.asnumpy(orig_src1) == orig_src)
            assert np.all(F.asnumpy(orig_dst1) == orig_dst)
622

623

624
625
626
627
628
629
630
631
632
def get_degrees(g, nids, ntype):
    """Return the total degree of ``nids`` of node type ``ntype`` in ``g``.

    The result sums out-degrees over every relation whose source type is
    ``ntype`` and in-degrees over every relation whose destination type is
    ``ntype``. A relation from ``ntype`` to itself contributes both.

    Args:
        g: Heterograph exposing ``canonical_etypes``, ``out_degrees`` and
            ``in_degrees``.
        nids: Node ids of type ``ntype``.
        ntype: Node type name.

    Returns:
        An int64 tensor with one entry per id in ``nids``.
    """
    deg = F.zeros((len(nids),), dtype=F.int64)
    for c_etype in g.canonical_etypes:
        srctype, _, dsttype = c_etype
        # Two independent checks (not if/elif): a relation whose source and
        # destination are both ``ntype`` must count in both directions.
        # The canonical triple is passed so a relation name reused between
        # different node-type pairs stays unambiguous.
        if srctype == ntype:
            deg += g.out_degrees(u=nids, etype=c_etype)
        if dsttype == ntype:
            deg += g.in_degrees(v=nids, etype=c_etype)
    return deg

633

634
635
636
def check_rpc_hetero_sampling_empty_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Check that sampling from degree-0 ``n3`` nodes yields an empty block.

    Args:
        tmpdir: ``Path`` of the directory the partitions are written to.
        num_server: Number of server processes (and of partitions).
        use_graphbolt: Forwarded to partitioning, servers and the client.
        return_eids: Passed to ``partition_graph`` as ``store_eids`` and
            forwarded to the client.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(empty=True)

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    disable_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server_args = (
            server_rank,
            tmpdir,
            disable_shared_mem,
            "test_sampling",
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    # Positions in the id mapping whose original node has no edges at all.
    n3_degrees = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(n3_degrees == 0)
    block, _ = start_hetero_sample_client(
        0,
        tmpdir,
        disable_shared_mem,
        nodes={"n3": empty_nids},
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # No edges sampled, but every edge type is still present in the block.
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

690
691
692
693

def check_rpc_hetero_etype_sampling_shuffle(
    tmpdir, num_server, graph_formats=None
):
    """Check per-edge-type sampling on a shuffled dense heterograph.

    With a fanout of 3 per edge type and six ``n3`` seeds, both relations
    ending in ``n3`` are expected to contribute exactly 18 edges. Sampled
    node and edge ids are then mapped back to the original graph.

    Args:
        tmpdir: ``Path`` of the directory the partitions are written to.
        num_server: Number of server processes (and of partitions).
        graph_formats: Forwarded to ``partition_graph``; also decides the
            ``etype_sorted`` flag handed to the client.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True)

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        graph_formats=graph_formats,
    )

    disable_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server_args = (
            server_rank,
            tmpdir,
            disable_shared_mem,
            "test_sampling",
            ["csc", "coo"],
        )
        proc = ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    fanout = {etype: 3 for etype in g.canonical_etypes}
    etype_sorted = graph_formats is not None and (
        "csc" in graph_formats or "csr" in graph_formats
    )
    block, _ = start_hetero_etype_sample_client(
        0,
        tmpdir,
        disable_shared_mem,
        fanout,
        nodes={"n3": [0, 10, 99, 66, 124, 208]},
        etype_sorted=etype_sorted,
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    src, dst = block.edges(etype=("n1", "r13", "n3"))
    assert len(src) == 18
    src, dst = block.edges(etype=("n2", "r23", "n3"))
    assert len(src) == 18

    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)

        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)

761

762
def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
    """Check that per-etype sampling from degree-0 ``n3`` nodes is empty.

    Args:
        tmpdir: ``Path`` of the directory the partitions are written to.
        num_server: Number of server processes (and of partitions).
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True, empty=True)

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    disable_shared_mem = num_server > 1
    ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, disable_shared_mem, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    fanout = 3
    # Positions in the id mapping whose original node has no edges at all.
    n3_degrees = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(n3_degrees == 0)
    block, _ = start_hetero_etype_sample_client(
        0, tmpdir, disable_shared_mem, fanout, nodes={"n3": empty_nids}
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # No edges sampled, but every edge type is still present in the block.
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

804
805

def create_random_bipartite():
    """Build a random user-buys-game bipartite graph with ``feat`` features.

    The graph comes from ``dgl.rand_bipartite("user", "buys", "game", 500,
    1000, 1000)``; both node types get an all-ones float32 feature with 10
    columns.
    """
    g = dgl.rand_bipartite("user", "buys", "game", 500, 1000, 1000)
    for ntype in ("user", "game"):
        feat_shape = (g.num_nodes(ntype), 10)
        g.nodes[ntype].data["feat"] = F.ones(feat_shape, F.float32, F.cpu())
    return g


816
817
818
819
820
821
822
823
def start_bipartite_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    nodes,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample neighbors on the user/game bipartite graph and return a block.

    Unlike the heterograph client, exceptions raised while sampling are not
    caught here and propagate to the caller.

    Args:
        rank: Partition id used when loading the partition book from disk.
        tmpdir: ``Path`` of the directory holding ``test_sampling.json``.
        disable_shared_mem: When true, load the partition book explicitly.
        nodes: Seed nodes passed to ``sample_neighbors`` and ``to_block``.
        use_graphbolt: Forwarded to ``DistGraph`` and ``sample_neighbors``.
        return_eids: Whether edge ids are copied when GraphBolt is used.

    Returns:
        ``(block, gpb)``: the block built from the sample and the partition
        book.
    """
    partition_book = None
    if disable_shared_mem:
        config_path = tmpdir / "test_sampling.json"
        _, _, _, partition_book, _, _, _ = load_partition(config_path, rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", gpb=partition_book, use_graphbolt=use_graphbolt
    )

    for ntype in ("user", "game"):
        assert "feat" in dist_graph.nodes[ntype].data

    if partition_book is None:
        partition_book = dist_graph.get_partition_book()

    # Enable sanity check in distributed sampling.
    os.environ["DGL_DIST_DEBUG"] = "1"
    sampled_graph = sample_neighbors(
        dist_graph, nodes, 3, use_graphbolt=use_graphbolt
    )
    block = dgl.to_block(sampled_graph, nodes)

    # Edge ids are copied only for a non-empty sample, and not when
    # GraphBolt is used without return_eids.
    copy_eids = return_eids or not use_graphbolt
    if sampled_graph.num_edges() > 0 and copy_eids:
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]

    dgl.distributed.exit_client()
    return block, partition_book


850
851
852
def start_bipartite_etype_sample_client(
    rank, tmpdir, disable_shared_mem, fanout=3, nodes={}
):
    """Run ``sample_etype_neighbors`` on the user/game bipartite graph.

    Exceptions raised while sampling are not caught here and propagate to
    the caller.

    Args:
        rank: Partition id used when loading the partition book from disk.
        tmpdir: ``Path`` of the directory holding ``test_sampling.json``.
        disable_shared_mem: When true, load the partition book explicitly.
        fanout: Fanout passed to ``sample_etype_neighbors``.
        nodes: Seed nodes passed to the sampler and to ``to_block``.

    Returns:
        ``(block, gpb)``: the block built from the sample and the partition
        book.
    """
    partition_book = None
    if disable_shared_mem:
        config_path = tmpdir / "test_sampling.json"
        _, _, _, partition_book, _, _, _ = load_partition(config_path, rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)

    for ntype in ("user", "game"):
        assert "feat" in dist_graph.nodes[ntype].data

    if dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph: for every local
        # node, the first occurrence of each edge type among its in-edges
        # must come in ascending type order.
        local_g = dist_graph.local_partition
        for local_nid in np.arange(local_g.num_nodes()):
            in_eids = local_g.in_edges(local_nid, form="eid")
            in_etypes = F.asnumpy(local_g.edata[dgl.ETYPE][in_eids])
            first_positions = np.unique(in_etypes, return_index=True)[1]
            assert np.all(np.diff(first_positions) >= 0)

    if partition_book is None:
        partition_book = dist_graph.get_partition_book()

    sampled_graph = sample_etype_neighbors(dist_graph, nodes, fanout)
    block = dgl.to_block(sampled_graph, nodes)
    # Edge ids are copied only for a non-empty sample.
    if sampled_graph.num_edges() > 0:
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]

    dgl.distributed.exit_client()
    return block, partition_book


883
884
885
def check_rpc_bipartite_sampling_empty(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample a bipartite graph with sample_neighbors() expecting no edges.

    The "game" seeds are exactly the nodes for which ``get_degrees`` reports
    a degree of zero, and the resulting block must contain no edges while
    still exposing every edge type of the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    # Launch one server process per partition.
    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server_args = (
            server_rank,
            tmpdir,
            disable_shared_mem,
            "test_sampling",
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = spawn_ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    game_degrees = get_degrees(g, orig_nids["game"], "game")
    seeds = {"game": F.nonzero_1d(game_degrees == 0), "user": [1]}
    block, _ = start_bipartite_sample_client(
        0,
        tmpdir,
        disable_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )

    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


943
944
945
def check_rpc_bipartite_sampling_shuffle(
    tmpdir, num_server, use_graphbolt=False, return_eids=False
):
    """Sample a bipartite graph with sample_neighbors() expecting edges.

    The "game" seeds are the nodes for which ``get_degrees`` reports a
    positive degree. Every sampled edge is mapped back to original ids and
    checked against the unpartitioned graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    # Launch one server process per partition.
    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server_args = (
            server_rank,
            tmpdir,
            disable_shared_mem,
            "test_sampling",
            ["csc", "coo"],
            use_graphbolt,
        )
        proc = spawn_ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    game_degrees = get_degrees(g, orig_nid_map["game"], "game")
    seeds = {"game": F.nonzero_1d(game_degrees > 0), "user": [0]}
    block, _ = start_bipartite_sample_client(
        0,
        tmpdir,
        disable_shared_mem,
        nodes=seeds,
        use_graphbolt=use_graphbolt,
        return_eids=return_eids,
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # Edge ids are verified unless GraphBolt is used without stored eids.
    verify_eids = not use_graphbolt or return_eids
    for canonical in block.canonical_etypes:
        src_type, etype, dst_type = canonical
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        edge_exists = g.has_edges_between(orig_src, orig_dst, etype=etype)
        assert np.all(F.asnumpy(edge_exists))

        if verify_eids:
            shuffled_eid = block.edges[etype].data[dgl.EID]
            orig_eid = F.asnumpy(
                F.gather_row(orig_eid_map[canonical], shuffled_eid)
            )

            # Check the node Ids and edge Ids.
            found_src, found_dst = g.find_edges(orig_eid, etype=etype)
            assert np.all(F.asnumpy(found_src) == orig_src)
            assert np.all(F.asnumpy(found_dst) == orig_dst)


def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
    """Sample a bipartite graph with sample_etype_neighbors() expecting no edges.

    The "game" seeds are the nodes for which ``get_degrees`` reports a
    degree of zero; the resulting block must be edge-free but keep every
    edge type of the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    # Launch one server process per partition.
    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server_args = (server_rank, tmpdir, disable_shared_mem, "test_sampling")
        proc = spawn_ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    game_degrees = get_degrees(g, orig_nids["game"], "game")
    seeds = {"game": F.nonzero_1d(game_degrees == 0), "user": [1]}
    block, _ = start_bipartite_etype_sample_client(
        0, tmpdir, disable_shared_mem, nodes=seeds
    )

    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block is not None
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
    """Sample a bipartite graph with sample_etype_neighbors() expecting edges.

    The "game" seeds are the nodes for which ``get_degrees`` reports a
    positive degree. Sampled node and edge ids are mapped back to original
    ids and compared against the unpartitioned graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    # Launch one server process per partition.
    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server_args = (server_rank, tmpdir, disable_shared_mem, "test_sampling")
        proc = spawn_ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    fanout = 3
    game_degrees = get_degrees(g, orig_nid_map["game"], "game")
    seeds = {"game": F.nonzero_1d(game_degrees > 0), "user": [0]}
    block, _ = start_bipartite_etype_sample_client(
        0, tmpdir, disable_shared_mem, fanout, nodes=seeds
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    for canonical in block.canonical_etypes:
        src_type, etype, dst_type = canonical
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        # Translate the shuffled ids back to the ids of the original graph.
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[canonical], shuffled_eid))

        # Check the node Ids and edge Ids.
        found_src, found_dst = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(found_src) == orig_src)
        assert np.all(F.asnumpy(found_dst) == orig_dst)

1124

1125
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run ``check_rpc_sampling_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = {"use_graphbolt": use_graphbolt, "return_eids": return_eids}
    # Partition files live in a scratch directory removed on exit.
    with tempfile.TemporaryDirectory() as scratch:
        check_rpc_sampling_shuffle(Path(scratch), num_server, **options)
1138
1139
1140


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run ``check_rpc_hetero_sampling_shuffle`` in distributed mode.

    Parametrized over the number of servers, whether GraphBolt is used and
    whether edge ids are returned.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files live in a scratch directory removed on exit.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_sampling_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1153
1154
1155


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_hetero_sampling_empty_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    options = {"use_graphbolt": use_graphbolt, "return_eids": return_eids}
    # Partition files live in a scratch directory removed on exit.
    with tempfile.TemporaryDirectory() as scratch:
        check_rpc_hetero_sampling_empty_shuffle(
            Path(scratch), num_server, **options
        )
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]]
)
def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats):
    """Run ``check_rpc_hetero_etype_sampling_shuffle`` for each graph format."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files live in a scratch directory removed on exit.
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_hetero_etype_sampling_shuffle(
            workdir, num_server, graph_formats=graph_formats
        )
1183
1184
1185
1186
1187
1188
1189


@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_etype_sampling_empty_shuffle(num_server):
    """Run ``check_rpc_hetero_etype_sampling_empty_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files live in a scratch directory removed on exit.
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_hetero_etype_sampling_empty_shuffle(workdir, num_server)
1193
1194
1195


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_empty_shuffle(
    num_server, use_graphbolt, return_eids
):
    """Run ``check_rpc_bipartite_sampling_empty`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files live in a scratch directory removed on exit.
    with tempfile.TemporaryDirectory() as scratch:
        check_rpc_bipartite_sampling_empty(
            Path(scratch),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1207
1208
1209


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run ``check_rpc_bipartite_sampling_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files live in a scratch directory removed on exit.
    with tempfile.TemporaryDirectory() as scratch:
        check_rpc_bipartite_sampling_shuffle(
            Path(scratch),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1219
1220
1221
1222
1223
1224
1225


@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_empty_shuffle(num_server):
    """Run ``check_rpc_bipartite_etype_sampling_empty`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files live in a scratch directory removed on exit.
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_bipartite_etype_sampling_empty(workdir, num_server)
1227
1228
1229
1230
1231
1232
1233


@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_shuffle(num_server):
    """Run ``check_rpc_bipartite_etype_sampling_shuffle`` in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files live in a scratch directory removed on exit.
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        check_rpc_bipartite_etype_sampling_shuffle(workdir, num_server)
Jinjing Zhou's avatar
Jinjing Zhou committed
1235

1236

1237
def check_standalone_sampling(tmpdir):
    """Check ``sample_neighbors`` on a one-partition graph in standalone mode.

    Three samples are drawn from the same seeds: an unweighted one, one
    using the boolean ``"mask"`` edge feature and one using the float
    ``"prob"`` edge feature. The weighted samples must only contain edges
    whose mask is set / whose probability is positive.

    Parameters
    ----------
    tmpdir
        Directory (supporting the ``/`` operator) that receives the
        partition files.
    """
    g = CitationGraphDataset("cora")[0]
    # Negative draws are clamped to zero, so some edges have zero weight.
    prob = np.maximum(np.random.randn(g.num_edges()), 0)
    mask = prob > 0
    g.edata["prob"] = F.tensor(prob)
    g.edata["mask"] = F.tensor(mask)
    num_parts = 1
    num_hops = 1
    partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )

    seeds = [0, 10, 99, 66, 1024, 2008]
    fanout = 3

    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph(
            "test_sampling", part_config=tmpdir / "test_sampling.json"
        )
        sampled_graph = sample_neighbors(dist_graph, seeds, fanout)

        src, dst = sampled_graph.edges()
        assert sampled_graph.num_nodes() == g.num_nodes()
        assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
        eids = g.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
        )

        sampled_graph = sample_neighbors(
            dist_graph, seeds, fanout, prob="mask"
        )
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert mask[eid].all()

        sampled_graph = sample_neighbors(
            dist_graph, seeds, fanout, prob="prob"
        )
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert (prob[eid] > 0).all()
    finally:
        # Shut the client down even when a check above fails, so that the
        # initialised client does not outlive this function.
        dgl.distributed.exit_client()
1281

1282

1283
def check_standalone_etype_sampling(tmpdir):
    """Check ``sample_etype_neighbors`` on a one-partition graph in standalone mode.

    Three samples are drawn from the same seeds: an unweighted one, one
    using the boolean ``"mask"`` edge feature and one using the float
    ``"prob"`` edge feature. The weighted samples must only contain edges
    whose mask is set / whose probability is positive.

    Parameters
    ----------
    tmpdir
        Directory (supporting the ``/`` operator) that receives the
        partition files.
    """
    hg = CitationGraphDataset("cora")[0]
    # Negative draws are clamped to zero, so some edges have zero weight.
    prob = np.maximum(np.random.randn(hg.num_edges()), 0)
    mask = prob > 0
    hg.edata["prob"] = F.tensor(prob)
    hg.edata["mask"] = F.tensor(mask)
    num_parts = 1
    num_hops = 1

    partition_graph(
        hg,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )

    seeds = [0, 10, 99, 66, 1023]
    fanout = 3

    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph(
            "test_sampling", part_config=tmpdir / "test_sampling.json"
        )
        sampled_graph = sample_etype_neighbors(dist_graph, seeds, fanout)

        src, dst = sampled_graph.edges()
        assert sampled_graph.num_nodes() == hg.num_nodes()
        assert np.all(F.asnumpy(hg.has_edges_between(src, dst)))
        eids = hg.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
        )

        sampled_graph = sample_etype_neighbors(
            dist_graph, seeds, fanout, prob="mask"
        )
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert mask[eid].all()

        sampled_graph = sample_etype_neighbors(
            dist_graph, seeds, fanout, prob="prob"
        )
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert (prob[eid] > 0).all()
    finally:
        # Shut the client down even when a check above fails, so that the
        # initialised client does not outlive this function.
        dgl.distributed.exit_client()

1328

1329
def check_standalone_etype_sampling_heterograph(tmpdir):
    """Check ``sample_etype_neighbors`` on a two-relation heterograph.

    The graph reuses the cora edges once as ``"cite"`` and once reversed as
    ``"cite-by"``. Sampling ten seeds with fanout 1 must yield ten edges
    for each of the two relations.

    Parameters
    ----------
    tmpdir
        Directory (supporting the ``/`` operator) that receives the
        partition files.
    """
    hg = CitationGraphDataset("cora")[0]
    num_parts = 1
    num_hops = 1
    src, dst = hg.edges()
    new_hg = dgl.heterograph(
        {
            ("paper", "cite", "paper"): (src, dst),
            ("paper", "cite-by", "paper"): (dst, src),
        },
        {"paper": hg.num_nodes()},
    )
    partition_graph(
        new_hg,
        "test_hetero_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph(
            "test_hetero_sampling",
            part_config=tmpdir / "test_hetero_sampling.json",
        )
        sampled_graph = sample_etype_neighbors(
            dist_graph, [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701], 1
        )
        src, dst = sampled_graph.edges(etype=("paper", "cite", "paper"))
        assert len(src) == 10
        src, dst = sampled_graph.edges(etype=("paper", "cite-by", "paper"))
        assert len(src) == 10
        assert sampled_graph.num_nodes() == new_hg.num_nodes()
    finally:
        # Shut the client down even when a check above fails, so that the
        # initialised client does not outlive this function.
        dgl.distributed.exit_client()

1364
1365
1366
1367
1368
1369

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_standalone_sampling():
    """Run ``check_standalone_sampling`` in standalone mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"
    # Partition files live in a scratch directory removed on exit.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_standalone_sampling(Path(tmpdirname))
1377

1378

1379
1380
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Run a client that calls ``dgl.distributed.in_subgraph`` on ``nodes``.

    When ``disable_shared_mem`` is set, the partition book is loaded from
    ``test_in_subgraph.json`` under ``tmpdir`` and handed to ``DistGraph``.
    Returns the resulting graph, or ``None`` if the call raised (the
    traceback is printed in that case).
    """
    dgl.distributed.initialize("rpc_ip_config.txt")
    partition_book = None
    if disable_shared_mem:
        part_config = tmpdir / "test_in_subgraph.json"
        _, _, _, partition_book, _, _, _ = load_partition(part_config, rank)
    dist_graph = DistGraph("test_in_subgraph", gpb=partition_book)

    # Stays None when in_subgraph fails; the failure is reported on stdout.
    result = None
    try:
        result = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception:
        print(traceback.format_exc())
    dgl.distributed.exit_client()
    return result


1396
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Check ``dgl.distributed.in_subgraph`` against ``dgl.in_subgraph``.

    The cora graph is partitioned into ``num_server`` parts, one server
    process is started per part, and the in-subgraph returned by the client
    is mapped back to original node/edge ids and compared with the
    in-subgraph computed directly on the unpartitioned graph.

    Parameters
    ----------
    tmpdir
        Directory that receives the partition files.
    num_server : int
        Number of partitions and of server processes.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_in_subgraph",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_in_subgraph"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = [0, 10, 99, 66, 1024, 2008]
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client swallows exceptions and returns None; fail here with a
    # clear message rather than with an AttributeError further down.
    assert sampled_graph is not None, "in_subgraph client failed"

    # Map the shuffled endpoint ids back to the ids of the original graph.
    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    subg1 = dgl.in_subgraph(g, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
1442

1443
1444
1445
1446
1447
1448

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_rpc_in_subgraph():
    """Run ``check_rpc_in_subgraph_shuffle`` with one server in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files live in a scratch directory removed on exit.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 1)
1456

1457
1458
1459
1460
1461
1462
1463
1464
1465

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone_etype_sampling():
    """Run both standalone etype-sampling checks, each in its own directory."""
    reset_envs()
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling(Path(tmpdirname))
1476

1477

Jinjing Zhou's avatar
Jinjing Zhou committed
1478
1479
if __name__ == "__main__":
    import tempfile
1480

Jinjing Zhou's avatar
Jinjing Zhou committed
1481
    with tempfile.TemporaryDirectory() as tmpdirname:
1482
        os.environ["DGL_DIST_MODE"] = "standalone"
1483
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
1484
1485

    with tempfile.TemporaryDirectory() as tmpdirname:
1486
        os.environ["DGL_DIST_MODE"] = "standalone"
1487
1488
        check_standalone_etype_sampling(Path(tmpdirname))
        check_standalone_sampling(Path(tmpdirname))
1489
        os.environ["DGL_DIST_MODE"] = "distributed"
1490
1491
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
1492
1493
        check_rpc_get_degree_shuffle(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 2)
1494
1495
        check_rpc_find_edges_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 1)
1496
1497
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 2)
1498
1499
1500
1501
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)
1502
1503
1504
1505
        check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), 1)