import os

os.environ["OMP_NUM_THREADS"] = "1"
import math
import multiprocessing as mp
import pickle
import socket
import sys
import time
import unittest
from multiprocessing import Condition, Manager, Process, Value

import backend as F
import numpy as np
import pytest
import torch as th
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs

import dgl
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    edge_split,
    load_partition,
    load_partition_book,
    node_split,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad
from dgl.heterograph_index import create_unitgraph_from_coo

if os.name != "nt":
    import fcntl
    import struct


def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
    keep_alive=False,
):
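    # Each server process loads one partition of `graph_name` from
    # /tmp/dist_graph and serves it to the trainer processes; with
    # keep_alive=True the server stays up across multiple client groups.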
    g = DistGraphServer(
        server_id,
        "kv_ip_config.txt",
        server_count,
        num_clients,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
        keep_alive=keep_alive,
    )
    print("start server", server_id)
    # verify dtype of underlying graph
    cg = g.client_g
    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
        if k in cg.ndata:
            assert (
                F.dtype(cg.ndata[k]) == dtype
            ), "Data type of {} in ndata should be {}.".format(k, dtype)
        if k in cg.edata:
            assert (
                F.dtype(cg.edata[k]) == dtype
            ), "Data type of {} in edata should be {}.".format(k, dtype)
    g.start()


def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())


def rand_init(shape, dtype):
    return F.tensor(np.random.normal(size=shape), F.float32)


def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
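    # Sanity checks on a DistGraph whose partitions carry no node/edge
    # features: graph size, on-the-fly DistTensor creation/destruction and
    # reading/writing node data.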
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    g.ndata["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor and destroy a tensor and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes(), 3), F.float32, "test3"
    )
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()["test1"].dtype == F.int32

    print("end")


def run_client_empty(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)


def check_server_client_empty(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_1"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_empty,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                g.number_of_nodes(),
                g.number_of_edges(),
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print("clients have terminated")


def run_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
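    # A trainer process: connect to the servers listed in kv_ip_config.txt,
    # rebuild the DistGraph from the partition book and run the checks.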
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph(g, num_clients, num_nodes, num_edges)


def run_emb_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)


def run_optim_client(
    graph_name,
    part_id,
    server_count,
    rank,
    world_size,
    num_nodes,
    optimizer_states,
    save,
):
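    # Optimizer clients join both the DGL RPC layer and a gloo
    # torch.distributed group before exercising SparseAdagrad save/load.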
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12355"
    dgl.distributed.initialize("kv_ip_config.txt")
    th.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size
    )
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_optim_store(rank, num_nodes, optimizer_states, save)


def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
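    # In the save phase, rank 0 overwrites the Adagrad state with known values
    # and the optimizer checkpoints them; in the load phase, a fresh optimizer
    # restores the checkpoint and rank 0 verifies the states and
    # hyperparameters.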
    try:
        total_idx = F.arange(0, num_nodes, F.int64, F.cpu())
        emb = DistEmbedding(num_nodes, 1, name="optim_emb1", init_func=emb_init)
        emb2 = DistEmbedding(
            num_nodes, 1, name="optim_emb2", init_func=emb_init
        )
        if save:
            optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
            if rank == 0:
                optimizer._state["optim_emb1"][total_idx] = optimizer_states[0]
                optimizer._state["optim_emb2"][total_idx] = optimizer_states[1]
            optimizer.save("/tmp/dist_graph/emb.pt")
        else:
            optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
            optimizer.load("/tmp/dist_graph/emb.pt")
            if rank == 0:
                assert F.allclose(
                    optimizer._state["optim_emb1"][total_idx],
                    optimizer_states[0],
                    0.0,
                    0.0,
                )
                assert F.allclose(
                    optimizer._state["optim_emb2"][total_idx],
                    optimizer_states[1],
                    0.0,
                    0.0,
                )
                assert 0.1 == optimizer._lr
                assert 1e-08 == optimizer._eps
            th.distributed.barrier()
    except Exception as e:
        print(e)
        sys.exit(-1)


def run_client_hierarchy(
    graph_name, part_id, server_count, node_mask, edge_mask, return_dict
):
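    # Split the masked nodes/edges according to the trainer assignment stored
    # in ndata/edata["trainer_id"] and report this rank's share via
    # return_dict.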
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(
        node_mask,
        g.get_partition_book(),
        node_trainer_ids=g.ndata["trainer_id"],
    )
    edges = edge_split(
        edge_mask,
        g.get_partition_book(),
        edge_trainer_ids=g.edata["trainer_id"],
    )
    rank = g.rank()
    return_dict[rank] = (nodes, edges)


def check_dist_emb(g, num_clients, num_nodes, num_edges):
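    # Exercise DistEmbedding with SparseAdagrad: after one backward pass only
    # the rows that were looked up should receive sparse updates.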
    # Test sparse emb
    try:
        emb = DistEmbedding(g.number_of_nodes(), 1, "emb1", emb_init)
        nids = F.arange(0, int(g.number_of_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
        grad_sum = dgl.distributed.DistTensor(
            (g.number_of_nodes(), 1), F.float32, "emb1_sum", policy
        )
        if num_clients == 1:
            assert np.all(
                F.asnumpy(grad_sum[nids])
                == np.ones((len(nids), 1)) * num_clients
            )
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = DistEmbedding(g.number_of_nodes(), 1, "emb2", emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(
                F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr
            )
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError:
        pass
    except Exception as e:
        print(e)
        sys.exit(-1)


def check_dist_graph(g, num_clients, num_nodes, num_edges):
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats1 = g.ndata["features"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges() / 2))
    feats1 = g.edata["features"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph(eids)
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    test1 = dgl.distributed.DistTensor(new_shape, F.int32)
    g.ndata["test1"] = test1
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)
    assert test1.count_nonzero() == 0

    # reference to a tensor that already exists
    test2 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test2", init_func=rand_init
    )
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, "test2")
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # create a tensor and destroy a tensor and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes(), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    # Recreating the persistent tensor under the same name but with a
    # different shape is expected to fail.
    with pytest.raises(Exception):
        test4 = dgl.distributed.DistTensor(
            (g.number_of_nodes(), 3), F.float32, "test4"
        )

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata["features"]) == g.number_of_nodes()
    assert g.ndata["features"].shape == (g.number_of_nodes(), 1)
    assert g.ndata["features"].dtype == F.int64
    assert g.node_attr_schemes()["features"].dtype == F.int64
    assert g.node_attr_schemes()["test1"].dtype == F.int32
    assert g.node_attr_schemes()["features"].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes())
    for n in nodes:
        assert n in local_nids

    print("end")


def check_dist_emb_server_client(
    shared_mem, num_servers, num_clients, num_groups=1
):
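    # Spawn num_servers server processes and num_clients * num_groups trainer
    # processes; servers are kept alive when more than one client group is
    # used.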
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = (
        f"check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    )
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_emb_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.number_of_nodes(),
                    g.number_of_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()

    print("clients have terminated")


def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.number_of_nodes(),
                    g.number_of_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)
    for p in cli_ps:
        p.join()

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()

    print("clients have terminated")


def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
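    # Partition with num_trainers_per_machine=num_clients so every node/edge
    # gets a "trainer_id"; the clients' node_split/edge_split results must
    # exactly cover the selected masks.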
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_2"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(
        g,
        graph_name,
        num_parts,
        "/tmp/dist_graph",
        num_trainers_per_machine=num_clients,
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
    node_mask = np.zeros((g.number_of_nodes(),), np.int32)
    edge_mask = np.zeros((g.number_of_edges(),), np.int32)
    nodes = np.random.choice(
        g.number_of_nodes(), g.number_of_nodes() // 10, replace=False
    )
    edges = np.random.choice(
        g.number_of_edges(), g.number_of_edges() // 10, replace=False
    )
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hierarchy,
            args=(
                graph_name,
                0,
                num_servers,
                node_mask,
                edge_mask,
                return_dict,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
    for p in serv_ps:
        p.join()

    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print("clients have terminated")


def run_client_hetero(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)


def create_random_hetero():
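    # Build a small heterograph with three node types and three relation
    # types whose edges are drawn from sparse random matrices.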
    num_nodes = {"n1": 10000, "n2": 10010, "n3": 10020}
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    g.nodes["n1"].data["feat"] = F.unsqueeze(
        F.arange(0, g.number_of_nodes("n1")), 1
    )
    g.edges["r1"].data["feat"] = F.unsqueeze(
        F.arange(0, g.number_of_edges("r1")), 1
    )
    return g


def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
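    # Heterogeneous counterpart of check_dist_graph: verify per-type counts,
    # per-type node/edge data access, edge subgraphs and DistTensor creation.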
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.number_of_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.number_of_edges(etype)
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    for i, etype in enumerate(g.canonical_etypes):
        assert etype[0] == etypes[i][0]
        assert etype[1] == etypes[i][1]
        assert etype[2] == etypes[i][2]
    assert g.number_of_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.number_of_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes("n1") / 2))
    feats1 = g.nodes["n1"].data["feat"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges("r1") / 2))
    # access via etype
    feats = g.edges["r1"].data["feat"][eids]
    feats = F.squeeze(feats, 1)
    assert np.all(F.asnumpy(feats == eids))
    # access via canonical etype
    c_etype = g.to_canonical_etype("r1")
    feats = g.edges[c_etype].data["feat"][eids]
    feats = F.squeeze(feats, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph({"r1": eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)
    sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes("n1"), 2)
    g.nodes["n1"].data["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor and destroy a tensor and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes("n1"), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    # Recreating the persistent tensor under the same name but with a
    # different shape is expected to fail.
    with pytest.raises(Exception):
        test4 = dgl.distributed.DistTensor(
            (g.number_of_nodes("n1"), 3), F.float32, "test4"
        )

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes["n1"].data["test1"][nids] = new_feats
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes["n1"].data["feat"]) == g.number_of_nodes("n1")
    assert g.nodes["n1"].data["feat"].shape == (g.number_of_nodes("n1"), 1)
    assert g.nodes["n1"].data["feat"].dtype == F.int64

    selected_nodes = (
        np.random.randint(0, 100, size=g.number_of_nodes("n1")) > 30
    )
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype="n1")
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes("n1"))
    for n in nodes:
        assert n in local_nids

    print("end")


def check_server_client_hetero(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.number_of_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hetero,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                num_nodes,
                num_edges,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print("clients have terminated")


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_server_client_hierarchy(False, 1, 4)
    check_server_client_empty(True, 1, 1)
    check_server_client_hetero(True, 1, 1)
    check_server_client_hetero(False, 1, 1)
    check_server_client(True, 1, 1)
    check_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_server_client(True, 2, 2)
    # check_server_client(True, 1, 1, 2)
    # check_server_client(False, 1, 1, 2)
    # check_server_client(True, 2, 2, 2)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_dist_emb_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_dist_emb_server_client(True, 2, 2)
    # check_dist_emb_server_client(True, 1, 1, 2)
    # check_dist_emb_server_client(False, 1, 1, 2)
    # check_dist_emb_server_client(True, 2, 2, 2)


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed Optimizer",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed Optimizer",
)
def test_dist_optim_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    optimizer_states = []
    num_nodes = 10000
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)
    check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)
    check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, False)


def check_dist_optim_server_client(
    num_nodes, num_servers, num_clients, optimizer_states, save
):
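    # The graph is partitioned only in the save run; the subsequent load runs
    # reuse the same partition and optimizer checkpoint under /tmp/dist_graph.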
    graph_name = f"check_dist_optim_{num_servers}_store"
    if save:
        prepare_dist(num_servers)
        g = create_random_graph(num_nodes)

        # Partition the graph
        num_parts = 1
        g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
        g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
        partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                True,
                False,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client[{}] for group[0]".format(cli_id))
        p = ctx.Process(
            target=run_optim_client,
            args=(
                graph_name,
                0,
                num_servers,
                cli_id,
                num_clients,
                num_nodes,
                optimizer_states,
                save,
            ),
        )
        p.start()
        time.sleep(1)  # avoid race condition when instantiating DistGraph
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client()  # this is needed since there are two tests here in one process


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_standalone_node_emb():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_emb(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client()  # this is needed since there are two tests here in one process


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("hetero", [True, False])
def test_split(hetero):
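    # Partition into several parts with METIS and check that node_split /
    # edge_split with force_even=False return exactly the masked IDs owned by
    # each partition, and that two trainers per partition together cover the
    # same set.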
    if hetero:
        g = create_random_hetero()
        ntype = "n1"
        etype = "r1"
    else:
        g = create_random_graph(10000)
        ntype = "_N"
        etype = "_E"
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes(ntype)) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges(etype)) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code now collects the roles of all client processes and uses that
    # information to determine how to split the workloads. Here we simulate
    # the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        if hetero:
            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
        else:
            local_nids = F.asnumpy(local_nids)
        nodes1 = np.intersect1d(selected_nodes, local_nids)
        nodes2 = node_split(
            node_mask, gpb, ntype=ntype, rank=i, force_even=False
        )
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        for n in F.asnumpy(nodes2):
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False
        )
        nodes4 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False
        )
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        if hetero:
            etype_ids, eids = gpb.map_to_per_etype(local_eids)
            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
        else:
            local_eids = F.asnumpy(local_eids)
        edges1 = np.intersect1d(selected_edges, local_eids)
        edges2 = edge_split(
            edge_mask, gpb, etype=etype, rank=i, force_even=False
        )
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        for e in F.asnumpy(edges2):
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2, force_even=False
        )
        edges4 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False
        )
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_split_even():
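    # With force_even=True the split is balanced across trainers instead of
    # being strictly partition-local; the union over all trainers must still
    # equal the full masked set.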
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code now collects the roles of all client processes and uses that
    # information to determine how to split the workloads. Here we simulate
    # the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print(
            "part {} get {} nodes and {} are in the partition".format(
                i, len(nodes), len(subset)
            )
        )

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print("intersection has", len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print(
            "part {} get {} edges and {} are in the partition".format(
                i, len(edges), len(subset)
            )
        )

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print("intersection has", len(subset))
    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))


def prepare_dist(num_servers=1):
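    # Write kv_ip_config.txt describing one machine that runs `num_servers`
    # server processes.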
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    test_dist_emb_server_client()
    test_server_client()
    test_split(True)
    test_split(False)
    test_split_even()
    test_standalone()
    test_standalone_node_emb()