test_distributed_sampling.py 42.1 KB
Newer Older
1
import multiprocessing as mp
Jinjing Zhou's avatar
Jinjing Zhou committed
2
import os
3
import random
4
import tempfile
Jinjing Zhou's avatar
Jinjing Zhou committed
5
import time
6
7
import traceback
import unittest
Jinjing Zhou's avatar
Jinjing Zhou committed
8
from pathlib import Path
9
10
11
12

import backend as F
import dgl
import numpy as np
13
import pytest
14
15
16
17
18
19
20
21
22
23
from dgl.data import CitationGraphDataset, WN18Dataset
from dgl.distributed import (
    DistGraph,
    DistGraphServer,
    load_partition,
    load_partition_book,
    partition_graph,
    sample_etype_neighbors,
    sample_neighbors,
)
24
from scipy import sparse as spsp
25
from utils import generate_ip_config, reset_envs
Jinjing Zhou's avatar
Jinjing Zhou committed
26
27


28
29
30
31
32
33
def start_server(
    rank,
    tmpdir,
    disable_shared_mem,
    graph_name,
    graph_format=None,
    use_graphbolt=False,
):
    """Create a ``DistGraphServer`` for one partition and run it.

    Used as the ``target`` of the server processes spawned by the
    ``check_*`` helpers in this file.

    Parameters
    ----------
    rank : int
        Passed as the first positional argument of ``DistGraphServer``.
    tmpdir : pathlib.Path
        Directory that contains ``<graph_name>.json``.
    disable_shared_mem : bool
        Forwarded to ``DistGraphServer``.
    graph_name : str
        Base name of the partition config file.
    graph_format : list of str, optional
        Forwarded to ``DistGraphServer``; ``["csc", "coo"]`` when omitted.
    use_graphbolt : bool
        Forwarded to ``DistGraphServer``.
    """
    # A fresh list per call avoids sharing one mutable default object.
    if graph_format is None:
        graph_format = ["csc", "coo"]
    server = DistGraphServer(
        rank,
        "rpc_ip_config.txt",
        # NOTE(review): presumably the server and client counts -- confirm
        # against the DistGraphServer signature.
        1,
        1,
        tmpdir / (graph_name + ".json"),
        disable_shared_mem=disable_shared_mem,
        graph_format=graph_format,
        use_graphbolt=use_graphbolt,
    )
    server.start()


49
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Connect as a client to "test_sampling" and sample neighbors once.

    Parameters
    ----------
    rank : int
        Partition id used when loading the partition book from disk.
    tmpdir : pathlib.Path
        Directory that contains ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded with ``load_partition``
        and handed to ``DistGraph`` explicitly.

    Returns
    -------
    The graph returned by ``sample_neighbors`` with fanout 3, or ``None``
    if sampling raised (the traceback is printed in that case).
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    try:
        sampled_graph = sample_neighbors(
            dist_graph, [0, 10, 99, 66, 1024, 2008], 3
        )
    except Exception:
        # Swallow the error so that exit_client() below still runs; the
        # caller sees the failure as a None result.
        print(traceback.format_exc())
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph

67

68
69
70
71
72
73
74
75
76
def start_sample_client_shuffle(
    rank,
    tmpdir,
    disable_shared_mem,
    g,
    num_servers,
    group_id,
    orig_nid,
    orig_eid,
    use_graphbolt=False,
    return_eids=False,
):
    """Sample from the shuffled "test_sampling" graph and verify the result.

    Used as the ``target`` of client processes. The sampled edges are
    mapped back through ``orig_nid`` / ``orig_eid`` and compared with the
    original graph ``g``.

    Parameters
    ----------
    rank : int
        Partition id used when loading the partition book from disk.
    tmpdir : pathlib.Path
        Directory that contains ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded explicitly.
    g : graph
        The original (unpartitioned) graph.
    num_servers : int
        Not used inside this function.
    group_id : int
        Written to the ``DGL_GROUP_ID`` environment variable.
    orig_nid, orig_eid
        Mappings indexed by shuffled ids that yield original ids.
    use_graphbolt : bool
        Forwarded to ``DistGraph`` and ``sample_neighbors``.
    return_eids : bool
        With ``use_graphbolt`` set, controls whether edge ids are expected
        in the sampled graph.
    """
    os.environ["DGL_GROUP_ID"] = str(group_id)
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph(
        "test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
    )
    sampled_graph = sample_neighbors(
        dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt
    )

    # Translate shuffled node ids back to the ids of the original graph.
    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    if use_graphbolt and not return_eids:
        assert (
            dgl.EID not in sampled_graph.edata
        ), "EID should not be in sampled graph if use_graphbolt=True."
    else:
        eids = g.edge_ids(src, dst)
        eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
        assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
107

108

109
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
    """Connect as a client to "test_find_edges" and call ``find_edges``.

    Parameters
    ----------
    rank : int
        Partition id used when loading the partition book from disk.
    tmpdir : pathlib.Path
        Directory that contains ``test_find_edges.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded explicitly.
    eids
        Edge ids passed to ``DistGraph.find_edges``.
    etype : optional
        Edge type passed to ``DistGraph.find_edges``.

    Returns
    -------
    tuple
        ``(u, v)`` from ``find_edges``, or ``(None, None)`` if the call
        raised (the traceback is printed in that case).
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_find_edges.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_find_edges", gpb=gpb)
    try:
        u, v = dist_graph.find_edges(eids, etype=etype)
    except Exception:
        # Keep going so that exit_client() below is still called.
        print(traceback.format_exc())
        u, v = None, None
    dgl.distributed.exit_client()
    return u, v
Jinjing Zhou's avatar
Jinjing Zhou committed
124

125

126
127
128
def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Connect as a client to "test_get_degrees" and query node degrees.

    Parameters
    ----------
    rank : int
        Partition id used when loading the partition book from disk.
    tmpdir : pathlib.Path
        Directory that contains ``test_get_degrees.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded explicitly.
    nids : optional
        Node ids passed to ``in_degrees`` / ``out_degrees``.

    Returns
    -------
    tuple
        ``(in_deg, out_deg, all_in_deg, all_out_deg)``; the ``all_*``
        entries come from calling the degree methods without arguments.
        All four are ``None`` if any query raised.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_get_degrees.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_get_degrees", gpb=gpb)
    try:
        in_deg = dist_graph.in_degrees(nids)
        all_in_deg = dist_graph.in_degrees()
        out_deg = dist_graph.out_degrees(nids)
        all_out_deg = dist_graph.out_degrees()
    except Exception:
        # Keep going so that exit_client() below is still called.
        print(traceback.format_exc())
        in_deg, out_deg, all_in_deg, all_out_deg = None, None, None, None
    dgl.distributed.exit_client()
    return in_deg, out_deg, all_in_deg, all_out_deg

145

146
def check_rpc_sampling(tmpdir, num_server):
    """Partition cora, start servers, sample once and validate the edges.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    num_parts = num_server
    num_hops = 1

    partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns None when sampling raised; fail with a clear
    # message instead of an AttributeError on the next line.
    assert sampled_graph is not None, "sampling client failed"
    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
    )

Jinjing Zhou's avatar
Jinjing Zhou committed
188

189
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on a shuffled partition of cora.

    Random edge ids are looked up through the distributed graph and the
    endpoints, mapped back to original node ids, are compared with
    ``find_edges`` on the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    eids = F.tensor(np.random.randint(g.num_edges(), size=100))
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

    # Reap the server processes, as the other check_* helpers do, so they
    # do not outlive this check.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns (None, None) when find_edges raised.
    assert du is not None and dv is not None, "find_edges client failed"
    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

224

225
def create_random_hetero(dense=False, empty=False):
    """Build a random heterograph with node types n1, n2, n3.

    Parameters
    ----------
    dense : bool
        Selects the smaller node counts (210/200/220) with edge density
        0.1 instead of the larger counts (1010/1000/1020) with density
        0.001.
    empty : bool
        When true, edges are drawn from a matrix that is 10 rows and 10
        columns smaller than the node counts, so the last 10 nodes of
        each type receive no edges.

    Returns
    -------
    The heterograph, with a ``"feat"`` tensor of ones (10 columns) on
    ``n1`` nodes only.
    """
    num_nodes = (
        {"n1": 210, "n2": 200, "n3": 220}
        if dense
        else {"n1": 1010, "n2": 1000, "n3": 1020}
    )
    etypes = [("n1", "r12", "n2"), ("n1", "r13", "n3"), ("n2", "r23", "n3")]
    edges = {}
    random.seed(42)
    # Rows/columns trimmed from the sparse matrix to leave isolated nodes.
    shrink = 10 if empty else 0
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype] - shrink,
            num_nodes[dst_ntype] - shrink,
            density=0.1 if dense else 0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    g.nodes["n1"].data["feat"] = F.ones(
        (g.num_nodes("n1"), 10), F.float32, F.cpu()
    )
    return g
249

250

251
def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` with an edge type on a heterograph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    test_etype = g.to_canonical_etype("r12")
    eids = F.tensor(np.random.randint(g.num_edges(test_etype), size=100))

    # A two-element edge type is expected to be rejected. Only Exception
    # subclasses count, so KeyboardInterrupt/SystemExit are not swallowed.
    with pytest.raises(Exception):
        g.find_edges(orig_eid[test_etype][eids], etype=("n1", "r12"))

    # The short name and the canonical triple must give the same result.
    u, v = g.find_edges(orig_eid[test_etype][eids], etype="r12")
    u1, v1 = g.find_edges(orig_eid[test_etype][eids], etype=("n1", "r12", "n2"))
    assert F.array_equal(u, u1)
    assert F.array_equal(v, v1)

    du, dv = start_find_edges_client(
        0, tmpdir, num_server > 1, eids, etype="r12"
    )

    # Reap the server processes, as the other check_* helpers do, so they
    # do not outlive this check.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns (None, None) when find_edges raised.
    assert du is not None and dv is not None, "find_edges client failed"
    du = orig_nid["n1"][du]
    dv = orig_nid["n2"][dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

298

299
# Wait for the non-shared-memory graph store
300
301
302
303
304
305
306
307
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_find_edges_shuffle(num_server):
    """Run the heterogeneous and homogeneous find_edges checks in turn.

    Both checks share one temporary directory for their partition files.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # tempfile is already imported at module level.
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_hetero_find_edges_shuffle(workdir, num_server)
        check_rpc_find_edges_shuffle(workdir, num_server)

318

319
def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Check ``DistGraph`` in/out degrees on a shuffled partition of cora.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, _ = partition_graph(
        g,
        "test_get_degrees",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_get_degrees"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nids = F.tensor(np.random.randint(g.num_nodes(), size=100))
    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(
        0, tmpdir, num_server > 1, nids
    )

    print("Done get_degree")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    print("check results")
    # The client returns four Nones when a degree query raised.
    assert in_degs is not None, "get_degrees client failed"
    # Compare against the original graph after mapping the shuffled ids.
    assert F.array_equal(g.in_degrees(orig_nid[nids]), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(orig_nid[nids]), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)

362

363
# Wait for the non-shared-memory graph store
364
365
366
367
368
369
370
371
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_get_degree_shuffle(num_server):
    """Run the degree check inside a fresh temporary directory."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # tempfile is already imported at module level.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_get_degree_shuffle(Path(tmpdirname), num_server)

381
382
383
384

# @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
# @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip("Only support partition with shuffle")
def test_rpc_sampling():
    """Run the sampling check with one server (currently skipped)."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # tempfile is already imported at module level.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling(Path(tmpdirname), 1)
Jinjing Zhou's avatar
Jinjing Zhou committed
392

393

394
def check_rpc_sampling_shuffle(
    tmpdir, num_server, num_groups=1, use_graphbolt=False, return_eids=False
):
    """Partition cora with shuffling and sample from client processes.

    The validation itself happens inside ``start_sample_client_shuffle``;
    this function only checks the exit codes of the spawned processes.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    num_groups : int
        Number of client processes started per client id, one per group id.
    use_graphbolt : bool
        Forwarded to partitioning, the servers and the clients.
    return_eids : bool
        Forwarded to ``partition_graph`` as ``store_eids`` and to the
        clients.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server
    num_hops = 1

    orig_nids, orig_eids = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=return_eids,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(
                i,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                use_graphbolt,
            ),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    pclient_list = []
    num_clients = 1
    for client_id in range(num_clients):
        for group_id in range(num_groups):
            p = ctx.Process(
                target=start_sample_client_shuffle,
                args=(
                    client_id,
                    tmpdir,
                    num_server > 1,
                    g,
                    num_server,
                    group_id,
                    orig_nids,
                    orig_eids,
                    use_graphbolt,
                    return_eids,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            pclient_list.append(p)
    # A failed assertion in a client shows up as a non-zero exit code.
    for p in pclient_list:
        p.join()
        assert p.exitcode == 0
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0
Jinjing Zhou's avatar
Jinjing Zhou committed
461

462

463
def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes):
    """Sample neighbors of ``nodes`` on the heterogeneous "test_sampling" graph.

    Parameters
    ----------
    rank : int
        Partition id used when loading the partition book from disk.
    tmpdir : pathlib.Path
        Directory that contains ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded explicitly.
    nodes : dict
        Seed nodes passed to ``sample_neighbors`` and ``dgl.to_block``.

    Returns
    -------
    tuple
        ``(block, gpb)``. ``block`` carries the sampled edge ids in
        ``edata[dgl.EID]`` and is ``None`` if sampling raised.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    # Only n1 is expected to carry the "feat" node data.
    assert "feat" in dist_graph.nodes["n1"].data
    assert "feat" not in dist_graph.nodes["n2"].data
    assert "feat" not in dist_graph.nodes["n3"].data
    if gpb is None:
        gpb = dist_graph.get_partition_book()
    try:
        sampled_graph = sample_neighbors(dist_graph, nodes, 3)
        block = dgl.to_block(sampled_graph, nodes)
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        # Keep going so that exit_client() below is still called.
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, gpb

486
487
488
489
490
491
492
493
494

def start_hetero_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    etype_sorted=False,
):
    """Run ``sample_etype_neighbors`` on the heterogeneous "test_sampling" graph.

    Parameters
    ----------
    rank : int
        Partition id used when loading the partition book from disk.
    tmpdir : pathlib.Path
        Directory that contains ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded explicitly.
    fanout
        Passed to ``sample_etype_neighbors``.
    nodes : dict, optional
        Seed nodes; ``{"n3": [0, 10, 99, 66, 124, 208]}`` when omitted.
    etype_sorted : bool
        Passed to ``sample_etype_neighbors``.

    Returns
    -------
    tuple
        ``(block, gpb)``. ``block`` carries the sampled edge ids in
        ``edata[dgl.EID]`` and is ``None`` if sampling raised.
    """
    # A fresh dict per call avoids sharing one mutable default object.
    if nodes is None:
        nodes = {"n3": [0, 10, 99, 66, 124, 208]}
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["n1"].data
    assert "feat" not in dist_graph.nodes["n2"].data
    assert "feat" not in dist_graph.nodes["n3"].data

    if dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph
        local_g = dist_graph.local_partition
        local_nids = np.arange(local_g.num_nodes())
        for lnid in local_nids:
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            # np.unique yields the sorted type ids with the index of their
            # first occurrence; those indices must be non-decreasing.
            _, first_seen = np.unique(letids, return_index=True)
            assert np.all(first_seen[:-1] <= first_seen[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    try:
        sampled_graph = sample_etype_neighbors(
            dist_graph, nodes, fanout, etype_sorted=etype_sorted
        )
        block = dgl.to_block(sampled_graph, nodes)
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        # Keep going so that exit_client() below is still called.
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, gpb

530

531
def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
    """Check ``sample_neighbors`` on a shuffled random heterograph.

    Every sampled edge is mapped back to original node and edge ids and
    compared with ``find_edges`` on the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    num_parts = num_server
    num_hops = 1

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    block, _ = start_hetero_sample_client(
        0, tmpdir, num_server > 1, nodes={"n3": [0, 10, 99, 66, 124, 208]}
    )
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns a None block when sampling raised.
    assert block is not None, "sampling client failed"
    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)
583

584

585
586
587
588
589
590
591
592
593
def get_degrees(g, nids, ntype):
    """Sum the degrees of ``nids`` over the relations touching ``ntype``.

    For each canonical edge type, out-degrees are added when ``ntype`` is
    the source type; otherwise in-degrees are added when it is the
    destination type. Returns an int64 tensor with one entry per node id.
    """
    total = F.zeros((len(nids),), dtype=F.int64)
    for canonical in g.canonical_etypes:
        src_ntype, rel, dst_ntype = canonical
        if src_ntype == ntype:
            total += g.out_degrees(u=nids, etype=rel)
            continue
        if dst_ntype == ntype:
            total += g.in_degrees(v=nids, etype=rel)
    return total

594

595
def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
    """Check ``sample_neighbors`` when all seed nodes have degree zero.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(empty=True)
    num_parts = num_server
    num_hops = 1

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Seeds are the shuffled ids of n3 nodes with no edges in the original graph.
    deg = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(deg == 0)
    block, _ = start_hetero_sample_client(
        0, tmpdir, num_server > 1, nodes={"n3": empty_nids}
    )
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns a None block when sampling raised.
    assert block is not None, "sampling client failed"
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

636
637
638
639

def check_rpc_hetero_etype_sampling_shuffle(
    tmpdir, num_server, graph_formats=None
):
    """Check ``sample_etype_neighbors`` on a dense shuffled heterograph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    graph_formats : optional
        Forwarded to ``partition_graph``. When it contains "csc" or
        "csr", the client samples with ``etype_sorted=True``.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True)
    num_parts = num_server
    num_hops = 1

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
        graph_formats=graph_formats,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling", ["csc", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    fanout = {etype: 3 for etype in g.canonical_etypes}
    etype_sorted = False
    if graph_formats is not None:
        etype_sorted = "csc" in graph_formats or "csr" in graph_formats
    block, _ = start_hetero_etype_sample_client(
        0,
        tmpdir,
        num_server > 1,
        fanout,
        nodes={"n3": [0, 10, 99, 66, 124, 208]},
        etype_sorted=etype_sorted,
    )
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns a None block when sampling raised.
    assert block is not None, "sampling client failed"
    # 18 edges per relation into n3: presumably 6 seeds times fanout 3.
    src, dst = block.edges(etype=("n1", "r13", "n3"))
    assert len(src) == 18
    src, dst = block.edges(etype=("n2", "r23", "n3"))
    assert len(src) == 18

    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)

707

708
def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
    """Check ``sample_etype_neighbors`` when all seeds have degree zero.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True, empty=True)
    num_parts = num_server
    num_hops = 1

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    fanout = 3
    # Seeds are the shuffled ids of n3 nodes with no edges in the original graph.
    deg = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(deg == 0)
    block, _ = start_hetero_etype_sample_client(
        0, tmpdir, num_server > 1, fanout, nodes={"n3": empty_nids}
    )
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns a None block when sampling raised.
    assert block is not None, "sampling client failed"
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

750
751

def create_random_bipartite():
    """Build a random user-buys-game bipartite graph with node features.

    The graph comes from ``dgl.rand_bipartite("user", "buys", "game", 500,
    1000, 1000)``. Both node types get a ``"feat"`` tensor of ones with
    10 columns.
    """
    g = dgl.rand_bipartite("user", "buys", "game", 500, 1000, 1000)
    for ntype in ("user", "game"):
        g.nodes[ntype].data["feat"] = F.ones(
            (g.num_nodes(ntype), 10), F.float32, F.cpu()
        )
    return g


def start_bipartite_sample_client(rank, tmpdir, disable_shared_mem, nodes):
    """Sample neighbors of ``nodes`` on the bipartite "test_sampling" graph.

    Parameters
    ----------
    rank : int
        Partition id used when loading the partition book from disk.
    tmpdir : pathlib.Path
        Directory that contains ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded explicitly.
    nodes : dict
        Seed nodes passed to ``sample_neighbors`` and ``dgl.to_block``.

    Returns
    -------
    tuple
        ``(block, gpb)``. Edge ids are copied into the block only when
        the sampled graph has at least one edge.

    Raises
    ------
    Exception
        Whatever sampling or block construction raises; the client is
        still shut down first.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["user"].data
    assert "feat" in dist_graph.nodes["game"].data
    if gpb is None:
        gpb = dist_graph.get_partition_book()
    try:
        sampled_graph = sample_neighbors(dist_graph, nodes, 3)
        block = dgl.to_block(sampled_graph, nodes)
        if sampled_graph.num_edges() > 0:
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    finally:
        # Always shut the client down, as the sibling clients do after a
        # sampling failure; otherwise the call is skipped on error.
        dgl.distributed.exit_client()
    return block, gpb


782
783
784
def start_bipartite_etype_sample_client(
    rank, tmpdir, disable_shared_mem, fanout=3, nodes=None
):
    """Run ``sample_etype_neighbors`` on the bipartite "test_sampling" graph.

    Parameters
    ----------
    rank : int
        Partition id used when loading the partition book from disk.
    tmpdir : pathlib.Path
        Directory that contains ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded explicitly.
    fanout
        Passed to ``sample_etype_neighbors``.
    nodes : dict, optional
        Seed nodes; an empty dict when omitted.

    Returns
    -------
    tuple
        ``(block, gpb)``. Edge ids are copied into the block only when
        the sampled graph has at least one edge.

    Raises
    ------
    Exception
        Whatever sampling or block construction raises; the client is
        still shut down first.
    """
    # A fresh dict per call avoids sharing one mutable default object.
    if nodes is None:
        nodes = {}
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["user"].data
    assert "feat" in dist_graph.nodes["game"].data

    if dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph
        local_g = dist_graph.local_partition
        local_nids = np.arange(local_g.num_nodes())
        for lnid in local_nids:
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            # np.unique yields the sorted type ids with the index of their
            # first occurrence; those indices must be non-decreasing.
            _, first_seen = np.unique(letids, return_index=True)
            assert np.all(first_seen[:-1] <= first_seen[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    try:
        sampled_graph = sample_etype_neighbors(dist_graph, nodes, fanout)
        block = dgl.to_block(sampled_graph, nodes)
        if sampled_graph.num_edges() > 0:
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    finally:
        # Always shut the client down, as the sibling clients do after a
        # sampling failure; otherwise the call is skipped on error.
        dgl.distributed.exit_client()
    return block, gpb


def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
    """sample on bipartite via sample_neighbors() which yields empty sample results

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Seed with the shuffled ids of "game" nodes that have no edges in the
    # original graph, plus user node 1.
    deg = get_degrees(g, orig_nids["game"], "game")
    empty_nids = F.nonzero_1d(deg == 0)
    block, _ = start_bipartite_sample_client(
        0, tmpdir, num_server > 1, nodes={"game": empty_nids, "user": [1]}
    )

    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
    """sample on bipartite via sample_neighbors() which yields non-empty sample results

    Every sampled edge is mapped back to original node and edge ids and
    compared with ``find_edges`` on the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Seed with the shuffled ids of "game" nodes that have at least one
    # edge in the original graph, plus user node 0.
    deg = get_degrees(g, orig_nid_map["game"], "game")
    nids = F.nonzero_1d(deg > 0)
    block, _ = start_bipartite_sample_client(
        0, tmpdir, num_server > 1, nodes={"game": nids, "user": [0]}
    )
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)


def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
    """Sample on a bipartite graph via the etype sample client; expect no edges.

    Partitions a random bipartite graph into ``num_server`` parts, spawns one
    server process per part, and samples from the "game" nodes for which
    ``get_degrees`` reports a degree of 0 (plus "user" node 1). The returned
    block must exist, contain no edges, and expose as many edge types as the
    original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
        Shared memory is disabled when it is greater than 1.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1

    # Only the node-ID mapping is needed here; the edge mapping is discarded.
    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    deg = get_degrees(g, orig_nids["game"], "game")
    empty_nids = F.nonzero_1d(deg == 0)
    # The partition book returned by the client is not used by this check.
    block, _ = start_bipartite_etype_sample_client(
        0, tmpdir, num_server > 1, nodes={"game": empty_nids, "user": [1]}
    )

    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    assert block is not None
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
    """Sample on a bipartite graph via the etype sample client; verify the IDs.

    Partitions a random bipartite graph into ``num_server`` parts, spawns one
    server process per part, and samples with fanout 3 from the "game" nodes
    for which ``get_degrees`` reports a positive degree (plus "user" node 0).
    Every sampled edge is mapped back to the original node and edge IDs and
    compared with ``g.find_edges``.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
        Shared memory is disabled when it is greater than 1.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    fanout = 3
    deg = get_degrees(g, orig_nid_map["game"], "game")
    nids = F.nonzero_1d(deg > 0)
    # The partition book returned by the client is not used by this check.
    block, _ = start_bipartite_etype_sample_client(
        0, tmpdir, num_server > 1, fanout, nodes={"game": nids, "user": [0]}
    )
    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        # Translate the shuffled Ids back to the Ids of the original graph.
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)

1018

1019
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
    """Run check_rpc_sampling_shuffle in distributed mode in a temp directory.

    Parametrized over the number of servers, whether GraphBolt is used, and
    whether edge IDs are requested; all three are forwarded unchanged.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling_shuffle(
            Path(tmpdirname),
            num_server,
            use_graphbolt=use_graphbolt,
            return_eids=return_eids,
        )
1032
1033
1034
1035
1036
1037
1038


@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_sampling_shuffle(num_server):
    """Run check_rpc_hetero_sampling_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files are written to a directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
1040
1041
1042
1043
1044
1045
1046


@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_sampling_empty_shuffle(num_server):
    """Run check_rpc_hetero_sampling_empty_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files are written to a directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server)
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize(
    "graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]]
)
def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats):
    """Run check_rpc_hetero_etype_sampling_shuffle in distributed mode.

    Parametrized over the number of servers and the ``graph_formats`` value
    that is forwarded to the check.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_etype_sampling_shuffle(
            Path(tmpdirname), num_server, graph_formats=graph_formats
        )
1061
1062
1063
1064
1065
1066
1067


@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_etype_sampling_empty_shuffle(num_server):
    """Run check_rpc_hetero_etype_sampling_empty_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files are written to a directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_hetero_etype_sampling_empty_shuffle(
            Path(tmpdirname), num_server
        )
1071
1072
1073
1074
1075
1076
1077


@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_sampling_empty_shuffle(num_server):
    """Run check_rpc_bipartite_sampling_empty in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files are written to a directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_bipartite_sampling_empty(Path(tmpdirname), num_server)
1079
1080
1081
1082
1083
1084
1085


@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_sampling_shuffle(num_server):
    """Run check_rpc_bipartite_sampling_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files are written to a directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_bipartite_sampling_shuffle(Path(tmpdirname), num_server)
1087
1088
1089
1090
1091
1092
1093


@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_empty_shuffle(num_server):
    """Run check_rpc_bipartite_etype_sampling_empty in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files are written to a directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_bipartite_etype_sampling_empty(Path(tmpdirname), num_server)
1095
1096
1097
1098
1099
1100
1101


@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_shuffle(num_server):
    """Run check_rpc_bipartite_etype_sampling_shuffle in distributed mode."""
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # Partition files are written to a directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_bipartite_etype_sampling_shuffle(Path(tmpdirname), num_server)
Jinjing Zhou's avatar
Jinjing Zhou committed
1103

1104

1105
def check_standalone_sampling(tmpdir):
    """Check sample_neighbors() on a single-partition graph in standalone mode.

    Loads the "cora" citation graph, attaches a random non-negative "prob"
    edge feature and the matching boolean "mask" feature, partitions the graph
    into one part under ``tmpdir`` and samples three neighbors per seed node:

    * without ``prob``: sampled edges must exist in the original graph and
      carry the original edge IDs;
    * with ``prob="mask"``: every sampled edge must have a true mask;
    * with ``prob="prob"``: every sampled edge must have a positive weight.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    """
    g = CitationGraphDataset("cora")[0]
    # Clamp negative draws to 0 so that some edges get a zero weight.
    prob = np.maximum(np.random.randn(g.num_edges()), 0)
    mask = prob > 0
    g.edata["prob"] = F.tensor(prob)
    g.edata["mask"] = F.tensor(mask)
    num_parts = 1
    num_hops = 1
    partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )

    seeds = [0, 10, 99, 66, 1024, 2008]
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph(
            "test_sampling", part_config=tmpdir / "test_sampling.json"
        )
        sampled_graph = sample_neighbors(dist_graph, seeds, 3)

        src, dst = sampled_graph.edges()
        assert sampled_graph.num_nodes() == g.num_nodes()
        assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
        eids = g.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
        )

        sampled_graph = sample_neighbors(dist_graph, seeds, 3, prob="mask")
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert mask[eid].all()

        sampled_graph = sample_neighbors(dist_graph, seeds, 3, prob="prob")
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert (prob[eid] > 0).all()
    finally:
        # Shut the client down even when an assertion above fails.
        dgl.distributed.exit_client()
1149

1150

1151
def check_standalone_etype_sampling(tmpdir):
    """Check sample_etype_neighbors() on a single-partition graph, standalone.

    Loads the "cora" citation graph, attaches a random non-negative "prob"
    edge feature and the matching boolean "mask" feature, partitions the graph
    into one part under ``tmpdir`` and samples three neighbors per seed node:

    * without ``prob``: sampled edges must exist in the original graph and
      carry the original edge IDs;
    * with ``prob="mask"``: every sampled edge must have a true mask;
    * with ``prob="prob"``: every sampled edge must have a positive weight.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    """
    hg = CitationGraphDataset("cora")[0]
    # Clamp negative draws to 0 so that some edges get a zero weight.
    prob = np.maximum(np.random.randn(hg.num_edges()), 0)
    mask = prob > 0
    hg.edata["prob"] = F.tensor(prob)
    hg.edata["mask"] = F.tensor(mask)
    num_parts = 1
    num_hops = 1

    partition_graph(
        hg,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )

    seeds = [0, 10, 99, 66, 1023]
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph(
            "test_sampling", part_config=tmpdir / "test_sampling.json"
        )
        sampled_graph = sample_etype_neighbors(dist_graph, seeds, 3)

        src, dst = sampled_graph.edges()
        assert sampled_graph.num_nodes() == hg.num_nodes()
        assert np.all(F.asnumpy(hg.has_edges_between(src, dst)))
        eids = hg.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
        )

        sampled_graph = sample_etype_neighbors(
            dist_graph, seeds, 3, prob="mask"
        )
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert mask[eid].all()

        sampled_graph = sample_etype_neighbors(
            dist_graph, seeds, 3, prob="prob"
        )
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert (prob[eid] > 0).all()
    finally:
        # Shut the client down even when an assertion above fails.
        dgl.distributed.exit_client()

1196

1197
def check_standalone_etype_sampling_heterograph(tmpdir):
    """Check sample_etype_neighbors() on a two-etype heterograph, standalone.

    Builds a heterograph over a single "paper" node type from the "cora"
    edges: "cite" keeps the original direction and "cite-by" reverses it.
    After partitioning into one part under ``tmpdir``, one neighbor is sampled
    for each of 10 seed nodes; each edge type must yield 10 edges and the
    sampled graph must keep the node count of the heterograph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    """
    hg = CitationGraphDataset("cora")[0]
    num_parts = 1
    num_hops = 1
    src, dst = hg.edges()
    new_hg = dgl.heterograph(
        {
            ("paper", "cite", "paper"): (src, dst),
            ("paper", "cite-by", "paper"): (dst, src),
        },
        {"paper": hg.num_nodes()},
    )
    partition_graph(
        new_hg,
        "test_hetero_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph(
            "test_hetero_sampling",
            part_config=tmpdir / "test_hetero_sampling.json",
        )
        sampled_graph = sample_etype_neighbors(
            dist_graph, [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701], 1
        )
        # Only the number of sampled edges per edge type is checked.
        src, _ = sampled_graph.edges(etype=("paper", "cite", "paper"))
        assert len(src) == 10
        src, _ = sampled_graph.edges(etype=("paper", "cite-by", "paper"))
        assert len(src) == 10
        assert sampled_graph.num_nodes() == new_hg.num_nodes()
    finally:
        # Shut the client down even when an assertion above fails.
        dgl.distributed.exit_client()

1232
1233
1234
1235
1236
1237

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_standalone_sampling():
    """Run check_standalone_sampling in standalone mode in a temp directory.

    Skipped on Windows and with the TensorFlow backend.
    """
    reset_envs()

    os.environ["DGL_DIST_MODE"] = "standalone"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_standalone_sampling(Path(tmpdirname))
1245

1246

1247
1248
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Connect as a client and extract the in-subgraph of ``nodes``.

    Parameters
    ----------
    rank : int
        Partition to load the partition book from when shared memory is
        disabled.
    tmpdir : pathlib.Path
        Directory that holds "test_in_subgraph.json".
    disable_shared_mem : bool
        When true, the partition book is loaded from disk and handed to
        ``DistGraph``; otherwise ``gpb=None`` is passed.
    nodes : list or tensor
        Node IDs passed to ``dgl.distributed.in_subgraph``.

    Returns
    -------
    The graph returned by ``dgl.distributed.in_subgraph``, or ``None`` when
    that call raised; the traceback is printed in that case.
    """
    gpb = None
    dgl.distributed.initialize("rpc_ip_config.txt")
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_in_subgraph.json", rank
        )
    dist_graph = DistGraph("test_in_subgraph", gpb=gpb)
    try:
        sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception:
        # Best effort: report the failure but still reach exit_client() below.
        print(traceback.format_exc())
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph


1264
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Check dgl.distributed.in_subgraph() against dgl.in_subgraph().

    Partitions the "cora" citation graph into ``num_server`` parts, spawns one
    server process per part and requests the in-subgraph of a fixed node list
    through a client. After mapping the result back to original IDs, its edges
    and edge IDs must match the original graph and the local
    ``dgl.in_subgraph`` result.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes, also used as the number of partitions.
        Shared memory is disabled when it is greater than 1.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_in_subgraph",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_in_subgraph"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = [0, 10, 99, 66, 1024, 2008]
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns None when in_subgraph() raised; fail with a clear
    # message instead of an AttributeError on the next line.
    assert sampled_graph is not None, "in_subgraph client failed"

    src, dst = sampled_graph.edges()
    # Map the shuffled node Ids back to the Ids of the original graph.
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    subg1 = dgl.in_subgraph(g, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
1310

1311
1312
1313
1314
1315
1316

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_rpc_in_subgraph():
    """Run check_rpc_in_subgraph_shuffle with one server in distributed mode.

    Skipped on Windows and with the TensorFlow backend.
    """
    reset_envs()

    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 1)
1324

1325
1326
1327
1328
1329
1330
1331
1332
1333

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone_etype_sampling():
    """Run both standalone etype-sampling checks, each in its own temp dir.

    Skipped on Windows and with the TensorFlow or MXNet backend.
    """
    reset_envs()

    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling(Path(tmpdirname))
1344

1345

Jinjing Zhou's avatar
Jinjing Zhou committed
1346
1347
if __name__ == "__main__":
    import tempfile
1348

Jinjing Zhou's avatar
Jinjing Zhou committed
1349
    with tempfile.TemporaryDirectory() as tmpdirname:
1350
        os.environ["DGL_DIST_MODE"] = "standalone"
1351
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
1352
1353

    with tempfile.TemporaryDirectory() as tmpdirname:
1354
        os.environ["DGL_DIST_MODE"] = "standalone"
1355
1356
        check_standalone_etype_sampling(Path(tmpdirname))
        check_standalone_sampling(Path(tmpdirname))
1357
        os.environ["DGL_DIST_MODE"] = "distributed"
1358
1359
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
1360
1361
        check_rpc_get_degree_shuffle(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 2)
1362
1363
        check_rpc_find_edges_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 1)
1364
1365
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 2)
1366
1367
1368
1369
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)
1370
1371
1372
1373
        check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), 1)