"vscode:/vscode.git/clone" did not exist on "bd4df2856ae399dc55e4ded57164eff8ecf6cb65"
test_distributed_sampling.py 40.6 KB
Newer Older
1
import multiprocessing as mp
Jinjing Zhou's avatar
Jinjing Zhou committed
2
import os
3
import random
Jinjing Zhou's avatar
Jinjing Zhou committed
4
5
import sys
import time
6
7
import traceback
import unittest
Jinjing Zhou's avatar
Jinjing Zhou committed
8
from pathlib import Path
9
10
11
12

import backend as F
import dgl
import numpy as np
13
import pytest
14
15
16
17
18
19
20
21
22
23
from dgl.data import CitationGraphDataset, WN18Dataset
from dgl.distributed import (
    DistGraph,
    DistGraphServer,
    load_partition,
    load_partition_book,
    partition_graph,
    sample_etype_neighbors,
    sample_neighbors,
)
24
from scipy import sparse as spsp
25
from utils import generate_ip_config, reset_envs
Jinjing Zhou's avatar
Jinjing Zhou committed
26
27


28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def start_server(
    rank,
    tmpdir,
    disable_shared_mem,
    graph_name,
    graph_format=None,
    keep_alive=False,
):
    """Build a ``DistGraphServer`` for one partition and start it.

    Used as the ``target`` of the server processes spawned by the
    ``check_*`` helpers in this file.

    Parameters
    ----------
    rank : int
        Passed to ``DistGraphServer`` as its first positional argument.
    tmpdir : pathlib.Path
        Directory containing the partition config ``<graph_name>.json``.
    disable_shared_mem : bool
        Forwarded to ``DistGraphServer``.
    graph_name : str
        Name of the partitioned graph; selects the JSON config file.
    graph_format : list of str, optional
        Forwarded to ``DistGraphServer``. ``None`` (the default) means
        ``["csc", "coo"]``.
    keep_alive : bool
        Forwarded to ``DistGraphServer``.
    """
    # A fresh list per call avoids sharing one mutable default object
    # across invocations.
    if graph_format is None:
        graph_format = ["csc", "coo"]
    server = DistGraphServer(
        rank,
        "rpc_ip_config.txt",
        1,
        1,
        tmpdir / (graph_name + ".json"),
        disable_shared_mem=disable_shared_mem,
        graph_format=graph_format,
        keep_alive=keep_alive,
    )
    server.start()


49
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Sample neighbors of a fixed seed list from the "test_sampling" graph.

    Returns the graph produced by ``sample_neighbors`` (fanout 3), or
    ``None`` when sampling raised; in that case the traceback is printed.
    ``dgl.distributed.exit_client()`` is called either way.
    """
    partition_book = None
    if disable_shared_mem:
        # Only the fourth value returned by load_partition is used; it is
        # handed to DistGraph as ``gpb``.
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    graph = DistGraph("test_sampling", gpb=partition_book)
    seeds = [0, 10, 99, 66, 1024, 2008]
    result = None
    try:
        result = sample_neighbors(graph, seeds, 3)
    except Exception:
        print(traceback.format_exc())
    dgl.distributed.exit_client()
    return result

67

68
69
70
71
72
73
74
75
76
77
78
def start_sample_client_shuffle(
    rank,
    tmpdir,
    disable_shared_mem,
    g,
    num_servers,
    group_id,
    orig_nid,
    orig_eid,
):
    """Sample from the shuffled "test_sampling" graph and verify against ``g``.

    Sets ``DGL_GROUP_ID`` in the environment, samples neighbors (fanout 3)
    of a fixed seed list, then maps the sampled node and edge IDs through
    ``orig_nid`` / ``orig_eid`` and asserts that they describe edges of the
    original graph ``g``. ``num_servers`` is accepted but not used here.
    """
    os.environ["DGL_GROUP_ID"] = str(group_id)

    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    graph = DistGraph("test_sampling", gpb=partition_book)
    seeds = [0, 10, 99, 66, 1024, 2008]
    subg = sample_neighbors(graph, seeds, 3)

    # Translate the sampled endpoints to IDs of the original graph.
    sampled_src, sampled_dst = subg.edges()
    src_in_g = orig_nid[sampled_src]
    dst_in_g = orig_nid[sampled_dst]

    assert subg.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src_in_g, dst_in_g)))

    # The sampled edge IDs, once mapped back, must match the original IDs.
    expected_eids = g.edge_ids(src_in_g, dst_in_g)
    mapped_eids = orig_eid[subg.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(mapped_eids), F.asnumpy(expected_eids))

97

98
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
    """Call ``DistGraph.find_edges`` on the "test_find_edges" graph.

    Returns the ``(u, v)`` pair from ``find_edges(eids, etype=etype)``, or
    ``(None, None)`` when the call raised (the traceback is printed).
    ``dgl.distributed.exit_client()`` is called either way.
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_find_edges.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    graph = DistGraph("test_find_edges", gpb=partition_book)

    src = dst = None
    try:
        src, dst = graph.find_edges(eids, etype=etype)
    except Exception:
        print(traceback.format_exc())
    dgl.distributed.exit_client()
    return src, dst
Jinjing Zhou's avatar
Jinjing Zhou committed
113

114

115
116
117
def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Query in/out degrees from the "test_get_degrees" distributed graph.

    Returns ``(in_deg, out_deg, all_in_deg, all_out_deg)``: degrees for
    ``nids`` and degrees queried without arguments. If any query raises,
    the traceback is printed and all four values are ``None``.
    ``dgl.distributed.exit_client()`` is called either way.
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / "test_get_degrees.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    graph = DistGraph("test_get_degrees", gpb=partition_book)
    try:
        # Tuple elements are evaluated left to right, so the queries run in
        # the order: in(nids), in(all), out(nids), out(all).
        degrees = (
            graph.in_degrees(nids),
            graph.in_degrees(),
            graph.out_degrees(nids),
            graph.out_degrees(),
        )
    except Exception:
        print(traceback.format_exc())
        degrees = (None, None, None, None)
    dgl.distributed.exit_client()
    in_deg, all_in_deg, out_deg, all_out_deg = degrees
    return in_deg, out_deg, all_in_deg, all_out_deg

134

135
def check_rpc_sampling(tmpdir, num_server):
    """Partition Cora without ID mapping, sample via RPC and verify edges.

    Spawns ``num_server`` server processes, runs ``start_sample_client`` in
    this process, waits for the servers to exit cleanly and then checks the
    sampled edges and edge IDs directly against the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)

    partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
    )

    # Shared memory is disabled as soon as there is more than one server.
    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = spawn_ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, disable_shared_mem, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    sampled_graph = start_sample_client(0, tmpdir, disable_shared_mem)
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    src, dst = sampled_graph.edges()
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    expected_eids = g.edge_ids(src, dst)
    sampled_eids = sampled_graph.edata[dgl.EID]
    assert np.array_equal(F.asnumpy(sampled_eids), F.asnumpy(expected_eids))

Jinjing Zhou's avatar
Jinjing Zhou committed
177

178
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on a METIS-partitioned Cora graph.

    Partitions Cora with ``return_mapping=True``, spawns ``num_server``
    server processes, looks up 100 random edge IDs through a client and
    compares the endpoints, mapped back through ``orig_nid``, with
    ``g.find_edges`` on the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory the partition files are written to.
    num_server : int
        Number of servers and of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    eids = F.tensor(np.random.randint(g.number_of_edges(), size=100))
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

    # Wait for the servers and require a clean exit, as the other check_*
    # helpers in this file do, so no server process outlives this check.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # Map the endpoints returned by the client back to original node IDs.
    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

213

214
def create_random_hetero(dense=False, empty=False):
    """Build a random heterograph with node types n1/n2/n3.

    Relations are r12 (n1->n2), r13 (n1->n3) and r23 (n2->n3). ``dense``
    selects a smaller graph with density 0.1 instead of 0.001. With
    ``empty`` the random adjacency matrices are 10 rows and columns smaller
    than the node counts, so the last 10 node IDs of each type have no
    edges. Node type "n1" gets a ``feat`` tensor of ones with 10 columns.
    """
    if dense:
        num_nodes = {"n1": 210, "n2": 200, "n3": 220}
        density = 0.1
    else:
        num_nodes = {"n1": 1010, "n2": 1000, "n3": 1020}
        density = 0.001
    # Rows/columns trimmed from every adjacency matrix when ``empty``.
    trim = 10 if empty else 0

    random.seed(42)
    edges = {}
    for canonical in (
        ("n1", "r12", "n2"),
        ("n1", "r13", "n3"),
        ("n2", "r23", "n3"),
    ):
        src_ntype, _, dst_ntype = canonical
        adj = spsp.random(
            num_nodes[src_ntype] - trim,
            num_nodes[dst_ntype] - trim,
            density=density,
            format="coo",
            random_state=100,
        )
        edges[canonical] = (adj.row, adj.col)

    hetero = dgl.heterograph(edges, num_nodes)
    feat_shape = (hetero.number_of_nodes("n1"), 10)
    hetero.nodes["n1"].data["feat"] = F.ones(feat_shape, F.float32, F.cpu())
    return hetero
238

239

240
def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` with an edge type on a heterograph.

    Partitions a random heterograph with ``return_mapping=True``, spawns
    ``num_server`` server processes and then:

    * checks that ``find_edges`` on the original graph rejects the
      two-element etype ``("n1", "r12")``;
    * checks that the etype name ``"r12"`` and its canonical triple give
      the same result on the original graph;
    * looks the same edges up through a client and compares the endpoints,
      mapped back through ``orig_nid``, with the local result.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory the partition files are written to.
    num_server : int
        Number of servers and of partitions.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_find_edges",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_find_edges", ["csr", "coo"]),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    test_etype = g.to_canonical_etype("r12")
    eids = F.tensor(np.random.randint(g.num_edges(test_etype), size=100))
    # Computed outside the try block so that an indexing error cannot be
    # mistaken for the expected find_edges failure.
    orig_test_eids = orig_eid[test_etype][eids]

    expect_except = False
    try:
        g.find_edges(orig_test_eids, etype=("n1", "r12"))
    except Exception:
        # Only ordinary errors count; KeyboardInterrupt/SystemExit propagate.
        expect_except = True
    assert expect_except

    u, v = g.find_edges(orig_test_eids, etype="r12")
    u1, v1 = g.find_edges(orig_test_eids, etype=("n1", "r12", "n2"))
    assert F.array_equal(u, u1)
    assert F.array_equal(v, v1)

    du, dv = start_find_edges_client(
        0, tmpdir, num_server > 1, eids, etype="r12"
    )

    # Wait for the servers and require a clean exit, as the other check_*
    # helpers in this file do, so no server process outlives this check.
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # r12 connects n1 -> n2, so each endpoint uses its own node-ID mapping.
    du = orig_nid["n1"][du]
    dv = orig_nid["n2"][dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

287

288
# Wait non shared memory graph store
289
290
291
292
293
294
295
296
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_find_edges_shuffle(num_server):
    """Run the hetero and then the homogeneous find_edges checks."""
    reset_envs()
    import tempfile

    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as workdir:
        # Both checks share the same temporary directory.
        for run_check in (
            check_rpc_hetero_find_edges_shuffle,
            check_rpc_find_edges_shuffle,
        ):
            run_check(Path(workdir), num_server)

307

308
def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Compare degrees from a distributed Cora graph with the original.

    Partitions Cora with ``return_mapping=True``, spawns ``num_server``
    servers, queries degrees of 100 random nodes and of all nodes through
    a client, waits for the servers to exit cleanly, then compares with
    degrees of the original graph indexed through the node-ID mapping.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]

    orig_nid, _ = partition_graph(
        g,
        "test_get_degrees",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = spawn_ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, disable_shared_mem, "test_get_degrees"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    nids = F.tensor(np.random.randint(g.number_of_nodes(), size=100))
    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(
        0, tmpdir, disable_shared_mem, nids
    )

    print("Done get_degree")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    print("check results")
    # IDs of the original graph that correspond to the queried nodes.
    queried = orig_nid[nids]
    assert F.array_equal(g.in_degrees(queried), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(orig_nid[nids]), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)

351

352
# Wait non shared memory graph store
353
354
355
356
357
358
359
360
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_get_degree_shuffle(num_server):
    """Run the degree check in a fresh temporary directory."""
    reset_envs()
    import tempfile

    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as workdir:
        partition_dir = Path(workdir)
        check_rpc_get_degree_shuffle(partition_dir, num_server)

370
371
372
373

# @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
# @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip("Only support partition with shuffle")
def test_rpc_sampling():
    """Run ``check_rpc_sampling`` with two servers (currently skipped)."""
    reset_envs()
    import tempfile

    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as workdir:
        partition_dir = Path(workdir)
        check_rpc_sampling(partition_dir, 2)
Jinjing Zhou's avatar
Jinjing Zhou committed
381

382

383
def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
    """Sample from a shuffled Cora partition using client subprocesses.

    Spawns ``num_server`` servers and one ``start_sample_client_shuffle``
    process per group; the clients do the verification themselves. When
    ``num_groups > 1`` the servers are started with ``keep_alive`` and are
    shut down explicitly once every client has finished.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]

    orig_nids, orig_eids = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    spawn_ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    disable_shared_mem = num_server > 1

    servers = []
    for server_rank in range(num_server):
        server_args = (
            server_rank,
            tmpdir,
            disable_shared_mem,
            "test_sampling",
            ["csc", "coo"],
            keep_alive,
        )
        proc = spawn_ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    clients = []
    num_clients = 1
    for client_id in range(num_clients):
        for group_id in range(num_groups):
            client_args = (
                client_id,
                tmpdir,
                disable_shared_mem,
                g,
                num_server,
                group_id,
                orig_nids,
                orig_eids,
            )
            proc = spawn_ctx.Process(
                target=start_sample_client_shuffle, args=client_args
            )
            proc.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            clients.append(proc)

    for proc in clients:
        proc.join()
        assert proc.exitcode == 0

    if keep_alive:
        # Servers must still be running after all clients have left.
        for proc in servers:
            assert proc.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("rpc_ip_config.txt", 1)

    for proc in servers:
        proc.join()
        assert proc.exitcode == 0
Jinjing Zhou's avatar
Jinjing Zhou committed
450

451

452
def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes):
    """Sample neighbors of ``nodes`` on the hetero "test_sampling" graph.

    Asserts that only node type "n1" carries a ``feat`` field, samples with
    fanout 3 and converts the result to a block whose edge data keeps the
    sampled edge IDs.

    Returns ``(block, gpb)``. ``block`` is ``None`` when sampling or block
    construction raised (the traceback is printed); ``gpb`` is the loaded
    partition book or, if none was loaded, the one reported by the graph.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    graph = DistGraph("test_sampling", gpb=gpb)

    for ntype, has_feat in (("n1", True), ("n2", False), ("n3", False)):
        assert ("feat" in graph.nodes[ntype].data) == has_feat

    if gpb is None:
        gpb = graph.get_partition_book()

    try:
        frontier = sample_neighbors(graph, nodes, 3)
        block = dgl.to_block(frontier, nodes)
        block.edata[dgl.EID] = frontier.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        # Discard a partially built block as well.
        block = None
    dgl.distributed.exit_client()
    return block, gpb

475
476
477
478
479
480
481
482
483

def start_hetero_etype_sample_client(
    rank,
    tmpdir,
    disable_shared_mem,
    fanout=3,
    nodes=None,
    etype_sorted=False,
):
    """Run ``sample_etype_neighbors`` on the hetero "test_sampling" graph.

    Parameters
    ----------
    rank : int
        Partition ID passed to ``load_partition`` when shared memory is
        disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk and handed to
        ``DistGraph``.
    fanout : int or dict
        Forwarded to ``sample_etype_neighbors``.
    nodes : dict, optional
        Seed nodes per node type. ``None`` (the default) means
        ``{"n3": [0, 10, 99, 66, 124, 208]}``.
    etype_sorted : bool
        Forwarded to ``sample_etype_neighbors``.

    Returns
    -------
    tuple
        ``(block, gpb)``. ``block`` is ``None`` when sampling or block
        construction raised (the traceback is printed).
    """
    # Build the default per call rather than sharing one mutable dict
    # between invocations.
    if nodes is None:
        nodes = {"n3": [0, 10, 99, 66, 124, 208]}

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["n1"].data
    assert "feat" not in dist_graph.nodes["n2"].data
    assert "feat" not in dist_graph.nodes["n3"].data

    local_g = dist_graph.local_partition
    if local_g is not None:
        # Check whether etypes are sorted in dist_graph: for every local
        # node, the first occurrence of each edge type among its in-edges
        # must appear in increasing type order.
        for lnid in np.arange(local_g.num_nodes()):
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            _, first_idx = np.unique(letids, return_index=True)
            assert np.all(first_idx[:-1] <= first_idx[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    try:
        sampled_graph = sample_etype_neighbors(
            dist_graph, nodes, fanout, etype_sorted=etype_sorted
        )
        block = dgl.to_block(sampled_graph, nodes)
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception:
        print(traceback.format_exc())
        block = None
    dgl.distributed.exit_client()
    return block, gpb

519

520
def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
    """Sample on a shuffled random heterograph and verify the edges.

    After the servers exited cleanly, every edge of the returned block is
    mapped back to original node and edge IDs and compared with
    ``g.find_edges`` on the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = spawn_ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, disable_shared_mem, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    block, gpb = start_hetero_sample_client(
        0,
        tmpdir,
        disable_shared_mem,
        nodes={"n3": [0, 10, 99, 66, 124, 208]},
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)

        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        # Map the shuffled Ids back to Ids of the original graph.
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        expected_src, expected_dst = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(expected_src) == orig_src)
        assert np.all(F.asnumpy(expected_dst) == orig_dst)
572

573

574
575
576
577
578
579
580
581
582
def get_degrees(g, nids, ntype):
    """Sum the degrees of ``nids`` over every relation touching ``ntype``.

    For each canonical edge type, out-degrees are added when ``ntype`` is
    the source type; otherwise in-degrees are added when it is the
    destination type. Returns an int64 tensor with one entry per node ID.
    """
    total = F.zeros((len(nids),), dtype=F.int64)
    for src_ntype, rel, dst_ntype in g.canonical_etypes:
        if src_ntype == ntype:
            total += g.out_degrees(u=nids, etype=rel)
            continue
        if dst_ntype == ntype:
            total += g.in_degrees(v=nids, etype=rel)
    return total

583

584
def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
    """Sample from "n3" nodes without any edge and expect an empty block.

    The seeds are the positions where the degree, computed on the original
    graph through the node-ID mapping, is zero. The block must have no
    edges but still list every edge type of the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(empty=True)

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = spawn_ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, disable_shared_mem, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    n3_degrees = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(n3_degrees == 0)
    block, gpb = start_hetero_sample_client(
        0, tmpdir, disable_shared_mem, nodes={"n3": empty_nids}
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block.number_of_edges() == 0
    assert len(block.etypes) == len(g.etypes)

625
626
627
628

def check_rpc_hetero_etype_sampling_shuffle(
    tmpdir, num_server, graph_formats=None
):
    """Per-etype sampling on a dense shuffled heterograph.

    Samples 3 neighbors per canonical edge type for six "n3" seeds, expects
    18 edges each for r13 and r23, and verifies every sampled edge against
    the original graph. ``graph_formats`` is forwarded to
    ``partition_graph``; ``etype_sorted`` is enabled for the client when it
    contains "csc" or "csr".
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True)

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        graph_formats=graph_formats,
    )

    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        server_args = (
            server_rank,
            tmpdir,
            disable_shared_mem,
            "test_sampling",
            ["csc", "coo"],
        )
        proc = spawn_ctx.Process(target=start_server, args=server_args)
        proc.start()
        time.sleep(1)
        servers.append(proc)

    fanout = dict.fromkeys(g.canonical_etypes, 3)
    etype_sorted = graph_formats is not None and (
        "csc" in graph_formats or "csr" in graph_formats
    )
    block, gpb = start_hetero_etype_sample_client(
        0,
        tmpdir,
        disable_shared_mem,
        fanout,
        nodes={"n3": [0, 10, 99, 66, 124, 208]},
        etype_sorted=etype_sorted,
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    # Six seeds with fanout 3: 18 edges per relation ending in n3.
    src, dst = block.edges(etype=("n1", "r13", "n3"))
    assert len(src) == 18
    src, dst = block.edges(etype=("n2", "r23", "n3"))
    assert len(src) == 18

    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)

        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        # Map the shuffled Ids back to Ids of the original graph.
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        expected_src, expected_dst = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(expected_src) == orig_src)
        assert np.all(F.asnumpy(expected_dst) == orig_dst)

696

697
def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
    """Per-etype sampling from edge-less "n3" nodes must yield no edges.

    The seeds are the positions where the degree, computed on the original
    dense graph through the node-ID mapping, is zero. The block must have
    no edges but still list every edge type of the original graph.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero(dense=True, empty=True)

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = spawn_ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, disable_shared_mem, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    fanout = 3
    n3_degrees = get_degrees(g, orig_nids["n3"], "n3")
    empty_nids = F.nonzero_1d(n3_degrees == 0)
    block, gpb = start_hetero_etype_sample_client(
        0, tmpdir, disable_shared_mem, fanout, nodes={"n3": empty_nids}
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block.number_of_edges() == 0
    assert len(block.etypes) == len(g.etypes)

739
740

def create_random_bipartite():
    """Build a random user-buys-game bipartite graph with features.

    The graph comes from ``dgl.rand_bipartite("user", "buys", "game", 500,
    1000, 1000)``; both node types get a ``feat`` tensor of ones with 10
    columns.
    """
    bipartite = dgl.rand_bipartite("user", "buys", "game", 500, 1000, 1000)
    for ntype in ("user", "game"):
        feat_shape = (bipartite.num_nodes(ntype), 10)
        bipartite.nodes[ntype].data["feat"] = F.ones(
            feat_shape, F.float32, F.cpu()
        )
    return bipartite


def start_bipartite_sample_client(rank, tmpdir, disable_shared_mem, nodes):
    """Sample neighbors of ``nodes`` on the bipartite "test_sampling" graph.

    Asserts that both node types expose ``feat``, samples with fanout 3 and
    builds a block; sampled edge IDs are copied to the block only when at
    least one edge was sampled. Exceptions are not caught here.

    Returns ``(block, gpb)``.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    graph = DistGraph("test_sampling", gpb=gpb)

    for ntype in ("user", "game"):
        assert "feat" in graph.nodes[ntype].data

    if gpb is None:
        gpb = graph.get_partition_book()

    frontier = sample_neighbors(graph, nodes, 3)
    block = dgl.to_block(frontier, nodes)
    has_edges = frontier.num_edges() > 0
    if has_edges:
        block.edata[dgl.EID] = frontier.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb


771
772
773
def start_bipartite_etype_sample_client(
    rank, tmpdir, disable_shared_mem, fanout=3, nodes=None
):
    """Run ``sample_etype_neighbors`` on the bipartite "test_sampling" graph.

    Parameters
    ----------
    rank : int
        Partition ID passed to ``load_partition`` when shared memory is
        disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk and handed to
        ``DistGraph``.
    fanout : int or dict
        Forwarded to ``sample_etype_neighbors``.
    nodes : dict, optional
        Seed nodes per node type. ``None`` (the default) means an empty
        dict.

    Returns
    -------
    tuple
        ``(block, gpb)``. Sampled edge IDs are copied to the block only
        when at least one edge was sampled.
    """
    # Build the default per call rather than sharing one mutable dict
    # between invocations.
    if nodes is None:
        nodes = {}

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_sampling.json", rank
        )
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert "feat" in dist_graph.nodes["user"].data
    assert "feat" in dist_graph.nodes["game"].data

    local_g = dist_graph.local_partition
    if local_g is not None:
        # Check whether etypes are sorted in dist_graph: for every local
        # node, the first occurrence of each edge type among its in-edges
        # must appear in increasing type order.
        for lnid in np.arange(local_g.num_nodes()):
            leids = local_g.in_edges(lnid, form="eid")
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            _, first_idx = np.unique(letids, return_index=True)
            assert np.all(first_idx[:-1] <= first_idx[1:])

    if gpb is None:
        gpb = dist_graph.get_partition_book()
    sampled_graph = sample_etype_neighbors(dist_graph, nodes, fanout)
    block = dgl.to_block(sampled_graph, nodes)
    if sampled_graph.num_edges() > 0:
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb


def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
    """sample on bipartite via sample_neighbors() which yields empty sample results"""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()

    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = spawn_ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, disable_shared_mem, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    # Seeds: "game" positions whose degree is zero, plus user 1.
    game_degrees = get_degrees(g, orig_nids["game"], "game")
    empty_nids = F.nonzero_1d(game_degrees == 0)
    seeds = {"game": empty_nids, "user": [1]}
    block, _ = start_bipartite_sample_client(
        0, tmpdir, disable_shared_mem, nodes=seeds
    )

    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    assert block.number_of_edges() == 0
    assert len(block.etypes) == len(g.etypes)


def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
    """sample on bipartite via sample_neighbors() which yields non-empty sample results"""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_server,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    disable_shared_mem = num_server > 1
    spawn_ctx = mp.get_context("spawn")
    servers = []
    for server_rank in range(num_server):
        proc = spawn_ctx.Process(
            target=start_server,
            args=(server_rank, tmpdir, disable_shared_mem, "test_sampling"),
        )
        proc.start()
        time.sleep(1)
        servers.append(proc)

    # Seeds: "game" positions whose degree is positive, plus user 0.
    game_degrees = get_degrees(g, orig_nid_map["game"], "game")
    nids = F.nonzero_1d(game_degrees > 0)
    seeds = {"game": nids, "user": [0]}
    block, gpb = start_bipartite_sample_client(
        0, tmpdir, disable_shared_mem, nodes=seeds
    )
    print("Done sampling")
    for proc in servers:
        proc.join()
        assert proc.exitcode == 0

    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)

        # These are global Ids after shuffling.
        src_nids = block.srcnodes[src_type].data[dgl.NID]
        dst_nids = block.dstnodes[dst_type].data[dgl.NID]
        shuffled_src = F.gather_row(src_nids, src)
        shuffled_dst = F.gather_row(dst_nids, dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        # Map the shuffled Ids back to Ids of the original graph.
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        expected_src, expected_dst = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(expected_src) == orig_src)
        assert np.all(F.asnumpy(expected_dst) == orig_dst)


def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
    """Check that ``sample_etype_neighbors()`` on a bipartite graph can yield an empty block.

    A random bipartite graph is partitioned into ``num_server`` parts, one
    server process is spawned per part, and a client samples from the "game"
    nodes selected by ``deg == 0`` (plus "user" node 1).

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of partitions and of server processes. Shared memory is
        disabled for servers and client when this is greater than 1.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1

    # Only the node-ID mapping is needed here; the edge-ID mapping is dropped.
    orig_nids, _ = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Seed with the "game" nodes for which `deg == 0`.
    deg = get_degrees(g, orig_nids["game"], "game")
    empty_nids = F.nonzero_1d(deg == 0)
    # The client also returns the partition book, which this check ignores.
    block, _ = start_bipartite_etype_sample_client(
        0, tmpdir, num_server > 1, nodes={"game": empty_nids, "user": [1]}
    )

    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The block must exist, carry no edges, and keep every edge type of `g`.
    assert block is not None
    assert block.number_of_edges() == 0
    assert len(block.etypes) == len(g.etypes)


def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
    """Check ``sample_etype_neighbors()`` on a bipartite graph with non-empty results.

    A random bipartite graph is partitioned into ``num_server`` parts with
    ``return_mapping=True``, one server process is spawned per part, and a
    client samples with fanout 3 from the "game" nodes selected by
    ``deg > 0`` (plus "user" node 0). The sampled IDs are then mapped back
    through the returned mappings and compared against the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of partitions and of server processes. Shared memory is
        disabled for servers and client when this is greater than 1.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1

    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_sampling"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    fanout = 3
    # Seed with the "game" nodes for which `deg > 0`.
    deg = get_degrees(g, orig_nid_map["game"], "game")
    nids = F.nonzero_1d(deg > 0)
    # The client also returns the partition book, which this check ignores.
    block, _ = start_bipartite_etype_sample_client(
        0, tmpdir, num_server > 1, fanout, nodes={"game": nids, "user": [0]}
    )

    print("Done sampling")
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    for c_etype in block.canonical_etypes:
        src_type, etype, dst_type = c_etype
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        # Translate the shuffled Ids through the mappings returned by
        # partition_graph().
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)

# Wait for the non-shared-memory graph store.
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_sampling_shuffle(num_server):
    """Run every RPC sampling check in distributed mode.

    Covers the homogeneous, heterogeneous and bipartite checks, each with
    ``num_server`` servers, all sharing one temporary partition directory.
    """
    import tempfile

    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmpdir = Path(tmpdirname)
        check_rpc_sampling_shuffle(tmpdir, num_server)
        # [TODO][Rhett] Tests for multiple groups may fail sometimes and
        # root cause is unknown. Let's disable them for now.
        # check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=2)
        check_rpc_hetero_sampling_shuffle(tmpdir, num_server)
        check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server)
        check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server)
        # Repeat the etype sampling check with explicit graph formats.
        check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server, ["csc"])
        check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server, ["csr"])
        check_rpc_hetero_etype_sampling_shuffle(
            tmpdir, num_server, ["csc", "coo"]
        )
        check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server)
        check_rpc_bipartite_sampling_empty(tmpdir, num_server)
        check_rpc_bipartite_sampling_shuffle(tmpdir, num_server)
        check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server)
        check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server)

1048

1049
def check_standalone_sampling(tmpdir):
    """Check ``sample_neighbors`` on a single-partition graph in standalone mode.

    The cora graph gets two edge features: ``prob`` (random values clipped
    below at 0) and ``mask`` (``prob > 0``). It is partitioned into one part
    under ``tmpdir``, opened as a ``DistGraph``, and sampled three times:
    without weights, with ``prob="mask"`` and with ``prob="prob"``.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    """
    g = CitationGraphDataset("cora")[0]
    prob = np.maximum(np.random.randn(g.num_edges()), 0)
    mask = prob > 0
    g.edata["prob"] = F.tensor(prob)
    g.edata["mask"] = F.tensor(mask)
    num_parts = 1
    num_hops = 1
    partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )

    seeds = [0, 10, 99, 66, 1024, 2008]
    fanout = 3

    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    # Always tear the client down, even when an assertion below fails, so a
    # failure here does not leave an initialized client behind for later tests.
    try:
        dist_graph = DistGraph(
            "test_sampling", part_config=tmpdir / "test_sampling.json"
        )

        # Unweighted sampling: every sampled edge must exist in `g` and carry
        # the edge ID that `g` assigns to that (src, dst) pair.
        sampled_graph = sample_neighbors(dist_graph, seeds, fanout)
        src, dst = sampled_graph.edges()
        assert sampled_graph.number_of_nodes() == g.number_of_nodes()
        assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
        eids = g.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
        )

        # Sampling with the boolean mask: only masked-in edges may appear.
        sampled_graph = sample_neighbors(dist_graph, seeds, fanout, prob="mask")
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert mask[eid].all()

        # Sampling with the weights: only positive-weight edges may appear.
        sampled_graph = sample_neighbors(dist_graph, seeds, fanout, prob="prob")
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert (prob[eid] > 0).all()
    finally:
        dgl.distributed.exit_client()
1093

1094

1095
def check_standalone_etype_sampling(tmpdir):
    """Check ``sample_etype_neighbors`` on a single-partition graph in standalone mode.

    The cora graph gets two edge features: ``prob`` (random values clipped
    below at 0) and ``mask`` (``prob > 0``). It is partitioned into one part
    under ``tmpdir``, opened as a ``DistGraph``, and sampled three times:
    without weights, with ``prob="mask"`` and with ``prob="prob"``.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    """
    hg = CitationGraphDataset("cora")[0]
    prob = np.maximum(np.random.randn(hg.num_edges()), 0)
    mask = prob > 0
    hg.edata["prob"] = F.tensor(prob)
    hg.edata["mask"] = F.tensor(mask)
    num_parts = 1
    num_hops = 1

    partition_graph(
        hg,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )

    seeds = [0, 10, 99, 66, 1023]
    fanout = 3

    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    # Always tear the client down, even when an assertion below fails, so a
    # failure here does not leave an initialized client behind for later tests.
    try:
        dist_graph = DistGraph(
            "test_sampling", part_config=tmpdir / "test_sampling.json"
        )

        # Unweighted sampling: every sampled edge must exist in `hg` and carry
        # the edge ID that `hg` assigns to that (src, dst) pair.
        sampled_graph = sample_etype_neighbors(dist_graph, seeds, fanout)
        src, dst = sampled_graph.edges()
        assert sampled_graph.number_of_nodes() == hg.number_of_nodes()
        assert np.all(F.asnumpy(hg.has_edges_between(src, dst)))
        eids = hg.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)
        )

        # Sampling with the boolean mask: only masked-in edges may appear.
        sampled_graph = sample_etype_neighbors(
            dist_graph, seeds, fanout, prob="mask"
        )
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert mask[eid].all()

        # Sampling with the weights: only positive-weight edges may appear.
        sampled_graph = sample_etype_neighbors(
            dist_graph, seeds, fanout, prob="prob"
        )
        eid = F.asnumpy(sampled_graph.edata[dgl.EID])
        assert (prob[eid] > 0).all()
    finally:
        dgl.distributed.exit_client()

1140

1141
def check_standalone_etype_sampling_heterograph(tmpdir):
    """Check ``sample_etype_neighbors`` on a two-relation heterograph in standalone mode.

    The cora edges are used to build a heterograph with one node type
    ("paper") and two edge types: "cite" (src -> dst) and "cite-by"
    (dst -> src). It is partitioned into one part under ``tmpdir``, opened as
    a ``DistGraph``, and sampled with fanout 1 from 10 seed nodes.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    """
    hg = CitationGraphDataset("cora")[0]
    num_parts = 1
    num_hops = 1
    src, dst = hg.edges()
    new_hg = dgl.heterograph(
        {
            ("paper", "cite", "paper"): (src, dst),
            ("paper", "cite-by", "paper"): (dst, src),
        },
        {"paper": hg.number_of_nodes()},
    )
    partition_graph(
        new_hg,
        "test_hetero_sampling",
        num_parts,
        tmpdir,
        num_hops=num_hops,
        part_method="metis",
    )

    seeds = [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701]

    os.environ["DGL_DIST_MODE"] = "standalone"
    dgl.distributed.initialize("rpc_ip_config.txt")
    # Always tear the client down, even when an assertion below fails, so a
    # failure here does not leave an initialized client behind for later tests.
    try:
        dist_graph = DistGraph(
            "test_hetero_sampling",
            part_config=tmpdir / "test_hetero_sampling.json",
        )
        sampled_graph = sample_etype_neighbors(dist_graph, seeds, 1)

        # 10 seeds with fanout 1: the test expects 10 sampled edges for each
        # of the two edge types.
        src, dst = sampled_graph.edges(etype=("paper", "cite", "paper"))
        assert len(src) == 10
        src, dst = sampled_graph.edges(etype=("paper", "cite-by", "paper"))
        assert len(src) == 10
        assert sampled_graph.number_of_nodes() == new_hg.number_of_nodes()
    finally:
        dgl.distributed.exit_client()

1176
1177
1178
1179
1180
1181

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_standalone_sampling():
    """Run the standalone neighbor-sampling check in a fresh temporary directory."""
    import tempfile

    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_standalone_sampling(Path(tmpdirname))
1189

1190

1191
1192
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Run ``dgl.distributed.in_subgraph`` as a client and return the result.

    Parameters
    ----------
    rank : int
        Partition whose partition book is loaded when shared memory is
        disabled.
    tmpdir : pathlib.Path
        Directory holding ``test_in_subgraph.json``.
    disable_shared_mem : bool
        If true, the partition book is loaded from disk and handed to
        ``DistGraph``; otherwise ``DistGraph`` is created with ``gpb=None``.
    nodes : list or tensor
        Seed nodes passed to ``in_subgraph``.

    Returns
    -------
    The sampled graph, or ``None`` if ``in_subgraph`` raised.
    """
    gpb = None
    dgl.distributed.initialize("rpc_ip_config.txt")
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / "test_in_subgraph.json", rank
        )
    dist_graph = DistGraph("test_in_subgraph", gpb=gpb)
    # Deliberately best-effort: a failure is reported and turned into `None`
    # so that exit_client() below still runs.
    # NOTE(review): presumably the servers only shut down once the client
    # exits - confirm.
    try:
        sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception:
        print(traceback.format_exc())
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph


1208
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Check ``dgl.distributed.in_subgraph`` against ``dgl.in_subgraph``.

    The cora graph is partitioned into ``num_server`` parts with
    ``return_mapping=True``, one server process is spawned per part, and a
    client extracts the in-subgraph of a fixed seed list. The result is mapped
    back through the returned mappings and compared with the in-subgraph
    computed directly on the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of partitions and of server processes. Shared memory is
        disabled for servers and client when this is greater than 1.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = CitationGraphDataset("cora")[0]
    num_parts = num_server

    orig_nid, orig_eid = partition_graph(
        g,
        "test_in_subgraph",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
    )

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        p = ctx.Process(
            target=start_server,
            args=(i, tmpdir, num_server > 1, "test_in_subgraph"),
        )
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = [0, 10, 99, 66, 1024, 2008]
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # The client returns None when in_subgraph raised; fail with a clear
    # message instead of an AttributeError on the next line.
    assert sampled_graph is not None, "in_subgraph client failed"

    # Map the endpoints of the distributed result back to original node IDs.
    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    # Reference result computed directly on the original graph.
    subg1 = dgl.in_subgraph(g, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))

    # Edge IDs mapped through `orig_eid` must agree with the IDs that `g`
    # assigns to the same (src, dst) pairs.
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
1254

1255
1256
1257
1258
1259
1260

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
def test_rpc_in_subgraph():
    """Run the distributed in-subgraph check with two servers."""
    import tempfile

    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
1268

1269
1270
1271
1272
1273
1274
1275
1276
1277

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="Not support tensorflow for now",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone_etype_sampling():
    """Run both standalone etype-sampling checks, each in its own temporary directory."""
    import tempfile

    reset_envs()
    # The heterograph check runs first, then the homogeneous one; the mode is
    # set again before each check.
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ["DGL_DIST_MODE"] = "standalone"
        check_standalone_etype_sampling(Path(tmpdirname))
1288

1289

Jinjing Zhou's avatar
Jinjing Zhou committed
1290
1291
if __name__ == "__main__":
    import tempfile
1292

Jinjing Zhou's avatar
Jinjing Zhou committed
1293
    with tempfile.TemporaryDirectory() as tmpdirname:
1294
        os.environ["DGL_DIST_MODE"] = "standalone"
1295
        check_standalone_etype_sampling_heterograph(Path(tmpdirname))
1296
1297

    with tempfile.TemporaryDirectory() as tmpdirname:
1298
        os.environ["DGL_DIST_MODE"] = "standalone"
1299
1300
        check_standalone_etype_sampling(Path(tmpdirname))
        check_standalone_sampling(Path(tmpdirname))
1301
        os.environ["DGL_DIST_MODE"] = "distributed"
1302
1303
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
1304
1305
        check_rpc_get_degree_shuffle(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 2)
1306
1307
        check_rpc_find_edges_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 1)
1308
1309
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 2)
1310
1311
1312
1313
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)
1314
1315
1316
1317
        check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), 1)