test_distributed_sampling.py 19.1 KB
Newer Older
Jinjing Zhou's avatar
Jinjing Zhou committed
1
2
3
4
import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
5
from dgl.distributed import sample_neighbors
Jinjing Zhou's avatar
Jinjing Zhou committed
6
7
8
9
10
11
12
13
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
import numpy as np
import backend as F
import time
from utils import get_local_usable_addr
from pathlib import Path
14
import pytest
15
from scipy import sparse as spsp
Jinjing Zhou's avatar
Jinjing Zhou committed
16
17
18
from dgl.distributed import DistGraphServer, DistGraph


19
def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format='csc'):
    """Run a ``DistGraphServer`` for one partition of ``graph_name``.

    Used as a ``multiprocessing`` target: the callers spawn one process per
    server and join it after the client has exited.

    Args:
        rank: Server id passed to ``DistGraphServer``.
        tmpdir: ``pathlib.Path`` holding the partition output; the partition
            config is read from ``tmpdir / (graph_name + '.json')``.
        disable_shared_mem: Forwarded to ``DistGraphServer``; callers set it
            when more than one server is started.
        graph_name: Base name ``partition_graph`` wrote the partitions under.
        graph_format: Sparse format(s) for the server, a string or a list
            such as ``['csr', 'coo']``.
    """
    # NOTE(review): the two ``1`` arguments are presumably the number of
    # servers per machine and the number of clients -- confirm against the
    # DistGraphServer signature.
    server = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1,
                             tmpdir / (graph_name + '.json'),
                             disable_shared_mem=disable_shared_mem,
                             graph_format=graph_format)
    server.start()


26
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Connect to the running servers and sample 3 neighbors of fixed seeds.

    Args:
        rank: Partition whose partition book is loaded from disk when
            ``disable_shared_mem`` is set.
        tmpdir: Directory containing ``test_sampling.json``.
        disable_shared_mem: When true, load the partition book locally;
            otherwise ``DistGraph`` is created with ``gpb=None``.

    Returns:
        The graph returned by ``sample_neighbors``, or ``None`` if sampling
        raised (the exception is printed).
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    try:
        sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
    except Exception as e:
        # Report rather than propagate so exit_client() below still runs;
        # the caller receives None.
        print(e)
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph

40
41
42
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids):
    """Look up the endpoints of ``eids`` in the distributed ``test_find_edges`` graph.

    Args:
        rank: Partition whose partition book is loaded from disk when
            ``disable_shared_mem`` is set.
        tmpdir: Directory containing ``test_find_edges.json``.
        disable_shared_mem: When true, load the partition book locally.
        eids: Edge IDs (in the partitioned/shuffled ID space) to look up.

    Returns:
        ``(u, v)`` as returned by ``DistGraph.find_edges``, or ``(None, None)``
        if the lookup raised (the exception is printed).
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_find_edges.json', rank)
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_find_edges", gpb=gpb)
    try:
        u, v = dist_graph.find_edges(eids)
    except Exception as e:
        # Report rather than propagate so exit_client() below still runs.
        print(e)
        u, v = None, None
    dgl.distributed.exit_client()
    return u, v
Jinjing Zhou's avatar
Jinjing Zhou committed
53

54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Query in/out degrees from the distributed ``test_get_degrees`` graph.

    Args:
        rank: Partition whose partition book is loaded from disk when
            ``disable_shared_mem`` is set.
        tmpdir: Directory containing ``test_get_degrees.json``.
        disable_shared_mem: When true, load the partition book locally.
        nids: Node IDs whose degrees are queried; forwarded as-is.

    Returns:
        ``(in_deg, out_deg, all_in_deg, all_out_deg)`` -- the degrees of
        ``nids`` followed by the degrees of every node -- or four ``None``
        values if any query raised (the exception is printed).
    """
    if disable_shared_mem:
        _, _, _, part_book, _, _, _ = load_partition(tmpdir / 'test_get_degrees.json', rank)
    else:
        part_book = None
    dgl.distributed.initialize("rpc_ip_config.txt", 1)
    graph = DistGraph("test_get_degrees", gpb=part_book)
    try:
        # The four RPCs are issued in a fixed order and regrouped afterwards.
        picked_in = graph.in_degrees(nids)
        every_in = graph.in_degrees()
        picked_out = graph.out_degrees(nids)
        every_out = graph.out_degrees()
        degrees = (picked_in, picked_out, every_in, every_out)
    except Exception as err:
        print(err)
        degrees = (None, None, None, None)
    dgl.distributed.exit_client()
    return degrees

71
def check_rpc_sampling(tmpdir, num_server):
    """Partition Cora without reshuffling, serve it and verify distributed sampling.

    Writes ``rpc_ip_config.txt`` (one local address per server) in the current
    directory, partitions Cora into ``num_server`` METIS parts with
    ``reshuffle=False``, starts one server process per part and samples from a
    client in this process.  Because IDs are not reshuffled, the sampled
    node/edge IDs are compared with the original graph directly.
    """
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    print(g.idtype)
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=False)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Presumably gives the servers time to come up before the client connects.
    time.sleep(3)
    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # start_sample_client returns None when sampling raised.
    assert sampled_graph is not None, "distributed sampling failed"
    src, dst = sampled_graph.edges()
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))

107
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Verify ``DistGraph.find_edges`` on a reshuffled, partitioned Cora.

    Servers are started with graph formats ``['csr', 'coo']``.  Random
    shuffled edge IDs are looked up through the distributed graph. The
    endpoints are mapped back through the partitions' ``orig_id`` data and
    compared with ``g.find_edges`` on the original graph.
    """
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_find_edges', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1,
                                                   'test_find_edges', ['csr', 'coo']))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Lookup tables from shuffled IDs to IDs of the original graph.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_find_edges.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    time.sleep(3)
    eids = F.tensor(np.random.randint(g.number_of_edges(), size=100))
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)
    # Wait for the server processes, as the other checks do, so none of them
    # outlives this test.
    for p in pserver_list:
        p.join()

    # start_find_edges_client returns (None, None) when the lookup raised.
    assert du is not None and dv is not None, "distributed find_edges failed"
    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_find_edges_shuffle(num_server):
    """Run the shuffled find_edges check in distributed mode in a scratch directory."""
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as workdir:
        check_rpc_find_edges_shuffle(Path(workdir), num_server)

def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Verify distributed in/out degree queries on a reshuffled, partitioned Cora.

    Degrees are queried by shuffled node ID. ``orig_nid`` (built from the
    partitions' ``orig_id`` data) maps each shuffled ID to its original ID,
    and the degrees are compared with those of the original graph.
    """
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_get_degrees', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_get_degrees'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # orig_nid[shuffled_id] -> node ID in the original graph.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_get_degrees.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
    time.sleep(3)

    nids = F.tensor(np.random.randint(g.number_of_nodes(), size=100))
    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(0, tmpdir, num_server > 1, nids)

    print("Done get_degree")
    for p in pserver_list:
        p.join()

    # start_get_degrees_client returns four None values when a query raised.
    assert in_degs is not None, "distributed degree queries failed"
    print('check results')
    assert F.array_equal(g.in_degrees(orig_nid[nids]), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(orig_nid[nids]), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)

# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_get_degree_shuffle(num_server):
    """Run the shuffled degree-query check in distributed mode in a scratch directory."""
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as workdir:
        check_rpc_get_degree_shuffle(Path(workdir), num_server)

205
206
207
#@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
#@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip('Only support partition with shuffle')
def test_rpc_sampling():
    """Distributed sampling on a non-reshuffled partition (currently skipped).

    NOTE(review): the commented-out platform/backend guards above presumably
    replace the unconditional skip once non-shuffled partitions are supported.
    """
    import tempfile
    os.environ['DGL_DIST_MODE'] = 'distributed'
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling(Path(tmpdirname), 2)
Jinjing Zhou's avatar
Jinjing Zhou committed
213

214
def check_rpc_sampling_shuffle(tmpdir, num_server):
    """Verify distributed neighbor sampling on a reshuffled, partitioned Cora.

    Sampled node and edge IDs are in the shuffled ID space. They are mapped
    back through the partitions' ``orig_id`` data before being checked
    against the original graph.
    """
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # start_sample_client returns None when sampling raised.
    assert sampled_graph is not None, "distributed sampling failed"

    # Lookup tables from shuffled IDs to IDs of the original graph.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))

258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
def create_random_hetero():
    """Build a random heterograph with node types n1/n2/n3 and relations r1/r2/r3.

    Each relation's edges come from a sparse random COO matrix with density
    0.001 and a fixed ``random_state`` of 100. Only node type ``'n1'``
    receives a ``'feat'`` field: a ``(num_n1, 10)`` float32 tensor of ones.
    """
    node_counts = {'n1': 1010, 'n2': 1000, 'n3': 1020}
    relations = [('n1', 'r1', 'n2'),
                 ('n1', 'r2', 'n3'),
                 ('n2', 'r3', 'n3')]

    def _random_endpoints(src_ntype, dst_ntype):
        # Rows/cols of the random matrix become source/destination node IDs.
        mat = spsp.random(node_counts[src_ntype], node_counts[dst_ntype], density=0.001,
                          format='coo', random_state=100)
        return mat.row, mat.col

    edge_dict = {rel: _random_endpoints(rel[0], rel[2]) for rel in relations}
    hetero = dgl.heterograph(edge_dict, node_counts)
    hetero.nodes['n1'].data['feat'] = F.ones((hetero.number_of_nodes('n1'), 10), F.float32, F.cpu())
    return hetero

def start_hetero_sample_client(rank, tmpdir, disable_shared_mem):
    """Sample 3 neighbors of fixed ``'n3'`` seeds from the distributed heterograph.

    Args:
        rank: Partition whose partition book is loaded from disk when
            ``disable_shared_mem`` is set.
        tmpdir: Directory containing ``test_sampling.json``.
        disable_shared_mem: When true, load the partition book locally;
            otherwise it is taken from the ``DistGraph``.

    Returns:
        ``(block, gpb)``. ``block`` is ``dgl.to_block`` of the sampled graph
        with the sampled edge IDs copied into ``block.edata[dgl.EID]``, or
        ``None`` if sampling or conversion raised (the exception is printed).
        ``gpb`` is the partition book used.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph("test_sampling", gpb=gpb)
        # create_random_hetero() only attaches 'feat' to node type 'n1'.
        assert 'feat' in dist_graph.nodes['n1'].data
        assert 'feat' not in dist_graph.nodes['n2'].data
        assert 'feat' not in dist_graph.nodes['n3'].data
        if gpb is None:
            gpb = dist_graph.get_partition_book()
        try:
            nodes = {'n3': [0, 10, 99, 66, 124, 208]}
            sampled_graph = sample_neighbors(dist_graph, nodes, 3)
            # map_to_homo_nid presumably turns the per-type 'n3' IDs into the
            # homogeneous IDs used by sampled_graph, so they can serve as the
            # block's destination nodes.
            nodes = gpb.map_to_homo_nid(nodes['n3'], 'n3')
            block = dgl.to_block(sampled_graph, nodes)
            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
        except Exception as e:
            print(e)
            block = None
    finally:
        # Disconnect even if an assertion above failed, so the server
        # processes are presumably not left waiting for this client.
        dgl.distributed.exit_client()
    return block, gpb

def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
    """Verify distributed sampling on a reshuffled, partitioned random heterograph.

    For every edge type present in the sampled block, the test checks:
    - all source/destination nodes share one node type;
    - the mapped-back endpoints match ``g.find_edges`` on the original graph;
    - the node types agree with the canonical edge type.
    """
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = create_random_hetero()
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    block, gpb = start_hetero_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # start_hetero_sample_client returns block=None when sampling raised.
    assert block is not None, "distributed heterogeneous sampling failed"

    # Lookup tables from shuffled IDs to IDs in the original graph.
    orig_nid_map = F.zeros((g.number_of_nodes(),), dtype=F.int64)
    orig_eid_map = F.zeros((g.number_of_edges(),), dtype=F.int64)
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        F.scatter_row_inplace(orig_nid_map, part.ndata[dgl.NID], part.ndata['orig_id'])
        F.scatter_row_inplace(orig_eid_map, part.edata[dgl.EID], part.edata['orig_id'])

    src, dst = block.edges()
    # Global IDs after shuffling.
    shuffled_src = F.gather_row(block.srcdata[dgl.NID], src)
    shuffled_dst = F.gather_row(block.dstdata[dgl.NID], dst)
    shuffled_eid = block.edata[dgl.EID]
    # Node/edge type IDs of every sampled edge and its endpoints.
    etype_ids, _ = gpb.map_to_per_etype(shuffled_eid)
    src_type, _ = gpb.map_to_per_ntype(shuffled_src)
    dst_type, _ = gpb.map_to_per_ntype(shuffled_dst)
    etype_ids = F.asnumpy(etype_ids)
    src_type = F.asnumpy(src_type)
    dst_type = F.asnumpy(dst_type)
    # IDs in the original graph. They are compared with g.find_edges(...,
    # etype=...) below, so they are presumably per-type IDs.
    orig_src = F.asnumpy(F.gather_row(orig_nid_map, shuffled_src))
    orig_dst = F.asnumpy(F.gather_row(orig_nid_map, shuffled_dst))
    orig_eid = F.asnumpy(F.gather_row(orig_eid_map, shuffled_eid))

    etype_map = {g.get_etype_id(name): name for name in g.etypes}
    etype_to_eptype = {g.get_etype_id(name): (src_ntype, dst_ntype)
                       for src_ntype, name, dst_ntype in g.canonical_etypes}
    for e in np.unique(etype_ids):
        mask = etype_ids == e
        src_t = src_type[mask]
        dst_t = dst_type[mask]
        assert np.all(src_t == src_t[0])
        assert np.all(dst_t == dst_t[0])

        # Check the node IDs and edge IDs.
        orig_src1, orig_dst1 = g.find_edges(orig_eid[mask], etype=etype_map[e])
        assert np.all(F.asnumpy(orig_src1) == orig_src[mask])
        assert np.all(F.asnumpy(orig_dst1) == orig_dst[mask])

        # Check the node types.
        src_ntype, dst_ntype = etype_to_eptype[e]
        assert np.all(src_t == g.get_ntype_id(src_ntype))
        assert np.all(dst_t == g.get_ntype_id(dst_ntype))

Jinjing Zhou's avatar
Jinjing Zhou committed
365
366
367
# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_sampling_shuffle(num_server):
    """Run the homogeneous and heterogeneous shuffled sampling checks in distributed mode."""
    import tempfile
    os.environ['DGL_DIST_MODE'] = 'distributed'
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling_shuffle(Path(tmpdirname), num_server)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
Jinjing Zhou's avatar
Jinjing Zhou committed
375

376
def check_standalone_sampling(tmpdir, reshuffle):
    """Partition Cora into one part and verify sampling in standalone mode.

    Args:
        tmpdir: Directory for the partition output (``test_sampling.json``).
        reshuffle: Forwarded to ``partition_graph``.  Sampled IDs are compared
            with the original graph directly in both cases.
    """
    g = CitationGraphDataset("cora")[0]
    num_parts = 1
    num_hops = 1
    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=reshuffle)

    os.environ['DGL_DIST_MODE'] = 'standalone'
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph("test_sampling", part_config=tmpdir / 'test_sampling.json')
        sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)

        src, dst = sampled_graph.edges()
        assert sampled_graph.number_of_nodes() == g.number_of_nodes()
        assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
        eids = g.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))
    finally:
        # Always tear the client down, even when an assertion fails, so later
        # calls in the same process start from a clean state.
        dgl.distributed.exit_client()
395
396
397
398
399
400
401

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_standalone_sampling():
    """Run the standalone sampling check without and then with reshuffling."""
    import tempfile
    os.environ['DGL_DIST_MODE'] = 'standalone'
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_standalone_sampling(Path(tmpdirname), False)
        check_standalone_sampling(Path(tmpdirname), True)
404

405
406
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Fetch the in-subgraph of ``nodes`` from the distributed ``test_in_subgraph`` graph.

    Args:
        rank: Partition whose partition book is loaded from disk when
            ``disable_shared_mem`` is set.
        tmpdir: Directory containing ``test_in_subgraph.json``.
        disable_shared_mem: When true, load the partition book locally.
        nodes: Seed node IDs (in the shuffled ID space).

    Returns:
        The result of ``dgl.distributed.in_subgraph``, or ``None`` if it
        raised (the exception is printed).
    """
    gpb = None
    # NOTE(review): unlike the other clients, initialize() runs before the
    # partition book is loaded; the order is kept as-is.
    dgl.distributed.initialize("rpc_ip_config.txt")
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
    dist_graph = DistGraph("test_in_subgraph", gpb=gpb)
    try:
        sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception as e:
        # Report rather than propagate so exit_client() below still runs.
        print(e)
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph


420
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Verify ``dgl.distributed.in_subgraph`` on a reshuffled, partitioned Cora.

    The seed ``nodes`` are shuffled IDs. The result's node and edge IDs are
    mapped back through the partitions' ``orig_id`` data and compared with
    ``dgl.in_subgraph`` on the original graph.
    """
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_in_subgraph', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_in_subgraph'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = [0, 10, 99, 66, 1024, 2008]
    time.sleep(3)
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()

    # start_in_subgraph_client returns None when the request raised.
    assert sampled_graph is not None, "distributed in_subgraph failed"

    # Lookup tables from shuffled IDs to IDs of the original graph.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_in_subgraph.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    subg1 = dgl.in_subgraph(g, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
468
469
470
471
472

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_in_subgraph():
    """Run the shuffled in_subgraph check with two servers in distributed mode."""
    import tempfile
    os.environ['DGL_DIST_MODE'] = 'distributed'
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
476

Jinjing Zhou's avatar
Jinjing Zhou committed
477
478
479
if __name__ == "__main__":
    # Run every check in one shared scratch directory, standalone mode first.
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ['DGL_DIST_MODE'] = 'standalone'
        check_standalone_sampling(Path(tmpdirname), True)
        check_standalone_sampling(Path(tmpdirname), False)
        os.environ['DGL_DIST_MODE'] = 'distributed'
        # NOTE(review): the pytest wrapper test_rpc_sampling is skipped with
        # 'Only support partition with shuffle'; these non-shuffled runs may fail.
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 1)
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)