test_distributed_sampling.py 10 KB
Newer Older
Jinjing Zhou's avatar
Jinjing Zhou committed
1
2
3
4
import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
5
from dgl.distributed import sample_neighbors, find_edges
Jinjing Zhou's avatar
Jinjing Zhou committed
6
7
8
9
10
11
12
13
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
import numpy as np
import backend as F
import time
from utils import get_local_usable_addr
from pathlib import Path
14
import pytest
Jinjing Zhou's avatar
Jinjing Zhou committed
15
16
17
from dgl.distributed import DistGraphServer, DistGraph


18
def start_server(rank, tmpdir, disable_shared_mem, graph_name):
    """Run a ``DistGraphServer`` for one partition of ``graph_name``.

    Used as the target of the spawned server processes in the ``check_*``
    helpers, which join these processes after their client has finished.

    Args:
        rank: Server id, passed as the first argument of ``DistGraphServer``.
        tmpdir: ``pathlib.Path`` of the directory holding the partition files.
        disable_shared_mem: Forwarded unchanged as ``disable_shared_mem``.
        graph_name: Base name of the partition config; ``tmpdir/<graph_name>.json``
            is handed to the server.
    """
    part_config = tmpdir / (graph_name + '.json')
    # NOTE(review): the two literal 1s are passed through unchanged; presumably
    # they are server/client counts -- confirm against DistGraphServer's signature.
    server = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1, part_config,
                             disable_shared_mem=disable_shared_mem)
    server.start()


24
def start_sample_client(rank, tmpdir, disable_shared_mem):
    """Connect to the servers, sample neighbors of a fixed seed set, then exit.

    Args:
        rank: Partition id whose partition book is loaded from disk when
            ``disable_shared_mem`` is true.
        tmpdir: ``pathlib.Path`` containing ``test_sampling.json``.
        disable_shared_mem: If true, load the partition book with
            ``load_partition`` and pass it to ``DistGraph``; otherwise
            ``gpb=None`` is passed.

    Returns:
        The sampled graph, or ``None`` if constructing the ``DistGraph`` or
        sampling raised. In that case the full traceback is printed.
    """
    import traceback

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    dgl.distributed.initialize("rpc_ip_config.txt", 1)
    sampled_graph = None
    try:
        dist_graph = DistGraph("test_sampling", gpb=gpb)
        sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
    except Exception:
        # Best effort: report the failure but still reach exit_client(). The
        # callers join the server processes right after this returns, so the
        # servers presumably rely on the client exiting to shut down.
        traceback.print_exc()
    finally:
        dgl.distributed.exit_client()
    return sampled_graph

38
39
40
41
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids):
    """Connect to the servers, look up the endpoints of ``eids``, then exit.

    Args:
        rank: Partition id whose partition book is loaded from disk when
            ``disable_shared_mem`` is true.
        tmpdir: ``pathlib.Path`` containing ``test_find_edges.json``.
        disable_shared_mem: If true, load the partition book with
            ``load_partition`` and pass it to ``DistGraph``; otherwise
            ``gpb=None`` is passed.
        eids: Edge IDs forwarded to ``find_edges``.

    Returns:
        ``(u, v)`` as returned by ``find_edges``, or ``(None, None)`` if
        constructing the ``DistGraph`` or the lookup raised. In that case the
        full traceback is printed.
    """
    import traceback

    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _ = load_partition(tmpdir / 'test_find_edges.json', rank)
    dgl.distributed.initialize("rpc_ip_config.txt", 1)
    u, v = None, None
    try:
        dist_graph = DistGraph("test_find_edges", gpb=gpb)
        u, v = find_edges(dist_graph, eids)
    except Exception:
        # Best effort: report the failure but always reach exit_client() so
        # the server processes the caller joins are not left waiting.
        traceback.print_exc()
    finally:
        dgl.distributed.exit_client()
    return u, v
Jinjing Zhou's avatar
Jinjing Zhou committed
51

52
def check_rpc_sampling(tmpdir, num_server):
    """Partition Cora (no reshuffle), sample over RPC, and validate the sample.

    Writes ``rpc_ip_config.txt`` in the current working directory with one
    local address per server, partitions the graph into ``num_server`` parts
    under ``tmpdir``, and starts one spawned server process per partition.
    It then samples from a client running in this process, and checks that
    every sampled edge exists in the original graph with a matching edge ID.
    """
    # Context manager so the file is closed even if an address lookup raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=False)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        # Timing-based staggering of server start-up (no readiness handshake).
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    # Fail with a clear message instead of an AttributeError on None.
    assert sampled_graph is not None, \
        "sampling client failed; see the client error printed above"
    src, dst = sampled_graph.edges()
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))

88
89
90
def check_rpc_find_edges(tmpdir, num_server):
    """Partition Cora, run ``find_edges`` over RPC, and compare with local results.

    Writes ``rpc_ip_config.txt`` in the current working directory with one
    local address per server, partitions the graph into ``num_server`` parts
    under ``tmpdir``, and starts one spawned server process per partition.
    It then looks up 100 random edge IDs through a client and checks the
    endpoints against ``g.find_edges``.
    """
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_find_edges', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=False)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_find_edges'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    eids = F.tensor(np.random.randint(g.number_of_edges(), size=100))
    u, v = g.find_edges(eids)
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)
    # Wait for the server processes to finish, as the other check_* helpers
    # do; previously they were never joined here.
    for p in pserver_list:
        p.join()

    assert du is not None and dv is not None, \
        "find_edges client failed; see the client error printed above"
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)

Jinjing Zhou's avatar
Jinjing Zhou committed
116
117
118
119
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_sampling():
    """Distributed neighbor sampling against two servers (no reshuffle)."""
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as workdir:
        check_rpc_sampling(Path(workdir), 2)
Jinjing Zhou's avatar
Jinjing Zhou committed
123

124
def check_rpc_sampling_shuffle(tmpdir, num_server):
    """Partition Cora with reshuffling, sample over RPC, and validate the sample.

    Same flow as ``check_rpc_sampling``, but the graph is partitioned with
    ``reshuffle=True``. The sampled node and edge IDs are mapped back to the
    original graph through each partition's ``'orig_id'`` data before
    validation.
    """
    # Context manager so the file is closed even if an address lookup raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for p in pserver_list:
        p.join()

    assert sampled_graph is not None, \
        "sampling client failed; see the client error printed above"

    # Build lookup tables from the IDs used by the partitions (dgl.NID /
    # dgl.EID) to the original graph's IDs ('orig_id').
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64)
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64)
    for i in range(num_server):
        part, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))

# Waiting for non-shared-memory graph store support.
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_sampling_shuffle(num_server):
    """Distributed neighbor sampling on a reshuffled partitioning.

    Parametrized over one and two servers.
    """
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as workdir:
        check_rpc_sampling_shuffle(Path(workdir), num_server)
Jinjing Zhou's avatar
Jinjing Zhou committed
177

178
179
180
181
182
183
def check_standalone_sampling(tmpdir):
    """Sample neighbors from a single-partition ``DistGraph`` in standalone mode.

    Partitions Cora into one part under ``tmpdir``, sets
    ``DGL_DIST_MODE=standalone``, opens the partition directly through
    ``part_config`` (no server processes), and checks that every sampled edge
    exists in the original graph with a matching edge ID.
    """
    g = CitationGraphDataset("cora")[0]
    num_parts = 1
    num_hops = 1
    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=False)

    os.environ['DGL_DIST_MODE'] = 'standalone'
    dgl.distributed.initialize("rpc_ip_config.txt", 1)
    # try/finally so a failed sample or assertion still tears down the
    # process-wide distributed state before later tests run.
    try:
        dist_graph = DistGraph("test_sampling", part_config=tmpdir / 'test_sampling.json')
        sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)

        src, dst = sampled_graph.edges()
        assert sampled_graph.number_of_nodes() == g.number_of_nodes()
        assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
        eids = g.edge_ids(src, dst)
        assert np.array_equal(
            F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))
    finally:
        dgl.distributed.exit_client()
197
198
199
200
201
202
203
204
205

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_standalone_sampling():
    """Neighbor sampling on a single partition in standalone mode (no servers)."""
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'standalone'
    with TemporaryDirectory() as workdir:
        check_standalone_sampling(Path(workdir))

206
207
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
    """Connect to the servers, extract the in-subgraph of ``nodes``, then exit.

    Args:
        rank: Partition id whose partition book is loaded from disk when
            ``disable_shared_mem`` is true.
        tmpdir: ``pathlib.Path`` containing ``test_in_subgraph.json``.
        disable_shared_mem: If true, load the partition book with
            ``load_partition`` and pass it to ``DistGraph``; otherwise
            ``gpb=None`` is passed.
        nodes: Seed nodes forwarded to ``dgl.distributed.in_subgraph``.

    Returns:
        The extracted subgraph, or ``None`` if any step after
        ``initialize`` raised. In that case the full traceback is printed.
    """
    import traceback

    gpb = None
    dgl.distributed.initialize("rpc_ip_config.txt", 1)
    sampled_graph = None
    try:
        if disable_shared_mem:
            _, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
        dist_graph = DistGraph("test_in_subgraph", gpb=gpb)
        sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
    except Exception:
        # Best effort: report the failure but always reach exit_client() so
        # the server processes the caller joins are not left waiting.
        traceback.print_exc()
    finally:
        dgl.distributed.exit_client()
    return sampled_graph


def check_rpc_in_subgraph(tmpdir, num_server):
    """Partition Cora, run ``in_subgraph`` over RPC, and compare with local DGL.

    Writes ``rpc_ip_config.txt`` in the current working directory with one
    local address per server, partitions the graph into ``num_server`` parts
    under ``tmpdir``, and starts one spawned server process per partition.
    It then checks that the distributed in-subgraph of a fixed node set has
    the same (sorted) endpoints as ``dgl.in_subgraph`` on the full graph and
    carries matching edge IDs.
    """
    # Context manager so the file is closed even if an address lookup raises.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_in_subgraph', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=False)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_in_subgraph'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    nodes = [0, 10, 99, 66, 1024, 2008]
    time.sleep(3)
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for p in pserver_list:
        p.join()

    assert sampled_graph is not None, \
        "in_subgraph client failed; see the client error printed above"
    src, dst = sampled_graph.edges()
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    subg1 = dgl.in_subgraph(g, nodes)
    src1, dst1 = subg1.edges()
    # Edge order may differ between the two subgraphs; compare sorted endpoints.
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_in_subgraph():
    """Distributed ``in_subgraph`` against two servers."""
    from tempfile import TemporaryDirectory

    os.environ['DGL_DIST_MODE'] = 'distributed'
    with TemporaryDirectory() as workdir:
        check_rpc_in_subgraph(Path(workdir), 2)

Jinjing Zhou's avatar
Jinjing Zhou committed
266
267
268
if __name__ == "__main__":
    # Run every check sequentially in one shared temporary directory.
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmpdir = Path(tmpdirname)

        os.environ['DGL_DIST_MODE'] = 'standalone'
        check_standalone_sampling(tmpdir)

        # The remaining checks start server processes and talk to them over RPC.
        os.environ['DGL_DIST_MODE'] = 'distributed'
        distributed_checks = [
            (check_rpc_in_subgraph, 2),
            (check_rpc_sampling_shuffle, 1),
            (check_rpc_sampling_shuffle, 2),
            (check_rpc_sampling, 2),
            (check_rpc_sampling, 1),
            (check_rpc_find_edges, 2),
            (check_rpc_find_edges, 1),
        ]
        for check, num_server in distributed_checks:
            check(tmpdir, num_server)