import os

os.environ["OMP_NUM_THREADS"] = "1"
import math
import multiprocessing as mp
import pickle
import socket
import sys
import time
import unittest
from multiprocessing import Condition, Manager, Process, Value

import backend as F

import dgl
import dgl.graphbolt as gb
import numpy as np
import pytest
import torch as th
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    edge_split,
    load_partition,
    load_partition_book,
    node_split,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad
from dgl.heterograph_index import create_unitgraph_from_coo
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs

if os.name != "nt":
    import fcntl
    import struct


def _verify_dist_graph_server_dgl(g):
    # verify dtype of underlying graph
    cg = g.client_g
    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():
        if k in cg.ndata:
            assert (
                F.dtype(cg.ndata[k]) == dtype
            ), "Data type of {} in ndata should be {}.".format(k, dtype)
        if k in cg.edata:
            assert (
                F.dtype(cg.edata[k]) == dtype
            ), "Data type of {} in edata should be {}.".format(k, dtype)


def _verify_dist_graph_server_graphbolt(g):
    graph = g.client_g
    assert isinstance(graph, gb.FusedCSCSamplingGraph)
    # [Rui][TODO] verify dtype of underlying graph.


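# Server/client test harness: `run_server` below serves one graph partition
# (loaded from /tmp/dist_graph/<graph_name>.json) over the endpoints listed in
# kv_ip_config.txt, while the check_server_client_* helpers partition a small
# random graph, spawn server processes, and then spawn client processes that
# attach a DistGraph and run the actual assertions.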
def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
    use_graphbolt=False,
):
    g = DistGraphServer(
        server_id,
        "kv_ip_config.txt",
        server_count,
        num_clients,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
        use_graphbolt=use_graphbolt,
    )
    print(f"Starting server[{server_id}] with use_graphbolt={use_graphbolt}")
    _verify = (
        _verify_dist_graph_server_graphbolt
        if use_graphbolt
        else _verify_dist_graph_server_dgl
    )
    _verify(g)
    g.start()


def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())


def rand_init(shape, dtype):
    return F.tensor(np.random.normal(size=shape), F.float32)


def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
    # Test API
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    # Test init node data
    new_shape = (g.num_nodes(), 2)
    g.ndata["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.num_nodes() / 2))
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor and destroy a tensor and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, "test3")
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()["test1"].dtype == F.int32

    print("end")


def run_client_empty(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    use_graphbolt=False,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)


def check_server_client_empty(
    shared_mem, num_servers, num_clients, use_graphbolt=False
):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_1"
    partition_graph(
        g, graph_name, num_parts, "/tmp/dist_graph", use_graphbolt=use_graphbolt
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                use_graphbolt,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_empty,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                g.num_nodes(),
                g.num_edges(),
                use_graphbolt,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def run_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
    use_graphbolt=False,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph(
        g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
    )


def run_emb_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)


def run_optim_client(
    graph_name,
    part_id,
    server_count,
    rank,
    world_size,
    num_nodes,
    optimizer_states,
    save,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12355"
    dgl.distributed.initialize("kv_ip_config.txt")
    th.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size
    )
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_optim_store(rank, num_nodes, optimizer_states, save)


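# Round-trip check for distributed optimizer state: the "save" pass writes the
# SparseAdagrad state for two DistEmbedding tables to /tmp/dist_graph/emb.pt,
# and the "load" pass re-creates the optimizer with different hyperparameters
# and verifies that both the state tensors and the saved lr/eps are restored.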
def check_dist_optim_store(rank, num_nodes, optimizer_states, save):
    try:
        total_idx = F.arange(0, num_nodes, F.int64, F.cpu())
        emb = DistEmbedding(num_nodes, 1, name="optim_emb1", init_func=emb_init)
        emb2 = DistEmbedding(
            num_nodes, 1, name="optim_emb2", init_func=emb_init
        )
        if save:
            optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)
            if rank == 0:
                optimizer._state["optim_emb1"][total_idx] = optimizer_states[0]
                optimizer._state["optim_emb2"][total_idx] = optimizer_states[1]
            optimizer.save("/tmp/dist_graph/emb.pt")
        else:
            optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)
            optimizer.load("/tmp/dist_graph/emb.pt")
            if rank == 0:
                assert F.allclose(
                    optimizer._state["optim_emb1"][total_idx],
                    optimizer_states[0],
                    0.0,
                    0.0,
                )
                assert F.allclose(
                    optimizer._state["optim_emb2"][total_idx],
                    optimizer_states[1],
                    0.0,
                    0.0,
                )
                assert 0.1 == optimizer._lr
                assert 1e-08 == optimizer._eps
            th.distributed.barrier()
    except Exception as e:
        print(e)
        sys.exit(-1)


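# Trainer-aware splitting: node_split/edge_split are given the "trainer_id"
# data produced by partition_graph(..., num_trainers_per_machine=...), so each
# client should only receive the nodes and edges assigned to its trainer rank.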
def run_client_hierarchy(
    graph_name,
    part_id,
    server_count,
    node_mask,
    edge_mask,
    return_dict,
    use_graphbolt=False,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(
        node_mask,
        g.get_partition_book(),
        node_trainer_ids=g.ndata["trainer_id"],
    )
    edges = edge_split(
        edge_mask,
        g.get_partition_book(),
        edge_trainer_ids=g.edata["trainer_id"],
    )
    rank = g.rank()
    return_dict[rank] = (nodes, edges)


def check_dist_emb(g, num_clients, num_nodes, num_edges):
    # Test sparse emb
    try:
        emb = DistEmbedding(g.num_nodes(), 1, "emb1", emb_init)
        nids = F.arange(0, int(g.num_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
        grad_sum = dgl.distributed.DistTensor(
            (g.num_nodes(), 1), F.float32, "emb1_sum", policy
        )
        if num_clients == 1:
            assert np.all(
                F.asnumpy(grad_sum[nids])
                == np.ones((len(nids), 1)) * num_clients
            )
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = DistEmbedding(g.num_nodes(), 1, "emb2", emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(
                F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr
            )
        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError as e:
        pass
    except Exception as e:
        print(e)
        sys.exit(-1)


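# check_dist_graph is the main single-partition checker: it reads node/edge
# features, exercises the DistTensor lifecycle (named, anonymous, and
# persistent tensors), checks edge_subgraph (expected to raise under GraphBolt
# because find_edges is unsupported there), writes data back, and verifies
# metadata plus node_split behavior.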
def check_dist_graph(g, num_clients, num_nodes, num_edges, use_graphbolt=False):
    # Test API
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.num_nodes() / 2))
    feats1 = g.ndata["features"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.num_edges() / 2))
    feats1 = g.edata["features"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    if use_graphbolt:
        with pytest.raises(
            AssertionError, match="find_edges is not supported in GraphBolt."
        ):
            g.edge_subgraph(eids)
    else:
        sg = g.edge_subgraph(eids)
        assert sg.num_edges() == len(eids)
        assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.num_nodes(), 2)
    test1 = dgl.distributed.DistTensor(new_shape, F.int32)
    g.ndata["test1"] = test1
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)
    assert test1.count_nonzero() == 0

    # reference to one that exists
    test2 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test2", init_func=rand_init
    )
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, "test2")
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # create a tensor and destroy a tensor and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    test3_name = test3.kvstore_key
    assert test3_name in g._client.data_name_list()
    assert test3_name in g._client.gdata_name_list()
    del test3
    assert test3_name not in g._client.data_name_list()
    assert test3_name not in g._client.gdata_name_list()
    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, "test3")
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    try:
        test4 = dgl.distributed.DistTensor(
            (g.num_nodes(), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata["features"]) == g.num_nodes()
    assert g.ndata["features"].shape == (g.num_nodes(), 1)
    assert g.ndata["features"].dtype == F.int64
    assert g.node_attr_schemes()["features"].dtype == F.int64
    assert g.node_attr_schemes()["test1"].dtype == F.int32
    assert g.node_attr_schemes()["features"].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.num_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.num_nodes())
    for n in nodes:
        assert n in local_nids

    print("end")


def check_dist_emb_server_client(
    shared_mem, num_servers, num_clients, num_groups=1
):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = (
        f"check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    )
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_emb_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.num_nodes(),
                    g.num_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def check_server_client(
    shared_mem, num_servers, num_clients, num_groups=1, use_graphbolt=False
):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(
        g, graph_name, num_parts, "/tmp/dist_graph", use_graphbolt=use_graphbolt
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                use_graphbolt,
            ),
        )
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.num_nodes(),
                    g.num_edges(),
                    group_id,
                    use_graphbolt,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)
    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


def check_server_client_hierarchy(
    shared_mem, num_servers, num_clients, use_graphbolt=False
):
    if num_clients == 1:
        # skip this test if there is only one client.
        return
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_2"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(
        g,
        graph_name,
        num_parts,
        "/tmp/dist_graph",
        num_trainers_per_machine=num_clients,
        use_graphbolt=use_graphbolt,
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                use_graphbolt,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
    node_mask = np.zeros((g.num_nodes(),), np.int32)
    edge_mask = np.zeros((g.num_edges(),), np.int32)
    nodes = np.random.choice(g.num_nodes(), g.num_nodes() // 10, replace=False)
    edges = np.random.choice(g.num_edges(), g.num_edges() // 10, replace=False)
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hierarchy,
            args=(
                graph_name,
                0,
                num_servers,
                node_mask,
                edge_mask,
                return_dict,
                use_graphbolt,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0
    for p in serv_ps:
        p.join()
        assert p.exitcode == 0
    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print("clients have terminated")


def run_client_hetero(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    use_graphbolt=False,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_hetero(
        g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
    )


def create_random_hetero():
    num_nodes = {"n1": 10000, "n2": 10010, "n3": 10020}
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    # assign ndata & edata.
    # data with same name as ntype/etype is assigned on purpose to verify
    # such same names can be correctly handled in DistGraph. See more details
    # in issues #4887 and #4463 on github.
    ntype = "n1"
    for name in ["feat", ntype]:
        g.nodes[ntype].data[name] = F.unsqueeze(
            F.arange(0, g.num_nodes(ntype)), 1
        )
    etype = "r1"
    for name in ["feat", etype]:
        g.edges[etype].data[name] = F.unsqueeze(
            F.arange(0, g.num_edges(etype)), 1
        )
    return g


def check_dist_graph_hetero(
    g, num_clients, num_nodes, num_edges, use_graphbolt=False
):
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.num_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.num_edges(etype)
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    for i, etype in enumerate(g.canonical_etypes):
        assert etype[0] == etypes[i][0]
        assert etype[1] == etypes[i][1]
        assert etype[2] == etypes[i][2]
    assert g.num_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.num_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    ntype = "n1"
    nids = F.arange(0, g.num_nodes(ntype) // 2)
    for name in ["feat", ntype]:
        data = g.nodes[ntype].data[name][nids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == nids))
    assert len(g.nodes["n2"].data) == 0
    expect_except = False
    try:
        g.nodes["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test reading edge data
    etype = "r1"
    eids = F.arange(0, g.num_edges(etype) // 2)
    for name in ["feat", etype]:
        # access via etype
        data = g.edges[etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
        # access via canonical etype
        c_etype = g.to_canonical_etype(etype)
        data = g.edges[c_etype].data[name][eids]
        data = F.squeeze(data, 1)
        assert np.all(F.asnumpy(data == eids))
    assert len(g.edges["r2"].data) == 0
    expect_except = False
    try:
        g.edges["xxx"].data["x"]
    except dgl.DGLError:
        expect_except = True
    assert expect_except

    # Test edge_subgraph
    if use_graphbolt:
        with pytest.raises(
            AssertionError, match="find_edges is not supported in GraphBolt."
        ):
            g.edge_subgraph({"r1": eids})
    else:
        sg = g.edge_subgraph({"r1": eids})
        assert sg.num_edges() == len(eids)
        assert F.array_equal(sg.edata[dgl.EID], eids)
        sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
        assert sg.num_edges() == len(eids)
        assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.num_nodes("n1"), 2)
    g.nodes["n1"].data["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor and destroy a tensor and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.num_nodes("n1"), 3), F.float32, "test3"
    )
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    try:
        test4 = dgl.distributed.DistTensor(
            (g.num_nodes("n1"), 3), F.float32, "test4"
        )
        raise Exception("")
    except:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes["n1"].data["test1"][nids] = new_feats
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes["n1"].data["feat"]) == g.num_nodes("n1")
    assert g.nodes["n1"].data["feat"].shape == (g.num_nodes("n1"), 1)
    assert g.nodes["n1"].data["feat"].dtype == F.int64

    selected_nodes = np.random.randint(0, 100, size=g.num_nodes("n1")) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype="n1")
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.num_nodes("n1"))
    for n in nodes:
        assert n in local_nids

    print("end")


def check_server_client_hetero(
    shared_mem, num_servers, num_clients, use_graphbolt=False
):
    prepare_dist(num_servers)
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    partition_graph(
        g, graph_name, num_parts, "/tmp/dist_graph", use_graphbolt=use_graphbolt
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                use_graphbolt,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.num_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hetero,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                num_nodes,
                num_edges,
                use_graphbolt,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0

    print("clients have terminated")


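# test_server_client drives all four check_server_client_* variants above with
# the same parametrization, once with GraphBolt serving and once without.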
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
@pytest.mark.parametrize("shared_mem", [True])
@pytest.mark.parametrize("num_servers", [1])
@pytest.mark.parametrize("num_clients", [1, 4])
@pytest.mark.parametrize("use_graphbolt", [True, False])
def test_server_client(shared_mem, num_servers, num_clients, use_graphbolt):
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    # [Rui]
    # 1. `disable_shared_mem=False` is not supported yet. Skip it.
    # 2. `num_servers` > 1 does not work on single machine. Skip it.
    for func in [
        check_server_client,
        check_server_client_hetero,
        check_server_client_empty,
        check_server_client_hierarchy,
    ]:
        func(shared_mem, num_servers, num_clients, use_graphbolt=use_graphbolt)


@unittest.skip(reason="Skip due to glitch in CI")
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_dist_emb_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_dist_emb_server_client(True, 2, 2)
    # check_dist_emb_server_client(True, 1, 1, 2)
    # check_dist_emb_server_client(False, 1, 1, 2)
    # check_dist_emb_server_client(True, 2, 2, 2)


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed Optimizer",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed Optimizer",
)
def test_dist_optim_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    optimizer_states = []
    num_nodes = 10000
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))
    check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)
    check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)
    check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, False)


def check_dist_optim_server_client(
    num_nodes, num_servers, num_clients, optimizer_states, save
):
    graph_name = f"check_dist_optim_{num_servers}_store"
    if save:
        prepare_dist(num_servers)
        g = create_random_graph(num_nodes)

        # Partition the graph
        num_parts = 1
        g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
        g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
        partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                True,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client[{}] for group[0]".format(cli_id))
        p = ctx.Process(
            target=run_optim_client,
            args=(
                graph_name,
                0,
                num_servers,
                cli_id,
                num_clients,
                num_nodes,
                optimizer_states,
                save,
            ),
        )
        p.start()
        time.sleep(1)  # avoid race condition when instantiating DistGraph
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    for p in serv_ps:
        p.join()
        assert p.exitcode == 0


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_graph(dist_g, 1, g.num_nodes(), g.num_edges())
    dgl.distributed.exit_client()  # this is needed since there are two tests here in one process


@unittest.skip(reason="Skip due to glitch in CI")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_standalone_node_emb():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_emb(dist_g, 1, g.num_nodes(), g.num_edges())
    dgl.distributed.exit_client()  # this is needed since there are two tests here in one process


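# test_split and test_split_even check node_split/edge_split directly against
# the partition book (no servers involved): the role globals are patched to
# simulate multiple clients, force_even=False keeps each rank's share inside
# its own partition, and force_even=True balances the shares while their union
# still covers exactly the masked nodes/edges.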
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("hetero", [True, False])
@pytest.mark.parametrize("empty_mask", [True, False])
def test_split(hetero, empty_mask):
    if hetero:
        g = create_random_hetero()
        ntype = "n1"
        etype = "r1"
    else:
        g = create_random_graph(10000)
        ntype = "_N"
        etype = "_E"
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    mask_thd = 100 if empty_mask else 30
    node_mask = np.random.randint(0, 100, size=g.num_nodes(ntype)) > mask_thd
    edge_mask = np.random.randint(0, 100, size=g.num_edges(etype)) > mask_thd
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code now collects the roles of all client processes and uses the
    # information to determine how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        if hetero:
            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
        else:
            local_nids = F.asnumpy(local_nids)
        nodes1 = np.intersect1d(selected_nodes, local_nids)
        nodes2 = node_split(
            node_mask, gpb, ntype=ntype, rank=i, force_even=False
        )
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        for n in F.asnumpy(nodes2):
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False
        )
        nodes4 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False
        )
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        if hetero:
            etype_ids, eids = gpb.map_to_per_etype(local_eids)
            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
        else:
            local_eids = F.asnumpy(local_eids)
        edges1 = np.intersect1d(selected_edges, local_eids)
        edges2 = edge_split(
            edge_mask, gpb, etype=etype, rank=i, force_even=False
        )
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        for e in F.asnumpy(edges2):
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2, force_even=False
        )
        edges4 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False
        )
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("empty_mask", [True, False])
def test_split_even(empty_mask):
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    mask_thd = 100 if empty_mask else 30
    node_mask = np.random.randint(0, 100, size=g.num_nodes()) > mask_thd
    edge_mask = np.random.randint(0, 100, size=g.num_edges()) > mask_thd
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code now collects the roles of all client processes and uses the
    # information to determine how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print(
            "part {} get {} nodes and {} are in the partition".format(
                i, len(nodes), len(subset)
            )
        )

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print("intersection has", len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print(
            "part {} get {} edges and {} are in the partition".format(
                i, len(edges), len(subset)
            )
        )

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print("intersection has", len(subset))

    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))


def prepare_dist(num_servers=1):
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    test_dist_emb_server_client()
    # The parametrized tests need explicit arguments when invoked directly;
    # the values below pick one point from each pytest parametrization.
    test_server_client(True, 1, 1, False)
    test_split(True, False)
    test_split(False, False)
    test_split_even(False)
    test_standalone()
    test_standalone_node_emb()