import os

os.environ["OMP_NUM_THREADS"] = "1"
import math
import multiprocessing as mp
import pickle
import socket
import sys
import time
import unittest
from multiprocessing import Condition, Manager, Process, Value

import backend as F
import numpy as np
import pytest
import torch as th
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs

import dgl
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    edge_split,
    load_partition,
    load_partition_book,
    node_split,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad
from dgl.heterograph_index import create_unitgraph_from_coo

if os.name != "nt":
    import fcntl
    import struct


def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
    keep_alive=False,
):
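    """Start a DistGraphServer for one partition, verify reserved field dtypes, and serve until shutdown."""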
    g = DistGraphServer(
        server_id,
        "kv_ip_config.txt",
        server_count,
        num_clients,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
        keep_alive=keep_alive,
    )
    print("start server", server_id)
    # verify dtype of underlying graph
    cg = g.client_g
    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
        if k in cg.ndata:
            assert (
                F.dtype(cg.ndata[k]) == dtype
            ), "Data type of {} in ndata should be {}.".format(k, dtype)
        if k in cg.edata:
            assert (
                F.dtype(cg.edata[k]) == dtype
            ), "Data type of {} in edata should be {}.".format(k, dtype)
    g.start()


def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())


def rand_init(shape, dtype):
    return F.tensor(np.random.normal(size=shape), F.float32)


def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
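    """Verify DistGraph APIs and DistTensor operations on a graph without features."""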
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    g.ndata["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes(), 3), F.float32, "test3"
    )
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()["test1"].dtype == F.int32

    print("end")


def run_client_empty(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
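    """Client process: connect to the servers and run check_dist_graph_empty."""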
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)


def check_server_client_empty(shared_mem, num_servers, num_clients):
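    """Spawn server and client processes for a featureless graph and run the empty-graph checks."""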
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_1"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_empty,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                g.number_of_nodes(),
                g.number_of_edges(),
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print("clients have terminated")


def run_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
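    """Client process: connect to the servers and run check_dist_graph."""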
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph(g, num_clients, num_nodes, num_edges)


def run_emb_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
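    """Client process: connect to the servers and run check_dist_emb."""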
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)


def run_optim_client(
    graph_name,
    part_id,
    server_count,
    rank,
    world_size,
    num_nodes,
    optimizer_states,
    save,
):
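    """Client process: join a gloo process group, then save or verify SparseAdagrad state."""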
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12355"
    dgl.distributed.initialize("kv_ip_config.txt")
    th.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size
    )
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_optim_store(rank, num_nodes, optimizer_states, save)


def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
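    """Save the SparseAdagrad states of two DistEmbeddings, or load them back and verify the values, lr and eps."""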
    try:
        total_idx = F.arange(0, num_nodes, F.int64, F.cpu())
        emb = DistEmbedding(num_nodes, 1, name="optim_emb1", init_func=emb_init)
        emb2 = DistEmbedding(
            num_nodes, 1, name="optim_emb2", init_func=emb_init
        )
        if save:
            optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
            if rank == 0:
                optimizer._state["optim_emb1"][total_idx] = optimizer_states[0]
                optimizer._state["optim_emb2"][total_idx] = optimizer_states[1]
            optimizer.save("/tmp/dist_graph/emb.pt")
        else:
            optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
            optimizer.load("/tmp/dist_graph/emb.pt")
            if rank == 0:
                assert F.allclose(
                    optimizer._state["optim_emb1"][total_idx],
                    optimizer_states[0],
                    0.0,
                    0.0,
                )
                assert F.allclose(
                    optimizer._state["optim_emb2"][total_idx],
                    optimizer_states[1],
                    0.0,
                    0.0,
                )
                assert 0.1 == optimizer._lr
                assert 1e-08 == optimizer._eps
            th.distributed.barrier()
    except Exception as e:
        print(e)
        sys.exit(-1)


def run_client_hierarchy(
    graph_name, part_id, server_count, node_mask, edge_mask, return_dict
):
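    """Client process: split nodes/edges by trainer id and record the result for this rank."""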
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(
        node_mask,
        g.get_partition_book(),
        node_trainer_ids=g.ndata["trainer_id"],
    )
    edges = edge_split(
        edge_mask,
        g.get_partition_book(),
        edge_trainer_ids=g.edata["trainer_id"],
    )
    rank = g.rank()
    return_dict[rank] = (nodes, edges)


def check_dist_emb(g, num_clients, num_nodes, num_edges):
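    """Verify DistEmbedding lookup and SparseAdagrad updates on a DistGraph."""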
    # Test sparse emb
    try:
        emb = DistEmbedding(g.number_of_nodes(), 1, "emb1", emb_init)
        nids = F.arange(0, int(g.number_of_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
        grad_sum = dgl.distributed.DistTensor(
            (g.number_of_nodes(), 1), F.float32, "emb1_sum", policy
        )
        if num_clients == 1:
            assert np.all(
                F.asnumpy(grad_sum[nids])
                == np.ones((len(nids), 1)) * num_clients
            )
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = DistEmbedding(g.number_of_nodes(), 1, "emb2", emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(
                F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr
            )
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError as e:
        pass
    except Exception as e:
        print(e)
        sys.exit(-1)


def check_dist_graph(g, num_clients, num_nodes, num_edges):
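    """Verify DistGraph node/edge data access, DistTensor life cycle, and node_split."""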
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats1 = g.ndata["features"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges() / 2))
    feats1 = g.edata["features"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph(eids)
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    test1 = dgl.distributed.DistTensor(new_shape, F.int32)
    g.ndata["test1"] = test1
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)
    assert test1.count_nonzero() == 0

    # Referencing an existing name returns the same tensor.
    test2 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test2", init_func=rand_init
    )
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, "test2")
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    test3_name = test3.kvstore_key
    assert test3_name in g._client.data_name_list()
    assert test3_name in g._client.gdata_name_list()
    del test3
    assert test3_name not in g._client.data_name_list()
    assert test3_name not in g._client.gdata_name_list()
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes(), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # Test a persistent tensor.
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    try:
        test4 = dgl.distributed.DistTensor(
            (g.number_of_nodes(), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata["features"]) == g.number_of_nodes()
    assert g.ndata["features"].shape == (g.number_of_nodes(), 1)
    assert g.ndata["features"].dtype == F.int64
    assert g.node_attr_schemes()["features"].dtype == F.int64
    assert g.node_attr_schemes()["test1"].dtype == F.int32
    assert g.node_attr_schemes()["features"].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes())
    for n in nodes:
        assert n in local_nids

    print("end")


def check_dist_emb_server_client(
    shared_mem, num_servers, num_clients, num_groups=1
):
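    """Spawn servers (optionally kept alive) and client groups that run the DistEmbedding checks."""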
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = (
        f"check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    )
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_emb_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.number_of_nodes(),
                    g.number_of_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()

    print("clients have terminated")


def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
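    """Spawn servers (optionally kept alive) and client groups that run check_dist_graph on one partition."""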
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.number_of_nodes(),
                    g.number_of_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)
    for p in cli_ps:
        p.join()

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()

    print("clients have terminated")


def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
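    """Partition with multiple trainers per machine and verify that trainer-aware node/edge splits cover the selected IDs."""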
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_2"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(
        g,
        graph_name,
        num_parts,
        "/tmp/dist_graph",
        num_trainers_per_machine=num_clients,
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
    node_mask = np.zeros((g.number_of_nodes(),), np.int32)
    edge_mask = np.zeros((g.number_of_edges(),), np.int32)
    nodes = np.random.choice(
        g.number_of_nodes(), g.number_of_nodes() // 10, replace=False
    )
    edges = np.random.choice(
        g.number_of_edges(), g.number_of_edges() // 10, replace=False
    )
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hierarchy,
            args=(
                graph_name,
                0,
                num_servers,
                node_mask,
                edge_mask,
                return_dict,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
    for p in serv_ps:
        p.join()

    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print("clients have terminated")


def run_client_hetero(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
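    """Client process: connect to the servers and run check_dist_graph_hetero."""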
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)


def create_random_hetero():
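    """Create a random heterograph with three node types and three relations, including data named after an ntype/etype."""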
    num_nodes = {"n1": 10000, "n2": 10010, "n3": 10020}
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    # Assign ndata & edata. Data with the same name as an ntype/etype is
    # assigned on purpose to verify that such names are handled correctly
    # by DistGraph. See issues #4887 and #4463 on GitHub for details.
    ntype = 'n1'
    for name in ['feat', ntype]:
        g.nodes[ntype].data[name] = F.unsqueeze(
            F.arange(0, g.num_nodes(ntype)), 1
        )
    etype = 'r1'
    for name in ['feat', etype]:
        g.edges[etype].data[name] = F.unsqueeze(
            F.arange(0, g.num_edges(etype)), 1
        )
    return g


def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
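    """Verify DistGraph data access, DistTensor life cycle, and node_split on a heterogeneous graph."""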
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.number_of_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.number_of_edges(etype)
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    for i, etype in enumerate(g.canonical_etypes):
        assert etype[0] == etypes[i][0]
        assert etype[1] == etypes[i][1]
        assert etype[2] == etypes[i][2]
    assert g.number_of_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.number_of_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    ntype = 'n1'
    nids = F.arange(0, g.num_nodes(ntype) // 2)
    for name in ['feat', ntype]:
        data = g.nodes[ntype].data[name][nids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == nids))
    assert len(g.nodes['n2'].data) == 0
    expect_except = False
    try:
        g.nodes['xxx'].data['x']
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test reading edge data
    etype = 'r1'
    eids = F.arange(0, g.num_edges(etype) // 2)
    for name in ['feat', etype]:
        # access via etype
        data = g.edges[etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
        # access via canonical etype
        c_etype = g.to_canonical_etype(etype)
        data = g.edges[c_etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
    assert len(g.edges['r2'].data) == 0
    expect_except = False
    try:
        g.edges['xxx'].data['x']
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test edge_subgraph
    sg = g.edge_subgraph({"r1": eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)
    sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes("n1"), 2)
    g.nodes["n1"].data["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes("n1"), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # Test a persistent tensor.
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    try:
        test4 = dgl.distributed.DistTensor(
            (g.number_of_nodes("n1"), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes["n1"].data["test1"][nids] = new_feats
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes["n1"].data["feat"]) == g.number_of_nodes("n1")
    assert g.nodes["n1"].data["feat"].shape == (g.number_of_nodes("n1"), 1)
    assert g.nodes["n1"].data["feat"].dtype == F.int64

    selected_nodes = (
        np.random.randint(0, 100, size=g.number_of_nodes("n1")) > 30
    )
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype="n1")
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes("n1"))
    for n in nodes:
        assert n in local_nids

    print("end")


def check_server_client_hetero(shared_mem, num_servers, num_clients):
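    """Spawn servers and clients and run the heterograph checks on one partition."""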
    prepare_dist(num_servers)
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.number_of_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hetero,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                num_nodes,
                num_edges,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print("clients have terminated")


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_server_client_hierarchy(False, 1, 4)
    check_server_client_empty(True, 1, 1)
    check_server_client_hetero(True, 1, 1)
    check_server_client_hetero(False, 1, 1)
    check_server_client(True, 1, 1)
    check_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_server_client(True, 2, 2)
    # check_server_client(True, 1, 1, 2)
    # check_server_client(False, 1, 1, 2)
    # check_server_client(True, 2, 2, 2)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_dist_emb_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_dist_emb_server_client(True, 2, 2)
    # check_dist_emb_server_client(True, 1, 1, 2)
    # check_dist_emb_server_client(False, 1, 1, 2)
    # check_dist_emb_server_client(True, 2, 2, 2)


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed Optimizer",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed Optimizer",
)
def test_dist_optim_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    optimizer_states = []
    num_nodes = 10000
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)
    check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)
    check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, False)


def check_dist_optim_server_client(
    num_nodes, num_servers, num_clients, optimizer_states, save
):
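    """Spawn servers and client processes that either save SparseAdagrad states or load and verify them."""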
    graph_name = f"check_dist_optim_{num_servers}_store"
    if save:
        prepare_dist(num_servers)
        g = create_random_graph(num_nodes)

        # Partition the graph
        num_parts = 1
        g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
        g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
        partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                True,
                False,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client[{}] for group[0]".format(cli_id))
        p = ctx.Process(
            target=run_optim_client,
            args=(
                graph_name,
                0,
                num_servers,
                cli_id,
                num_clients,
                num_nodes,
                optimizer_states,
                save,
            ),
        )
        p.start()
        time.sleep(1)  # avoid race condition when instantiating DistGraph
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client()  # needed because multiple tests run in the same process


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_standalone_node_emb():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_emb(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client()  # needed because multiple tests run in the same process


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("hetero", [True, False])
def test_split(hetero):
    if hetero:
        g = create_random_hetero()
        ntype = "n1"
        etype = "r1"
    else:
        g = create_random_graph(10000)
        ntype = "_N"
        etype = "_E"
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes(ntype)) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges(etype)) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code collects the roles of all client processes and uses that
    # information to determine how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        if hetero:
            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
        else:
            local_nids = F.asnumpy(local_nids)
        nodes1 = np.intersect1d(selected_nodes, local_nids)
        nodes2 = node_split(
            node_mask, gpb, ntype=ntype, rank=i, force_even=False
        )
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        for n in F.asnumpy(nodes2):
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False
        )
        nodes4 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False
        )
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        if hetero:
            etype_ids, eids = gpb.map_to_per_etype(local_eids)
            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
        else:
            local_eids = F.asnumpy(local_eids)
        edges1 = np.intersect1d(selected_edges, local_eids)
        edges2 = edge_split(
            edge_mask, gpb, etype=etype, rank=i, force_even=False
        )
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        for e in F.asnumpy(edges2):
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2, force_even=False
        )
        edges4 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False
        )
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_split_even():
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code collects the roles of all client processes and uses that
    # information to determine how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print(
            "part {} get {} nodes and {} are in the partition".format(
                i, len(nodes), len(subset)
            )
        )

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print("intersection has", len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print(
            "part {} get {} edges and {} are in the partition".format(
                i, len(edges), len(subset)
            )
        )

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print("intersection has", len(subset))
    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))


def prepare_dist(num_servers=1):
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    test_dist_emb_server_client()
    test_server_client()
    test_split(True)
    test_split(False)
    test_split_even()
    test_standalone()
    test_standalone_node_emb()