import os
os.environ['OMP_NUM_THREADS'] = '1'
import dgl
import sys
import numpy as np
import time
import socket
from scipy import sparse as spsp
from numpy.testing import assert_array_equal
from multiprocessing import Process, Manager, Condition, Value
import multiprocessing as mp
from dgl.heterograph_index import create_unitgraph_from_coo
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import DistGraphServer, DistGraph
from dgl.distributed import partition_graph, load_partition, load_partition_book, node_split, edge_split
from numpy.testing import assert_almost_equal
import backend as F
import math
import unittest
import pickle
from utils import reset_envs, generate_ip_config, create_random_graph
import pytest

if os.name != 'nt':
    import fcntl
    import struct

def run_server(graph_name, server_id, server_count, num_clients, shared_mem, keep_alive=False):
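    """Start a DistGraphServer that serves one partition of ``graph_name``.

    The partition config is read from /tmp/dist_graph. With keep_alive=True
    the server stays up after a client group exits, so multiple client
    groups can connect in turn (see check_server_client).
    """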
    g = DistGraphServer(server_id, "kv_ip_config.txt", server_count, num_clients,
                        '/tmp/dist_graph/{}.json'.format(graph_name),
                        disable_shared_mem=not shared_mem,
                        graph_format=['csc', 'coo'], keep_alive=keep_alive)
    print('start server', server_id)
    # verify dtype of underlying graph
    cg = g.client_g
    for k, dtype in dgl.distributed.dist_graph.FIELD_DICT.items():
        if k in cg.ndata:
            assert F.dtype(
                cg.ndata[k]) == dtype, "Data type of {} in ndata should be {}.".format(k, dtype)
        if k in cg.edata:
            assert F.dtype(
                cg.edata[k]) == dtype, "Data type of {} in edata should be {}.".format(k, dtype)
    g.start()

def emb_init(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())

def rand_init(shape, dtype):
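    # Note: returns float32 regardless of the requested dtype.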
    return F.tensor(np.random.normal(size=shape), F.float32)

def check_dist_graph_empty(g, num_clients, num_nodes, num_edges):
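    """Basic DistGraph checks for a graph without precomputed features:
    size queries, creating DistTensor node data, reading and writing it,
    and node attribute schemes.
    """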
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    g.ndata['test1'] = dgl.distributed.DistTensor(new_shape, F.int32)
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats = g.ndata['test1'][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # Create a named tensor, destroy it, and create it again with a different shape.
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, 'test3', init_func=rand_init)
    del test3
    test3 = dgl.distributed.DistTensor((g.number_of_nodes(), 3), F.float32, 'test3')
    del test3

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata['test1'][nids] = new_feats
    feats = g.ndata['test1'][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert g.node_attr_schemes()['test1'].dtype == F.int32

    print('end')

def run_client_empty(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
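    """Client entry point for the featureless-graph test; runs check_dist_graph_empty()."""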
    os.environ['DGL_NUM_SERVER'] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                                part_id, None)
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)

def check_server_client_empty(shared_mem, num_servers, num_clients):
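    """Spawn server and client processes and run the featureless DistGraph
    checks on a single partition.
    """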
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_1'
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    # Let's just test on one partition for now: we cannot run servers and
    # clients for multiple partitions on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem))
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print('start client', cli_id)
        p = ctx.Process(target=run_client_empty, args=(graph_name, 0, num_servers, num_clients,
                                                       g.number_of_nodes(), g.number_of_edges()))
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print('clients have terminated')

def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges, group_id):
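    """Client entry point: connect to the servers, load the partition book,
    build a DistGraph and run check_dist_graph() against it.
    """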
    os.environ['DGL_NUM_SERVER'] = str(server_count)
    os.environ['DGL_GROUP_ID'] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                                part_id, None)
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph(g, num_clients, num_nodes, num_edges)

def run_emb_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges, group_id):
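    """Client entry point for the DistEmbedding test; runs check_dist_emb()."""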
    os.environ['DGL_NUM_SERVER'] = str(server_count)
    os.environ['DGL_GROUP_ID'] = str(group_id)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                                part_id, None)
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_emb(g, num_clients, num_nodes, num_edges)

def run_client_hierarchy(graph_name, part_id, server_count, node_mask, edge_mask, return_dict):
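    """Client entry point for the hierarchical test: split nodes/edges using
    trainer IDs and report this rank's share through return_dict.
    """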
    os.environ['DGL_NUM_SERVER'] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                                part_id, None)
    g = DistGraph(graph_name, gpb=gpb)
    node_mask = F.tensor(node_mask)
    edge_mask = F.tensor(edge_mask)
    nodes = node_split(node_mask, g.get_partition_book(), node_trainer_ids=g.ndata['trainer_id'])
    edges = edge_split(edge_mask, g.get_partition_book(), edge_trainer_ids=g.edata['trainer_id'])
    rank = g.rank()
    return_dict[rank] = (nodes, edges)

def check_dist_emb(g, num_clients, num_nodes, num_edges):
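    """Exercise DistEmbedding with SparseAdagrad: zero-initialized lookups,
    one backward/step update, and checks that rows that never appeared in a
    mini-batch keep their initial values.
    """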
    from dgl.distributed.optim import SparseAdagrad
    from dgl.distributed import DistEmbedding
    # Test sparse emb
    try:
        emb = DistEmbedding(g.number_of_nodes(), 1, 'emb1', emb_init)
        nids = F.arange(0, int(g.number_of_nodes()))
        lr = 0.001
        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats = emb(nids)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))

        policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book())
        grad_sum = dgl.distributed.DistTensor((g.number_of_nodes(), 1), F.float32,
                                              'emb1_sum', policy)
        if num_clients == 1:
            assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)) * num_clients)
        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

        emb = DistEmbedding(g.number_of_nodes(), 1, 'emb2', emb_init)
        with F.no_grad():
            feats1 = emb(nids)
        assert np.all(F.asnumpy(feats1) == 0)

        optimizer = SparseAdagrad([emb], lr=lr)
        with F.record_grad():
            feats1 = emb(nids)
            feats2 = emb(nids)
            feats = F.cat([feats1, feats2], 0)
            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))
            loss = F.sum(feats + 1, 0)
        loss.backward()
        optimizer.step()
        with F.no_grad():
            feats = emb(nids)
        if num_clients == 1:
            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)
        rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
        feats1 = emb(rest)
        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
    except NotImplementedError:
        pass
    except Exception as e:
        print(e)
        sys.exit(-1)

def check_dist_graph(g, num_clients, num_nodes, num_edges):
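    """Core DistGraph checks: size queries, reading node/edge features,
    edge_subgraph, named/anonymous/persistent DistTensors, writes,
    metadata and node_split.
    """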
    # Test API
    assert g.number_of_nodes() == num_nodes
    assert g.number_of_edges() == num_edges

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes() / 2))
    feats1 = g.ndata['features'][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges() / 2))
    feats1 = g.edata['features'][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph(eids)
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes(), 2)
    test1 = dgl.distributed.DistTensor(new_shape, F.int32)
    g.ndata['test1'] = test1
    feats = g.ndata['test1'][nids]
    assert np.all(F.asnumpy(feats) == 0)
    assert test1.count_nonzero() == 0

    # Reference a tensor that already exists under the same name.
    test2 = dgl.distributed.DistTensor(new_shape, F.float32, 'test2', init_func=rand_init)
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, 'test2')
    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))

    # Create a named tensor, destroy it, and create it again with a different shape.
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, 'test3', init_func=rand_init)
    del test3
    test3 = dgl.distributed.DistTensor((g.number_of_nodes(), 3), F.float32, 'test3')
    del test3

    # Test anonymous distributed tensors.
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    del test3
    test5 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(new_shape, F.float32, 'test4', init_func=rand_init,
                                       persistent=True)
    del test4
    # Recreating the persistent tensor with a different shape must fail.
    try:
        test4 = dgl.distributed.DistTensor((g.number_of_nodes(), 3), F.float32, 'test4')
        assert False, 'recreating a persistent tensor with a new shape should fail'
    except AssertionError:
        raise
    except Exception:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.ndata['test1'][nids] = new_feats
    feats = g.ndata['test1'][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.ndata['features']) == g.number_of_nodes()
    assert g.ndata['features'].shape == (g.number_of_nodes(), 1)
    assert g.ndata['features'].dtype == F.int64
    assert g.node_attr_schemes()['features'].dtype == F.int64
    assert g.node_attr_schemes()['test1'].dtype == F.int32
    assert g.node_attr_schemes()['features'].shape == (1,)

    selected_nodes = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book())
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes())
    for n in nodes:
        assert n in local_nids

    print('end')

def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_groups=1):
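    """Spawn server and client processes and run the DistEmbedding checks.
    With num_groups > 1 the servers are started with keep_alive so they
    survive the exit of a single client group.
    """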
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f'check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    # Let's just test on one partition for now: we cannot run servers and
    # clients for multiple partitions on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem, keep_alive))
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print('start client[{}] for group[{}]'.format(cli_id, group_id))
            p = ctx.Process(target=run_emb_client, args=(graph_name, 0, num_servers, num_clients,
                                                        g.number_of_nodes(),
                                                        g.number_of_edges(),
                                                        group_id))
            p.start()
            time.sleep(1) # avoid race condition when instantiating DistGraph
            cli_ps.append(p)

    for p in cli_ps:
        p.join()
        assert p.exitcode == 0

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()

    print('clients have terminated')

def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
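    """Spawn server and client processes and run the DistGraph checks;
    supports multiple client groups against keep-alive servers.
    """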
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = f'check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    # Let's just test on one partition for now: we cannot run servers and
    # clients for multiple partitions on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    keep_alive = num_groups > 1
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem, keep_alive))
        serv_ps.append(p)
        p.start()

    # launch different client groups simultaneously
    cli_ps = []
    for cli_id in range(num_clients):
        for group_id in range(num_groups):
            print('start client[{}] for group[{}]'.format(cli_id, group_id))
            p = ctx.Process(target=run_client, args=(graph_name, 0, num_servers, num_clients, g.number_of_nodes(),
                                                    g.number_of_edges(), group_id))
            p.start()
            time.sleep(1) # avoid race condition when instantiating DistGraph
            cli_ps.append(p)
    for p in cli_ps:
        p.join()

    if keep_alive:
        for p in serv_ps:
            assert p.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
    for p in serv_ps:
        p.join()

    print('clients have terminated')

def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
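    """Partition with num_trainers_per_machine=num_clients and check that
    node_split/edge_split guided by trainer IDs cover exactly the masked
    nodes and edges across all clients.
    """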
    prepare_dist(num_servers)
    g = create_random_graph(10000)

    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_2'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph', num_trainers_per_machine=num_clients)

    # Let's just test on one partition for now: we cannot run servers and
    # clients for multiple partitions on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem))
        serv_ps.append(p)
        p.start()

    cli_ps = []
    manager = mp.Manager()
    return_dict = manager.dict()
    node_mask = np.zeros((g.number_of_nodes(),), np.int32)
    edge_mask = np.zeros((g.number_of_edges(),), np.int32)
    nodes = np.random.choice(g.number_of_nodes(), g.number_of_nodes() // 10, replace=False)
    edges = np.random.choice(g.number_of_edges(), g.number_of_edges() // 10, replace=False)
    node_mask[nodes] = 1
    edge_mask[edges] = 1
    nodes = np.sort(nodes)
    edges = np.sort(edges)
    for cli_id in range(num_clients):
        print('start client', cli_id)
        p = ctx.Process(target=run_client_hierarchy, args=(graph_name, 0, num_servers,
                                                           node_mask, edge_mask, return_dict))
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()
    for p in serv_ps:
        p.join()

    nodes1 = []
    edges1 = []
    for n, e in return_dict.values():
        nodes1.append(n)
        edges1.append(e)
    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))
    edges1, _ = F.sort_1d(F.cat(edges1, 0))
    assert np.all(F.asnumpy(nodes1) == nodes)
    assert np.all(F.asnumpy(edges1) == edges)

    print('clients have terminated')


def run_client_hetero(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
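    """Client entry point for the heterograph test; runs check_dist_graph_hetero()."""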
    os.environ['DGL_NUM_SERVER'] = str(server_count)
    dgl.distributed.initialize("kv_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                                part_id, None)
    g = DistGraph(graph_name, gpb=gpb)
    check_dist_graph_hetero(g, num_clients, num_nodes, num_edges)

def create_random_hetero():
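    """Create a random heterograph with node types n1/n2/n3 and relations
    r1/r2/r3; n1 nodes and r1 edges carry an integer 'feat' equal to
    their ID.
    """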
    num_nodes = {'n1': 10000, 'n2': 10010, 'n3': 10020}
    etypes = [('n1', 'r1', 'n2'),
              ('n1', 'r2', 'n3'),
              ('n2', 'r3', 'n3')]
    edges = {}
    for etype in etypes:
        src_ntype, _, dst_ntype = etype
        arr = spsp.random(num_nodes[src_ntype], num_nodes[dst_ntype], density=0.001, format='coo',
                          random_state=100)
        edges[etype] = (arr.row, arr.col)
    g = dgl.heterograph(edges, num_nodes)
    g.nodes['n1'].data['feat'] = F.unsqueeze(F.arange(0, g.number_of_nodes('n1')), 1)
    g.edges['r1'].data['feat'] = F.unsqueeze(F.arange(0, g.number_of_edges('r1')), 1)
    return g

def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
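    """DistGraph checks for a heterograph: per-type counts, canonical edge
    types, typed feature reads, edge_subgraph, DistTensors and a typed
    node_split.
    """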
    # Test API
    for ntype in num_nodes:
        assert ntype in g.ntypes
        assert num_nodes[ntype] == g.number_of_nodes(ntype)
    for etype in num_edges:
        assert etype in g.etypes
        assert num_edges[etype] == g.number_of_edges(etype)
    etypes = [('n1', 'r1', 'n2'),
              ('n1', 'r2', 'n3'),
              ('n2', 'r3', 'n3')]
    for i, etype in enumerate(g.canonical_etypes):
        assert etype[0] == etypes[i][0]
        assert etype[1] == etypes[i][1]
        assert etype[2] == etypes[i][2]
    assert g.number_of_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
    assert g.number_of_edges() == sum([num_edges[etype] for etype in num_edges])

    # Test reading node data
    nids = F.arange(0, int(g.number_of_nodes('n1') / 2))
    feats1 = g.nodes['n1'].data['feat'][nids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == nids))

    # Test reading edge data
    eids = F.arange(0, int(g.number_of_edges('r1') / 2))
    feats1 = g.edges['r1'].data['feat'][eids]
    feats = F.squeeze(feats1, 1)
    assert np.all(F.asnumpy(feats == eids))

    # Test edge_subgraph
    sg = g.edge_subgraph({'r1': eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)
    sg = g.edge_subgraph({('n1', 'r1', 'n2'): eids})
    assert sg.num_edges() == len(eids)
    assert F.array_equal(sg.edata[dgl.EID], eids)

    # Test init node data
    new_shape = (g.number_of_nodes('n1'), 2)
    g.nodes['n1'].data['test1'] = dgl.distributed.DistTensor(new_shape, F.int32)
    feats = g.nodes['n1'].data['test1'][nids]
    assert np.all(F.asnumpy(feats) == 0)

    # Create a named tensor, destroy it, and create it again with a different shape.
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, 'test3', init_func=rand_init)
    del test3
    test3 = dgl.distributed.DistTensor((g.number_of_nodes('n1'), 3), F.float32, 'test3')
    del test3

    # Test anonymous distributed tensors.
    test3 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    data = test3[0:10]
    test4 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    del test3
    test5 = dgl.distributed.DistTensor(new_shape, F.float32, init_func=rand_init)
    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0

    # test a persistent tensor
    test4 = dgl.distributed.DistTensor(new_shape, F.float32, 'test4', init_func=rand_init,
                                       persistent=True)
    del test4
    # Recreating the persistent tensor with a different shape must fail.
    try:
        test4 = dgl.distributed.DistTensor((g.number_of_nodes('n1'), 3), F.float32, 'test4')
        assert False, 'recreating a persistent tensor with a new shape should fail'
    except AssertionError:
        raise
    except Exception:
        pass

    # Test write data
    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())
    g.nodes['n1'].data['test1'][nids] = new_feats
    feats = g.nodes['n1'].data['test1'][nids]
    assert np.all(F.asnumpy(feats) == 1)

    # Test metadata operations.
    assert len(g.nodes['n1'].data['feat']) == g.number_of_nodes('n1')
    assert g.nodes['n1'].data['feat'].shape == (g.number_of_nodes('n1'), 1)
    assert g.nodes['n1'].data['feat'].dtype == F.int64

    selected_nodes = np.random.randint(0, 100, size=g.number_of_nodes('n1')) > 30
    # Test node split
    nodes = node_split(selected_nodes, g.get_partition_book(), ntype='n1')
    nodes = F.asnumpy(nodes)
    # We only have one partition, so the local nodes are basically all nodes in the graph.
    local_nids = np.arange(g.number_of_nodes('n1'))
    for n in nodes:
        assert n in local_nids

    print('end')

def check_server_client_hetero(shared_mem, num_servers, num_clients):
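    """Spawn server and client processes and run the heterograph DistGraph
    checks on a single partition.
    """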
    prepare_dist(num_servers)
    g = create_random_hetero()

    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_3'
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    # Let's just test on one partition for now: we cannot run servers and
    # clients for multiple partitions on the same machine.
    serv_ps = []
    ctx = mp.get_context('spawn')
    for serv_id in range(num_servers):
        p = ctx.Process(target=run_server, args=(graph_name, serv_id, num_servers,
                                                 num_clients, shared_mem))
        serv_ps.append(p)
        p.start()

    cli_ps = []
    num_nodes = {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}
    num_edges = {etype: g.number_of_edges(etype) for etype in g.etypes}
    for cli_id in range(num_clients):
        print('start client', cli_id)
        p = ctx.Process(target=run_client_hetero, args=(graph_name, 0, num_servers, num_clients, num_nodes,
                                                        num_edges))
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()

    print('clients have terminated')

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
def test_server_client():
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    check_server_client_hierarchy(False, 1, 4)
    check_server_client_empty(True, 1, 1)
    check_server_client_hetero(True, 1, 1)
    check_server_client_hetero(False, 1, 1)
    check_server_client(True, 1, 1)
    check_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    #check_server_client(True, 2, 2)
    #check_server_client(True, 1, 1, 2)
    #check_server_client(False, 1, 1, 2)
    #check_server_client(True, 2, 2, 2)

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Mxnet doesn't support distributed DistEmbedding")
def test_dist_emb_server_client():
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    check_dist_emb_server_client(True, 1, 1)
    check_dist_emb_server_client(False, 1, 1)
    # [TODO][Rhett] Tests for multiple groups may fail sometimes and
    # root cause is unknown. Let's disable them for now.
    #check_dist_emb_server_client(True, 2, 2)
    #check_dist_emb_server_client(True, 1, 1, 2)
    #check_dist_emb_server_client(False, 1, 1, 2)
    #check_dist_emb_server_client(True, 2, 2, 2)

@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
def test_standalone():
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'standalone'

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_3'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(graph_name, part_config='/tmp/dist_graph/{}.json'.format(graph_name))
    check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client() # needed since two tests run in this one process

@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Mxnet doesn't support distributed DistEmbedding")
def test_standalone_node_emb():
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'standalone'

    g = create_random_graph(10000)
    # Partition the graph
    num_parts = 1
    graph_name = 'dist_graph_test_3'
    g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
    g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
    partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')

    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = DistGraph(graph_name, part_config='/tmp/dist_graph/{}.json'.format(graph_name))
    check_dist_emb(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
    dgl.distributed.exit_client() # needed since two tests run in this one process

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("hetero", [True, False])
def test_split(hetero):
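    """Check node_split/edge_split with force_even=False: each rank must get
    exactly the masked IDs local to its partition, and with twice as many
    ranks the two halves of a partition must cover the same set.
    """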
    if hetero:
        g = create_random_hetero()
        ntype = 'n1'
        etype = 'r1'
    else:
        g = create_random_graph(10000)
        ntype = '_N'
        etype = '_E'
    num_parts = 4
    num_hops = 2
    partition_graph(g, 'dist_graph_test', num_parts, '/tmp/dist_graph', num_hops=num_hops, part_method='metis')

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes(ntype)) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges(etype)) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]

    # The code now collects the roles of all client processes and uses that
    # information to determine how to split the workload. The helper below
    # simulates the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = 'default'
        dgl.distributed.role.GLOBAL_RANK = {i:i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK['default'] = {i:i for i in range(num_clients)}

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
        local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        if hetero:
            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
        else:
            local_nids = F.asnumpy(local_nids)
        nodes1 = np.intersect1d(selected_nodes, local_nids)
        nodes2 = node_split(node_mask, gpb, ntype=ntype, rank=i, force_even=False)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
        for n in F.asnumpy(nodes2):
            assert n in local_nids

        set_roles(num_parts * 2)
        nodes3 = node_split(node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False)
        nodes4 = node_split(node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False)
        nodes5 = F.cat([nodes3, nodes4], 0)
        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata['inner_edge'])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        if hetero:
            etype_ids, eids = gpb.map_to_per_etype(local_eids)
            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
        else:
            local_eids = F.asnumpy(local_eids)
        edges1 = np.intersect1d(selected_edges, local_eids)
        edges2 = edge_split(edge_mask, gpb, etype=etype, rank=i, force_even=False)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
        for e in F.asnumpy(edges2):
            assert e in local_eids

        set_roles(num_parts * 2)
        edges3 = edge_split(edge_mask, gpb, etype=etype, rank=i * 2, force_even=False)
        edges4 = edge_split(edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False)
        edges5 = F.cat([edges3, edges4], 0)
        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_split_even():
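    """Check node_split/edge_split with force_even=True: per-rank pieces may
    cross partition boundaries, but concatenated over all ranks they must
    cover exactly the masked nodes and edges.
    """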
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
    partition_graph(g, 'dist_graph_test', num_parts, '/tmp/dist_graph', num_hops=num_hops, part_method='metis')

    node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
    edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
    selected_nodes = np.nonzero(node_mask)[0]
    selected_edges = np.nonzero(edge_mask)[0]
    all_nodes1 = []
    all_nodes2 = []
    all_edges1 = []
    all_edges2 = []

    # The code now collects the roles of all client processes and uses that
    # information to determine how to split the workload. The helper below
    # simulates the multi-client use case.
    def set_roles(num_clients):
        dgl.distributed.role.CUR_ROLE = 'default'
        dgl.distributed.role.GLOBAL_RANK = {i:i for i in range(num_clients)}
        dgl.distributed.role.PER_ROLE_RANK['default'] = {i:i for i in range(num_clients)}

    for i in range(num_parts):
        set_roles(num_parts)
        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
        local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
        nodes = node_split(node_mask, gpb, rank=i, force_even=True)
        all_nodes1.append(nodes)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
        print('part {} get {} nodes and {} are in the partition'.format(i, len(nodes), len(subset)))

        set_roles(num_parts * 2)
        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
        all_nodes2.append(nodes3)
        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
        print('intersection has', len(subset))

        set_roles(num_parts)
        local_eids = F.nonzero_1d(part_g.edata['inner_edge'])
        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)
        all_edges1.append(edges)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))
        print('part {} get {} edges and {} are in the partition'.format(i, len(edges), len(subset)))

        set_roles(num_parts * 2)
        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)
        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)
        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))
        all_edges2.append(edges3)
        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))
        print('intersection has', len(subset))
    all_nodes1 = F.cat(all_nodes1, 0)
    all_edges1 = F.cat(all_edges1, 0)
    all_nodes2 = F.cat(all_nodes2, 0)
    all_edges2 = F.cat(all_edges2, 0)
    all_nodes = np.nonzero(node_mask)[0]
    all_edges = np.nonzero(edge_mask)[0]
    assert np.all(all_nodes == F.asnumpy(all_nodes1))
    assert np.all(all_edges == F.asnumpy(all_edges1))
    assert np.all(all_nodes == F.asnumpy(all_nodes2))
    assert np.all(all_edges == F.asnumpy(all_edges2))

def prepare_dist(num_servers=1):
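    """Generate kv_ip_config.txt for one machine running num_servers servers."""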
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)

if __name__ == '__main__':
    os.makedirs('/tmp/dist_graph', exist_ok=True)
    test_dist_emb_server_client()
    test_server_client()
    test_split(True)
    test_split(False)
    test_split_even()
    test_standalone()
    test_standalone_node_emb()