test_distributed_sampling.py 26.1 KB
Newer Older
Jinjing Zhou's avatar
Jinjing Zhou committed
1
2
3
4
import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
5
6
from dgl.data import WN18Dataset
from dgl.distributed import sample_neighbors, sample_etype_neighbors
Jinjing Zhou's avatar
Jinjing Zhou committed
7
8
9
10
11
12
13
14
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
import numpy as np
import backend as F
import time
from utils import get_local_usable_addr
from pathlib import Path
15
import pytest
16
from scipy import sparse as spsp
17
import random
Jinjing Zhou's avatar
Jinjing Zhou committed
18
19
20
from dgl.distributed import DistGraphServer, DistGraph


21
def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format=None):
    """Create a ``DistGraphServer`` for ``graph_name`` and start it.

    Used as the target of the server processes spawned by the ``check_*``
    helpers in this file.

    Parameters
    ----------
    rank : int
        Server id handed to ``DistGraphServer``.
    tmpdir : path-like supporting ``/``
        Directory containing ``<graph_name>.json`` (the partition config).
    disable_shared_mem : bool
        Forwarded to ``DistGraphServer``.
    graph_name : str
        Name the graph was partitioned under.
    graph_format : list of str, optional
        Forwarded to ``DistGraphServer``. ``None`` (the default) means
        ``['csc', 'coo']``.
    """
    # Resolve the default inside the body so that no list object is shared
    # between calls through the signature.
    if graph_format is None:
        graph_format = ['csc', 'coo']
    server = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1,
                             tmpdir / (graph_name + '.json'),
                             disable_shared_mem=disable_shared_mem,
                             graph_format=graph_format)
    server.start()


28
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Connect a client to the 'test_sampling' graph and sample neighbors.

    When ``disable_shared_mem`` is true the partition book is read from
    ``tmpdir / 'test_sampling.json'`` for partition ``rank``; otherwise
    ``DistGraph`` is created with ``gpb=None``.

    Returns the sampled graph, or ``None`` if sampling raised (the exception
    is printed, not propagated, so that ``exit_client`` is still reached).
    """
    partition_book = None
    if disable_shared_mem:
        # Only the fourth element returned by load_partition is needed here.
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / 'test_sampling.json', rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)

    seed_nodes = [0, 10, 99, 66, 1024, 2008]
    fanout = 3
    try:
        result = sample_neighbors(dist_graph, seed_nodes, fanout)
    except Exception as err:
        print(err)
        result = None

    dgl.distributed.exit_client()
    return result

42
43
44
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids):
    """Connect a client to the 'test_find_edges' graph and look up ``eids``.

    When ``disable_shared_mem`` is true the partition book is read from
    ``tmpdir / 'test_find_edges.json'`` for partition ``rank``.

    Returns the ``(u, v)`` pair produced by ``DistGraph.find_edges``, or
    ``(None, None)`` if the lookup raised (the exception is printed).
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / 'test_find_edges.json', rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_find_edges", gpb=partition_book)

    try:
        src_ids, dst_ids = dist_graph.find_edges(eids)
    except Exception as err:
        print(err)
        src_ids = None
        dst_ids = None

    dgl.distributed.exit_client()
    return src_ids, dst_ids
Jinjing Zhou's avatar
Jinjing Zhou committed
55

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Connect a client to the 'test_get_degrees' graph and query degrees.

    Queries, in this order: in-degrees of ``nids``, in-degrees with no
    argument, out-degrees of ``nids``, out-degrees with no argument.

    Returns ``(in_deg, out_deg, all_in_deg, all_out_deg)``; every element is
    ``None`` if any of the queries raised (the exception is printed).
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / 'test_get_degrees.json', rank)

    dgl.distributed.initialize("rpc_ip_config.txt", 1)
    dist_graph = DistGraph("test_get_degrees", gpb=partition_book)

    try:
        in_deg = dist_graph.in_degrees(nids)
        all_in_deg = dist_graph.in_degrees()
        out_deg = dist_graph.out_degrees(nids)
        all_out_deg = dist_graph.out_degrees()
        degrees = (in_deg, out_deg, all_in_deg, all_out_deg)
    except Exception as err:
        print(err)
        # A failure in any query discards the partial results as well.
        degrees = (None, None, None, None)

    dgl.distributed.exit_client()
    return degrees

73
def check_rpc_sampling(tmpdir, num_server):
    """Check distributed neighbor sampling on cora partitioned without reshuffle.

    Writes ``rpc_ip_config.txt`` with one address per server, partitions the
    graph into ``num_server`` parts under ``tmpdir``, spawns the servers,
    samples through a single client and compares the sampled edges and edge
    IDs with the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes (and partitions).
    """
    # ``with`` closes the file even if get_local_usable_addr() raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    print(g.idtype)
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=False)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Give the servers time to come up before the client connects.
    time.sleep(3)
    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # The client returns None when sampling raised; fail with a clear message
    # instead of an AttributeError on the next line.
    assert sampled_graph is not None, "sampling client failed; see the printed exception"

    src, dst = sampled_graph.edges()
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))

109
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on cora partitioned with reshuffle.

    Random edge IDs are looked up through a distributed client; the returned
    endpoints are mapped back through the partitions' ``'orig_id'`` data and
    compared with ``find_edges`` on the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes (and partitions).
    """
    # ``with`` closes the file even if get_local_usable_addr() raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_find_edges', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1,
                                                   'test_find_edges', ['csr', 'coo']))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Lookup tables from the IDs stored in the partitions (dgl.NID / dgl.EID)
    # to their 'orig_id' values.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_find_edges.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    time.sleep(3)
    eids = F.tensor(np.random.randint(g.number_of_edges(), size=100))
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

    # Wait for the servers to finish, as the other check_* helpers do, so no
    # server process is left behind.
    for p in pserver_list:
        p.join()

    # The client returns (None, None) when the lookup raised.
    assert du is not None and dv is not None, \
        "find_edges client failed; see the printed exception"

    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

147
148
149
# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_find_edges_shuffle(num_server):
    """Run the find_edges check in distributed mode inside a scratch directory."""
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as scratch:
        check_rpc_find_edges_shuffle(Path(scratch), num_server)

def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Check ``DistGraph.in_degrees`` / ``out_degrees`` on reshuffled cora.

    Degrees returned by a distributed client, for 100 random node IDs and for
    the no-argument queries, are compared with the original graph's degrees
    after mapping IDs through the partitions' ``'orig_id'`` data.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes (and partitions).
    """
    # ``with`` closes the file even if get_local_usable_addr() raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_get_degrees', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_get_degrees'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Lookup table from the node IDs stored in the partitions (dgl.NID) to
    # their 'orig_id' values.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_get_degrees.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
    time.sleep(3)

    nids = F.tensor(np.random.randint(g.number_of_nodes(), size=100))
    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(0, tmpdir, num_server > 1, nids)

    print("Done get_degree")
    for p in pserver_list:
        p.join()

    # The client returns all-None when any degree query raised.
    assert in_degs is not None, "get_degrees client failed; see the printed exception"

    print('check results')
    assert F.array_equal(g.in_degrees(orig_nid[nids]), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(orig_nid[nids]), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)

# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_get_degree_shuffle(num_server):
    """Run the degree-query check in distributed mode inside a scratch directory."""
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as scratch:
        check_rpc_get_degree_shuffle(Path(scratch), num_server)

209
210
211
#@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
#@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip('Only support partition with shuffle')
def test_rpc_sampling():
    """Run the non-reshuffled sampling check with two servers (currently skipped)."""
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as scratch:
        check_rpc_sampling(Path(scratch), 2)
Jinjing Zhou's avatar
Jinjing Zhou committed
217

218
def check_rpc_sampling_shuffle(tmpdir, num_server):
    """Check distributed neighbor sampling on cora partitioned with reshuffle.

    The sampled edges and edge IDs are mapped back through the partitions'
    ``'orig_id'`` data before being compared with the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes (and partitions).
    """
    # ``with`` closes the file even if get_local_usable_addr() raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # The client returns None when sampling raised.
    assert sampled_graph is not None, "sampling client failed; see the printed exception"

    # Lookup tables from the IDs stored in the partitions (dgl.NID / dgl.EID)
    # to their 'orig_id' values.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))

262
263
264
def create_random_hetero(dense=False):
    """Build a random heterograph with node types n1/n2/n3 and relations r1/r2/r3.

    ``dense=True`` uses fewer nodes (210/200/220) with edge density 0.1;
    otherwise 1010/1000/1020 nodes with density 0.001. Node type 'n1' gets an
    all-ones float32 'feat' tensor with 10 columns.
    """
    if dense:
        num_nodes = {'n1': 210, 'n2': 200, 'n3': 220}
        density = 0.1
    else:
        num_nodes = {'n1': 1010, 'n2': 1000, 'n3': 1020}
        density = 0.001

    canonical_etypes = [('n1', 'r1', 'n2'),
                        ('n1', 'r2', 'n3'),
                        ('n2', 'r3', 'n3')]
    random.seed(42)

    edges = {}
    for src_ntype, relation, dst_ntype in canonical_etypes:
        # Every relation is drawn with the same fixed random_state.
        adjacency = spsp.random(num_nodes[src_ntype], num_nodes[dst_ntype],
                                density=density, format='coo',
                                random_state=100)
        edges[(src_ntype, relation, dst_ntype)] = (adjacency.row, adjacency.col)

    hetero = dgl.heterograph(edges, num_nodes)
    feat_shape = (hetero.number_of_nodes('n1'), 10)
    hetero.nodes['n1'].data['feat'] = F.ones(feat_shape, F.float32, F.cpu())
    return hetero

def start_hetero_sample_client(rank, tmpdir, disable_shared_mem):
    """Sample neighbors of six 'n3' nodes from the hetero 'test_sampling' graph.

    Asserts that only node type 'n1' carries 'feat' data, then samples with
    fanout 3 and converts the result to a block whose edge data keeps the
    sampled graph's ``dgl.EID``.

    Returns ``(block, gpb)``; ``block`` is ``None`` if sampling raised (the
    exception is printed). ``gpb`` is the loaded partition book, or the one
    obtained from the ``DistGraph`` when none was loaded.
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / 'test_sampling.json', rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)

    assert 'feat' in dist_graph.nodes['n1'].data
    for featureless_ntype in ('n2', 'n3'):
        assert 'feat' not in dist_graph.nodes[featureless_ntype].data

    if partition_book is None:
        partition_book = dist_graph.get_partition_book()

    try:
        seeds = {'n3': [0, 10, 99, 66, 124, 208]}
        sampled = sample_neighbors(dist_graph, seeds, 3)
        block = dgl.to_block(sampled, seeds)
        block.edata[dgl.EID] = sampled.edata[dgl.EID]
    except Exception as err:
        print(err)
        block = None

    dgl.distributed.exit_client()
    return block, partition_book

301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3):
    """Per-etype sample neighbors of six 'n3' nodes from 'test_sampling'.

    Same flow as the plain hetero sampling client, but calls
    ``sample_etype_neighbors`` with ``dgl.ETYPE`` and the given ``fanout``.

    Returns ``(block, gpb)``; ``block`` is ``None`` if sampling raised (the
    exception is printed).
    """
    partition_book = None
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / 'test_sampling.json', rank)

    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=partition_book)

    assert 'feat' in dist_graph.nodes['n1'].data
    for featureless_ntype in ('n2', 'n3'):
        assert 'feat' not in dist_graph.nodes[featureless_ntype].data

    if partition_book is None:
        partition_book = dist_graph.get_partition_book()

    try:
        seeds = {'n3': [0, 10, 99, 66, 124, 208]}
        sampled = sample_etype_neighbors(dist_graph, seeds, dgl.ETYPE, fanout)
        block = dgl.to_block(sampled, seeds)
        block.edata[dgl.EID] = sampled.edata[dgl.EID]
    except Exception as err:
        print(err)
        block = None

    dgl.distributed.exit_client()
    return block, partition_book

323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
    """Check distributed neighbor sampling on a reshuffled random heterograph.

    For every canonical edge type in the sampled block, the shuffled node and
    edge IDs are mapped back to original IDs and the endpoints are compared
    with ``find_edges`` on the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes (and partitions).
    """
    # ``with`` closes the file even if get_local_usable_addr() raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = create_random_hetero()
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    block, gpb = start_hetero_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # The client returns a None block when sampling raised.
    assert block is not None, "hetero sampling client failed; see the printed exception"

    # Per-type lookup tables from shuffled per-type IDs to original IDs.
    orig_nid_map = {ntype: F.zeros((g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes}
    orig_eid_map = {etype: F.zeros((g.number_of_edges(etype),), dtype=F.int64) for etype in g.etypes}
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
        for ntype_id, ntype in enumerate(g.ntypes):
            idx = ntype_ids == ntype_id
            F.scatter_row_inplace(orig_nid_map[ntype], F.boolean_mask(type_nids, idx),
                                  F.boolean_mask(part.ndata['orig_id'], idx))
        etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
        for etype_id, etype in enumerate(g.etypes):
            idx = etype_ids == etype_id
            F.scatter_row_inplace(orig_eid_map[etype], F.boolean_mask(type_eids, idx),
                                  F.boolean_mask(part.edata['orig_id'], idx))

    for src_type, etype, dst_type in block.canonical_etypes:
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)
380

381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server):
    """Check distributed per-etype sampling on a dense reshuffled heterograph.

    Verifies the number of sampled edges for the two relations ending in
    'n3', then maps shuffled IDs back to original IDs and compares edge
    endpoints with ``find_edges`` on the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes (and partitions).
    """
    # ``with`` closes the file even if get_local_usable_addr() raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = create_random_hetero(dense=True)
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    fanout = 3
    block, gpb = start_hetero_etype_sample_client(0, tmpdir, num_server > 1, fanout)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # The client returns a None block when sampling raised.
    assert block is not None, "hetero etype sampling client failed; see the printed exception"

    # NOTE(review): 18 presumably equals 6 seed nodes x fanout 3 per edge
    # type -- confirm against start_hetero_etype_sample_client.
    src, _ = block.edges(etype=('n1', 'r2', 'n3'))
    assert len(src) == 18
    src, _ = block.edges(etype=('n2', 'r3', 'n3'))
    assert len(src) == 18

    # Per-type lookup tables from shuffled per-type IDs to original IDs.
    orig_nid_map = {ntype: F.zeros((g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes}
    orig_eid_map = {etype: F.zeros((g.number_of_edges(etype),), dtype=F.int64) for etype in g.etypes}
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
        for ntype_id, ntype in enumerate(g.ntypes):
            idx = ntype_ids == ntype_id
            F.scatter_row_inplace(orig_nid_map[ntype], F.boolean_mask(type_nids, idx),
                                  F.boolean_mask(part.ndata['orig_id'], idx))
        etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
        for etype_id, etype in enumerate(g.etypes):
            idx = etype_ids == etype_id
            F.scatter_row_inplace(orig_eid_map[etype], F.boolean_mask(type_eids, idx),
                                  F.boolean_mask(part.edata['orig_id'], idx))

    for src_type, etype, dst_type in block.canonical_etypes:
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]

        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[etype], shuffled_eid))

        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)

Jinjing Zhou's avatar
Jinjing Zhou committed
444
445
446
# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_sampling_shuffle(num_server):
    """Run the homogeneous, hetero and hetero-etype sampling checks in turn.

    All three checks share one scratch directory and run in distributed mode.
    """
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as scratch:
        for check in (check_rpc_sampling_shuffle,
                      check_rpc_hetero_sampling_shuffle,
                      check_rpc_hetero_etype_sampling_shuffle):
            check(Path(scratch), num_server)
Jinjing Zhou's avatar
Jinjing Zhou committed
456

457
def check_standalone_sampling(tmpdir, reshuffle):
    """Check neighbor sampling on cora in standalone mode with one partition.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    reshuffle : bool
        Forwarded to ``partition_graph``.
    """
    g = CitationGraphDataset("cora")[0]
    num_parts = 1
    num_hops = 1
    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=reshuffle)

    os.environ['DGL_DIST_MODE'] = 'standalone'
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph("test_sampling", part_config=tmpdir / 'test_sampling.json')
        sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)

        src, dst = sampled_graph.edges()
        assert sampled_graph.number_of_nodes() == g.number_of_nodes()
        assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
        eids = g.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))
    finally:
        # Always release the client, so a failed assertion here does not
        # leave it initialized for whatever runs next.
        dgl.distributed.exit_client()
476

477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
def check_standalone_etype_sampling(tmpdir, reshuffle):
    """Check per-etype neighbor sampling on cora in standalone mode.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    reshuffle : bool
        Forwarded to ``partition_graph``.
    """
    hg = CitationGraphDataset('cora')[0]
    num_parts = 1
    num_hops = 1

    partition_graph(hg, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=reshuffle)
    os.environ['DGL_DIST_MODE'] = 'standalone'
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph("test_sampling", part_config=tmpdir / 'test_sampling.json')
        sampled_graph = sample_etype_neighbors(dist_graph, [0, 10, 99, 66, 1023], dgl.ETYPE, 3)

        src, dst = sampled_graph.edges()
        assert sampled_graph.number_of_nodes() == hg.number_of_nodes()
        assert np.all(F.asnumpy(hg.has_edges_between(src, dst)))
        eids = hg.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))
    finally:
        # Always release the client, so a failed assertion here does not
        # leave it initialized for whatever runs next.
        dgl.distributed.exit_client()

def check_standalone_etype_sampling_heterograph(tmpdir, reshuffle):
    """Check per-etype sampling on a two-relation heterograph built from cora.

    The heterograph has a single 'paper' node type with a 'cite' relation
    (the cora edges) and a 'cite-by' relation (the same edges reversed).

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    reshuffle : bool
        Forwarded to ``partition_graph``.
    """
    hg = CitationGraphDataset('cora')[0]
    num_parts = 1
    num_hops = 1
    cite_src, cite_dst = hg.edges()
    new_hg = dgl.heterograph({('paper', 'cite', 'paper'): (cite_src, cite_dst),
                              ('paper', 'cite-by', 'paper'): (cite_dst, cite_src)},
                             {'paper': hg.number_of_nodes()})
    partition_graph(new_hg, 'test_hetero_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=reshuffle)
    os.environ['DGL_DIST_MODE'] = 'standalone'
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph("test_hetero_sampling", part_config=tmpdir / 'test_hetero_sampling.json')
        sampled_graph = sample_etype_neighbors(dist_graph, [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701], dgl.ETYPE, 1)
        # NOTE(review): 10 presumably equals 10 seed nodes x fanout 1 per
        # edge type -- confirm.
        src, _ = sampled_graph.edges(etype=('paper', 'cite', 'paper'))
        assert len(src) == 10
        src, _ = sampled_graph.edges(etype=('paper', 'cite-by', 'paper'))
        assert len(src) == 10
        assert sampled_graph.number_of_nodes() == new_hg.number_of_nodes()
    finally:
        # Always release the client, so a failed assertion here does not
        # leave it initialized for whatever runs next.
        dgl.distributed.exit_client()

518
519
520
521
522
523
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_standalone_sampling():
    """Run the standalone sampling check without and then with reshuffle."""
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'standalone'
    with TemporaryDirectory() as scratch:
        for reshuffle in (False, True):
            check_standalone_sampling(Path(scratch), reshuffle)
526

527
528
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Connect a client to 'test_in_subgraph' and extract the in-subgraph of ``nodes``.

    The client is initialized before the partition book is (optionally)
    loaded from ``tmpdir / 'test_in_subgraph.json'``.

    Returns the result of ``dgl.distributed.in_subgraph``, or ``None`` if the
    call raised (the exception is printed).
    """
    partition_book = None
    dgl.distributed.initialize("rpc_ip_config.txt")
    if disable_shared_mem:
        _, _, _, partition_book, _, _, _ = load_partition(
            tmpdir / 'test_in_subgraph.json', rank)
    dist_graph = DistGraph("test_in_subgraph", gpb=partition_book)

    try:
        subgraph = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception as err:
        print(err)
        subgraph = None

    dgl.distributed.exit_client()
    return subgraph


542
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Check ``dgl.distributed.in_subgraph`` on cora partitioned with reshuffle.

    The distributed result is mapped back through the partitions'
    ``'orig_id'`` data and compared with ``dgl.in_subgraph`` on the original
    graph: edge existence, sorted endpoints and edge IDs.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of server processes (and partitions).
    """
    # ``with`` closes the file even if get_local_usable_addr() raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_in_subgraph', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_in_subgraph'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = [0, 10, 99, 66, 1024, 2008]
    time.sleep(3)
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()

    # The client returns None when in_subgraph raised.
    assert sampled_graph is not None, "in_subgraph client failed; see the printed exception"

    # Lookup tables from the IDs stored in the partitions (dgl.NID / dgl.EID)
    # to their 'orig_id' values.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_in_subgraph.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    subg1 = dgl.in_subgraph(g, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    # Edge order may differ between the two results, so compare sorted endpoints.
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
590
591
592
593
594

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_in_subgraph():
    """Run the in_subgraph check with two servers in distributed mode."""
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as scratch:
        check_rpc_in_subgraph_shuffle(Path(scratch), 2)
598

599
600
601
602
603
604
605
606
607
608
609
610
611
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
def test_standalone_etype_sampling():
    """Run the standalone per-etype sampling checks.

    The heterograph check gets its own scratch directory; the two
    homogeneous checks (reshuffle on, then off) share a second one.
    """
    from tempfile import TemporaryDirectory

    with TemporaryDirectory() as hetero_dir:
        os.environ['DGL_DIST_MODE'] = 'standalone'
        check_standalone_etype_sampling_heterograph(Path(hetero_dir), True)

    with TemporaryDirectory() as homo_dir:
        os.environ['DGL_DIST_MODE'] = 'standalone'
        for reshuffle in (True, False):
            check_standalone_etype_sampling(Path(homo_dir), reshuffle)

Jinjing Zhou's avatar
Jinjing Zhou committed
612
613
614
if __name__ == "__main__":
    # Script entry point: runs the checks directly, without pytest.
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ['DGL_DIST_MODE'] = 'standalone'
        check_standalone_etype_sampling_heterograph(Path(tmpdirname), True)

    for server_count in (1, 2):
        test_rpc_sampling_shuffle(server_count)

    with tempfile.TemporaryDirectory() as tmpdirname:
        # Standalone-mode checks.
        os.environ['DGL_DIST_MODE'] = 'standalone'
        for reshuffle in (True, False):
            check_standalone_etype_sampling(Path(tmpdirname), reshuffle)
        for reshuffle in (True, False):
            check_standalone_sampling(Path(tmpdirname), reshuffle)

        # Distributed-mode checks; all of them reuse this one directory.
        os.environ['DGL_DIST_MODE'] = 'distributed'
        for server_count in (2, 1):
            check_rpc_sampling(Path(tmpdirname), server_count)
        for server_count in (1, 2):
            check_rpc_get_degree_shuffle(Path(tmpdirname), server_count)
        for server_count in (2, 1):
            check_rpc_find_edges_shuffle(Path(tmpdirname), server_count)
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
        for server_count in (1, 2):
            check_rpc_sampling_shuffle(Path(tmpdirname), server_count)
        for server_count in (1, 2):
            check_rpc_hetero_sampling_shuffle(Path(tmpdirname), server_count)