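"""Tests for DGL's distributed graph store.

Covers DistGraph/DistGraphServer server-client setups, DistTensor and
DistEmbedding access, saving/loading SparseAdagrad state, and the
node_split/edge_split utilities, for homogeneous and heterogeneous graphs
in both standalone and distributed modes.
"""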
import os

os.environ["OMP_NUM_THREADS"] = "1"
import math
import multiprocessing as mp
import pickle
import socket
import sys
import time
import unittest
from multiprocessing import Condition, Manager, Process, Value

import backend as F

import dgl
import numpy as np
import pytest
import torch as th
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    edge_split,
    load_partition,
    load_partition_book,
    node_split,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad
from dgl.heterograph_index import create_unitgraph_from_coo
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs

if os.name != "nt":
    import fcntl
    import struct


def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
    keep_alive=False,
):
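    """Start a DistGraphServer for one partition of `graph_name`.

    Before serving, verify that the reserved node/edge fields of the
    underlying client graph have the expected dtypes.
    """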
    g = DistGraphServer(
        server_id,
        "kv_ip_config.txt",
        server_count,
        num_clients,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
        keep_alive=keep_alive,
    )
    print("start server", server_id)
    # verify dtype of underlying graph
    cg = g.client_g
    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
        if k in cg.ndata:
            assert (
                F.dtype(cg.ndata[k]) == dtype
            ), "Data type of {} in ndata should be {}.".format(k, dtype)
        if k in cg.edata:
            assert (
                F.dtype(cg.edata[k]) == dtype
            ), "Data type of {} in edata should be {}.".format(k, dtype)
    g.start()


def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())


def rand_init(shape, dtype):
    return F.tensor(np.random.normal(size=shape), F.float32)


def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
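    """Exercise a DistGraph that has no node/edge features attached:
    basic counts, creating/deleting DistTensors, and reading/writing them."""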
    # Test API
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    # Test init node data
    new_shape = (g.num_nodes(), 2)
    g.ndata["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.num_nodes() / 2))
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor, destroy it and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, "test3")
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()["test1"].dtype == F.int32

    print("end")


def run_client_empty(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
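    """Client process: connect to the servers and run the feature-less
    DistGraph checks against the given partition book."""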
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)

def check_server_client_empty(shared_mem, num_servers, num_clients):
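    """Spawn server and client processes for a random graph without features
    and make sure every process exits cleanly."""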
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_1"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_empty,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                g.num_nodes(),
                g.num_edges(),
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def run_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
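    """Client process: connect to the servers (within client group
    `group_id`) and run the full DistGraph checks."""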
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph(g, num_clients, num_nodes, num_edges)


def run_emb_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)

def run_optim_client(
    graph_name,
    part_id,
    server_count,
    rank,
    world_size,
    num_nodes,
    optimizer_states,
    save,
):
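    """Client process for the sparse-optimizer test: join a gloo process
    group and either save or load the SparseAdagrad state."""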
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12355"
    dgl.distributed.initialize("kv_ip_config.txt")
    th.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size
    )
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_optim_store(rank, num_nodes, optimizer_states, save)


def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
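    """Write the SparseAdagrad state from rank 0, or load it back and check
    that the stored state and hyperparameters survive the round trip."""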
    try:
        total_idx = F.arange(0, num_nodes, F.int64, F.cpu())
        emb = DistEmbedding(num_nodes, 1, name="optim_emb1", init_func=emb_init)
        emb2 = DistEmbedding(
            num_nodes, 1, name="optim_emb2", init_func=emb_init
        )
        if save:
            optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
            if rank == 0:
                optimizer._state["optim_emb1"][total_idx] = optimizer_states[0]
                optimizer._state["optim_emb2"][total_idx] = optimizer_states[1]
            optimizer.save("/tmp/dist_graph/emb.pt")
        else:
            optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
            optimizer.load("/tmp/dist_graph/emb.pt")
            if rank == 0:
                assert F.allclose(
                    optimizer._state["optim_emb1"][total_idx],
                    optimizer_states[0],
                    0.0,
                    0.0,
                )
                assert F.allclose(
                    optimizer._state["optim_emb2"][total_idx],
                    optimizer_states[1],
                    0.0,
                    0.0,
                )
                assert 0.1 == optimizer._lr
                assert 1e-08 == optimizer._eps
            th.distributed.barrier()
    except Exception as e:
        print(e)
        sys.exit(-1)


def run_client_hierarchy(
    graph_name, part_id, server_count, node_mask, edge_mask, return_dict
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(
        node_mask,
        g.get_partition_book(),
        node_trainer_ids=g.ndata["trainer_id"],
    )
    edges = edge_split(
        edge_mask,
        g.get_partition_book(),
        edge_trainer_ids=g.edata["trainer_id"],
    )
    rank = g.rank()
    return_dict[rank] = (nodes, edges)

def check_dist_emb(g, num_clients, num_nodes, num_edges):
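    """Exercise DistEmbedding with SparseAdagrad: updates only touch the
    looked-up rows, while untouched rows stay at their zero initialization."""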
    # Test sparse emb
    try:
        emb = DistEmbedding(g.num_nodes(), 1, "emb1", emb_init)
        nids = F.arange(0, int(g.num_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
        grad_sum = dgl.distributed.DistTensor(
            (g.num_nodes(), 1), F.float32, "emb1_sum", policy
        )
        if num_clients == 1:
            assert np.all(
                F.asnumpy(grad_sum[nids])
                == np.ones((len(nids), 1)) * num_clients
            )
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = DistEmbedding(g.num_nodes(), 1, "emb2", emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(
                F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr
            )
        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError as e:
        pass
    except Exception as e:
        print(e)
        sys.exit(-1)

def check_dist_graph(g, num_clients, num_nodes, num_edges):
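    """Exercise the main DistGraph APIs: node/edge feature reads, DistTensor
    creation/deletion (named, anonymous and persistent), writes, attribute
    schemes and node_split on a single partition."""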
    # Test API
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.num_nodes() / 2))
    feats1 = g.ndata["features"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.num_edges() / 2))
    feats1 = g.edata["features"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph(eids)
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.num_nodes(), 2)
    test1 = dgl.distributed.DistTensor(new_shape, F.int32)
    g.ndata["test1"] = test1
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)
    assert test1.count_nonzero() == 0

    # reference a tensor that already exists
    test2 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test2", init_func=rand_init
    )
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, "test2")
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # create a tensor, destroy it and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    test3_name = test3.kvstore_key
    assert test3_name in g._client.data_name_list()
    assert test3_name in g._client.gdata_name_list()
    del test3
    assert test3_name not in g._client.data_name_list()
    assert test3_name not in g._client.gdata_name_list()
    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, "test3")
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    try:
        test4 = dgl.distributed.DistTensor(
            (g.num_nodes(), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata["features"]) == g.num_nodes()
    assert g.ndata["features"].shape == (g.num_nodes(), 1)
    assert g.ndata["features"].dtype == F.int64
    assert g.node_attr_schemes()["features"].dtype == F.int64
    assert g.node_attr_schemes()["test1"].dtype == F.int32
    assert g.node_attr_schemes()["features"].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.num_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.num_nodes())
    for n in nodes:
        assert n in local_nids

    print("end")

def check_dist_emb_server_client(
    shared_mem, num_servers, num_clients, num_groups=1
):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = (
        f"check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    )
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_emb_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.num_nodes(),
                    g.num_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")

def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
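    """Spawn `num_servers` server and `num_clients * num_groups` client
    processes on one partition and run check_dist_graph in every client."""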
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.num_nodes(),
                    g.num_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)
    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")

def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
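    """Partition with trainer assignments, run node_split/edge_split in each
    client, and check that the per-client pieces add up to the masked sets."""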
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_2"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(
        g,
        graph_name,
        num_parts,
        "/tmp/dist_graph",
        num_trainers_per_machine=num_clients,
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
    node_mask = np.zeros((g.num_nodes(),), np.int32)
    edge_mask = np.zeros((g.num_edges(),), np.int32)
    nodes = np.random.choice(g.num_nodes(), g.num_nodes() // 10, replace=False)
    edges = np.random.choice(g.num_edges(), g.num_edges() // 10, replace=False)
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hierarchy,
            args=(
                graph_name,
                0,
                num_servers,
                node_mask,
                edge_mask,
                return_dict,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0
    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print("clients have terminated")

def run_client_hetero(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)

def create_random_hetero():
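    """Build a random three-node-type, three-relation heterograph with a few
    node/edge features for the heterogeneous DistGraph tests."""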
    num_nodes = {"n1": 10000, "n2": 10010, "n3": 10020}
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    # assign ndata & edata.
    # data with the same name as an ntype/etype is assigned on purpose to
    # verify that such names are handled correctly in DistGraph. See more
    # details in issue #4887 and #4463 on github.
    ntype = "n1"
    for name in ["feat", ntype]:
        g.nodes[ntype].data[name] = F.unsqueeze(
            F.arange(0, g.num_nodes(ntype)), 1
        )
    etype = "r1"
    for name in ["feat", etype]:
        g.edges[etype].data[name] = F.unsqueeze(
            F.arange(0, g.num_edges(etype)), 1
        )
    return g

def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
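    """Heterogeneous counterpart of check_dist_graph: per-type counts, typed
    feature access, edge_subgraph by etype, DistTensors and node_split."""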
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.num_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.num_edges(etype)
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    for i, etype in enumerate(g.canonical_etypes):
        assert etype[0] == etypes[i][0]
        assert etype[1] == etypes[i][1]
        assert etype[2] == etypes[i][2]
    assert g.num_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.num_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    ntype = "n1"
    nids = F.arange(0, g.num_nodes(ntype) // 2)
    for name in ["feat", ntype]:
        data = g.nodes[ntype].data[name][nids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == nids))
    assert len(g.nodes["n2"].data) == 0
    expect_except = False
    try:
        g.nodes["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test reading edge data
    etype = "r1"
    eids = F.arange(0, g.num_edges(etype) // 2)
    for name in ["feat", etype]:
        # access via etype
        data = g.edges[etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
        # access via canonical etype
        c_etype = g.to_canonical_etype(etype)
        data = g.edges[c_etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
    assert len(g.edges["r2"].data) == 0
    expect_except = False
    try:
        g.edges["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test edge_subgraph
    sg = g.edge_subgraph({"r1": eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)
    sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.num_nodes("n1"), 2)
    g.nodes["n1"].data["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor, destroy it and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.num_nodes("n1"), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    try:
        test4 = dgl.distributed.DistTensor(
            (g.num_nodes("n1"), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes["n1"].data["test1"][nids] = new_feats
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes["n1"].data["feat"]) == g.num_nodes("n1")
    assert g.nodes["n1"].data["feat"].shape == (g.num_nodes("n1"), 1)
    assert g.nodes["n1"].data["feat"].dtype == F.int64

    selected_nodes = np.random.randint(0, 100, size=g.num_nodes("n1")) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype="n1")
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.num_nodes("n1"))
    for n in nodes:
        assert n in local_nids

    print("end")


def check_server_client_hetero(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.num_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hetero,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                num_nodes,
                num_edges,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_server_client_hierarchy(False, 1, 4)
    check_server_client_empty(True, 1, 1)
    check_server_client_hetero(True, 1, 1)
    check_server_client_hetero(False, 1, 1)
    check_server_client(True, 1, 1)
    check_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_server_client(True, 2, 2)
    # check_server_client(True, 1, 1, 2)
    # check_server_client(False, 1, 1, 2)
    # check_server_client(True, 2, 2, 2)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_dist_emb_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_dist_emb_server_client(True, 2, 2)
    # check_dist_emb_server_client(True, 1, 1, 2)
    # check_dist_emb_server_client(False, 1, 1, 2)
    # check_dist_emb_server_client(True, 2, 2, 2)


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed Optimizer",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed Optimizer",
)
def test_dist_optim_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    optimizer_states = []
    num_nodes = 10000
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)
    check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)
    check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, False)


def check_dist_optim_server_client(
    num_nodes, num_servers, num_clients, optimizer_states, save
):
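    """Spawn servers and optimizer clients; with save=True the partition is
    (re)created and the optimizer state is written, otherwise the state is
    loaded back and verified."""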
    graph_name = f"check_dist_optim_{num_servers}_store"
    if save:
        prepare_dist(num_servers)
        g = create_random_graph(num_nodes)

        # Partition the graph
        num_parts = 1
        g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
        g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
        partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                True,
                False,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client[{}] for group[0]".format(cli_id))
        p = ctx.Process(
            target=run_optim_client,
            args=(
                graph_name,
                0,
                num_servers,
                cli_id,
                num_clients,
                num_nodes,
                optimizer_states,
                save,
            ),
        )
        p.start()
        time.sleep(1)  # avoid race condition when instantiating DistGraph
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_graph(dist_g, 1, g.num_nodes(), g.num_edges())
    dgl.distributed.exit_client()  # needed since there are two tests here in one process

@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_standalone_node_emb():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_emb(dist_g, 1, g.num_nodes(), g.num_edges())
    dgl.distributed.exit_client()  # needed since there are two tests here in one process

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("hetero", [True, False])
def test_split(hetero):
    if hetero:
        g = create_random_hetero()
        ntype = "n1"
        etype = "r1"
    else:
        g = create_random_graph(10000)
        ntype = "_N"
        etype = "_E"
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.num_nodes(ntype)) > 30
    edge_mask = np.random.randint(0, 100, size=g.num_edges(etype)) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code now collects the roles of all client processes and uses the
    # information to determine how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        if hetero:
            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
        else:
            local_nids = F.asnumpy(local_nids)
        nodes1 = np.intersect1d(selected_nodes, local_nids)
        nodes2 = node_split(
            node_mask, gpb, ntype=ntype, rank=i, force_even=False
        )
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        for n in F.asnumpy(nodes2):
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False
        )
        nodes4 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False
        )
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        if hetero:
            etype_ids, eids = gpb.map_to_per_etype(local_eids)
            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
        else:
            local_eids = F.asnumpy(local_eids)
        edges1 = np.intersect1d(selected_edges, local_eids)
        edges2 = edge_split(
            edge_mask, gpb, etype=etype, rank=i, force_even=False
        )
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        for e in F.asnumpy(edges2):
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2, force_even=False
        )
        edges4 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False
        )
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_split_even():
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.num_nodes()) > 30
    edge_mask = np.random.randint(0, 100, size=g.num_edges()) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code now collects the roles of all client processes and uses the
    # information to determine how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print(
            "part {} gets {} nodes and {} are in the partition".format(
                i, len(nodes), len(subset)
            )
        )

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print("intersection has", len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print(
            "part {} gets {} edges and {} are in the partition".format(
                i, len(edges), len(subset)
            )
        )

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print("intersection has", len(subset))

    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))

def prepare_dist(num_servers=1):
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    test_dist_emb_server_client()
    test_server_client()
    test_split(True)
    test_split(False)
    test_split_even()
    test_standalone()
    test_standalone_node_emb()