import os

os.environ["OMP_NUM_THREADS"] = "1"
import math
import multiprocessing as mp
import pickle
import socket
import sys
import time
import unittest
from multiprocessing import Condition, Manager, Process, Value

import backend as F
import numpy as np
import pytest
import torch as th
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs

import dgl
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    edge_split,
    load_partition,
    load_partition_book,
    node_split,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad
from dgl.heterograph_index import create_unitgraph_from_coo

if os.name != "nt":
    import fcntl
    import struct

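# This file exercises DGL's distributed graph store end to end: the check_*
# helpers below partition a random (or heterogeneous) graph into
# /tmp/dist_graph, spawn DistGraphServer and client processes, and verify
# DistGraph, DistTensor, DistEmbedding and the sparse optimizer from the
# clients.
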
def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
    keep_alive=False,
):
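    """Start one DistGraphServer for a partition of `graph_name` stored under
    /tmp/dist_graph, verify the dtypes of the reserved node/edge fields, and
    serve `num_clients` clients (optionally staying alive across client
    groups when `keep_alive` is set)."""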
    g = DistGraphServer(
        server_id,
        "kv_ip_config.txt",
        server_count,
        num_clients,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
        keep_alive=keep_alive,
    )
    print("start server", server_id)
    # verify dtype of underlying graph
    cg = g.client_g
    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
        if k in cg.ndata:
            assert (
                F.dtype(cg.ndata[k]) == dtype
            ), "Data type of {} in ndata should be {}.".format(k, dtype)
        if k in cg.edata:
            assert (
                F.dtype(cg.edata[k]) == dtype
            ), "Data type of {} in edata should be {}.".format(k, dtype)
    g.start()


def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())


def rand_init(shape, dtype):
    return F.tensor(np.random.normal(size=shape), F.float32)


def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
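    """Basic DistGraph checks on a featureless partitioned graph: node/edge
    counts, creating and writing DistTensor node data, re-creating a named
    DistTensor after deletion, and node attribute schemes."""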
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    g.ndata["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor, destroy it, and then create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes(), 3), F.float32, "test3"
    )
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()["test1"].dtype == F.int32

    print("end")


def run_client_empty(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
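    """Client entry point: connect to the servers, build a DistGraph from the
    partition book and run check_dist_graph_empty against it."""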
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)


def check_server_client_empty(shared_mem, num_servers, num_clients):
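    """Partition a random featureless graph, spawn `num_servers` server and
    `num_clients` client processes, and require all of them to exit
    cleanly."""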
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_1"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_empty,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                g.number_of_nodes(),
                g.number_of_edges(),
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def run_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
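    """Client entry point for check_server_client: build a DistGraph from the
    partition book and run the full check_dist_graph suite against it."""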
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph(g, num_clients, num_nodes, num_edges)


def run_emb_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)


def run_optim_client(
    graph_name,
    part_id,
    server_count,
    rank,
    world_size,
    num_nodes,
    optimizer_states,
    save,
):
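    """Client entry point for the distributed optimizer test: join a gloo
    process group and either save or load/verify SparseAdagrad state (see
    check_dist_optim_store)."""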
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12355"
    dgl.distributed.initialize("kv_ip_config.txt")
    th.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size
    )
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_optim_store(rank, num_nodes, optimizer_states, save)


def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
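    """When `save` is True, rank 0 seeds the SparseAdagrad state with the given
    tensors and the state is saved to /tmp/dist_graph/emb.pt; otherwise the
    state is loaded back and rank 0 checks the restored values, learning rate
    and eps. Any failure exits with -1 so the parent sees a non-zero exit
    code."""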
    try:
        total_idx = F.arange(0, num_nodes, F.int64, F.cpu())
        emb = DistEmbedding(num_nodes, 1, name="optim_emb1", init_func=emb_init)
        emb2 = DistEmbedding(
            num_nodes, 1, name="optim_emb2", init_func=emb_init
        )
        if save:
            optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
            if rank == 0:
                optimizer._state["optim_emb1"][total_idx] = optimizer_states[0]
                optimizer._state["optim_emb2"][total_idx] = optimizer_states[1]
            optimizer.save("/tmp/dist_graph/emb.pt")
        else:
            optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
            optimizer.load("/tmp/dist_graph/emb.pt")
            if rank == 0:
                assert F.allclose(
                    optimizer._state["optim_emb1"][total_idx],
                    optimizer_states[0],
                    0.0,
                    0.0,
                )
                assert F.allclose(
                    optimizer._state["optim_emb2"][total_idx],
                    optimizer_states[1],
                    0.0,
                    0.0,
                )
                assert 0.1 == optimizer._lr
                assert 1e-08 == optimizer._eps
            th.distributed.barrier()
    except Exception as e:
        print(e)
        sys.exit(-1)


def run_client_hierarchy(
    graph_name, part_id, server_count, node_mask, edge_mask, return_dict
):
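    """Client entry point for the hierarchical-partition test: split the given
    node/edge masks with trainer-id awareness and report the per-rank result
    through `return_dict`."""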
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(
        node_mask,
        g.get_partition_book(),
        node_trainer_ids=g.ndata["trainer_id"],
    )
    edges = edge_split(
        edge_mask,
        g.get_partition_book(),
        edge_trainer_ids=g.edata["trainer_id"],
    )
    rank = g.rank()
    return_dict[rank] = (nodes, edges)


def check_dist_emb(g, num_clients, num_nodes, num_edges):
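    """Exercise DistEmbedding with SparseAdagrad: run forward/backward over all
    node IDs, check the updated values (exact only for a single client), look
    at the per-node state the optimizer is expected to expose as "emb1_sum",
    and make sure rows outside the updated set stay zero. Failures exit with
    -1."""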
    # Test sparse emb
    try:
        emb = DistEmbedding(g.number_of_nodes(), 1, "emb1", emb_init)
        nids = F.arange(0, int(g.number_of_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
        grad_sum = dgl.distributed.DistTensor(
            (g.number_of_nodes(), 1), F.float32, "emb1_sum", policy
        )
        if num_clients == 1:
            assert np.all(
                F.asnumpy(grad_sum[nids])
                == np.ones((len(nids), 1)) * num_clients
            )
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = DistEmbedding(g.number_of_nodes(), 1, "emb2", emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(
                F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr
            )
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError:
        pass
    except Exception as e:
        print(e)
        sys.exit(-1)


def check_dist_graph(g, num_clients, num_nodes, num_edges):
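    """The main DistGraph test: read node/edge features, build edge subgraphs,
    create/overwrite/delete named, anonymous and persistent DistTensors, and
    check attribute schemes and node_split on a single partition."""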
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats1 = g.ndata["features"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges() / 2))
    feats1 = g.edata["features"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph(eids)
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    test1 = dgl.distributed.DistTensor(new_shape, F.int32)
    g.ndata["test1"] = test1
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)
    assert test1.count_nonzero() == 0

    # reference to one that already exists
    test2 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test2", init_func=rand_init
    )
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, "test2")
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # create a tensor, destroy it, and then create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    test3_name = test3.kvstore_key
    assert test3_name in g._client.data_name_list()
    assert test3_name in g._client.gdata_name_list()
    del test3
    assert test3_name not in g._client.data_name_list()
    assert test3_name not in g._client.gdata_name_list()
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes(), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    # Re-creating "test4" with a different shape should fail: the tensor was
    # created with persistent=True, so its data outlives the Python object.
    expect_except = False
    try:
        test4 = dgl.distributed.DistTensor(
            (g.number_of_nodes(), 3), F.float32, "test4"
        )
    except Exception:
        expect_except = True
    assert expect_except

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata["features"]) == g.number_of_nodes()
    assert g.ndata["features"].shape == (g.number_of_nodes(), 1)
    assert g.ndata["features"].dtype == F.int64
    assert g.node_attr_schemes()["features"].dtype == F.int64
    assert g.node_attr_schemes()["test1"].dtype == F.int32
    assert g.node_attr_schemes()["features"].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes())
    for n in nodes:
        assert n in local_nids

    print("end")


def check_dist_emb_server_client(
    shared_mem, num_servers, num_clients, num_groups=1
):
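    """Partition a random graph, launch `num_servers` servers and
    `num_clients` clients (optionally several client groups against
    keep-alive servers) and run check_dist_emb in every client."""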
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = (
        f"check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    )
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_emb_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.number_of_nodes(),
                    g.number_of_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
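    """Partition a random graph, launch servers and clients (optionally
    several client groups against keep-alive servers) and run
    check_dist_graph in every client."""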
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.number_of_nodes(),
                    g.number_of_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)
    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
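    """Partition with `num_clients` trainers per machine, let every client
    split shared node/edge masks via trainer-id-aware node_split/edge_split,
    and verify the per-client pieces cover exactly the selected IDs."""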
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_2"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(
        g,
        graph_name,
        num_parts,
        "/tmp/dist_graph",
        num_trainers_per_machine=num_clients,
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
    node_mask = np.zeros((g.number_of_nodes(),), np.int32)
    edge_mask = np.zeros((g.number_of_edges(),), np.int32)
    nodes = np.random.choice(
        g.number_of_nodes(), g.number_of_nodes() // 10, replace=False
    )
    edges = np.random.choice(
        g.number_of_edges(), g.number_of_edges() // 10, replace=False
    )
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hierarchy,
            args=(
                graph_name,
                0,
                num_servers,
                node_mask,
                edge_mask,
                return_dict,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0
    for p in serv_ps:
        p.join()
        assert p.exitcode == 0
    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print("clients have terminated")


def run_client_hetero(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)


def create_random_hetero():
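    """Create a random heterograph with three node types and three edge types;
    "n1" node data and "r1" edge data also carry features named after the
    type itself to cover the name clashes from issues #4887 and #4463."""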
    num_nodes = {"n1": 10000, "n2": 10010, "n3": 10020}
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    # assign ndata & edata.
    # data with same name as ntype/etype is assigned on purpose to verify
    # such same names can be correctly handled in DistGraph. See more details
    # in issue #4887 and #4463 on github.
    ntype = "n1"
    for name in ["feat", ntype]:
        g.nodes[ntype].data[name] = F.unsqueeze(
            F.arange(0, g.num_nodes(ntype)), 1
        )
    etype = "r1"
    for name in ["feat", etype]:
        g.edges[etype].data[name] = F.unsqueeze(
            F.arange(0, g.num_edges(etype)), 1
        )
    return g


def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
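    """Heterogeneous counterpart of check_dist_graph: check ntypes/etypes and
    counts, per-type feature access (including canonical etypes and error
    cases), edge_subgraph, DistTensor lifecycles and node_split on "n1"."""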
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.number_of_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.number_of_edges(etype)
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    for i, etype in enumerate(g.canonical_etypes):
        assert etype[0] == etypes[i][0]
        assert etype[1] == etypes[i][1]
        assert etype[2] == etypes[i][2]
    assert g.number_of_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.number_of_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    ntype = "n1"
    nids = F.arange(0, g.num_nodes(ntype) // 2)
    for name in ["feat", ntype]:
        data = g.nodes[ntype].data[name][nids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == nids))
    assert len(g.nodes["n2"].data) == 0
    expect_except = False
    try:
        g.nodes["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test reading edge data
    etype = "r1"
    eids = F.arange(0, g.num_edges(etype) // 2)
    for name in ["feat", etype]:
        # access via etype
        data = g.edges[etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
        # access via canonical etype
        c_etype = g.to_canonical_etype(etype)
        data = g.edges[c_etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
    assert len(g.edges["r2"].data) == 0
    expect_except = False
    try:
        g.edges["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test edge_subgraph
    sg = g.edge_subgraph({"r1": eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)
    sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes("n1"), 2)
    g.nodes["n1"].data["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor, destroy it, and then create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes("n1"), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    # Re-creating "test4" with a different shape should fail: the tensor was
    # created with persistent=True, so its data outlives the Python object.
    expect_except = False
    try:
        test4 = dgl.distributed.DistTensor(
            (g.number_of_nodes("n1"), 3), F.float32, "test4"
        )
    except Exception:
        expect_except = True
    assert expect_except

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes["n1"].data["test1"][nids] = new_feats
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes["n1"].data["feat"]) == g.number_of_nodes("n1")
    assert g.nodes["n1"].data["feat"].shape == (g.number_of_nodes("n1"), 1)
    assert g.nodes["n1"].data["feat"].dtype == F.int64

    selected_nodes = (
        np.random.randint(0, 100, size=g.number_of_nodes("n1")) > 30
    )
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype="n1")
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes("n1"))
    for n in nodes:
        assert n in local_nids

    print("end")


def check_server_client_hetero(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.number_of_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hetero,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                num_nodes,
                num_edges,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_server_client_hierarchy(False, 1, 4)
    check_server_client_empty(True, 1, 1)
    check_server_client_hetero(True, 1, 1)
    check_server_client_hetero(False, 1, 1)
    check_server_client(True, 1, 1)
    check_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_server_client(True, 2, 2)
    # check_server_client(True, 1, 1, 2)
    # check_server_client(False, 1, 1, 2)
    # check_server_client(True, 2, 2, 2)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_dist_emb_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_dist_emb_server_client(True, 2, 2)
    # check_dist_emb_server_client(True, 1, 1, 2)
    # check_dist_emb_server_client(False, 1, 1, 2)
    # check_dist_emb_server_client(True, 2, 2, 2)


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed Optimizer",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed Optimizer",
)
def test_dist_optim_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    optimizer_states = []
    num_nodes = 10000
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)
    check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)
    check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, False)


def check_dist_optim_server_client(
    num_nodes, num_servers, num_clients, optimizer_states, save
):
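    """Launch servers plus `num_clients` optimizer clients over the shared
    single-partition graph; with `save=True` the graph is partitioned first
    and the clients persist SparseAdagrad state, otherwise they reload and
    verify it (see check_dist_optim_store)."""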
    graph_name = f"check_dist_optim_{num_servers}_store"
    if save:
        prepare_dist(num_servers)
        g = create_random_graph(num_nodes)

        # Partition the graph
        num_parts = 1
        g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
        g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
        partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                True,
                False,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client[{}] for group[0]".format(cli_id))
        p = ctx.Process(
            target=run_optim_client,
            args=(
                graph_name,
                0,
                num_servers,
                cli_id,
                num_clients,
                num_nodes,
                optimizer_states,
                save,
            ),
        )
        p.start()
        time.sleep(1)  # avoid race condition when instantiating DistGraph
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client()  # needed since there are two tests in one process


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_standalone_node_emb():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_emb(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client()  # needed since there are two tests in one process


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("hetero", [True, False])
def test_split(hetero):
    if hetero:
        g = create_random_hetero()
        ntype = "n1"
        etype = "r1"
    else:
        g = create_random_graph(10000)
        ntype = "_N"
        etype = "_E"
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes(ntype)) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges(etype)) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code now collects the roles of all client processes and uses the
    # information to determine how to split the workloads. Here we simulate
    # the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        if hetero:
            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
        else:
            local_nids = F.asnumpy(local_nids)
        nodes1 = np.intersect1d(selected_nodes, local_nids)
        nodes2 = node_split(
            node_mask, gpb, ntype=ntype, rank=i, force_even=False
        )
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        for n in F.asnumpy(nodes2):
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False
        )
        nodes4 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False
        )
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        if hetero:
            etype_ids, eids = gpb.map_to_per_etype(local_eids)
            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
        else:
            local_eids = F.asnumpy(local_eids)
        edges1 = np.intersect1d(selected_edges, local_eids)
        edges2 = edge_split(
            edge_mask, gpb, etype=etype, rank=i, force_even=False
        )
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        for e in F.asnumpy(edges2):
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2, force_even=False
        )
        edges4 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False
        )
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_split_even():
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code now collects the roles of all client processes and uses the
    # information to determine how to split the workloads. Here we simulate
    # the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print(
            "part {} get {} nodes and {} are in the partition".format(
                i, len(nodes), len(subset)
            )
        )

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print("intersection has", len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print(
            "part {} get {} edges and {} are in the partition".format(
                i, len(edges), len(subset)
            )
        )

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print("intersection has", len(subset))
    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))


def prepare_dist(num_servers=1):
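    """Write kv_ip_config.txt for one machine running `num_servers` server
    processes; the servers and dgl.distributed.initialize() both read it."""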
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    test_dist_emb_server_client()
    test_server_client()
    test_split(True)
    test_split(False)
    test_split_even()
    test_standalone()
    test_standalone_node_emb()