import os

os.environ["OMP_NUM_THREADS"] = "1"
import math
import multiprocessing as mp
import pickle
import socket
import sys
import time
import unittest
from multiprocessing import Condition, Manager, Process, Value

import backend as F

import dgl
import numpy as np
import pytest
import torch as th
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    edge_split,
    load_partition,
    load_partition_book,
    node_split,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad
from dgl.heterograph_index import create_unitgraph_from_coo
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs

if os.name != "nt":
    import fcntl
    import struct


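# The helpers below spawn one DistGraphServer process per server and one
# process per client with the multiprocessing "spawn" context, join them,
# and assert that every process exited with code 0. Servers and clients
# find each other through the "kv_ip_config.txt" file written by
# prepare_dist().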
def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
):
    g = DistGraphServer(
        server_id,
        "kv_ip_config.txt",
        server_count,
        num_clients,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
    )
    print("start server", server_id)
    # verify dtype of underlying graph
    cg = g.client_g
    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
        if k in cg.ndata:
            assert (
                F.dtype(cg.ndata[k]) == dtype
            ), "Data type of {} in ndata should be {}.".format(k, dtype)
        if k in cg.edata:
            assert (
                F.dtype(cg.edata[k]) == dtype
            ), "Data type of {} in edata should be {}.".format(k, dtype)
    g.start()



def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())


def rand_init(shape, dtype):
    return F.tensor(np.random.normal(size=shape), F.float32)


def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
    # Test API
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    # Test init node data
    new_shape = (g.num_nodes(), 2)
    g.ndata["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.num_nodes() / 2))
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, "test3")
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()["test1"].dtype == F.int32

    print("end")


def run_client_empty(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)


def check_server_client_empty(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_1"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_empty,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                g.num_nodes(),
                g.num_edges(),
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def run_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph(g, num_clients, num_nodes, num_edges)


def run_emb_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)


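# run_optim_client additionally joins a torch.distributed gloo process group
# so that the ranks can synchronize (via a barrier) after verifying the
# optimizer state they loaded.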
def run_optim_client(
    graph_name,
    part_id,
    server_count,
    rank,
    world_size,
    num_nodes,
    optimizer_states,
    save,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12355"
    dgl.distributed.initialize("kv_ip_config.txt")
    th.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size
    )
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_optim_store(rank, num_nodes, optimizer_states, save)


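# check_dist_optim_store is run in two phases (two separate launches): with
# save=True, rank 0 seeds the SparseAdagrad state and calls optimizer.save();
# with save=False, a freshly created optimizer loads the checkpoint and rank 0
# verifies that the state, learning rate, and eps match the saved values.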
def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
    try:
        total_idx = F.arange(0, num_nodes, F.int64, F.cpu())
        emb = DistEmbedding(num_nodes, 1, name="optim_emb1", init_func=emb_init)
        emb2 = DistEmbedding(
            num_nodes, 1, name="optim_emb2", init_func=emb_init
        )
        if save:
            optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
            if rank == 0:
                optimizer._state["optim_emb1"][total_idx] = optimizer_states[0]
                optimizer._state["optim_emb2"][total_idx] = optimizer_states[1]
            optimizer.save("/tmp/dist_graph/emb.pt")
        else:
            optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
            optimizer.load("/tmp/dist_graph/emb.pt")
            if rank == 0:
                assert F.allclose(
                    optimizer._state["optim_emb1"][total_idx],
                    optimizer_states[0],
                    0.0,
                    0.0,
                )
                assert F.allclose(
                    optimizer._state["optim_emb2"][total_idx],
                    optimizer_states[1],
                    0.0,
                    0.0,
                )
                assert 0.1 == optimizer._lr
                assert 1e-08 == optimizer._eps
            th.distributed.barrier()
    except Exception as e:
        print(e)
        sys.exit(-1)


def run_client_hierarchy(
    graph_name, part_id, server_count, node_mask, edge_mask, return_dict
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(
        node_mask,
        g.get_partition_book(),
        node_trainer_ids=g.ndata["trainer_id"],
    )
    edges = edge_split(
        edge_mask,
        g.get_partition_book(),
        edge_trainer_ids=g.edata["trainer_id"],
    )
    rank = g.rank()
    return_dict[rank] = (nodes, edges)


def check_dist_emb(g, num_clients, num_nodes, num_edges):
    # Test sparse emb
    try:
        emb = DistEmbedding(g.num_nodes(), 1, "emb1", emb_init)
        nids = F.arange(0, int(g.num_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
        grad_sum = dgl.distributed.DistTensor(
            (g.num_nodes(), 1), F.float32, "emb1_sum", policy
        )
        if num_clients == 1:
            assert np.all(
                F.asnumpy(grad_sum[nids])
                == np.ones((len(nids), 1)) * num_clients
            )
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = DistEmbedding(g.num_nodes(), 1, "emb2", emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(
                F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr
            )
        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError as e:
        pass
    except Exception as e:
        print(e)
        sys.exit(-1)


def check_dist_graph(g, num_clients, num_nodes, num_edges):
    # Test API
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.num_nodes() / 2))
    feats1 = g.ndata["features"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.num_edges() / 2))
    feats1 = g.edata["features"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph(eids)
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.num_nodes(), 2)
    test1 = dgl.distributed.DistTensor(new_shape, F.int32)
    g.ndata["test1"] = test1
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)
    assert test1.count_nonzero() == 0

    # Reference a tensor that already exists.
    test2 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test2", init_func=rand_init
    )
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, "test2")
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    test3_name = test3.kvstore_key
    assert test3_name in g._client.data_name_list()
    assert test3_name in g._client.gdata_name_list()
    del test3
    assert test3_name not in g._client.data_name_list()
    assert test3_name not in g._client.gdata_name_list()
    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, "test3")
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # Test a persistent tensor.
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    try:
        test4 = dgl.distributed.DistTensor(
            (g.num_nodes(), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata["features"]) == g.num_nodes()
    assert g.ndata["features"].shape == (g.num_nodes(), 1)
    assert g.ndata["features"].dtype == F.int64
    assert g.node_attr_schemes()["features"].dtype == F.int64
    assert g.node_attr_schemes()["test1"].dtype == F.int32
    assert g.node_attr_schemes()["features"].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.num_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.num_nodes())
    for n in nodes:
        assert n in local_nids

    print("end")


def check_dist_emb_server_client(
    shared_mem, num_servers, num_clients, num_groups=1
):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = (
        f"check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    )
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_emb_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.num_nodes(),
                    g.num_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
            ),
        )
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.num_nodes(),
                    g.num_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_2"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(
        g,
        graph_name,
        num_parts,
        "/tmp/dist_graph",
        num_trainers_per_machine=num_clients,
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
    node_mask = np.zeros((g.num_nodes(),), np.int32)
    edge_mask = np.zeros((g.num_edges(),), np.int32)
    nodes = np.random.choice(g.num_nodes(), g.num_nodes() // 10, replace=False)
    edges = np.random.choice(g.num_edges(), g.num_edges() // 10, replace=False)
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hierarchy,
            args=(
                graph_name,
                0,
                num_servers,
                node_mask,
                edge_mask,
                return_dict,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0
    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print("clients have terminated")


def run_client_hetero(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)


def create_random_hetero():
    num_nodes = {"n1": 10000, "n2": 10010, "n3": 10020}
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)

    # Assign ndata & edata. Data with the same name as an ntype/etype is
    # assigned on purpose, to verify that such names are handled correctly in
    # DistGraph. See issues #4887 and #4463 on GitHub for details.
    ntype = "n1"
    for name in ["feat", ntype]:
        g.nodes[ntype].data[name] = F.unsqueeze(
            F.arange(0, g.num_nodes(ntype)), 1
        )
    etype = "r1"
    for name in ["feat", etype]:
        g.edges[etype].data[name] = F.unsqueeze(
            F.arange(0, g.num_edges(etype)), 1
        )
    return g


def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.num_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.num_edges(etype)
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    for i, etype in enumerate(g.canonical_etypes):
        assert etype[0] == etypes[i][0]
        assert etype[1] == etypes[i][1]
        assert etype[2] == etypes[i][2]
    assert g.num_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.num_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    ntype = "n1"
    nids = F.arange(0, g.num_nodes(ntype) // 2)
    for name in ["feat", ntype]:
        data = g.nodes[ntype].data[name][nids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == nids))
    assert len(g.nodes["n2"].data) == 0
    expect_except = False
    try:
        g.nodes["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test reading edge data
    etype = "r1"
    eids = F.arange(0, g.num_edges(etype) // 2)
    for name in ["feat", etype]:
        # access via etype
        data = g.edges[etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
        # access via canonical etype
        c_etype = g.to_canonical_etype(etype)
        data = g.edges[c_etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
    assert len(g.edges["r2"].data) == 0
    expect_except = False
    try:
        g.edges["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test edge_subgraph
    sg = g.edge_subgraph({"r1": eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)
    sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.num_nodes("n1"), 2)
    g.nodes["n1"].data["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.num_nodes("n1"), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # Test a persistent tensor.
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    try:
        test4 = dgl.distributed.DistTensor(
            (g.num_nodes("n1"), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes["n1"].data["test1"][nids] = new_feats
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes["n1"].data["feat"]) == g.num_nodes("n1")
    assert g.nodes["n1"].data["feat"].shape == (g.num_nodes("n1"), 1)
    assert g.nodes["n1"].data["feat"].dtype == F.int64

    selected_nodes = np.random.randint(0, 100, size=g.num_nodes("n1")) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype="n1")
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.num_nodes("n1"))
    for n in nodes:
        assert n in local_nids

    print("end")


def check_server_client_hetero(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.num_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hetero,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                num_nodes,
                num_edges,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_server_client_hierarchy(False, 1, 4)
    check_server_client_empty(True, 1, 1)
    check_server_client_hetero(True, 1, 1)
    check_server_client_hetero(False, 1, 1)
    check_server_client(True, 1, 1)
    check_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_server_client(True, 2, 2)
    # check_server_client(True, 1, 1, 2)
    # check_server_client(False, 1, 1, 2)
    # check_server_client(True, 2, 2, 2)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_dist_emb_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_dist_emb_server_client(True, 2, 2)
    # check_dist_emb_server_client(True, 1, 1, 2)
    # check_dist_emb_server_client(False, 1, 1, 2)
    # check_dist_emb_server_client(True, 2, 2, 2)


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed Optimizer",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed Optimizer",
)
def test_dist_optim_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    optimizer_states = []
    num_nodes = 10000
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)
    check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)
    check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, False)


def check_dist_optim_server_client(
    num_nodes, num_servers, num_clients, optimizer_states, save
):
    graph_name = f"check_dist_optim_{num_servers}_store"
    if save:
        prepare_dist(num_servers)
        g = create_random_graph(num_nodes)

        # Partition the graph
        num_parts = 1
        g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
        g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
        partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                True,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client[{}] for group[0]".format(cli_id))
        p = ctx.Process(
            target=run_optim_client,
            args=(
                graph_name,
                0,
                num_servers,
                cli_id,
                num_clients,
                num_nodes,
                optimizer_states,
                save,
            ),
        )
        p.start()
        time.sleep(1)  # avoid race condition when instantiating DistGraph
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0


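# In standalone mode (DGL_DIST_MODE="standalone") there are no separate server
# processes: DistGraph is constructed directly from the partition JSON via
# part_config, and the same check_dist_graph/check_dist_emb checks are reused
# with a single client.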
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_graph(dist_g, 1, g.num_nodes(), g.num_edges())
    dgl.distributed.exit_client()  # needed because two tests run in one process


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_standalone_node_emb():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_emb(dist_g, 1, g.num_nodes(), g.num_edges())
    dgl.distributed.exit_client()  # needed because two tests run in one process


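# test_split exercises mask-based workload splitting offline (no servers):
# with force_even=False, node_split/edge_split should return only IDs owned by
# the local partition, and splitting the same mask across twice as many
# simulated trainers must still cover exactly the same ID set.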
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("hetero", [True, False])
@pytest.mark.parametrize("empty_mask", [True, False])
def test_split(hetero, empty_mask):
    if hetero:
        g = create_random_hetero()
        ntype = "n1"
        etype = "r1"
    else:
        g = create_random_graph(10000)
        ntype = "_N"
        etype = "_E"
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    mask_thd = 100 if empty_mask else 30
    node_mask = np.random.randint(0, 100, size=g.num_nodes(ntype)) > mask_thd
    edge_mask = np.random.randint(0, 100, size=g.num_edges(etype)) > mask_thd
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code now collects the roles of all client processes and uses that
    # information to decide how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        if hetero:
            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
        else:
            local_nids = F.asnumpy(local_nids)
        nodes1 = np.intersect1d(selected_nodes, local_nids)
        nodes2 = node_split(
            node_mask, gpb, ntype=ntype, rank=i, force_even=False
        )
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        for n in F.asnumpy(nodes2):
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False
        )
        nodes4 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False
        )
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        if hetero:
            etype_ids, eids = gpb.map_to_per_etype(local_eids)
            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
        else:
            local_eids = F.asnumpy(local_eids)
        edges1 = np.intersect1d(selected_edges, local_eids)
        edges2 = edge_split(
            edge_mask, gpb, etype=etype, rank=i, force_even=False
        )
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        for e in F.asnumpy(edges2):
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2, force_even=False
        )
        edges4 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False
        )
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))


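# test_split_even uses force_even=True, so IDs are balanced across trainers
# rather than kept strictly partition-local; the assertions below only require
# that the union of the splits over all ranks equals the full masked ID set.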
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("empty_mask", [True, False])
def test_split_even(empty_mask):
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    mask_thd = 100 if empty_mask else 30
    node_mask = np.random.randint(0, 100, size=g.num_nodes()) > mask_thd
    edge_mask = np.random.randint(0, 100, size=g.num_edges()) > mask_thd
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code now collects the roles of all client processes and uses that
    # information to decide how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print(
            "part {} get {} nodes and {} are in the partition".format(
                i, len(nodes), len(subset)
            )
        )

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print("intersection has", len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print(
            "part {} get {} edges and {} are in the partition".format(
                i, len(edges), len(subset)
            )
        )

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print("intersection has", len(subset))

    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))


def prepare_dist(num_servers=1):
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    test_dist_emb_server_client()
    test_server_client()
    # test_split and test_split_even are parametrized by pytest, so pass
    # empty_mask explicitly when calling them directly.
    test_split(True, False)
    test_split(False, False)
    test_split_even(False)
    test_standalone()
    test_standalone_node_emb()