test_distributed_sampling.py 40.2 KB
Newer Older
1
import multiprocessing as mp
Jinjing Zhou's avatar
Jinjing Zhou committed
2
import os
3
import random
Jinjing Zhou's avatar
Jinjing Zhou committed
4
5
import sys
import time
6
7
import traceback
import unittest
Jinjing Zhou's avatar
Jinjing Zhou committed
8
from pathlib import Path
9
10
11
12

import backend as F
import dgl
import numpy as np
13
import pytest
14
15
16
17
18
19
20
21
22
23
from dgl.data import CitationGraphDataset, WN18Dataset
from dgl.distributed import (
    DistGraph,
    DistGraphServer,
    load_partition,
    load_partition_book,
    partition_graph,
    sample_etype_neighbors,
    sample_neighbors,
)
24
from scipy import sparse as spsp
25
from utils import generate_ip_config, reset_envs
Jinjing Zhou's avatar
Jinjing Zhou committed
26
27


28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def start_server(
    rank,
    tmpdir,
    disable_shared_mem,
    graph_name,
    graph_format=None,
):
    """Create a ``DistGraphServer`` for one partition and run it.

    Used as the ``target`` of the server processes spawned by the
    ``check_*`` helpers in this file.

    Parameters
    ----------
    rank : int
        Server id handed to ``DistGraphServer``.
    tmpdir : pathlib.Path
        Directory containing the partition config ``<graph_name>.json``.
    disable_shared_mem : bool
        Forwarded to ``DistGraphServer`` unchanged.
    graph_name : str
        Base name of the partition config file.
    graph_format : list of str, optional
        Sparse formats forwarded to the server. ``None`` (the default)
        means ``["csc", "coo"]``.
    """
    # Build the default per call instead of sharing one mutable list object
    # across every invocation of this function.
    if graph_format is None:
        graph_format = ["csc", "coo"]
    server = DistGraphServer(
        rank,
        "rpc_ip_config.txt",
        1,
        1,
        tmpdir / (graph_name + ".json"),
        disable_shared_mem=disable_shared_mem,
        graph_format=graph_format,
    )
    server.start()


47
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Connect to the "test_sampling" graph and sample neighbors once.

    Returns the sampled graph, or ``None`` if sampling raised. The traceback
    is printed in that case and the client still shuts down via
    ``exit_client`` so the servers are not left waiting.
    """
    gpb = None
    if disable_shared_mem:
        # Only the partition book (4th item) of the loaded partition is used.
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)

    seeds = [0, 10, 99, 66, 1024, 2008]
    sampled_graph = None
    try:
        sampled_graph = sample_neighbors(dist_graph, seeds, 3)
    except Exception:
        print(traceback.format_exc())
    dgl.distributed.exit_client()
    return sampled_graph

65

66
67
68
69
70
71
72
73
74
75
76
def start_sample_client_shuffle(
    rank,
    tmpdir,
    disable_shared_mem,
    g,
    num_servers,
    group_id,
    orig_nid,
    orig_eid,
):
    """Sample from the shuffled "test_sampling" graph and verify the result.

    The sampled node and edge ids are mapped back through ``orig_nid`` and
    ``orig_eid`` and compared against the unpartitioned graph ``g``. Any
    mismatch raises ``AssertionError``. ``num_servers`` is accepted but not
    used by the body.
    """
    os.environ["DGL_GROUP_ID"] = str(group_id)

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)

    seeds = [0, 10, 99, 66, 1024, 2008]
    sampled_graph = sample_neighbors(dist_graph, seeds, 3)

    # Translate the sampled endpoints back to ids of the original graph.
    shuffled_src, shuffled_dst = sampled_graph.edges()
    src = orig_nid[shuffled_src]
    dst = orig_nid[shuffled_dst]

    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    expected_eids = g.edge_ids(src, dst)
    sampled_eids = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(sampled_eids), F.asnumpy(expected_eids))

95

96
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
    """Look up the endpoints of ``eids`` on the "test_find_edges" graph.

    Returns ``(u, v)`` from ``DistGraph.find_edges``, or ``(None, None)`` when
    the lookup raised; the traceback is printed in that case.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_find_edges.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_find_edges", gpb=gpb)

    # Stay (None, None) unless the lookup below succeeds.
    u = v = None
    try:
        u, v = dist_graph.find_edges(eids, etype=etype)
    except Exception:
        print(traceback.format_exc())
    dgl.distributed.exit_client()
    return u, v
Jinjing Zhou's avatar
Jinjing Zhou committed
111

112

113
114
115
def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Query in/out degrees on the "test_get_degrees" graph.

    Returns ``(in_deg, out_deg, all_in_deg, all_out_deg)``: the degrees of
    ``nids`` plus the degrees returned when no ids are passed. If any query
    raises, the traceback is printed and all four entries are ``None``.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_get_degrees.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_get_degrees", gpb=gpb)

    try:
        in_deg = dist_graph.in_degrees(nids)
        all_in_deg = dist_graph.in_degrees()
        out_deg = dist_graph.out_degrees(nids)
        all_out_deg = dist_graph.out_degrees()
        degrees = (in_deg, out_deg, all_in_deg, all_out_deg)
    except Exception:
        print(traceback.format_exc())
        # A failure part-way through discards the partial results as well.
        degrees = (None, None, None, None)
    dgl.distributed.exit_client()
    return degrees

132

133
def check_rpc_sampling(tmpdir, num_server):
    """Partition cora, sample through one client and verify the sampled edges.

    Spawns ``num_server`` server processes, samples in the current process via
    ``start_sample_client``, waits for the servers to exit cleanly, then
    checks that each sampled edge exists in the original graph with the same
    edge id.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)

    partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    use_gpb = num_server > 1
    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn.Process(
            target=start_server,
            args=(server_id, tmpdir, use_gpb, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    sampled_graph = start_sample_client(0, tmpdir, use_gpb)
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    src, dst = sampled_graph.edges()
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    expected_eids = g.edge_ids(src, dst)
    sampled_eids = sampled_graph.edata[dgl.EID]
    assert np.array_equal(F.asnumpy(sampled_eids), F.asnumpy(expected_eids))

Jinjing Zhou's avatar
Jinjing Zhou committed
175

176
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on a shuffled partition of cora.

    100 random edge ids are looked up through a distributed client. The
    returned endpoints, mapped back with ``orig_nid``, must equal what the
    original graph reports for the corresponding original edge ids.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    eids = F.tensor(np.random.randint(g.num_edges(), size=100))
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

    # Reap the servers before checking results, as the other check_* helpers
    # in this file do, so no process is left behind and a crashed server
    # fails the test.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

211

212
def create_random_hetero(dense=False, empty=False):
    """Build a random heterograph with node types n1/n2/n3.

    Parameters
    ----------
    dense : bool
        Use the smaller node counts and an edge density of 0.1 instead of
        0.001.
    empty : bool
        Generate edges only among all but the last 10 ids of each node type,
        so those trailing nodes have no edges.

    Returns
    -------
    The heterograph, with an all-ones ``feat`` of width 10 on "n1" nodes.
    """
    if dense:
        num_nodes = {"n1": 210, "n2": 200, "n3": 220}
        density = 0.1
    else:
        num_nodes = {"n1": 1010, "n2": 1000, "n3": 1020}
        density = 0.001
    # Number of trailing node ids excluded from the random adjacency.
    reserved = 10 if empty else 0

    random.seed(42)
    edges = {}
    for etype in [
        ("n1", "r12", "n2"),
        ("n1", "r13", "n3"),
        ("n2", "r23", "n3"),
    ]:
        src_ntype, _, dst_ntype = etype
        adj = spsp.random(
            num_nodes[src_ntype] - reserved,
            num_nodes[dst_ntype] - reserved,
            density=density,
            format="coo",
            random_state=100,
        )
        edges[etype] = (adj.row, adj.col)

    g = dgl.heterograph(edges, num_nodes)
    feat_shape = (g.num_nodes("n1"), 10)
    g.nodes["n1"].data["feat"] = F.ones(feat_shape, F.float32, F.cpu())
    return g
236

237

238
def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` for edge type "r12" on a heterograph.

    Also verifies locally that an incomplete edge-type tuple is rejected and
    that the short name "r12" and its canonical form give the same answer.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    test_etype = g.to_canonical_etype("r12")
    eids = F.tensor(np.random.randint(g.num_edges(test_etype), size=100))

    # A two-element edge type is invalid and must raise. Catch Exception
    # rather than using a bare except, so KeyboardInterrupt/SystemExit are
    # not mistaken for the expected failure.
    expect_except = False
    try:
        _, _ = g.find_edges(orig_eid[test_etype][eids], etype=("n1", "r12"))
    except Exception:
        expect_except = True
    assert expect_except

    u, v = g.find_edges(orig_eid[test_etype][eids], etype="r12")
    u1, v1 = g.find_edges(orig_eid[test_etype][eids], etype=("n1", "r12", "n2"))
    assert F.array_equal(u, u1)
    assert F.array_equal(v, v1)

    du, dv = start_find_edges_client(
        0, tmpdir, num_server > 1, eids, etype="r12"
    )

    # Reap the servers before checking results, as the other check_* helpers
    # in this file do, so no process is left behind and a crashed server
    # fails the test.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    du = orig_nid["n1"][du]
    dv = orig_nid["n2"][dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

285

286
# Wait non shared memory graph store
287
288
289
290
291
292
293
294
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_find_edges_shuffle(num_server):
    """Run the heterograph and homogeneous find_edges checks in a temp dir."""
    import tempfile

    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_hetero_find_edges_shuffle(workdir, num_server)
        check_rpc_find_edges_shuffle(workdir, num_server)

305

306
def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Check distributed in/out degree queries against the original cora graph.

    Degrees are requested for 100 random node ids and for all nodes. Each
    result must equal the degree of the corresponding original node
    (``orig_nid`` maps ids back to the original graph).
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]

    orig_nid, _ = partition_graph(
        g,
        "test_get_degrees",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    use_gpb = num_server > 1
    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn.Process(
            target=start_server,
            args=(server_id, tmpdir, use_gpb, "test_get_degrees"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    nids = F.tensor(np.random.randint(g.num_nodes(), size=100))
    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(
        0, tmpdir, use_gpb, nids
    )

    print("Done get_degree")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    print("check results")
    queried = orig_nid[nids]
    assert F.array_equal(g.in_degrees(queried), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(queried), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)

349

350
# Wait non shared memory graph store
351
352
353
354
355
356
357
358
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_get_degree_shuffle(num_server):
    """Run the distributed degree check inside a temporary directory."""
    import tempfile

    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_get_degree_shuffle(workdir, num_server)

368
369
370
371

# @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
# @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip("Only support partition with shuffle")
def test_rpc_sampling():
    """Run the single-server sampling check in a temp dir (currently skipped)."""
    import tempfile

    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)
        check_rpc_sampling(workdir, 1)
Jinjing Zhou's avatar
Jinjing Zhou committed
379

380

381
def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
    """Sample from a shuffled cora partition using client subprocesses.

    One client process is spawned per group id (``num_groups`` in total).
    Each runs ``start_sample_client_shuffle``, which performs the actual
    verification. Every client and server process must exit with code 0.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]

    orig_nids, orig_eids = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    use_gpb = num_server > 1
    spawn = mp.get_context("spawn")

    servers = []
    for server_id in range(num_server):
        server_args = (
            server_id,
            tmpdir,
            use_gpb,
            "test_sampling",
            ["csc", "coo"],
        )
        proc = spawn.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    clients = []
    num_clients = 1
    for client_id in range(num_clients):
        for group_id in range(num_groups):
            client_args = (
                client_id,
                tmpdir,
                use_gpb,
                g,
                num_server,
                group_id,
                orig_nids,
                orig_eids,
            )
            proc = spawn.Process(
                target=start_sample_client_shuffle, args=client_args
            )
            proc.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            clients.append(proc)

    # Clients are reaped first, then the servers they were talking to.
    for proc in clients:
        proc.join()
        assert proc.exitcode == 0
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0
Jinjing Zhou's avatar
Jinjing Zhou committed
441

442

443
def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes):
    """Sample neighbors of ``nodes`` on the heterogeneous "test_sampling" graph.

    Returns ``(block, gpb)``: the sampled graph converted with
    ``dgl.to_block`` and carrying the sampled edge ids, plus the partition
    book. ``block`` is ``None`` if sampling or conversion raised; the
    traceback is printed in that case.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)

    # Only "n1" was given a feature when the graph was created.
    assert "feat" in dist_graph.nodes["n1"].data
    assert "feat" not in dist_graph.nodes["n2"].data
    assert "feat" not in dist_graph.nodes["n3"].data

    if gpb is None:
        gpb = dist_graph.get_partition_book()

    block = None
    try:
        sampled_graph = sample_neighbors(dist_graph, nodes, 3)
        converted = dgl.to_block(sampled_graph, nodes)
        converted.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
        # Publish the block only once every step above has succeeded.
        block = converted
    except Exception:
        print(traceback.format_exc())
    dgl.distributed.exit_client()
    return block, gpb

466
467
468
469
470
471
472
473
474

def start_hetero_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    etype_sorted=False,
):
    """Sample per-edge-type neighbors on the heterogeneous "test_sampling" graph.

    Parameters
    ----------
    rank : int
        Partition id passed to ``load_partition`` when shared memory is
        disabled.
    tmpdir : pathlib.Path
        Directory holding ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk instead of being
        obtained from the ``DistGraph``.
    fanout : int or dict, optional
        Forwarded to ``sample_etype_neighbors``.
    nodes : dict, optional
        Seed nodes per node type. ``None`` (the default) means
        ``{"n3": [0, 10, 99, 66, 124, 208]}``.
    etype_sorted : bool, optional
        Forwarded to ``sample_etype_neighbors``.

    Returns
    -------
    (block, gpb)
        ``block`` is ``None`` if sampling or conversion raised; the traceback
        is printed in that case.
    """
    # Build the default per call instead of sharing one mutable dict object
    # across every invocation of this function.
    if nodes is None:
        nodes = {"n3": [0, 10, 99, 66, 124, 208]}

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["n1"].data
    assert "feat" not in dist_graph.nodes["n2"].data
    assert "feat" not in dist_graph.nodes["n3"].data

    if dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph
        local_g = dist_graph.local_partition
        local_nids = np.arange(local_g.num_nodes())
        for lnid in local_nids:
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            # np.unique returns first-occurrence positions ordered by etype
            # id; those positions must themselves be non-decreasing.
            _, idices = np.unique(letids, return_index=True)
            assert np.all(idices[:-1] <= idices[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    try:
        sampled_graph = sample_etype_neighbors(
            dist_graph, nodes, fanout, etype_sorted=etype_sorted
        )
        block = dgl.to_block(sampled_graph, nodes)
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, gpb

510

511
def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
    """Sample on a shuffled heterograph and validate ids against the original.

    For every edge type of the returned block, the sampled source,
    destination and edge ids are mapped back to the original graph, where
    ``find_edges`` must report the same endpoints.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    use_gpb = num_server > 1
    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn.Process(
            target=start_server,
            args=(server_id, tmpdir, use_gpb, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    seeds = {"n3": [0, 10, 99, 66, 124, 208]}
    block, _ = start_hetero_sample_client(0, tmpdir, use_gpb, nodes=seeds)
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    for canonical in block.canonical_etypes:
        src_type, etype, dst_type = canonical
        src, dst = block.edges(etype=etype)

        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        # Map them back to ids of the original, unpartitioned graph.
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[canonical], shuffled_eid))

        # Check the node Ids and edge Ids.
        expected_src, expected_dst = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(expected_src) == orig_src)
        assert np.all(F.asnumpy(expected_dst) == orig_dst)
563

564

565
566
567
568
569
570
571
572
573
def get_degrees(g, nids, ntype):
    """Sum the degrees of ``nids`` over every edge type touching ``ntype``.

    An edge type whose source type is ``ntype`` contributes out-degrees.
    Otherwise, if its destination type is ``ntype``, it contributes
    in-degrees. Returns an int64 tensor with one entry per id in ``nids``.
    """
    total = F.zeros((len(nids),), dtype=F.int64)
    for canonical in g.canonical_etypes:
        srctype, etype, dsttype = canonical
        if srctype == ntype:
            total += g.out_degrees(u=nids, etype=etype)
            continue
        if dsttype == ntype:
            total += g.in_degrees(v=nids, etype=etype)
    return total

574

575
def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
    """Sampling from edge-less "n3" seeds must yield an empty block.

    The block must still expose as many edge types as the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(empty=True)

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    use_gpb = num_server > 1
    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn.Process(
            target=start_server,
            args=(server_id, tmpdir, use_gpb, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    # Degrees are computed on the original graph in the order given by
    # orig_nids["n3"]; the positions with degree 0 are used as seeds.
    deg = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(deg == 0)
    block, _ = start_hetero_sample_client(
        0, tmpdir, use_gpb, nodes={"n3": empty_nids}
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

616
617
618
619

def check_rpc_hetero_etype_sampling_shuffle(
    tmpdir, num_server, graph_formats=None
):
    """Per-edge-type sampling on a dense shuffled heterograph.

    With 6 "n3" seeds and a fanout of 3 per edge type, both edge types ending
    in "n3" must return exactly 18 edges. Every sampled edge is then mapped
    back to the original graph and validated with ``find_edges``.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True)

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        graph_formats=graph_formats,
    )

    use_gpb = num_server > 1
    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn.Process(
            target=start_server,
            args=(server_id, tmpdir, use_gpb, "test_sampling", ["csc", "coo"]),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    fanout = dict.fromkeys(g.canonical_etypes, 3)
    # Treat etypes as sorted only when a csc/csr format was requested.
    etype_sorted = graph_formats is not None and (
        "csc" in graph_formats or "csr" in graph_formats
    )
    block, _ = start_hetero_etype_sample_client(
        0,
        tmpdir,
        use_gpb,
        fanout,
        nodes={"n3": [0, 10, 99, 66, 124, 208]},
        etype_sorted=etype_sorted,
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    for into_n3 in (("n1", "r13", "n3"), ("n2", "r23", "n3")):
        src, _ = block.edges(etype=into_n3)
        assert len(src) == 18

    for canonical in block.canonical_etypes:
        src_type, etype, dst_type = canonical
        src, dst = block.edges(etype=etype)

        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        # Map them back to ids of the original, unpartitioned graph.
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[canonical], shuffled_eid))

        # Check the node Ids and edge Ids.
        expected_src, expected_dst = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(expected_src) == orig_src)
        assert np.all(F.asnumpy(expected_dst) == orig_dst)

687

688
def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
    """Per-edge-type sampling from edge-less "n3" seeds must yield no edges.

    The block must still expose as many edge types as the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True, empty=True)

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    use_gpb = num_server > 1
    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn.Process(
            target=start_server,
            args=(server_id, tmpdir, use_gpb, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    fanout = 3
    # Degrees are computed on the original graph in the order given by
    # orig_nids["n3"]; the positions with degree 0 are used as seeds.
    deg = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(deg == 0)
    block, _ = start_hetero_etype_sample_client(
        0, tmpdir, use_gpb, fanout, nodes={"n3": empty_nids}
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)

730
731

def create_random_bipartite():
    """Build a random user-buys-game bipartite graph.

    Both node types get an all-ones ``feat`` of width 10.
    """
    g = dgl.rand_bipartite("user", "buys", "game", 500, 1000, 1000)
    for ntype in ("user", "game"):
        feat_shape = (g.num_nodes(ntype), 10)
        g.nodes[ntype].data["feat"] = F.ones(feat_shape, F.float32, F.cpu())
    return g


def start_bipartite_sample_client(rank, tmpdir, disable_shared_mem, nodes):
    """Sample neighbors of ``nodes`` on the bipartite "test_sampling" graph.

    Returns ``(block, gpb)``. Edge ids are copied onto the block only when
    the sample contains at least one edge. Exceptions are not caught here.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)

    # Both node types were given a feature when the graph was created.
    for ntype in ("user", "game"):
        assert "feat" in dist_graph.nodes[ntype].data

    if gpb is None:
        gpb = dist_graph.get_partition_book()

    sampled_graph = sample_neighbors(dist_graph, nodes, 3)
    block = dgl.to_block(sampled_graph, nodes)
    has_edges = sampled_graph.num_edges() > 0
    if has_edges:
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb


762
763
764
def start_bipartite_etype_sample_client(
    rank, tmpdir, disable_shared_mem, fanout=3, nodes=None
):
    """Sample per-edge-type neighbors on the bipartite "test_sampling" graph.

    Parameters
    ----------
    rank : int
        Partition id passed to ``load_partition`` when shared memory is
        disabled.
    tmpdir : pathlib.Path
        Directory holding ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk instead of being
        obtained from the ``DistGraph``.
    fanout : int or dict, optional
        Forwarded to ``sample_etype_neighbors``.
    nodes : dict, optional
        Seed nodes per node type. ``None`` (the default) means an empty dict.

    Returns
    -------
    (block, gpb)
        Edge ids are copied onto ``block`` only when the sample contains at
        least one edge. Exceptions are not caught here.
    """
    # Build the default per call instead of sharing one mutable dict object
    # across every invocation of this function.
    if nodes is None:
        nodes = {}

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["user"].data
    assert "feat" in dist_graph.nodes["game"].data

    if dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph
        local_g = dist_graph.local_partition
        local_nids = np.arange(local_g.num_nodes())
        for lnid in local_nids:
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            # np.unique returns first-occurrence positions ordered by etype
            # id; those positions must themselves be non-decreasing.
            _, idices = np.unique(letids, return_index=True)
            assert np.all(idices[:-1] <= idices[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    sampled_graph = sample_etype_neighbors(dist_graph, nodes, fanout)
    block = dgl.to_block(sampled_graph, nodes)
    if sampled_graph.num_edges() > 0:
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb


def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
    """sample on bipartite via sample_neighbors() which yields empty sample results"""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    use_gpb = num_server > 1
    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn.Process(
            target=start_server,
            args=(server_id, tmpdir, use_gpb, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    # Seed with the "game" positions whose degree is 0, plus one "user" node.
    deg = get_degrees(g, orig_nids["game"], "game")
    empty_nids = F.nonzero_1d(deg == 0)
    seeds = {"game": empty_nids, "user": [1]}
    block, _ = start_bipartite_sample_client(0, tmpdir, use_gpb, nodes=seeds)

    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
    """sample on bipartite via sample_neighbors() which yields non-empty sample results"""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    use_gpb = num_server > 1
    spawn = mp.get_context("spawn")
    servers = []
    for server_id in range(num_server):
        proc = spawn.Process(
            target=start_server,
            args=(server_id, tmpdir, use_gpb, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    # Seed with the "game" positions that have at least one edge.
    deg = get_degrees(g, orig_nid_map["game"], "game")
    nids = F.nonzero_1d(deg > 0)
    seeds = {"game": nids, "user": [0]}
    block, _ = start_bipartite_sample_client(0, tmpdir, use_gpb, nodes=seeds)
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    for canonical in block.canonical_etypes:
        src_type, etype, dst_type = canonical
        src, dst = block.edges(etype=etype)

        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        # Map them back to ids of the original, unpartitioned graph.
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[canonical], shuffled_eid))

        # Check the node Ids and edge Ids.
        expected_src, expected_dst = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(expected_src) == orig_src)
        assert np.all(F.asnumpy(expected_dst) == orig_dst)


def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
    """Sample on a bipartite graph via sample_etype_neighbors() with seeds
    that yield an empty sample result.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory the partitioned graph is written to; the servers and the
        client read the partition config from it.
    num_server : int
        Number of server processes to spawn. Also used as the number of
        partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    # Only the node-ID mapping is used below; the edge-ID mapping is dropped.
    nid_map, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    spawn_ctx = mp.get_context("spawn")
    servers = []
    for rank in range(num_server):
        # Third argument is start_server's ``disable_shared_mem`` flag.
        server = spawn_ctx.Process(
            target=start_server,
            args=(rank, tmpdir, num_server > 1, "test_sampling"),
        )
        server.start()
        # Presumably gives the server time to come up -- TODO confirm.
        time.sleep(1)
        servers.append(server)

    # Seed with the "game" nodes for which get_degrees() reports 0.
    game_deg = get_degrees(g, nid_map["game"], "game")
    seeds = {"game": F.nonzero_1d(game_deg == 0), "user": [1]}
    block, _ = start_bipartite_etype_sample_client(
        0, tmpdir, num_server > 1, nodes=seeds
    )

    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    # A block must still come back: no edges, but as many edge types as
    # the input graph has.
    assert block is not None
    assert block.num_edges() == 0
    assert len(block.etypes) == len(g.etypes)


def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
    """Sample on a bipartite graph via sample_etype_neighbors() with seeds
    that yield a non-empty sample result, then verify the sampled edges.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory the partitioned graph is written to; the servers and the
        client read the partition config from it.
    num_server : int
        Number of server processes to spawn. Also used as the number of
        partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    # Both mappings are needed to translate sampled IDs back to ``g``.
    nid_map, eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    spawn_ctx = mp.get_context("spawn")
    servers = []
    for rank in range(num_server):
        # Third argument is start_server's ``disable_shared_mem`` flag.
        server = spawn_ctx.Process(
            target=start_server,
            args=(rank, tmpdir, num_server > 1, "test_sampling"),
        )
        server.start()
        # Presumably gives the server time to come up -- TODO confirm.
        time.sleep(1)
        servers.append(server)

    fanout = 3
    # Seed with the "game" nodes for which get_degrees() reports > 0.
    game_deg = get_degrees(g, nid_map["game"], "game")
    seeds = {"game": F.nonzero_1d(game_deg > 0), "user": [0]}
    block, _ = start_bipartite_etype_sample_client(
        0, tmpdir, num_server > 1, fanout, nodes=seeds
    )

    print("Done sampling")
    for server in servers:
        server.join()
        assert server.exitcode == 0

    for canon in block.canonical_etypes:
        src_type, etype, dst_type = canon
        local_src, local_dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, local_src)
        shuffled_dst = F.gather_row(dst_nids, local_dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        # Map the shuffled Ids back to the Ids of the original graph.
        want_src = F.asnumpy(F.gather_row(nid_map[src_type], shuffled_src))
        want_dst = F.asnumpy(F.gather_row(nid_map[dst_type], shuffled_dst))
        want_eid = F.asnumpy(F.gather_row(eid_map[canon], shuffled_eid))

        # Check the node Ids and edge Ids: each sampled edge Id must have
        # the same endpoints in the original graph.
        found_src, found_dst = g.find_edges(want_eid, etype=etype)
        assert np.all(F.asnumpy(found_src) == want_src)
        assert np.all(F.asnumpy(found_dst) == want_dst)

998

Jinjing Zhou's avatar
Jinjing Zhou committed
999
# Wait for the non-shared-memory graph store.
1000
1001
1002
1003
1004
1005
1006
1007
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_sampling_shuffle(num_server):
    """Run the distributed shuffle-sampling checks in one temp directory.

    Skipped on Windows and under the TensorFlow and MXNet backends. All
    checks run with ``DGL_DIST_MODE`` set to ``"distributed"`` and share the
    same temporary directory.
    """
    reset_envs()
    import tempfile

    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)

        check_rpc_sampling_shuffle(workdir, num_server)
        # [TODO][Rhett] Tests for multiple groups may fail sometimes and
        # root cause is unknown. Let's disable them for now.
        # check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=2)
        check_rpc_hetero_sampling_shuffle(workdir, num_server)
        check_rpc_hetero_sampling_empty_shuffle(workdir, num_server)

        # Once with the check's default third argument, then once per
        # explicit format list.
        check_rpc_hetero_etype_sampling_shuffle(workdir, num_server)
        for formats in (["csc"], ["csr"], ["csc", "coo"]):
            check_rpc_hetero_etype_sampling_shuffle(
                workdir, num_server, formats
            )
        check_rpc_hetero_etype_sampling_empty_shuffle(workdir, num_server)

        check_rpc_bipartite_sampling_empty(workdir, num_server)
        check_rpc_bipartite_sampling_shuffle(workdir, num_server)
        check_rpc_bipartite_etype_sampling_empty(workdir, num_server)
        check_rpc_bipartite_etype_sampling_shuffle(workdir, num_server)
Jinjing Zhou's avatar
Jinjing Zhou committed
1038

1039

1040
def check_standalone_sampling(tmpdir):
    """Check sample_neighbors() on a single-partition graph in standalone mode.

    The Cora graph gets two edge features: ``prob`` (standard-normal draws
    clipped at zero from below) and ``mask`` (``prob > 0``). Three samplings
    are verified: unweighted, weighted by ``mask`` and weighted by ``prob``.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory the partitioned graph is written to and loaded from.

    Raises
    ------
    AssertionError
        If any sampled graph violates the checks below.
    """
    g = CitationGraphDataset("cora")[0]
    prob = np.maximum(np.random.randn(g.num_edges()), 0)
    mask = prob > 0
    g.edata["prob"] = F.tensor(prob)
    g.edata["mask"] = F.tensor(mask)
    partition_graph(
        g,
        "test_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    seeds = [0, 10, 99, 66, 1024, 2008]
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph(
            "test_sampling", part_config=tmpdir / "test_sampling.json"
        )

        # Unweighted: every sampled edge must exist in ``g`` with the same
        # edge Id, and the node count must match.
        sampled_graph = sample_neighbors(dist_graph, seeds, 3)
        src, dst = sampled_graph.edges()
        assert sampled_graph.num_nodes() == g.num_nodes()
        assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
        eids = g.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
        )

        # Weighted by the boolean mask: only masked-in edges may be picked.
        sampled_graph = sample_neighbors(dist_graph, seeds, 3, prob="mask")
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert mask[eid].all()

        # Weighted by probability: zero-probability edges must not be picked.
        sampled_graph = sample_neighbors(dist_graph, seeds, 3, prob="prob")
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert (prob[eid] > 0).all()
    finally:
        # Shut the client down even when a check fails, so a failure here
        # does not leave an initialized client behind for later tests.
        dgl.distributed.exit_client()
1084

1085

1086
def check_standalone_etype_sampling(tmpdir):
    """Check sample_etype_neighbors() on a single-partition graph in
    standalone mode.

    The Cora graph gets two edge features: ``prob`` (standard-normal draws
    clipped at zero from below) and ``mask`` (``prob > 0``). Three samplings
    are verified: unweighted, weighted by ``mask`` and weighted by ``prob``.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory the partitioned graph is written to and loaded from.

    Raises
    ------
    AssertionError
        If any sampled graph violates the checks below.
    """
    hg = CitationGraphDataset("cora")[0]
    prob = np.maximum(np.random.randn(hg.num_edges()), 0)
    mask = prob > 0
    hg.edata["prob"] = F.tensor(prob)
    hg.edata["mask"] = F.tensor(mask)
    partition_graph(
        hg,
        "test_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    seeds = [0, 10, 99, 66, 1023]
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph(
            "test_sampling", part_config=tmpdir / "test_sampling.json"
        )

        # Unweighted: every sampled edge must exist in ``hg`` with the same
        # edge Id, and the node count must match.
        sampled_graph = sample_etype_neighbors(dist_graph, seeds, 3)
        src, dst = sampled_graph.edges()
        assert sampled_graph.num_nodes() == hg.num_nodes()
        assert np.all(F.asnumpy(hg.has_edges_between(src, dst)))
        eids = hg.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
        )

        # Weighted by the boolean mask: only masked-in edges may be picked.
        sampled_graph = sample_etype_neighbors(
            dist_graph, seeds, 3, prob="mask"
        )
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert mask[eid].all()

        # Weighted by probability: zero-probability edges must not be picked.
        sampled_graph = sample_etype_neighbors(
            dist_graph, seeds, 3, prob="prob"
        )
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert (prob[eid] > 0).all()
    finally:
        # Shut the client down even when a check fails, so a failure here
        # does not leave an initialized client behind for later tests.
        dgl.distributed.exit_client()

1131

1132
def check_standalone_etype_sampling_heterograph(tmpdir):
    """Check sample_etype_neighbors() on a two-relation heterograph in
    standalone mode.

    The heterograph is built from Cora's edges: relation ``cite`` keeps each
    edge as is and relation ``cite-by`` reverses it, both over a single
    ``paper`` node type.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory the partitioned graph is written to and loaded from.

    Raises
    ------
    AssertionError
        If the sampled graph violates the checks below.
    """
    hg = CitationGraphDataset("cora")[0]
    src, dst = hg.edges()
    new_hg = dgl.heterograph(
        {
            ("paper", "cite", "paper"): (src, dst),
            ("paper", "cite-by", "paper"): (dst, src),
        },
        {"paper": hg.num_nodes()},
    )
    partition_graph(
        new_hg,
        "test_hetero_sampling",
        1,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    seeds = [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701]
    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph(
            "test_hetero_sampling",
            part_config=tmpdir / "test_hetero_sampling.json",
        )
        sampled_graph = sample_etype_neighbors(dist_graph, seeds, 1)

        # Ten seeds with fanout 1: ten sampled edges are expected for each
        # of the two relations.
        for canon in (
            ("paper", "cite", "paper"),
            ("paper", "cite-by", "paper"),
        ):
            sampled_src, _ = sampled_graph.edges(etype=canon)
            assert len(sampled_src) == 10
        assert sampled_graph.num_nodes() == new_hg.num_nodes()
    finally:
        # Shut the client down even when a check fails, so a failure here
        # does not leave an initialized client behind for later tests.
        dgl.distributed.exit_client()

1167
1168
1169
1170
1171
1172

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_standalone_sampling():
    """Run check_standalone_sampling() in a fresh temporary directory.

    Skipped on Windows and under the TensorFlow backend.
    """
    reset_envs()
    from tempfile import TemporaryDirectory

    os.environ["DGL_DIST_MODE"] = "standalone"
    with TemporaryDirectory() as scratch:
        check_standalone_sampling(Path(scratch))
1180

1181

1182
1183
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Connect as a client and extract the in-subgraph of ``nodes``.

    Parameters
    ----------
    rank : int
        Partition to load the partition book from; only used when
        ``disable_shared_mem`` is true.
    tmpdir : pathlib.Path
        Directory containing ``test_in_subgraph.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk and handed to
        ``DistGraph``; otherwise ``gpb=None`` is passed.
    nodes
        Seed nodes forwarded to ``dgl.distributed.in_subgraph``.

    Returns
    -------
    The graph returned by ``dgl.distributed.in_subgraph``, or ``None`` if
    that call raised (the traceback is printed in that case).
    """
    dgl.distributed.initialize("rpc_ip_config.txt")

    if disable_shared_mem:
        # The partition book is the fourth of the seven returned values.
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_in_subgraph.json", rank
        )
    else:
        gpb = None
    dist_graph = DistGraph("test_in_subgraph", gpb=gpb)

    result = None
    try:
        result = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception:
        # Best effort: report the failure and let the caller see ``None``.
        print(traceback.format_exc())

    dgl.distributed.exit_client()
    return result


1199
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Check dgl.distributed.in_subgraph() against dgl.in_subgraph() on Cora.

    The graph is partitioned with ``return_mapping=True``, so node and edge
    Ids seen by the client are shuffled. Results are mapped back through the
    returned mappings before being compared with the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory the partitioned graph is written to; the servers and the
        client read the partition config from it.
    num_server : int
        Number of server processes to spawn. Also used as the number of
        partitions.

    Raises
    ------
    AssertionError
        If a server exits with a non-zero code, the client returns no graph,
        or the sampled subgraph does not match the reference.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_in_subgraph",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        # Third argument is start_server's ``disable_shared_mem`` flag.
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_in_subgraph"),
        )
        p.start()
        # Presumably gives the server time to come up -- TODO confirm.
        time.sleep(1)
        pserver_list.append(p)

    nodes = [0, 10, 99, 66, 1024, 2008]
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client swallows exceptions and returns None on failure; fail with
    # a clear message instead of an AttributeError on the next line.
    assert sampled_graph is not None, "in_subgraph client returned no graph"

    # Translate shuffled node Ids back to the Ids of the original graph.
    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.num_nodes() == g.num_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    # Reference result computed directly on the original graph. Edge order
    # may differ, so endpoints are compared after sorting.
    subg1 = dgl.in_subgraph(g, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))

    # Sampled edge Ids, mapped back, must equal the Ids ``g`` assigns to the
    # same (src, dst) pairs.
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
1245

1246
1247
1248
1249
1250
1251

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_rpc_in_subgraph():
    """Run check_rpc_in_subgraph_shuffle() with one server.

    Skipped on Windows and under the TensorFlow backend.
    """
    reset_envs()
    from tempfile import TemporaryDirectory

    os.environ["DGL_DIST_MODE"] = "distributed"
    with TemporaryDirectory() as scratch:
        check_rpc_in_subgraph_shuffle(Path(scratch), 1)
1259

1260
1261
1262
1263
1264
1265
1266
1267
1268

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone_etype_sampling():
    """Run both standalone etype-sampling checks, each in its own temp dir.

    Skipped on Windows and under the TensorFlow and MXNet backends.
    """
    reset_envs()
    import tempfile

    # Heterograph variant first, then the homogeneous one.
    checks = (
        check_standalone_etype_sampling_heterograph,
        check_standalone_etype_sampling,
    )
    for run_check in checks:
        with tempfile.TemporaryDirectory() as scratch:
            os.environ["DGL_DIST_MODE"] = "standalone"
            run_check(Path(scratch))
1279

1280

Jinjing Zhou's avatar
Jinjing Zhou committed
1281
1282
if __name__ == "__main__":
    # Script entry point: run the checks directly, without pytest.
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))

    with tempfile.TemporaryDirectory() as tmpdirname:
        workdir = Path(tmpdirname)

        # Standalone checks.
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling(workdir)
        check_standalone_sampling(workdir)

        # Distributed checks; each tuple lists server counts in run order.
        os.environ["DGL_DIST_MODE"] = "distributed"
        for count in (2, 1):
            check_rpc_sampling(workdir, count)
        for count in (1, 2):
            check_rpc_get_degree_shuffle(workdir, count)
        for count in (2, 1):
            check_rpc_find_edges_shuffle(workdir, count)
        for count in (1, 2):
            check_rpc_hetero_find_edges_shuffle(workdir, count)
        check_rpc_in_subgraph_shuffle(workdir, 2)
        check_rpc_sampling_shuffle(workdir, 1)
        for count in (1, 2):
            check_rpc_hetero_sampling_shuffle(workdir, count)
        check_rpc_hetero_sampling_empty_shuffle(workdir, 1)
        for count in (1, 2):
            check_rpc_hetero_etype_sampling_shuffle(workdir, count)
        check_rpc_hetero_etype_sampling_empty_shuffle(workdir, 1)