import os
os.environ['OMP_NUM_THREADS'] = '1'
import dgl
import sys
import numpy as np
import time
import socket
from scipy import sparse as spsp
from numpy.testing import assert_array_equal
from multiprocessing import Process, Manager, Condition, Value
import multiprocessing as mp
from dgl.heterograph_index import create_unitgraph_from_coo
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import DistGraphServer, DistGraph
from dgl.distributed import partition_graph, load_partition, load_partition_book, node_split, edge_split
from numpy.testing import assert_almost_equal
import backend as F
import math
import unittest
import pickle

if os.name != 'nt':
    import fcntl
    import struct

def get_local_usable_addr():
    """Get local usable IP and port

    Returns
    -------
    str
        IP address and port separated by a space, e.g., '192.168.8.12 50051'
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # doesn't even have to be reachable
        sock.connect(('10.255.255.255', 1))
        ip_addr = sock.getsockname()[0]
    except OSError:
        ip_addr = '127.0.0.1'
    finally:
        sock.close()
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    sock.listen(1)
    port = sock.getsockname()[1]
    sock.close()

    return ip_addr + ' ' + str(port)

def create_random_graph(n):
    arr = (spsp.random(n, n, density=0.001, format='coo', random_state=100) != 0).astype(np.int64)
    return dgl.from_scipy(arr)

def run_server(graph_name, server_id, server_count, num_clients, shared_mem):
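    # Load the partition config from /tmp/dist_graph and serve it as one DistGraphServer instance.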
    g = DistGraphServer(server_id, "kv_ip_config.txt", server_count, num_clients,
                        '/tmp/dist_graph/{}.json'.format(graph_name),
                        disable_shared_mem=not shared_mem)
    print('start server', server_id)
    g.start()

def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())

def rand_init(shape, dtype):
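    # Note: the values are drawn as float32 regardless of the requested dtype.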
    return F.tensor(np.random.normal(size=shape), F.float32)

def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    g.ndata['test1'] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats = g.ndata['test1'][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor and destroy a tensor and create it again.
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, 'test3', init_func=rand_init)
    del test3
    test3 = dgl.distributed.DistTensor((g.number_of_nodes(), 3), F.float32, 'test3')
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata['test1'][nids] = new_feats
    feats = g.ndata['test1'][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()['test1'].dtype == F.int32

    print('end')

def run_client_empty(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
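    # Give the server processes a moment to start before connecting.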
    time.sleep(5)
    os.environ['DGL_NUM_SERVER'] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                                part_id, None)
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)

def check_server_client_empty(shared_mem, num_servers, num_clients):
    prepare_dist()
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_1'
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem))
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print('start client', cli_id)
        p = ctx.Process(target=run_client_empty, args=(graph_name, 0, num_servers, num_clients,
                                                       g.number_of_nodes(), g.number_of_edges()))
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print('clients have terminated')

def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
    time.sleep(5)
    os.environ['DGL_NUM_SERVER'] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                                part_id, None)
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph(g, num_clients, num_nodes, num_edges)

def run_emb_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
    time.sleep(5)
    os.environ['DGL_NUM_SERVER'] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                                part_id, None)
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)

def run_client_hierarchy(graph_name, part_id, server_count, node_mask, edge_mask, return_dict):
    time.sleep(5)
    os.environ['DGL_NUM_SERVER'] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                                part_id, None)
    g = DistGraph(graph_name, gpb=gpb)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(node_mask, g.get_partition_book(), node_trainer_ids=g.ndata['trainer_id'])
    edges = edge_split(edge_mask, g.get_partition_book(), edge_trainer_ids=g.edata['trainer_id'])
    rank = g.rank()
    return_dict[rank] = (nodes, edges)

def check_dist_emb(g, num_clients, num_nodes, num_edges):
    from dgl.distributed.optim import SparseAdagrad
    from dgl.distributed.nn import NodeEmbedding
    # Test sparse emb
    try:
        emb = NodeEmbedding(g.number_of_nodes(), 1, 'emb1', emb_init)
        nids = F.arange(0, int(g.number_of_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book())
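        # SparseAdagrad is expected to keep its per-row state in a DistTensor named
        # '<embedding name>_sum'; rows that were never updated should stay zero.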
        grad_sum = dgl.distributed.DistTensor((g.number_of_nodes(),), F.float32,
                                              'emb1_sum', policy)
        if num_clients == 1:
            assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)) * num_clients)
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = NodeEmbedding(g.number_of_nodes(), 1, 'emb2', emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * math.sqrt(2) * -lr)
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError:
        # Some backends do not implement the operations used above; skip silently.
        pass

def check_dist_graph(g, num_clients, num_nodes, num_edges):
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats1 = g.ndata['features'][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges() / 2))
    feats1 = g.edata['features'][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    g.ndata['test1'] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.ndata['test1'][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # Reference a tensor that already exists: 'test2' is reused, so test3 aliases test2.
    test2 = dgl.distributed.DistTensor(new_shape, F.float32, 'test2', init_func=rand_init)
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, 'test2')
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # create a tensor and destroy a tensor and create it again.
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, 'test3', init_func=rand_init)
    del test3
    test3 = dgl.distributed.DistTensor((g.number_of_nodes(), 3), F.float32, 'test3')
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    del test3
    test5 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # Test a persistent tensor.
    test4 = dgl.distributed.DistTensor(new_shape, F.float32, 'test4', init_func=rand_init,
                                       persistent=True)
    del test4
    # The tensor was created with persistent=True, so it outlives the Python reference;
    # recreating 'test4' with a different shape is expected to fail.
    raised = False
    try:
        test4 = dgl.distributed.DistTensor((g.number_of_nodes(), 3), F.float32, 'test4')
    except Exception:
        raised = True
    assert raised, 'recreating the persistent tensor test4 with a new shape should fail'

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata['test1'][nids] = new_feats
    feats = g.ndata['test1'][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata['features']) == g.number_of_nodes()
    assert g.ndata['features'].shape == (g.number_of_nodes(), 1)
    assert g.ndata['features'].dtype == F.int64
    assert g.node_attr_schemes()['features'].dtype == F.int64
    assert g.node_attr_schemes()['test1'].dtype == F.int32
    assert g.node_attr_schemes()['features'].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes())
    for n in nodes:
        assert n in local_nids

    print('end')

def check_dist_emb_server_client(shared_mem, num_servers, num_clients):
    prepare_dist()
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_2'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem))
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print('start client', cli_id)
        p = ctx.Process(target=run_emb_client, args=(graph_name, 0, num_servers, num_clients,
                                                     g.number_of_nodes(),
                                                     g.number_of_edges()))
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print('clients have terminated')

def check_server_client(shared_mem, num_servers, num_clients):
    prepare_dist()
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_2'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem))
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print('start client', cli_id)
        p = ctx.Process(target=run_client, args=(graph_name, 0, num_servers, num_clients, g.number_of_nodes(),
                                                 g.number_of_edges()))
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print('clients have terminated')

def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
    prepare_dist()
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_2'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph', num_trainers_per_machine=num_clients)

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem))
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
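    # Randomly select ~10% of the nodes/edges; node_split/edge_split should hand each
    # trainer a disjoint share whose union equals the selected set (checked below).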
    node_mask = np.zeros((g.number_of_nodes(),), np.int32)
    edge_mask = np.zeros((g.number_of_edges(),), np.int32)
    nodes = np.random.choice(g.number_of_nodes(), g.number_of_nodes() // 10, replace=False)
    edges = np.random.choice(g.number_of_edges(), g.number_of_edges() // 10, replace=False)
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print('start client', cli_id)
        p = ctx.Process(target=run_client_hierarchy, args=(graph_name, 0, num_servers,
                                                           node_mask, edge_mask, return_dict))
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
    for p in serv_ps:
        p.join()

    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print('clients have terminated')


def run_client_hetero(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
    time.sleep(5)
    os.environ['DGL_NUM_SERVER'] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                                part_id, None)
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)

def create_random_hetero():
    num_nodes = {'n1': 10000, 'n2': 10010, 'n3': 10020}
    etypes = [('n1', 'r1', 'n2'),
              ('n1', 'r2', 'n3'),
              ('n2', 'r3', 'n3')]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(num_nodes[src_ntype], num_nodes[dst_ntype], density=0.001, format='coo',
                          random_state=100)
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    g.nodes['n1'].data['feat'] = F.unsqueeze(F.arange(0, g.number_of_nodes('n1')), 1)
    g.edges['r1'].data['feat'] = F.unsqueeze(F.arange(0, g.number_of_edges('r1')), 1)
    return g

def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.number_of_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.number_of_edges(etype)
    assert g.number_of_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.number_of_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes('n1') / 2))
    feats1 = g.nodes['n1'].data['feat'][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges('r1') / 2))
    feats1 = g.edges['r1'].data['feat'][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test init node data
    new_shape = (g.number_of_nodes('n1'), 2)
    g.nodes['n1'].data['test1'] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes['n1'].data['test1'][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # create a tensor and destroy a tensor and create it again.
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, 'test3', init_func=rand_init)
    del test3
    test3 = dgl.distributed.DistTensor((g.number_of_nodes('n1'), 3), F.float32, 'test3')
    del test3

    # add tests for anonymous distributed tensor.
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    del test3
    test5 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # Test a persistent tensor.
    test4 = dgl.distributed.DistTensor(new_shape, F.float32, 'test4', init_func=rand_init,
                                       persistent=True)
    del test4
    # The tensor was created with persistent=True, so it outlives the Python reference;
    # recreating 'test4' with a different shape is expected to fail.
    raised = False
    try:
        test4 = dgl.distributed.DistTensor((g.number_of_nodes('n1'), 3), F.float32, 'test4')
    except Exception:
        raised = True
    assert raised, 'recreating the persistent tensor test4 with a new shape should fail'

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes['n1'].data['test1'][nids] = new_feats
    feats = g.nodes['n1'].data['test1'][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes['n1'].data['feat']) == g.number_of_nodes('n1')
    assert g.nodes['n1'].data['feat'].shape == (g.number_of_nodes('n1'), 1)
    assert g.nodes['n1'].data['feat'].dtype == F.int64

    selected_nodes = np.random.randint(0, 100, size=g.number_of_nodes('n1')) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype='n1')
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes('n1'))
    for n in nodes:
        assert n in local_nids

    print('end')

def check_server_client_hetero(shared_mem, num_servers, num_clients):
    prepare_dist()
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_3'
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem))
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.number_of_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print('start client', cli_id)
        p = ctx.Process(target=run_client_hetero, args=(graph_name, 0, num_servers, num_clients, num_nodes,
                                                        num_edges))
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print('clients have terminated')

@unittest.skipIf(os.name == 'nt', reason='Windows is not supported yet')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of the operations in DistGraph")
def test_server_client():
    os.environ['DGL_DIST_MODE'] = 'distributed'
    check_server_client_hierarchy(False, 1, 4)
    check_server_client_empty(True, 1, 1)
    check_server_client_hetero(True, 1, 1)
    check_server_client_hetero(False, 1, 1)
    check_server_client(True, 1, 1)
    check_server_client(False, 1, 1)
    check_server_client(True, 2, 2)
    check_server_client(False, 2, 2)

@unittest.skipIf(os.name == 'nt', reason='Windows is not supported yet')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed NodeEmbedding")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="MXNet doesn't support distributed NodeEmbedding")
def test_dist_emb_server_client():
    os.environ['DGL_DIST_MODE'] = 'distributed'
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    check_dist_emb_server_client(True, 2, 2)
    check_dist_emb_server_client(False, 2, 2)

@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of the operations in DistGraph")
def test_standalone():
    os.environ['DGL_DIST_MODE'] = 'standalone'

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_3'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(graph_name, part_config='/tmp/dist_graph/{}.json'.format(graph_name))
    try:
        check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    finally:
        # This is needed since there are two tests in one process.
        dgl.distributed.exit_client()

@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed NodeEmbedding")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="MXNet doesn't support distributed NodeEmbedding")
def test_standalone_node_emb():
    os.environ['DGL_DIST_MODE'] = 'standalone'

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_3'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(graph_name, part_config='/tmp/dist_graph/{}.json'.format(graph_name))
    try:
        check_dist_emb(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    finally:
        # This is needed since there are two tests in one process.
        dgl.distributed.exit_client()

@unittest.skipIf(os.name == 'nt', reason='Windows is not supported yet')
def test_split():
    #prepare_dist()
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(g, 'dist_graph_test', num_parts, '/tmp/dist_graph', num_hops=num_hops, part_method='metis')

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code now collects the roles of all client processes and uses that information
    # to determine how to split the workload. Here we simulate the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = 'default'
        dgl.distributed.role.GLOBAL_RANK = {i:i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK['default'] = {i:i for i in range(num_clients)}

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
        local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes1 = np.intersect1d(selected_nodes, F.asnumpy(local_nids))
        nodes2 = node_split(node_mask, gpb, rank=i, force_even=False)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        local_nids = F.asnumpy(local_nids)
        for n in nodes1:
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(node_mask, gpb, rank=i * 2, force_even=False)
        nodes4 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=False)
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata['inner_edge'])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges1 = np.intersect1d(selected_edges, F.asnumpy(local_eids))
        edges2 = edge_split(edge_mask, gpb, rank=i, force_even=False)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        local_eids = F.asnumpy(local_eids)
        for e in edges1:
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(edge_mask, gpb, rank=i * 2, force_even=False)
        edges4 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=False)
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))

@unittest.skipIf(os.name == 'nt', reason='Windows is not supported yet')
def test_split_even():
    #prepare_dist(1)
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(g, 'dist_graph_test', num_parts, '/tmp/dist_graph', num_hops=num_hops, part_method='metis')

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code now collects the roles of all client processes and uses that information
    # to determine how to split the workload. Here we simulate the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = 'default'
        dgl.distributed.role.GLOBAL_RANK = {i:i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK['default'] = {i:i for i in range(num_clients)}

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
        local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print('part {} gets {} nodes and {} of them are in the partition'.format(i, len(nodes), len(subset)))

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print('intersection has', len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata['inner_edge'])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print('part {} gets {} edges and {} of them are in the partition'.format(i, len(edges), len(subset)))

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print('intersection has', len(subset))
    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))

def prepare_dist():
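    # Write this machine's usable IP address and a free port into the config file
    # that the servers and clients read.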
    with open("kv_ip_config.txt", "w") as ip_config:
        ip_addr = get_local_usable_addr()
        ip_config.write('{}\n'.format(ip_addr))

if __name__ == '__main__':
    os.makedirs('/tmp/dist_graph', exist_ok=True)
    test_server_client()
    test_split()
    test_split_even()
    test_standalone()

    test_standalone_node_emb()