import os

os.environ["OMP_NUM_THREADS"] = "1"
import math
import multiprocessing as mp
import pickle
import socket
import sys
import time
import unittest
from multiprocessing import Condition, Manager, Process, Value

import backend as F

import dgl
import numpy as np
import pytest
import torch as th
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    edge_split,
    load_partition,
    load_partition_book,
    node_split,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad
from dgl.heterograph_index import create_unitgraph_from_coo
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs

if os.name != "nt":
    import fcntl
    import struct


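# Launch one graph server process. The server loads the given partition from
# /tmp/dist_graph, optionally shares the graph structure with co-located
# clients via shared memory, and blocks in g.start() until all clients have
# disconnected.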
def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
):
    g = DistGraphServer(
        server_id,
        "kv_ip_config.txt",
        server_count,
        num_clients,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
    )
    print("start server", server_id)
    # verify dtype of underlying graph
    cg = g.client_g
    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
        if k in cg.ndata:
            assert (
                F.dtype(cg.ndata[k]) == dtype
            ), "Data type of {} in ndata should be {}.".format(k, dtype)
        if k in cg.edata:
            assert (
                F.dtype(cg.edata[k]) == dtype
            ), "Data type of {} in edata should be {}.".format(k, dtype)
    g.start()


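# Initializers used by the tests below: embeddings start at zero so optimizer
# updates are easy to verify; random DistTensors are drawn from a normal
# distribution.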
def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())


def rand_init(shape, dtype):
    return F.tensor(np.random.normal(size=shape), F.float32)


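# Basic DistGraph checks for a partitioned graph without node/edge features:
# node/edge counts, creating and deleting named DistTensors, and writing and
# reading tensor data through g.ndata.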
def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
    # Test API
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    # Test init node data
    new_shape = (g.num_nodes(), 2)
    g.ndata["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.num_nodes() / 2))
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, "test3")
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()["test1"].dtype == F.int32

    print("end")


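# Client entry point: initialize the distributed runtime, load the partition
# book, attach a DistGraph to the running servers and run the checks above.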
def run_client_empty(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)


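# End-to-end driver: partition a random graph into a single part, spawn the
# requested number of server and client processes, and assert that every
# process exits cleanly.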
def check_server_client_empty(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_1"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_empty,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                g.num_nodes(),
                g.num_edges(),
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


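# Client entry point for check_server_client. DGL_GROUP_ID distinguishes
# independent client groups that share the same set of servers.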
def run_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph(g, num_clients, num_nodes, num_edges)


def run_emb_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)


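# Client entry point for the sparse-optimizer save/load test. Besides the DGL
# runtime, it also joins a gloo process group so the clients can synchronize
# with a barrier while checking the optimizer state.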
def run_optim_client(
    graph_name,
    part_id,
    server_count,
    rank,
    world_size,
    num_nodes,
    optimizer_states,
    save,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12355"
    dgl.distributed.initialize("kv_ip_config.txt")
    th.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size
    )
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_optim_store(rank, num_nodes, optimizer_states, save)


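# On the save path, rank 0 seeds the SparseAdagrad state for two embeddings
# and writes a checkpoint; on the load path, a new optimizer created with
# different hyperparameters restores it, and rank 0 verifies that both the
# state and the original lr/eps come back intact.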
def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
    try:
        total_idx = F.arange(0, num_nodes, F.int64, F.cpu())
        emb = DistEmbedding(num_nodes, 1, name="optim_emb1", init_func=emb_init)
        emb2 = DistEmbedding(
            num_nodes, 1, name="optim_emb2", init_func=emb_init
        )
        if save:
            optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
            if rank == 0:
                optimizer._state["optim_emb1"][total_idx] = optimizer_states[0]
                optimizer._state["optim_emb2"][total_idx] = optimizer_states[1]
            optimizer.save("/tmp/dist_graph/emb.pt")
        else:
            optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
            optimizer.load("/tmp/dist_graph/emb.pt")
            if rank == 0:
                assert F.allclose(
                    optimizer._state["optim_emb1"][total_idx],
                    optimizer_states[0],
                    0.0,
                    0.0,
                )
                assert F.allclose(
                    optimizer._state["optim_emb2"][total_idx],
                    optimizer_states[1],
                    0.0,
                    0.0,
                )
                assert 0.1 == optimizer._lr
                assert 1e-08 == optimizer._eps
            th.distributed.barrier()
    except Exception as e:
        print(e)
        sys.exit(-1)


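# Client entry point for the hierarchical-partition test: split the masked
# nodes/edges according to the trainer ids stored in ndata/edata and report
# this rank's share through return_dict.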
def run_client_hierarchy(
    graph_name, part_id, server_count, node_mask, edge_mask, return_dict
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(
        node_mask,
        g.get_partition_book(),
        node_trainer_ids=g.ndata["trainer_id"],
    )
    edges = edge_split(
        edge_mask,
        g.get_partition_book(),
        edge_trainer_ids=g.edata["trainer_id"],
    )
    rank = g.rank()
    return_dict[rank] = (nodes, edges)


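# DistEmbedding + SparseAdagrad checks: after one optimizer step only the rows
# that received gradients should move (by roughly -lr for this loss), untouched
# rows stay zero, and the per-row gradient-sum tensor reflects how many clients
# contributed.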
def check_dist_emb(g, num_clients, num_nodes, num_edges):
    # Test sparse emb
    try:
        emb = DistEmbedding(g.num_nodes(), 1, "emb1", emb_init)
        nids = F.arange(0, int(g.num_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
        grad_sum = dgl.distributed.DistTensor(
            (g.num_nodes(), 1), F.float32, "emb1_sum", policy
        )
        if num_clients == 1:
            assert np.all(
                F.asnumpy(grad_sum[nids])
                == np.ones((len(nids), 1)) * num_clients
            )
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = DistEmbedding(g.num_nodes(), 1, "emb2", emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(
                F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr
            )
        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError as e:
        pass
    except Exception as e:
        print(e)
        sys.exit(-1)


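# Main homogeneous DistGraph checks: node/edge feature reads, edge_subgraph,
# named/anonymous/persistent DistTensors, writes through g.ndata, attribute
# schemes, and node_split on a single partition.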
def check_dist_graph(g, num_clients, num_nodes, num_edges):
    # Test API
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.num_nodes() / 2))
    feats1 = g.ndata["features"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.num_edges() / 2))
    feats1 = g.edata["features"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph(eids)
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.num_nodes(), 2)
    test1 = dgl.distributed.DistTensor(new_shape, F.int32)
    g.ndata["test1"] = test1
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)
    assert test1.count_nonzero() == 0

    # reference a tensor that already exists under the same name
    test2 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test2", init_func=rand_init
    )
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, "test2")
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    test3_name = test3.kvstore_key
    assert test3_name in g._client.data_name_list()
    assert test3_name in g._client.gdata_name_list()
    del test3
    assert test3_name not in g._client.data_name_list()
    assert test3_name not in g._client.gdata_name_list()
    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, "test3")
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
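    # re-creating the deleted persistent tensor under the same name but with a
    # different shape is expected to fail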
    try:
        test4 = dgl.distributed.DistTensor(
            (g.num_nodes(), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata["features"]) == g.num_nodes()
    assert g.ndata["features"].shape == (g.num_nodes(), 1)
    assert g.ndata["features"].dtype == F.int64
    assert g.node_attr_schemes()["features"].dtype == F.int64
    assert g.node_attr_schemes()["test1"].dtype == F.int32
    assert g.node_attr_schemes()["features"].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.num_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.num_nodes())
    for n in nodes:
        assert n in local_nids

    print("end")


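# Driver for the DistEmbedding test: partition a random graph with integer
# node/edge features, start the servers, and launch one embedding client per
# (client, group) pair.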
def check_dist_emb_server_client(
    shared_mem, num_servers, num_clients, num_groups=1
):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = (
        f"check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    )
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_emb_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.num_nodes(),
                    g.num_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


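# Driver for the main DistGraph test; num_groups > 1 launches several
# independent client groups against the same servers.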
def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
            ),
        )
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.num_nodes(),
                    g.num_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)
    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


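# Driver for the trainer-hierarchy test: the graph is partitioned with
# num_trainers_per_machine=num_clients, each client splits a random node/edge
# mask by trainer id, and the union of all per-client splits must equal the
# original mask.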
def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_2"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(
        g,
        graph_name,
        num_parts,
        "/tmp/dist_graph",
        num_trainers_per_machine=num_clients,
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
    node_mask = np.zeros((g.num_nodes(),), np.int32)
    edge_mask = np.zeros((g.num_edges(),), np.int32)
    nodes = np.random.choice(g.num_nodes(), g.num_nodes() // 10, replace=False)
    edges = np.random.choice(g.num_edges(), g.num_edges() // 10, replace=False)
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hierarchy,
            args=(
                graph_name,
                0,
                num_servers,
                node_mask,
                edge_mask,
                return_dict,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0
    for p in serv_ps:
        p.join()
        assert p.exitcode == 0
    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print("clients have terminated")


def run_client_hetero(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)


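# Random heterograph with three node types and three relations at ~0.1% edge
# density; some feature names intentionally collide with ntype/etype names
# (see the comment below).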
def create_random_hetero():
    num_nodes = {"n1": 10000, "n2": 10010, "n3": 10020}
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    # assign ndata & edata.
    # data with same name as ntype/etype is assigned on purpose to verify
    # such same names can be correctly handled in DistGraph. See more details
    # in issue #4887 and #4463 on github.
    ntype = "n1"
    for name in ["feat", ntype]:
        g.nodes[ntype].data[name] = F.unsqueeze(
            F.arange(0, g.num_nodes(ntype)), 1
        )
    etype = "r1"
    for name in ["feat", etype]:
        g.edges[etype].data[name] = F.unsqueeze(
            F.arange(0, g.num_edges(etype)), 1
        )
    return g


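# Heterogeneous counterpart of check_dist_graph: per-type node/edge counts,
# per-type feature access (by etype and by canonical etype), edge_subgraph,
# DistTensor lifecycle, and node_split on a single partition.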
def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.num_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.num_edges(etype)
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    for i, etype in enumerate(g.canonical_etypes):
        assert etype[0] == etypes[i][0]
        assert etype[1] == etypes[i][1]
        assert etype[2] == etypes[i][2]
    assert g.num_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.num_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    ntype = "n1"
    nids = F.arange(0, g.num_nodes(ntype) // 2)
    for name in ["feat", ntype]:
        data = g.nodes[ntype].data[name][nids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == nids))
    assert len(g.nodes["n2"].data) == 0
    expect_except = False
    try:
        g.nodes["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test reading edge data
    etype = "r1"
    eids = F.arange(0, g.num_edges(etype) // 2)
    for name in ["feat", etype]:
        # access via etype
        data = g.edges[etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
        # access via canonical etype
        c_etype = g.to_canonical_etype(etype)
        data = g.edges[c_etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
    assert len(g.edges["r2"].data) == 0
    expect_except = False
    try:
        g.edges["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test edge_subgraph
    sg = g.edge_subgraph({"r1": eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)
    sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.num_nodes("n1"), 2)
    g.nodes["n1"].data["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.num_nodes("n1"), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
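    # re-creating the deleted persistent tensor under the same name but with a
    # different shape is expected to fail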
    try:
        test4 = dgl.distributed.DistTensor(
            (g.num_nodes("n1"), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes["n1"].data["test1"][nids] = new_feats
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes["n1"].data["feat"]) == g.num_nodes("n1")
    assert g.nodes["n1"].data["feat"].shape == (g.num_nodes("n1"), 1)
    assert g.nodes["n1"].data["feat"].dtype == F.int64

    selected_nodes = np.random.randint(0, 100, size=g.num_nodes("n1")) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype="n1")
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.num_nodes("n1"))
    for n in nodes:
        assert n in local_nids

    print("end")


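# Driver for the heterogeneous DistGraph test.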
def check_server_client_hetero(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.num_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hetero,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                num_nodes,
                num_edges,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_server_client_hierarchy(False, 1, 4)
    check_server_client_empty(True, 1, 1)
    check_server_client_hetero(True, 1, 1)
    check_server_client_hetero(False, 1, 1)
    check_server_client(True, 1, 1)
    check_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_server_client(True, 2, 2)
    # check_server_client(True, 1, 1, 2)
    # check_server_client(False, 1, 1, 2)
    # check_server_client(True, 2, 2, 2)


@unittest.skip(reason="Skip due to glitch in CI")
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_dist_emb_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_dist_emb_server_client(True, 2, 2)
    # check_dist_emb_server_client(True, 1, 1, 2)
    # check_dist_emb_server_client(False, 1, 1, 2)
    # check_dist_emb_server_client(True, 2, 2, 2)


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed Optimizer",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed Optimizer",
)
def test_dist_optim_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    optimizer_states = []
    num_nodes = 10000
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)
    check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)
    check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, False)


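# Driver for the optimizer save/load test. It is called by
# test_dist_optim_server_client once with save=True to write the checkpoint
# and again with save=False (and a different client count) to reload and
# verify it.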
def check_dist_optim_server_client(
    num_nodes, num_servers, num_clients, optimizer_states, save
):
    graph_name = f"check_dist_optim_{num_servers}_store"
    if save:
        prepare_dist(num_servers)
        g = create_random_graph(num_nodes)

        # Partition the graph
        num_parts = 1
        g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
        g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
        partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                True,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client[{}] for group[0]".format(cli_id))
        p = ctx.Process(
            target=run_optim_client,
            args=(
                graph_name,
                0,
                num_servers,
                cli_id,
                num_clients,
                num_nodes,
                optimizer_states,
                save,
            ),
        )
        p.start()
        time.sleep(1)  # avoid race condition when instantiating DistGraph
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0


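# Standalone mode: no server processes are launched; the DistGraph is backed
# directly by the local partition in the same process.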
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_graph(dist_g, 1, g.num_nodes(), g.num_edges())
    dgl.distributed.exit_client()  # needed since there are two tests here in one process


@unittest.skip(reason="Skip due to glitch in CI")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_standalone_node_emb():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_emb(dist_g, 1, g.num_nodes(), g.num_edges())
    dgl.distributed.exit_client()  # needed since there are two tests here in one process


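# node_split/edge_split with force_even=False: each rank must receive exactly
# the masked nodes/edges that are local to its partition, and doubling the
# number of clients per partition just splits that local set in two.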
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("hetero", [True, False])
@pytest.mark.parametrize("empty_mask", [True, False])
def test_split(hetero, empty_mask):
    if hetero:
        g = create_random_hetero()
        ntype = "n1"
        etype = "r1"
    else:
        g = create_random_graph(10000)
        ntype = "_N"
        etype = "_E"
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    mask_thd = 100 if empty_mask else 30
    node_mask = np.random.randint(0, 100, size=g.num_nodes(ntype)) > mask_thd
    edge_mask = np.random.randint(0, 100, size=g.num_edges(etype)) > mask_thd
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code now collects the roles of all client processes and uses the
    # information to determine how to split the workloads. Here we simulate
    # the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        if hetero:
            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
        else:
            local_nids = F.asnumpy(local_nids)
        nodes1 = np.intersect1d(selected_nodes, local_nids)
        nodes2 = node_split(
            node_mask, gpb, ntype=ntype, rank=i, force_even=False
        )
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        for n in F.asnumpy(nodes2):
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False
        )
        nodes4 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False
        )
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        if hetero:
            etype_ids, eids = gpb.map_to_per_etype(local_eids)
            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
        else:
            local_eids = F.asnumpy(local_eids)
        edges1 = np.intersect1d(selected_edges, local_eids)
        edges2 = edge_split(
            edge_mask, gpb, etype=etype, rank=i, force_even=False
        )
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        for e in F.asnumpy(edges2):
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2, force_even=False
        )
        edges4 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False
        )
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))


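# node_split/edge_split with force_even=True: per-rank shares are balanced
# (and may cross partition boundaries), but the concatenation over all ranks
# must still cover exactly the masked nodes/edges.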
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("empty_mask", [True, False])
def test_split_even(empty_mask):
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    mask_thd = 100 if empty_mask else 30
    node_mask = np.random.randint(0, 100, size=g.num_nodes()) > mask_thd
    edge_mask = np.random.randint(0, 100, size=g.num_edges()) > mask_thd
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code now collects the roles of all client processes and uses the
    # information to determine how to split the workloads. Here we simulate
    # the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print(
            "part {} get {} nodes and {} are in the partition".format(
                i, len(nodes), len(subset)
            )
        )

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print("intersection has", len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print(
            "part {} get {} edges and {} are in the partition".format(
                i, len(edges), len(subset)
            )
        )

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print("intersection has", len(subset))
    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))


def prepare_dist(num_servers=1):
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    test_dist_emb_server_client()
    test_server_client()
    test_split(True, False)
    test_split(False, False)
    test_split_even(False)
    test_standalone()
    test_standalone_node_emb()