import os

os.environ["OMP_NUM_THREADS"] = "1"
import math
import multiprocessing as mp
import pickle
import socket
import sys
import time
import unittest
from multiprocessing import Condition, Manager, Process, Value

import backend as F
import numpy as np
import pytest
import torch as th
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs

import dgl
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    edge_split,
    load_partition,
    load_partition_book,
    node_split,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad
from dgl.heterograph_index import create_unitgraph_from_coo

if os.name != "nt":
    import fcntl
    import struct


def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
    keep_alive=False,
):
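    """Start a DistGraphServer that serves one partition from /tmp/dist_graph."""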
    g = DistGraphServer(
        server_id,
        "kv_ip_config.txt",
        server_count,
        num_clients,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
        keep_alive=keep_alive,
    )
    print("start server", server_id)
    # verify dtype of underlying graph
    cg = g.client_g
    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
        if k in cg.ndata:
            assert (
                F.dtype(cg.ndata[k]) == dtype
            ), "Data type of {} in ndata should be {}.".format(k, dtype)
        if k in cg.edata:
            assert (
                F.dtype(cg.edata[k]) == dtype
            ), "Data type of {} in edata should be {}.".format(k, dtype)
    g.start()


def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())


def rand_init(shape, dtype):
    return F.tensor(np.random.normal(size=shape), F.float32)


def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
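    """Check basic DistGraph APIs and DistTensor creation on a featureless graph."""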
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    g.ndata["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes(), 3), F.float32, "test3"
    )
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()["test1"].dtype == F.int32

    print("end")


def run_client_empty(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
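    """Client process: connect to the servers and run check_dist_graph_empty()."""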
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)


def check_server_client_empty(shared_mem, num_servers, num_clients):
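    """Partition a featureless random graph and verify it through server/client processes."""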
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_1"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_empty,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                g.number_of_nodes(),
                g.number_of_edges(),
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print("clients have terminated")


def run_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
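    """Client process: connect to the servers and run check_dist_graph()."""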
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph(g, num_clients, num_nodes, num_edges)


def run_emb_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
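    """Client process: connect to the servers and run check_dist_emb()."""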
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)


def run_optim_client(
    graph_name,
    part_id,
    server_count,
    rank,
    world_size,
    num_nodes,
    optimizer_states,
    save,
):
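    """Client process: set up torch.distributed and check saving/loading of SparseAdagrad state."""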
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12355"
    dgl.distributed.initialize("kv_ip_config.txt")
    th.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size
    )
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_optim_store(rank, num_nodes, optimizer_states, save)


def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
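    """Save SparseAdagrad state to /tmp/dist_graph/emb.pt, or load it back and compare with the expected states."""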
    try:
        total_idx = F.arange(0, num_nodes, F.int64, F.cpu())
        emb = DistEmbedding(num_nodes, 1, name="optim_emb1", init_func=emb_init)
        emb2 = DistEmbedding(
            num_nodes, 1, name="optim_emb2", init_func=emb_init
        )
        if save:
            optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
            if rank == 0:
                optimizer._state["optim_emb1"][total_idx] = optimizer_states[0]
                optimizer._state["optim_emb2"][total_idx] = optimizer_states[1]
            optimizer.save("/tmp/dist_graph/emb.pt")
        else:
            optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
            optimizer.load("/tmp/dist_graph/emb.pt")
            if rank == 0:
                assert F.allclose(
                    optimizer._state["optim_emb1"][total_idx],
                    optimizer_states[0],
                    0.0,
                    0.0,
                )
                assert F.allclose(
                    optimizer._state["optim_emb2"][total_idx],
                    optimizer_states[1],
                    0.0,
                    0.0,
                )
                assert 0.1 == optimizer._lr
                assert 1e-08 == optimizer._eps
            th.distributed.barrier()
    except Exception as e:
        print(e)
        sys.exit(-1)


def run_client_hierarchy(
    graph_name, part_id, server_count, node_mask, edge_mask, return_dict
):
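    """Client process: split nodes/edges by trainer id and report the result via return_dict."""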
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(
        node_mask,
        g.get_partition_book(),
        node_trainer_ids=g.ndata["trainer_id"],
    )
    edges = edge_split(
        edge_mask,
        g.get_partition_book(),
        edge_trainer_ids=g.edata["trainer_id"],
    )
    rank = g.rank()
    return_dict[rank] = (nodes, edges)


def check_dist_emb(g, num_clients, num_nodes, num_edges):
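    """Check DistEmbedding lookups and sparse updates through SparseAdagrad."""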
    # Test sparse emb
    try:
        emb = DistEmbedding(g.number_of_nodes(), 1, "emb1", emb_init)
        nids = F.arange(0, int(g.number_of_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
        grad_sum = dgl.distributed.DistTensor(
            (g.number_of_nodes(), 1), F.float32, "emb1_sum", policy
        )
        if num_clients == 1:
            assert np.all(
                F.asnumpy(grad_sum[nids])
                == np.ones((len(nids), 1)) * num_clients
            )
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = DistEmbedding(g.number_of_nodes(), 1, "emb2", emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(
                F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr
            )
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError as e:
        pass
    except Exception as e:
        print(e)
        sys.exit(-1)


def check_dist_graph(g, num_clients, num_nodes, num_edges):
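    """Check DistGraph reads/writes, DistTensor lifetimes and node_split on a homogeneous graph."""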
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats1 = g.ndata["features"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges() / 2))
    feats1 = g.edata["features"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph(eids)
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    test1 = dgl.distributed.DistTensor(new_shape, F.int32)
    g.ndata["test1"] = test1
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)
    assert test1.count_nonzero() == 0

    # Reference a tensor that already exists.
    test2 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test2", init_func=rand_init
    )
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, "test2")
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes(), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # Test a persistent tensor.
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    # "test4" is persistent, so re-creating it with a different shape must fail.
    with pytest.raises(Exception):
        test4 = dgl.distributed.DistTensor(
            (g.number_of_nodes(), 3), F.float32, "test4"
        )

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata["features"]) == g.number_of_nodes()
    assert g.ndata["features"].shape == (g.number_of_nodes(), 1)
    assert g.ndata["features"].dtype == F.int64
    assert g.node_attr_schemes()["features"].dtype == F.int64
    assert g.node_attr_schemes()["test1"].dtype == F.int32
    assert g.node_attr_schemes()["features"].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes())
    for n in nodes:
        assert n in local_nids

    print("end")


def check_dist_emb_server_client(
    shared_mem, num_servers, num_clients, num_groups=1
):
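    """Partition a random graph and verify DistEmbedding through server/client processes."""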
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = (
        f"check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    )
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_emb_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.number_of_nodes(),
                    g.number_of_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()

    print("clients have terminated")


def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
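    """Partition a random graph and verify DistGraph through server/client processes."""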
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.number_of_nodes(),
                    g.number_of_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)
    for p in cli_ps:
        p.join()

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()

    print("clients have terminated")


def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
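    """Verify node_split/edge_split with multiple trainers per machine."""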
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_2"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(
        g,
        graph_name,
        num_parts,
        "/tmp/dist_graph",
        num_trainers_per_machine=num_clients,
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
    node_mask = np.zeros((g.number_of_nodes(),), np.int32)
    edge_mask = np.zeros((g.number_of_edges(),), np.int32)
    nodes = np.random.choice(
        g.number_of_nodes(), g.number_of_nodes() // 10, replace=False
    )
    edges = np.random.choice(
        g.number_of_edges(), g.number_of_edges() // 10, replace=False
    )
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hierarchy,
            args=(
                graph_name,
                0,
                num_servers,
                node_mask,
                edge_mask,
                return_dict,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
    for p in serv_ps:
        p.join()

    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print("clients have terminated")


def run_client_hetero(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
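    """Client process: connect to the servers and run check_dist_graph_hetero()."""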
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)


def create_random_hetero():
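    """Create a random heterograph with three node types and three canonical edge types."""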
    num_nodes = {"n1": 10000, "n2": 10010, "n3": 10020}
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    g.nodes["n1"].data["feat"] = F.unsqueeze(
        F.arange(0, g.number_of_nodes("n1")), 1
    )
    g.edges["r1"].data["feat"] = F.unsqueeze(
        F.arange(0, g.number_of_edges("r1")), 1
    )
    return g


def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
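    """Check DistGraph reads/writes and DistTensor lifetimes on a heterogeneous graph."""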
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.number_of_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.number_of_edges(etype)
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    for i, etype in enumerate(g.canonical_etypes):
        assert etype[0] == etypes[i][0]
        assert etype[1] == etypes[i][1]
        assert etype[2] == etypes[i][2]
    assert g.number_of_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.number_of_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes("n1") / 2))
    feats1 = g.nodes["n1"].data["feat"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges("r1") / 2))
    feats1 = g.edges["r1"].data["feat"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph({"r1": eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)
    sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes("n1"), 2)
    g.nodes["n1"].data["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes("n1"), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # Test a persistent tensor.
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    # "test4" is persistent, so re-creating it with a different shape must fail.
    with pytest.raises(Exception):
        test4 = dgl.distributed.DistTensor(
            (g.number_of_nodes("n1"), 3), F.float32, "test4"
        )

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes["n1"].data["test1"][nids] = new_feats
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes["n1"].data["feat"]) == g.number_of_nodes("n1")
    assert g.nodes["n1"].data["feat"].shape == (g.number_of_nodes("n1"), 1)
    assert g.nodes["n1"].data["feat"].dtype == F.int64

    selected_nodes = (
        np.random.randint(0, 100, size=g.number_of_nodes("n1")) > 30
    )
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype="n1")
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes("n1"))
    for n in nodes:
        assert n in local_nids

    print("end")


def check_server_client_hetero(shared_mem, num_servers, num_clients):
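    """Partition a random heterograph and verify it through server/client processes."""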
    prepare_dist(num_servers)
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.number_of_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hetero,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                num_nodes,
                num_edges,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print("clients have terminated")


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_server_client_hierarchy(False, 1, 4)
    check_server_client_empty(True, 1, 1)
    check_server_client_hetero(True, 1, 1)
    check_server_client_hetero(False, 1, 1)
    check_server_client(True, 1, 1)
    check_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_server_client(True, 2, 2)
    # check_server_client(True, 1, 1, 2)
    # check_server_client(False, 1, 1, 2)
    # check_server_client(True, 2, 2, 2)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_dist_emb_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_dist_emb_server_client(True, 2, 2)
    # check_dist_emb_server_client(True, 1, 1, 2)
    # check_dist_emb_server_client(False, 1, 1, 2)
    # check_dist_emb_server_client(True, 2, 2, 2)


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed Optimizer",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed Optimizer",
)
def test_dist_optim_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    optimizer_states = []
    num_nodes = 10000
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)
    check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)
    check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, False)


def check_dist_optim_server_client(
    num_nodes, num_servers, num_clients, optimizer_states, save
):
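    """Run one save or load round of distributed SparseAdagrad state through server/client processes."""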
    graph_name = f"check_dist_optim_{num_servers}_store"
    if save:
        prepare_dist(num_servers)
        g = create_random_graph(num_nodes)

        # Partition the graph
        num_parts = 1
        g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
        g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
        partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                True,
                False,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client[{}] for group[0]".format(cli_id))
        p = ctx.Process(
            target=run_optim_client,
            args=(
                graph_name,
                0,
                num_servers,
                cli_id,
                num_clients,
                num_nodes,
                optimizer_states,
                save,
            ),
        )
        p.start()
        time.sleep(1)  # avoid race condition when instantiating DistGraph
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client()  # this is needed since there are two tests here in one process


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_standalone_node_emb():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_emb(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client()  # this is needed since there are two tests here in one process


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("hetero", [True, False])
def test_split(hetero):
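    """Check node_split/edge_split with force_even=False against each partition's local IDs."""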
    if hetero:
        g = create_random_hetero()
        ntype = "n1"
        etype = "r1"
    else:
        g = create_random_graph(10000)
        ntype = "_N"
        etype = "_E"
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes(ntype)) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges(etype)) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code now collects the roles of all client processes and uses the
    # information to determine how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        if hetero:
            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
        else:
            local_nids = F.asnumpy(local_nids)
        nodes1 = np.intersect1d(selected_nodes, local_nids)
        nodes2 = node_split(
            node_mask, gpb, ntype=ntype, rank=i, force_even=False
        )
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        for n in F.asnumpy(nodes2):
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False
        )
        nodes4 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False
        )
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        if hetero:
            etype_ids, eids = gpb.map_to_per_etype(local_eids)
            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
        else:
            local_eids = F.asnumpy(local_eids)
        edges1 = np.intersect1d(selected_edges, local_eids)
        edges2 = edge_split(
            edge_mask, gpb, etype=etype, rank=i, force_even=False
        )
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        for e in F.asnumpy(edges2):
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2, force_even=False
        )
        edges4 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False
        )
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_split_even():
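    """Check that force_even=True splits cover exactly the masked nodes and edges."""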
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code now collects the roles of all client processes and uses the
    # information to determine how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print(
            "part {} get {} nodes and {} are in the partition".format(
                i, len(nodes), len(subset)
            )
        )

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print("intersection has", len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print(
            "part {} get {} edges and {} are in the partition".format(
                i, len(edges), len(subset)
            )
        )

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print("intersection has", len(subset))
    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))


def prepare_dist(num_servers=1):
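    """Generate the IP config file shared by the servers and clients."""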
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    test_dist_emb_server_client()
    test_server_client()
    test_split(True)
    test_split(False)
    test_split_even()
    test_standalone()
    test_standalone_node_emb()