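"""Tests for DGL's distributed graph store.

Each test partitions a random (hetero)graph into /tmp/dist_graph, spawns
DistGraphServer processes and client processes with ``multiprocessing``, and
checks DistGraph, DistTensor, DistEmbedding, and node_split/edge_split
behavior.  OMP_NUM_THREADS is pinned to 1 below before the numeric imports.
"""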
import os

os.environ["OMP_NUM_THREADS"] = "1"
import math
import multiprocessing as mp
import pickle
import socket
import sys
import time
import unittest
from multiprocessing import Condition, Manager, Process, Value

import backend as F
import numpy as np
import pytest
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy import sparse as spsp
from utils import create_random_graph, generate_ip_config, reset_envs

import dgl
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
    DistGraph,
    DistGraphServer,
    edge_split,
    load_partition,
    load_partition_book,
    node_split,
    partition_graph,
)
from dgl.heterograph_index import create_unitgraph_from_coo

if os.name != "nt":
    import fcntl
    import struct


def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
    keep_alive=False,
):
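    """Start a DistGraphServer that serves one partition of ``graph_name``.

    The server reads the partition config from /tmp/dist_graph, verifies the
    dtypes of the reserved ndata/edata fields on the underlying graph, and
    then blocks in ``g.start()`` to serve client requests.
    """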
    g = DistGraphServer(
        server_id,
        "kv_ip_config.txt",
        server_count,
        num_clients,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
        keep_alive=keep_alive,
    )
    print("start server", server_id)
    # verify dtype of underlying graph
    cg = g.client_g
    for k, dtype in dgl.distributed.dist_graph.FIELD_DICT.items():
        if k in cg.ndata:
            assert (
                F.dtype(cg.ndata[k]) == dtype
            ), "Data type of {} in ndata should be {}.".format(k, dtype)
        if k in cg.edata:
            assert (
                F.dtype(cg.edata[k]) == dtype
            ), "Data type of {} in edata should be {}.".format(k, dtype)
    g.start()


def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())


def rand_init(shape, dtype):
    return F.tensor(np.random.normal(size=shape), F.float32)


def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
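    """Basic DistGraph checks on a graph partitioned without node/edge
    features: graph size, creating a DistTensor as node data, reading its
    zero-initialized values, re-creating a named tensor after deletion, and
    writing new values back.
    """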
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    g.ndata["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes(), 3), F.float32, "test3"
    )
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()["test1"].dtype == F.int32

    print("end")


def run_client_empty(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)


def check_server_client_empty(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_1"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_empty,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                g.number_of_nodes(),
                g.number_of_edges(),
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print("clients have terminated")


def run_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph(g, num_clients, num_nodes, num_edges)


def run_emb_client(
    graph_name,
    part_id,
    server_count,
    num_clients,
    num_nodes,
    num_edges,
    group_id,
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    os.environ["DGL_GROUP_ID"] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)


def run_client_hierarchy(
    graph_name, part_id, server_count, node_mask, edge_mask, return_dict
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(
        node_mask,
        g.get_partition_book(),
        node_trainer_ids=g.ndata["trainer_id"],
    )
    edges = edge_split(
        edge_mask,
        g.get_partition_book(),
        edge_trainer_ids=g.edata["trainer_id"],
    )
    rank = g.rank()
    return_dict[rank] = (nodes, edges)


def check_dist_emb(g, num_clients, num_nodes, num_edges):
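    """Exercise DistEmbedding with the SparseAdagrad optimizer.

    Embeddings start at zero (emb_init); after one optimizer step with an
    all-ones gradient the touched rows should move by roughly -lr (checked
    only when a single client performs the update), while untouched rows must
    stay zero.  Backends without sparse-embedding support raise
    NotImplementedError and the checks are skipped.
    """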
    from dgl.distributed import DistEmbedding
    from dgl.distributed.optim import SparseAdagrad

    # Test sparse emb
    try:
        emb = DistEmbedding(g.number_of_nodes(), 1, "emb1", emb_init)
        nids = F.arange(0, int(g.number_of_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
        grad_sum = dgl.distributed.DistTensor(
            (g.number_of_nodes(), 1), F.float32, "emb1_sum", policy
        )
        if num_clients == 1:
            assert np.all(
                F.asnumpy(grad_sum[nids])
                == np.ones((len(nids), 1)) * num_clients
            )
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = DistEmbedding(g.number_of_nodes(), 1, "emb2", emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(
                F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr
            )
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError:
        pass
    except Exception as e:
        print(e)
        sys.exit(-1)


def check_dist_graph(g, num_clients, num_nodes, num_edges):
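    """End-to-end checks on a DistGraph client: graph size, reading the
    partitioned node/edge features, edge_subgraph, named/anonymous/persistent
    DistTensor creation and deletion, writing node data, attribute schemes,
    and node_split over a single partition.
    """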
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats1 = g.ndata["features"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges() / 2))
    feats1 = g.edata["features"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph(eids)
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    test1 = dgl.distributed.DistTensor(new_shape, F.int32)
    g.ndata["test1"] = test1
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)
    assert test1.count_nonzero() == 0

    # Reference an existing tensor by creating one with the same name.
    test2 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test2", init_func=rand_init
    )
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, "test2")
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes(), 3), F.float32, "test3"
    )
    del test3

    # Test anonymous distributed tensors.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # Test a persistent tensor.
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    raised = False
    try:
        test4 = dgl.distributed.DistTensor(
            (g.number_of_nodes(), 3), F.float32, "test4"
        )
    except Exception:
        raised = True
    # "test4" is persistent, so re-creating it with a different shape must fail.
    assert raised

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata["test1"][nids] = new_feats
    feats = g.ndata["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata["features"]) == g.number_of_nodes()
    assert g.ndata["features"].shape == (g.number_of_nodes(), 1)
    assert g.ndata["features"].dtype == F.int64
    assert g.node_attr_schemes()["features"].dtype == F.int64
    assert g.node_attr_schemes()["test1"].dtype == F.int32
    assert g.node_attr_schemes()["features"].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are all the nodes in the graph.
    local_nids = np.arange(g.number_of_nodes())
    for n in nodes:
        assert n in local_nids

    print("end")


def check_dist_emb_server_client(
    shared_mem, num_servers, num_clients, num_groups=1
):
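    """Spawn ``num_servers`` server processes and ``num_clients * num_groups``
    client processes over a single-partition random graph and run
    check_dist_emb in every client.  With more than one client group the
    servers are started with keep_alive=True and shut down explicitly once
    all clients have exited.
    """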
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = (
        f"check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    )
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_emb_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.number_of_nodes(),
                    g.number_of_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()

    print("clients have terminated")


def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
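    """Same launch pattern as check_dist_emb_server_client, but every client
    runs check_dist_graph against the shared single-partition graph.
    """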
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(
                graph_name,
                serv_id,
                num_servers,
                num_clients,
                shared_mem,
                keep_alive,
            ),
        )
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print("start client[{}] for group[{}]".format(cli_id, group_id))
            p = ctx.Process(
                target=run_client,
                args=(
                    graph_name,
                    0,
                    num_servers,
                    num_clients,
                    g.number_of_nodes(),
                    g.number_of_edges(),
                    group_id,
                ),
            )
            p.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            cli_ps.append(p)
    for p in cli_ps:
        p.join()

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()

    print("clients have terminated")


def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
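    """Partition with ``num_trainers_per_machine=num_clients`` so nodes and
    edges carry a ``trainer_id``, then check that node_split/edge_split driven
    by the trainer ids give each client a share whose union over all clients
    is exactly the masked node/edge set.
    """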
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_2"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(
        g,
        graph_name,
        num_parts,
        "/tmp/dist_graph",
        num_trainers_per_machine=num_clients,
    )

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
    node_mask = np.zeros((g.number_of_nodes(),), np.int32)
    edge_mask = np.zeros((g.number_of_edges(),), np.int32)
    nodes = np.random.choice(
        g.number_of_nodes(), g.number_of_nodes() // 10, replace=False
    )
    edges = np.random.choice(
        g.number_of_edges(), g.number_of_edges() // 10, replace=False
    )
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hierarchy,
            args=(
                graph_name,
                0,
                num_servers,
                node_mask,
                edge_mask,
                return_dict,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
    for p in serv_ps:
        p.join()

    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print("clients have terminated")


def run_client_hetero(
    graph_name, part_id, server_count, num_clients, num_nodes, num_edges
):
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)


def create_random_hetero():
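    """Build a random heterograph with three node types and three relations
    (sparse COO adjacency, density 0.001) and attach integer 'feat' features
    to the 'n1' nodes and 'r1' edges.
    """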
    num_nodes = {"n1": 10000, "n2": 10010, "n3": 10020}
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(
            num_nodes[src_ntype],
            num_nodes[dst_ntype],
            density=0.001,
            format="coo",
            random_state=100,
        )
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    g.nodes["n1"].data["feat"] = F.unsqueeze(
        F.arange(0, g.number_of_nodes("n1")), 1
    )
    g.edges["r1"].data["feat"] = F.unsqueeze(
        F.arange(0, g.number_of_edges("r1")), 1
    )
    return g


def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
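    """Heterogeneous counterpart of check_dist_graph: per-type node/edge
    counts, canonical etype order, reading typed features, edge_subgraph on a
    single relation, DistTensor creation/persistence, and a typed node_split.
    """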
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.number_of_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.number_of_edges(etype)
    etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
    for i, etype in enumerate(g.canonical_etypes):
        assert etype[0] == etypes[i][0]
        assert etype[1] == etypes[i][1]
        assert etype[2] == etypes[i][2]
    assert g.number_of_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.number_of_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes("n1") / 2))
    feats1 = g.nodes["n1"].data["feat"][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges("r1") / 2))
    feats1 = g.edges["r1"].data["feat"][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph({"r1": eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)
    sg = g.edge_subgraph({("n1", "r1", "n2"): eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes("n1"), 2)
    g.nodes["n1"].data["test1"] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # Create a tensor, destroy it, and create it again.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test3", init_func=rand_init
    )
    del test3
    test3 = dgl.distributed.DistTensor(
        (g.number_of_nodes("n1"), 3), F.float32, "test3"
    )
    del test3

    # Test anonymous distributed tensors.
    test3 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    del test3
    test5 = dgl.distributed.DistTensor(
        new_shape, F.float32, init_func=rand_init
    )
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # Test a persistent tensor.
    test4 = dgl.distributed.DistTensor(
        new_shape, F.float32, "test4", init_func=rand_init, persistent=True
    )
    del test4
    raised = False
    try:
        test4 = dgl.distributed.DistTensor(
            (g.number_of_nodes("n1"), 3), F.float32, "test4"
        )
    except Exception:
        raised = True
    # "test4" is persistent, so re-creating it with a different shape must fail.
    assert raised

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes["n1"].data["test1"][nids] = new_feats
    feats = g.nodes["n1"].data["test1"][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes["n1"].data["feat"]) == g.number_of_nodes("n1")
    assert g.nodes["n1"].data["feat"].shape == (g.number_of_nodes("n1"), 1)
    assert g.nodes["n1"].data["feat"].dtype == F.int64

    selected_nodes = (
        np.random.randint(0, 100, size=g.number_of_nodes("n1")) > 30
    )
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype="n1")
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are all the nodes in the graph.
    local_nids = np.arange(g.number_of_nodes("n1"))
    for n in nodes:
        assert n in local_nids

    print("end")


def check_server_client_hetero(shared_mem, num_servers, num_clients):
    prepare_dist(num_servers)
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.number_of_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client_hetero,
            args=(
                graph_name,
                0,
                num_servers,
                num_clients,
                num_nodes,
                num_edges,
            ),
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print("clients have terminated")


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_server_client_hierarchy(False, 1, 4)
    check_server_client_empty(True, 1, 1)
    check_server_client_hetero(True, 1, 1)
    check_server_client_hetero(False, 1, 1)
    check_server_client(True, 1, 1)
    check_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_server_client(True, 2, 2)
    # check_server_client(True, 1, 1, 2)
    # check_server_client(False, 1, 1, 2)
    # check_server_client(True, 2, 2, 2)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_dist_emb_server_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    # check_dist_emb_server_client(True, 2, 2)
    # check_dist_emb_server_client(True, 1, 1, 2)
    # check_dist_emb_server_client(False, 1, 1, 2)
    # check_dist_emb_server_client(True, 2, 2, 2)


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support"
)
def test_standalone():
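    """Standalone mode: with DGL_DIST_MODE="standalone" the DistGraph is
    backed directly by the local partition, so no separate server processes
    are spawned.
    """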
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client()  # needed because two tests run in this process


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support distributed DistEmbedding",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet",
    reason="Mxnet doesn't support distributed DistEmbedding",
)
def test_standalone_node_emb():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = "dist_graph_test_3"
    g.ndata["features"] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata["features"] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
    check_dist_emb(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client()  # needed because two tests run in this process


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("hetero", [True, False])
def test_split(hetero):
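    """Check node_split/edge_split with force_even=False: rank i must receive
    exactly the masked nodes/edges owned by partition i, and doubling the
    number of clients splits each partition's share between two consecutive
    ranks.
    """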
    if hetero:
        g = create_random_hetero()
        ntype = "n1"
        etype = "r1"
    else:
        g = create_random_graph(10000)
        ntype = "_N"
        etype = "_E"
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes(ntype)) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges(etype)) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code now collects the roles of all client processes and uses that
    # information to determine how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        if hetero:
            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
        else:
            local_nids = F.asnumpy(local_nids)
        nodes1 = np.intersect1d(selected_nodes, local_nids)
        nodes2 = node_split(
            node_mask, gpb, ntype=ntype, rank=i, force_even=False
        )
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        for n in F.asnumpy(nodes2):
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False
        )
        nodes4 = node_split(
            node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False
        )
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        if hetero:
            etype_ids, eids = gpb.map_to_per_etype(local_eids)
            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
        else:
            local_eids = F.asnumpy(local_eids)
        edges1 = np.intersect1d(selected_edges, local_eids)
        edges2 = edge_split(
            edge_mask, gpb, etype=etype, rank=i, force_even=False
        )
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        for e in F.asnumpy(edges2):
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2, force_even=False
        )
        edges4 = edge_split(
            edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False
        )
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_split_even():
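    """Check node_split/edge_split with force_even=True: the masked nodes and
    edges are split evenly across ranks (not necessarily along partition
    boundaries), and concatenating the shares of all ranks must reproduce the
    full masked set.
    """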
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(
        g,
        "dist_graph_test",
        num_parts,
        "/tmp/dist_graph",
        num_hops=num_hops,
        part_method="metis",
    )

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code now collects the roles of all client processes and uses that
    # information to determine how to split the workloads. This simulates the
    # multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = "default"
        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK["default"] = {
            i: i for i in range(num_clients)
        }

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
            "/tmp/dist_graph/dist_graph_test.json", i
        )
        local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print(
            "part {} get {} nodes and {} are in the partition".format(
                i, len(nodes), len(subset)
            )
        )

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print("intersection has", len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print(
            "part {} get {} edges and {} are in the partition".format(
                i, len(edges), len(subset)
            )
        )

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print("intersection has", len(subset))
    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))


def prepare_dist(num_servers=1):
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    test_dist_emb_server_client()
    test_server_client()
    test_split(True)
    test_split(False)
    test_split_even()
    test_standalone()
    test_standalone_node_emb()