test_distributed_sampling.py 19 KB
Newer Older
Jinjing Zhou's avatar
Jinjing Zhou committed
1
2
3
4
import multiprocessing as mp
import os
import sys
import time
import unittest
from pathlib import Path

import dgl
import numpy as np
import pytest
from dgl.data import CitationGraphDataset
from dgl.distributed import sample_neighbors
from dgl.distributed import partition_graph, load_partition, load_partition_book
from dgl.distributed import DistGraphServer, DistGraph
from scipy import sparse as spsp

import backend as F
from utils import get_local_usable_addr


19
def start_server(rank, tmpdir, disable_shared_mem, graph_name):
    """Create a ``DistGraphServer`` for one partition and start it.

    Used as the target of the server processes spawned by the ``check_*``
    helpers in this file.

    Parameters
    ----------
    rank : int
        Passed as the first argument of ``DistGraphServer``; callers pass the
        index of the server being started.
    tmpdir : pathlib.Path
        Directory containing the partition output. ``<graph_name>.json``
        inside it is handed to the server as the partition config.
    disable_shared_mem : bool
        Forwarded unchanged to ``DistGraphServer(disable_shared_mem=...)``.
    graph_name : str
        Base name of the partition config file.
    """
    part_config = tmpdir / (graph_name + '.json')
    # NOTE(review): the two literal 1s are presumably the server and client
    # counts -- confirm against the DistGraphServer signature in use.
    server = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1,
                             part_config, disable_shared_mem=disable_shared_mem)
    server.start()


25
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Connect a client to the ``test_sampling`` graph and sample neighbors.

    Parameters
    ----------
    rank : int
        Partition index passed to ``load_partition`` when shared memory is
        disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk and passed to
        ``DistGraph`` explicitly; otherwise ``gpb=None`` is passed.

    Returns
    -------
    The graph returned by ``sample_neighbors`` for six fixed seed nodes with
    a fanout argument of 3, or ``None`` if sampling raised.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    try:
        sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
    except Exception as e:
        # Deliberately best-effort: report the error but still reach
        # exit_client() below so the client is always shut down.
        print(e)
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph

39
40
41
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids):
    """Connect a client to the ``test_find_edges`` graph and look up edges.

    Parameters
    ----------
    rank : int
        Partition index passed to ``load_partition`` when shared memory is
        disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_find_edges.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk and passed to
        ``DistGraph`` explicitly; otherwise ``gpb=None`` is passed.
    eids
        Edge IDs forwarded to ``DistGraph.find_edges``.

    Returns
    -------
    tuple
        ``(u, v)`` as returned by ``find_edges``, or ``(None, None)`` if the
        lookup raised.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_find_edges.json', rank)
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_find_edges", gpb=gpb)
    try:
        u, v = dist_graph.find_edges(eids)
    except Exception as e:
        # Deliberately best-effort: report the error but still reach
        # exit_client() below so the client is always shut down.
        print(e)
        u, v = None, None
    dgl.distributed.exit_client()
    return u, v
Jinjing Zhou's avatar
Jinjing Zhou committed
52

53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
    """Connect a client to the ``test_get_degrees`` graph and query degrees.

    Parameters
    ----------
    rank : int
        Partition index passed to ``load_partition`` when shared memory is
        disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_get_degrees.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk and passed to
        ``DistGraph`` explicitly; otherwise ``gpb=None`` is passed.
    nids : optional
        Node IDs forwarded to ``in_degrees`` / ``out_degrees``.

    Returns
    -------
    tuple
        ``(in_deg, out_deg, all_in_deg, all_out_deg)``; the ``all_*`` entries
        come from the argument-less degree calls. All four entries are
        ``None`` if any of the queries raised.
    """
    book = None
    if disable_shared_mem:
        part_config = tmpdir / 'test_get_degrees.json'
        _, _, _, book, _, _, _ = load_partition(part_config, rank)
    dgl.distributed.initialize("rpc_ip_config.txt", 1)
    graph = DistGraph("test_get_degrees", gpb=book)
    try:
        some_in = graph.in_degrees(nids)
        every_in = graph.in_degrees()
        some_out = graph.out_degrees(nids)
        every_out = graph.out_degrees()
        degrees = (some_in, some_out, every_in, every_out)
    except Exception as err:
        # Report the failure and fall through so exit_client() still runs.
        print(err)
        degrees = (None, None, None, None)
    dgl.distributed.exit_client()
    return degrees

70
def check_rpc_sampling(tmpdir, num_server):
    """Check distributed neighbor sampling on Cora partitioned without reshuffling.

    Writes ``num_server`` addresses to ``rpc_ip_config.txt``, partitions the
    graph into ``num_server`` parts under ``tmpdir``, spawns one server
    process per part, samples from a client running in this process, waits
    for the servers to exit, and then verifies the sampled edges against the
    original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of servers and partitions. Shared memory is disabled
        whenever this is greater than 1.
    """
    # The context manager closes the file even if address lookup raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    print(g.idtype)
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=False)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Fixed delay before connecting the client.
    time.sleep(3)
    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # Every sampled edge must exist in the original graph and carry the
    # edge ID the original graph assigns to that (src, dst) pair.
    src, dst = sampled_graph.edges()
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))

106
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Check ``DistGraph.find_edges`` on Cora partitioned with ``reshuffle=True``.

    Writes ``num_server`` addresses to ``rpc_ip_config.txt``, partitions the
    graph under ``tmpdir``, spawns one server process per part, looks up 100
    random edge IDs through a client running in this process, and compares
    the endpoints with ``g.find_edges`` on the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of servers and partitions. Shared memory is disabled
        whenever this is greater than 1.
    """
    # The context manager closes the file even if address lookup raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_find_edges', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_find_edges'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Map each partition's dgl.NID / dgl.EID values to the stored 'orig_id'
    # values so results can be compared against the original graph.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_find_edges.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    # Fixed delay before connecting the client.
    time.sleep(3)
    eids = F.tensor(np.random.randint(g.number_of_edges(), size=100))
    u, v = g.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)

    # Wait for the server processes, as the sibling check_* helpers do, so
    # they are not left running after this check returns.
    for p in pserver_list:
        p.join()

    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_find_edges_shuffle(num_server):
    """Run the find_edges check in distributed mode inside a fresh temp directory."""
    from tempfile import TemporaryDirectory
    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as workdir:
        partition_dir = Path(workdir)
        check_rpc_find_edges_shuffle(partition_dir, num_server)

def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Check ``DistGraph.in_degrees`` / ``out_degrees`` on reshuffled Cora partitions.

    Writes ``num_server`` addresses to ``rpc_ip_config.txt``, partitions the
    graph under ``tmpdir`` with ``reshuffle=True``, spawns one server process
    per part, queries degrees for 100 random nodes and for all nodes through
    a client running in this process, waits for the servers, and compares
    the results with the degrees of the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of servers and partitions. Shared memory is disabled
        whenever this is greater than 1.
    """
    # The context manager closes the file even if address lookup raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_get_degrees', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_get_degrees'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Map each partition's dgl.NID values to the stored 'orig_id' values.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_get_degrees.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
    # Fixed delay before connecting the client.
    time.sleep(3)

    nids = F.tensor(np.random.randint(g.number_of_nodes(), size=100))
    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(0, tmpdir, num_server > 1, nids)

    print("Done get_degree")
    for p in pserver_list:
        p.join()

    print('check results')
    assert F.array_equal(g.in_degrees(orig_nid[nids]), in_degs)
    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(g.out_degrees(orig_nid[nids]), out_degs)
    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)

# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_get_degree_shuffle(num_server):
    """Run the degree-query check in distributed mode inside a fresh temp directory."""
    from tempfile import TemporaryDirectory
    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as workdir:
        partition_dir = Path(workdir)
        check_rpc_get_degree_shuffle(partition_dir, num_server)

203
204
205
#@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
#@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skip('Only support partition with shuffle')
def test_rpc_sampling():
    """Run the non-reshuffled sampling check with two servers (currently skipped)."""
    import tempfile
    os.environ['DGL_DIST_MODE'] = 'distributed'
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling(Path(tmpdirname), 2)
Jinjing Zhou's avatar
Jinjing Zhou committed
211

212
def check_rpc_sampling_shuffle(tmpdir, num_server):
    """Check distributed neighbor sampling on Cora partitioned with ``reshuffle=True``.

    Writes ``num_server`` addresses to ``rpc_ip_config.txt``, partitions the
    graph under ``tmpdir``, spawns one server process per part, samples from
    a client running in this process, waits for the servers, and then maps
    the sampled node and edge IDs back through the partitions' ``'orig_id'``
    data before verifying them against the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of servers and partitions. Shared memory is disabled
        whenever this is greater than 1.
    """
    # The context manager closes the file even if address lookup raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Fixed delay before connecting the client.
    time.sleep(3)
    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # Map each partition's dgl.NID / dgl.EID values to the stored 'orig_id'
    # values so results can be compared against the original graph.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))

256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
def create_random_hetero():
    """Build a small random heterograph with three node types and three relations.

    Each relation's edges come from a ``scipy.sparse.random`` COO matrix of
    density 0.001 created with ``random_state=100``. Node type ``n1`` also
    receives a ``'feat'`` tensor of ones with 10 columns.

    Returns
    -------
    The graph produced by ``dgl.heterograph``.
    """
    node_counts = {'n1': 1010, 'n2': 1000, 'n3': 1020}
    relations = (('n1', 'r1', 'n2'),
                 ('n1', 'r2', 'n3'),
                 ('n2', 'r3', 'n3'))
    edge_lists = {}
    for relation in relations:
        src_ntype, dst_ntype = relation[0], relation[2]
        adjacency = spsp.random(node_counts[src_ntype], node_counts[dst_ntype],
                                density=0.001, format='coo', random_state=100)
        edge_lists[relation] = (adjacency.row, adjacency.col)
    hetero = dgl.heterograph(edge_lists, node_counts)
    n1_total = hetero.number_of_nodes('n1')
    hetero.nodes['n1'].data['feat'] = F.ones((n1_total, 10), F.float32, F.cpu())
    return hetero

def start_hetero_sample_client(rank, tmpdir, disable_shared_mem):
    """Connect a client to the heterogeneous ``test_sampling`` graph and sample.

    Checks that the ``'feat'`` node data is present on ``n1`` only, samples
    neighbors of six fixed ``n3`` nodes with a fanout argument of 3, and
    converts the result to a block whose destination nodes are the seeds
    mapped through ``gpb.map_to_homo_nid``.

    Parameters
    ----------
    rank : int
        Partition index passed to ``load_partition`` when shared memory is
        disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_sampling.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk; otherwise it is
        taken from the ``DistGraph`` after construction.

    Returns
    -------
    tuple
        ``(block, gpb)``. ``block`` is ``None`` if sampling or block
        construction raised.
    """
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert 'feat' in dist_graph.nodes['n1'].data
    assert 'feat' not in dist_graph.nodes['n2'].data
    assert 'feat' not in dist_graph.nodes['n3'].data
    if gpb is None:
        gpb = dist_graph.get_partition_book()
    try:
        nodes = {'n3': [0, 10, 99, 66, 124, 208]}
        sampled_graph = sample_neighbors(dist_graph, nodes, 3)
        nodes = gpb.map_to_homo_nid(nodes['n3'], 'n3')
        block = dgl.to_block(sampled_graph, nodes)
        # Carry the sampled edge IDs over to the block for the caller.
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    except Exception as e:
        # Deliberately best-effort: report the error but still reach
        # exit_client() below so the client is always shut down.
        print(e)
        block = None
    dgl.distributed.exit_client()
    return block, gpb

def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
    """Check distributed neighbor sampling on a random heterograph with reshuffling.

    Writes ``num_server`` addresses to ``rpc_ip_config.txt``, partitions the
    graph from ``create_random_hetero`` under ``tmpdir``, spawns one server
    process per part, samples a block through a client running in this
    process, waits for the servers, and then verifies per edge type that the
    sampled endpoints, edge IDs and node types agree with the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of servers and partitions. Shared memory is disabled
        whenever this is greater than 1.
    """
    # The context manager closes the file even if address lookup raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = create_random_hetero()
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Fixed delay before connecting the client.
    time.sleep(3)
    block, gpb = start_hetero_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # Map each partition's dgl.NID / dgl.EID values to the stored 'orig_id'
    # values.
    orig_nid_map = F.zeros((g.number_of_nodes(),), dtype=F.int64)
    orig_eid_map = F.zeros((g.number_of_edges(),), dtype=F.int64)
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        F.scatter_row_inplace(orig_nid_map, part.ndata[dgl.NID], part.ndata['orig_id'])
        F.scatter_row_inplace(orig_eid_map, part.edata[dgl.EID], part.edata['orig_id'])

    src, dst = block.edges()
    # These are global Ids after shuffling.
    shuffled_src = F.gather_row(block.srcdata[dgl.NID], src)
    shuffled_dst = F.gather_row(block.dstdata[dgl.NID], dst)
    shuffled_eid = block.edata[dgl.EID]
    # Get node/edge types.
    etype, _ = gpb.map_to_per_etype(shuffled_eid)
    src_type, _ = gpb.map_to_per_ntype(shuffled_src)
    dst_type, _ = gpb.map_to_per_ntype(shuffled_dst)
    etype = F.asnumpy(etype)
    src_type = F.asnumpy(src_type)
    dst_type = F.asnumpy(dst_type)
    # These are global Ids in the original graph.
    orig_src = F.asnumpy(F.gather_row(orig_nid_map, shuffled_src))
    orig_dst = F.asnumpy(F.gather_row(orig_nid_map, shuffled_dst))
    orig_eid = F.asnumpy(F.gather_row(orig_eid_map, shuffled_eid))

    # Lookup tables keyed by edge-type ID. The comprehension variables use
    # their own names so they are not confused with the `etype` array above.
    etype_map = {g.get_etype_id(name): name for name in g.etypes}
    etype_to_eptype = {g.get_etype_id(name): (src_ntype, dst_ntype)
                       for src_ntype, name, dst_ntype in g.canonical_etypes}
    for e in np.unique(etype):
        # Select the sampled edges of this edge type once and reuse the mask.
        mask = etype == e
        src_t = src_type[mask]
        dst_t = dst_type[mask]
        assert np.all(src_t == src_t[0])
        assert np.all(dst_t == dst_t[0])

        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid[mask], etype=etype_map[e])
        assert np.all(F.asnumpy(orig_src1) == orig_src[mask])
        assert np.all(F.asnumpy(orig_dst1) == orig_dst[mask])

        # Check the node types.
        src_ntype, dst_ntype = etype_to_eptype[e]
        assert np.all(src_t == g.get_ntype_id(src_ntype))
        assert np.all(dst_t == g.get_ntype_id(dst_ntype))

Jinjing Zhou's avatar
Jinjing Zhou committed
363
364
365
# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_sampling_shuffle(num_server):
    """Run the homogeneous and heterogeneous reshuffled sampling checks.

    Both checks run in distributed mode and share one temporary directory.
    """
    import tempfile
    os.environ['DGL_DIST_MODE'] = 'distributed'
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_sampling_shuffle(Path(tmpdirname), num_server)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
Jinjing Zhou's avatar
Jinjing Zhou committed
373

374
def check_standalone_sampling(tmpdir, reshuffle):
    """Check neighbor sampling through ``DistGraph`` in standalone mode.

    Partitions Cora into a single part under ``tmpdir``, opens it as a
    ``DistGraph`` from the partition config, samples neighbors of six fixed
    seed nodes, and verifies the sampled edges against the original graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    reshuffle : bool
        Forwarded to ``partition_graph(reshuffle=...)``.
    """
    g = CitationGraphDataset("cora")[0]
    num_parts = 1
    num_hops = 1
    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=reshuffle)

    os.environ['DGL_DIST_MODE'] = 'standalone'
    dgl.distributed.initialize("rpc_ip_config.txt")
    try:
        dist_graph = DistGraph("test_sampling", part_config=tmpdir / 'test_sampling.json')
        sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)

        # NOTE(review): IDs are compared with the original graph directly,
        # with no 'orig_id' mapping even when reshuffle=True -- confirm this
        # is valid for a single partition.
        src, dst = sampled_graph.edges()
        assert sampled_graph.number_of_nodes() == g.number_of_nodes()
        assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
        eids = g.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))
    finally:
        # Shut the client down even when sampling or an assertion fails, so
        # a failure here does not leave it initialized for later checks.
        dgl.distributed.exit_client()
393
394
395
396
397
398
399

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_standalone_sampling():
    """Run the standalone sampling check without and then with reshuffling.

    Both runs share one temporary directory.
    """
    import tempfile
    os.environ['DGL_DIST_MODE'] = 'standalone'
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_standalone_sampling(Path(tmpdirname), False)
        check_standalone_sampling(Path(tmpdirname), True)
402

403
404
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Connect a client to the ``test_in_subgraph`` graph and extract an in-subgraph.

    Parameters
    ----------
    rank : int
        Partition index passed to ``load_partition`` when shared memory is
        disabled.
    tmpdir : pathlib.Path
        Directory containing ``test_in_subgraph.json``.
    disable_shared_mem : bool
        When true, the partition book is loaded from disk and passed to
        ``DistGraph`` explicitly; otherwise ``gpb=None`` is passed.
    nodes
        Node IDs forwarded to ``dgl.distributed.in_subgraph``.

    Returns
    -------
    The graph returned by ``dgl.distributed.in_subgraph``, or ``None`` if
    the call raised.
    """
    gpb = None
    # Unlike the other client helpers in this file, the client is
    # initialized before the partition book is loaded; that order is kept.
    dgl.distributed.initialize("rpc_ip_config.txt")
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
    dist_graph = DistGraph("test_in_subgraph", gpb=gpb)
    try:
        sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception as e:
        # Deliberately best-effort: report the error but still reach
        # exit_client() below so the client is always shut down.
        print(e)
        sampled_graph = None
    dgl.distributed.exit_client()
    return sampled_graph


418
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Check ``dgl.distributed.in_subgraph`` on Cora partitioned with ``reshuffle=True``.

    Writes ``num_server`` addresses to ``rpc_ip_config.txt``, partitions the
    graph under ``tmpdir``, spawns one server process per part, extracts the
    in-subgraph of six fixed nodes through a client running in this process,
    waits for the servers, and compares the result (mapped back through the
    partitions' ``'orig_id'`` data) with ``dgl.in_subgraph`` on the original
    graph.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory that receives the partition files.
    num_server : int
        Number of servers and partitions. Shared memory is disabled
        whenever this is greater than 1.
    """
    # The context manager closes the file even if address lookup raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_in_subgraph', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_in_subgraph'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = [0, 10, 99, 66, 1024, 2008]
    # Fixed delay before connecting the client.
    time.sleep(3)
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()

    # Map each partition's dgl.NID / dgl.EID values to the stored 'orig_id'
    # values so results can be compared against the original graph.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_in_subgraph.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))

    # The distributed result must hold the same endpoints (compared after
    # sorting) as the in-subgraph computed on the original graph.
    subg1 = dgl.in_subgraph(g, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
466
467
468
469
470

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_in_subgraph():
    """Run the in_subgraph check in distributed mode with two servers."""
    import tempfile
    os.environ['DGL_DIST_MODE'] = 'distributed'
    with tempfile.TemporaryDirectory() as tmpdirname:
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
474

Jinjing Zhou's avatar
Jinjing Zhou committed
475
476
477
if __name__ == "__main__":
    # Script entry point: run every check directly, all sharing one
    # temporary directory. Standalone-mode checks run first, then the
    # distributed-mode checks.
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.environ['DGL_DIST_MODE'] = 'standalone'
        check_standalone_sampling(Path(tmpdirname), True)
        check_standalone_sampling(Path(tmpdirname), False)
        os.environ['DGL_DIST_MODE'] = 'distributed'
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 1)
        check_rpc_get_degree_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 2)
        check_rpc_find_edges_shuffle(Path(tmpdirname), 1)
        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)
        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)