"tests/vscode:/vscode.git/clone" did not exist on "352ca3198cb25e6098f795568547075ff28e3133"
test_dist_graph_store.py 42.7 KB
Newer Older
1
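"""Tests for the dgl.distributed graph store: DistGraph server/client setup,
DistTensor creation and persistence, DistEmbedding with a sparse optimizer,
node/edge splits, heterogeneous graphs, and the GraphBolt code path."""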
import os

os.environ["OMP_NUM_THREADS"] = "1"
import math
import multiprocessing as mp
import pickle
import socket
import sys
import time
import unittest
from multiprocessing import Condition, Manager, Process, Value

import backend as F

import dgl
import dgl.graphbolt as gb
import numpy as np
import pytest
import torch as th
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
    dgl_partition_to_graphbolt,
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    edge_split,
    load_partition,
    load_partition_book,
    node_split,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad
from dgl.heterograph_index import create_unitgraph_from_coo
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs

if os.name != "nt":
    import fcntl
    import struct


def _verify_dist_graph_server_dgl(g):
    # verify dtype of underlying graph
    cg = g.client_g
    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
        if k in cg.ndata:
            assert (
                F.dtype(cg.ndata[k]) == dtype
            ), "Data type of {} in ndata should be {}.".format(k, dtype)
        if k in cg.edata:
            assert (
                F.dtype(cg.edata[k]) == dtype
            ), "Data type of {} in edata should be {}.".format(k, dtype)


def _verify_dist_graph_server_graphbolt(g):
    graph = g.client_g
    assert isinstance(graph, gb.FusedCSCSamplingGraph)
    # [Rui][TODO] verify dtype of underlying graph.


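# Entry point for each spawned server process: builds a DistGraphServer over one
# partition, sanity-checks the loaded graph, then blocks serving in g.start().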
def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
    use_graphbolt=False,
):
    g = DistGraphServer(
        server_id,
        "kv_ip_config.txt",
        server_count,
        num_clients,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
        use_graphbolt=use_graphbolt,
    )
    print(f"Starting server[{server_id}] with use_graphbolt={use_graphbolt}")
    _verify = (
        _verify_dist_graph_server_graphbolt
        if use_graphbolt
        else _verify_dist_graph_server_dgl
    )
    _verify(g)
    g.start()


def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())


def rand_init(shape, dtype):
    return F.tensor(np.random.normal(size=shape), F.float32)


def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
    # Test API
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    # Test init node data
    new_shape = (g.num_nodes(), 2)
    g.ndata["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.num_nodes() / 2))
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor, destroy it, then create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, "test3")
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()["test1"].dtype == F.int32

    print("end")


def run_client_empty(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    use_graphbolt=False,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)


def check_server_client_empty(
    shared_mem, num_servers, num_clients, use_graphbolt=False
):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_1"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
    if use_graphbolt:
        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
        dgl_partition_to_graphbolt(part_config)

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                use_graphbolt,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_empty,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                g.num_nodes(),
                g.num_edges(),
                use_graphbolt,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def run_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
    use_graphbolt=False,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
    check_dist_graph(
        g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
    )


def run_emb_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)


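# Client process for the sparse-optimizer save/load test. It also joins a gloo
# process group so that check_dist_optim_store can synchronize with a barrier.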
def run_optim_client(
    graph_name,
    part_id,
    server_count,
    rank,
    world_size,
    num_nodes,
    optimizer_states,
    save,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12355"
    dgl.distributed.initialize("kv_ip_config.txt")
    th.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size
    )
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_optim_store(rank, num_nodes, optimizer_states, save)


def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
    try:
        total_idx = F.arange(0, num_nodes, F.int64, F.cpu())
        emb = DistEmbedding(num_nodes, 1, name="optim_emb1", init_func=emb_init)
        emb2 = DistEmbedding(
            num_nodes, 1, name="optim_emb2", init_func=emb_init
        )
        if save:
            optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
            if rank == 0:
                optimizer._state["optim_emb1"][total_idx] = optimizer_states[0]
                optimizer._state["optim_emb2"][total_idx] = optimizer_states[1]
            optimizer.save("/tmp/dist_graph/emb.pt")
        else:
            optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
            optimizer.load("/tmp/dist_graph/emb.pt")
            if rank == 0:
                assert F.allclose(
                    optimizer._state["optim_emb1"][total_idx],
                    optimizer_states[0],
                    0.0,
                    0.0,
                )
                assert F.allclose(
                    optimizer._state["optim_emb2"][total_idx],
                    optimizer_states[1],
                    0.0,
                    0.0,
                )
                assert 0.1 == optimizer._lr
                assert 1e-08 == optimizer._eps
            th.distributed.barrier()
    except Exception as e:
        print(e)
        sys.exit(-1)


def run_client_hierarchy(
    graph_name,
    part_id,
    server_count,
    node_mask,
    edge_mask,
    return_dict,
    use_graphbolt=False,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(
        node_mask,
        g.get_partition_book(),
        node_trainer_ids=g.ndata["trainer_id"],
    )
    edges = edge_split(
        edge_mask,
        g.get_partition_book(),
        edge_trainer_ids=g.edata["trainer_id"],
    )
    rank = g.rank()
    return_dict[rank] = (nodes, edges)


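# Exercises DistEmbedding end to end: zero-initialized lookups, a SparseAdagrad
# step, and the gradient-sum DistTensor ("emb1_sum") kept in the kvstore.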
def check_dist_emb(g, num_clients, num_nodes, num_edges):
    # Test sparse emb
    try:
        emb = DistEmbedding(g.num_nodes(), 1, "emb1", emb_init)
        nids = F.arange(0, int(g.num_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
        grad_sum = dgl.distributed.DistTensor(
            (g.num_nodes(), 1), F.float32, "emb1_sum", policy
        )
        if num_clients == 1:
            assert np.all(
                F.asnumpy(grad_sum[nids])
                == np.ones((len(nids), 1)) * num_clients
            )
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = DistEmbedding(g.num_nodes(), 1, "emb2", emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(
                F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr
            )
        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError as e:
        pass
    except Exception as e:
        print(e)
        sys.exit(-1)


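# Core DistGraph checks run in every client: graph stats, node/edge feature
# reads and writes, DistTensor lifetime semantics, and node_split.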
def check_dist_graph(g, num_clients, num_nodes, num_edges, use_graphbolt=False):
    # Test API
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.num_nodes() / 2))
    feats1 = g.ndata["features"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.num_edges() / 2))
    feats1 = g.edata["features"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    if use_graphbolt:
        with pytest.raises(
            AssertionError, match="find_edges is not supported in GraphBolt."
        ):
            g.edge_subgraph(eids)
    else:
        sg = g.edge_subgraph(eids)
        assert sg.num_edges() == len(eids)
        assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.num_nodes(), 2)
    test1 = dgl.distributed.DistTensor(new_shape, F.int32)
    g.ndata["test1"] = test1
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)
    assert test1.count_nonzero() == 0

    # reference to a tensor that already exists
    test2 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test2", init_func=rand_init
    )
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, "test2")
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # create a tensor, destroy it, then create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    test3_name = test3.kvstore_key
    assert test3_name in g._client.data_name_list()
    assert test3_name in g._client.gdata_name_list()
    del test3
    assert test3_name not in g._client.data_name_list()
    assert test3_name not in g._client.gdata_name_list()
    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, "test3")
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    try:
        test4 = dgl.distributed.DistTensor(
            (g.num_nodes(), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata["features"]) == g.num_nodes()
    assert g.ndata["features"].shape == (g.num_nodes(), 1)
    assert g.ndata["features"].dtype == F.int64
    assert g.node_attr_schemes()["features"].dtype == F.int64
    assert g.node_attr_schemes()["test1"].dtype == F.int32
    assert g.node_attr_schemes()["features"].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.num_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.num_nodes())
    for n in nodes:
        assert n in local_nids

    print("end")


def check_dist_emb_server_client(
    shared_mem, num_servers, num_clients, num_groups=1
):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = (
        f"check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    )
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_emb_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.num_nodes(),
                    g.num_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


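# Spawn num_servers server processes and num_clients * num_groups client
# processes against a single partition, then require clean exits from all.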
def check_server_client(
    shared_mem, num_servers, num_clients, num_groups=1, use_graphbolt=False
):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
    if use_graphbolt:
        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
        dgl_partition_to_graphbolt(part_config)

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                use_graphbolt,
            ),
        )
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.num_nodes(),
                    g.num_edges(),
                    group_id,
                    use_graphbolt,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)
    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def check_server_client_hierarchy(
    shared_mem, num_servers, num_clients, use_graphbolt=False
):
    if num_clients == 1:
        # skip this test if there is only one client.
        return
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_2"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(
        g,
        graph_name,
        num_parts,
        "/tmp/dist_graph",
        num_trainers_per_machine=num_clients,
    )
    if use_graphbolt:
        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
        dgl_partition_to_graphbolt(part_config)

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                use_graphbolt,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
    node_mask = np.zeros((g.num_nodes(),), np.int32)
    edge_mask = np.zeros((g.num_edges(),), np.int32)
    nodes = np.random.choice(g.num_nodes(), g.num_nodes() // 10, replace=False)
    edges = np.random.choice(g.num_edges(), g.num_edges() // 10, replace=False)
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hierarchy,
            args=(
                graph_name,
                0,
                num_servers,
                node_mask,
                edge_mask,
                return_dict,
                use_graphbolt,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0
    for p in serv_ps:
        p.join()
        assert p.exitcode == 0
    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print("clients have terminated")


def run_client_hetero(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    use_graphbolt=False,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
    check_dist_graph_hetero(
        g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
    )


def create_random_hetero():
    num_nodes = {"n1": 10000, "n2": 10010, "n3": 10020}
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    # assign ndata & edata.
    # Data with the same name as an ntype/etype is assigned on purpose to
    # verify that such names are handled correctly in DistGraph. See issues
    # #4887 and #4463 on GitHub for more details.
    ntype = "n1"
    for name in ["feat", ntype]:
        g.nodes[ntype].data[name] = F.unsqueeze(
            F.arange(0, g.num_nodes(ntype)), 1
        )
    etype = "r1"
    for name in ["feat", etype]:
        g.edges[etype].data[name] = F.unsqueeze(
            F.arange(0, g.num_edges(etype)), 1
        )
    return g


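# Heterogeneous variant of check_dist_graph; also covers ndata/edata names that
# collide with ntype/etype names (see issues #4887 and #4463).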
def check_dist_graph_hetero(
    g, num_clients, num_nodes, num_edges, use_graphbolt=False
):
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.num_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.num_edges(etype)
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    for i, etype in enumerate(g.canonical_etypes):
        assert etype[0] == etypes[i][0]
        assert etype[1] == etypes[i][1]
        assert etype[2] == etypes[i][2]
    assert g.num_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.num_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    ntype = "n1"
    nids = F.arange(0, g.num_nodes(ntype) // 2)
    for name in ["feat", ntype]:
        data = g.nodes[ntype].data[name][nids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == nids))
    assert len(g.nodes["n2"].data) == 0
    expect_except = False
    try:
        g.nodes["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test reading edge data
    etype = "r1"
    eids = F.arange(0, g.num_edges(etype) // 2)
    for name in ["feat", etype]:
        # access via etype
        data = g.edges[etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
        # access via canonical etype
        c_etype = g.to_canonical_etype(etype)
        data = g.edges[c_etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
    assert len(g.edges["r2"].data) == 0
    expect_except = False
    try:
        g.edges["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test edge_subgraph
    if use_graphbolt:
        with pytest.raises(
            AssertionError, match="find_edges is not supported in GraphBolt."
        ):
            g.edge_subgraph({"r1": eids})
    else:
        sg = g.edge_subgraph({"r1": eids})
        assert sg.num_edges() == len(eids)
        assert F.array_equal(sg.edata[dgl.EID], eids)
        sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
        assert sg.num_edges() == len(eids)
        assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.num_nodes("n1"), 2)
    g.nodes["n1"].data["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor, destroy it, then create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.num_nodes("n1"), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    try:
        test4 = dgl.distributed.DistTensor(
            (g.num_nodes("n1"), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes["n1"].data["test1"][nids] = new_feats
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes["n1"].data["feat"]) == g.num_nodes("n1")
    assert g.nodes["n1"].data["feat"].shape == (g.num_nodes("n1"), 1)
    assert g.nodes["n1"].data["feat"].dtype == F.int64

    selected_nodes = np.random.randint(0, 100, size=g.num_nodes("n1")) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype="n1")
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.num_nodes("n1"))
    for n in nodes:
        assert n in local_nids

    print("end")


def check_server_client_hetero(
    shared_mem, num_servers, num_clients, use_graphbolt=False
):
    prepare_dist(num_servers)
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
    if use_graphbolt:
        part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
        dgl_partition_to_graphbolt(part_config)

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                use_graphbolt,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.num_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hetero,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                num_nodes,
                num_edges,
                use_graphbolt,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


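# End-to-end matrix: homogeneous, heterogeneous, featureless, and hierarchical
# (trainer-per-machine) setups, each with and without GraphBolt.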
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("shared_mem", [True])
@pytest.mark.parametrize("num_servers", [1])
@pytest.mark.parametrize("num_clients", [1, 4])
@pytest.mark.parametrize("use_graphbolt", [True, False])
def test_server_client(shared_mem, num_servers, num_clients, use_graphbolt):
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # [Rui]
    # 1. `disable_shared_mem=False` is not supported yet. Skip it.
    # 2. `num_servers` > 1 does not work on single machine. Skip it.
    for func in [
        check_server_client,
        check_server_client_hetero,
        check_server_client_empty,
        check_server_client_hierarchy,
    ]:
        func(shared_mem, num_servers, num_clients, use_graphbolt=use_graphbolt)


@unittest.skip(reason="Skip due to glitch in CI")
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_dist_emb_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_dist_emb_server_client(True, 2, 2)
    # check_dist_emb_server_client(True, 1, 1, 2)
    # check_dist_emb_server_client(False, 1, 1, 2)
    # check_dist_emb_server_client(True, 2, 2, 2)


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed Optimizer",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed Optimizer",
)
def test_dist_optim_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    optimizer_states = []
    num_nodes = 10000
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)
    check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)
    check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, False)


def check_dist_optim_server_client(
    num_nodes, num_servers, num_clients, optimizer_states, save
):
    graph_name = f"check_dist_optim_{num_servers}_store"
    if save:
        prepare_dist(num_servers)
        g = create_random_graph(num_nodes)

        # Partition the graph
        num_parts = 1
        g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
        g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
        partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                True,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client[{}] for group[0]".format(cli_id))
        p = ctx.Process(
            target=run_optim_client,
            args=(
                graph_name,
                0,
                num_servers,
                cli_id,
                num_clients,
                num_nodes,
                optimizer_states,
                save,
            ),
        )
        p.start()
        time.sleep(1)  # avoid race condition when instantiating DistGraph
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_graph(dist_g, 1, g.num_nodes(), g.num_edges())
    dgl.distributed.exit_client()  # needed since two tests run in one process


@unittest.skip(reason="Skip due to glitch in CI")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_standalone_node_emb():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_emb(dist_g, 1, g.num_nodes(), g.num_edges())
    dgl.distributed.exit_client()  # needed since two tests run in one process


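# Checks node_split/edge_split with force_even=False directly against
# partitions loaded from disk; no server/client processes are involved.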
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("hetero", [True, False])
@pytest.mark.parametrize("empty_mask", [True, False])
def test_split(hetero, empty_mask):
    if hetero:
        g = create_random_hetero()
        ntype = "n1"
        etype = "r1"
    else:
        g = create_random_graph(10000)
        ntype = "_N"
        etype = "_E"
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    mask_thd = 100 if empty_mask else 30
    node_mask = np.random.randint(0, 100, size=g.num_nodes(ntype)) > mask_thd
    edge_mask = np.random.randint(0, 100, size=g.num_edges(etype)) > mask_thd
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code now collects the roles of all client processes and uses that
    # information to decide how to split the workload. set_roles() simulates
    # the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        if hetero:
            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
        else:
            local_nids = F.asnumpy(local_nids)
        nodes1 = np.intersect1d(selected_nodes, local_nids)
        nodes2 = node_split(
            node_mask, gpb, ntype=ntype, rank=i, force_even=False
        )
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        for n in F.asnumpy(nodes2):
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False
        )
        nodes4 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False
        )
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        if hetero:
            etype_ids, eids = gpb.map_to_per_etype(local_eids)
            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
        else:
            local_eids = F.asnumpy(local_eids)
        edges1 = np.intersect1d(selected_edges, local_eids)
        edges2 = edge_split(
            edge_mask, gpb, etype=etype, rank=i, force_even=False
        )
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        for e in F.asnumpy(edges2):
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2, force_even=False
        )
        edges4 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False
        )
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("empty_mask", [True, False])
def test_split_even(empty_mask):
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    mask_thd = 100 if empty_mask else 30
    node_mask = np.random.randint(0, 100, size=g.num_nodes()) > mask_thd
    edge_mask = np.random.randint(0, 100, size=g.num_edges()) > mask_thd
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code now collects the roles of all client processes and uses that
    # information to decide how to split the workload. set_roles() simulates
    # the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print(
            "part {} get {} nodes and {} are in the partition".format(
                i, len(nodes), len(subset)
            )
        )

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print("intersection has", len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print(
            "part {} get {} edges and {} are in the partition".format(
                i, len(edges), len(subset)
            )
        )

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print("intersection has", len(subset))
    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))


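# Write kv_ip_config.txt for a single machine running num_servers server processes.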
def prepare_dist(num_servers=1):
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    # The parametrized tests need explicit arguments when invoked directly;
    # the values below correspond to one setting from each parametrize list.
    test_dist_emb_server_client()
    test_server_client(True, 1, 1, False)
    test_split(True, False)
    test_split(False, False)
    test_split_even(False)
    test_standalone()
    test_standalone_node_emb()