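"""Unit tests for the dgl.distributed RPC layer: request/response
serialization, send/recv and remote_call paths, timeout handling, and
multi-client / multi-server setups."""
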
import multiprocessing as mp
import os
import socket
import time
import unittest

import backend as F

import dgl
import pytest
from numpy.testing import assert_array_equal
from utils import generate_ip_config, reset_envs

if os.name != "nt":
    import fcntl
    import struct

INTEGER = 2
STR = "hello world!"
HELLO_SERVICE_ID = 901231
TENSOR = F.zeros((1000, 1000), F.int64, F.cpu())


def foo(x, y):
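    """Callable shipped inside MyRequest; validates the deserialized fields."""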
    assert x == 123
    assert y == "abc"


class MyRequest(dgl.distributed.Request):
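    """Request fixture carrying an int, a string, a tensor, and a callable."""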
    def __init__(self):
        self.x = 123
        self.y = "abc"
        self.z = F.randn((3, 4))
        self.foo = foo

    def __getstate__(self):
        return self.x, self.y, self.z, self.foo

    def __setstate__(self, state):
        self.x, self.y, self.z, self.foo = state

    def process_request(self, server_state):
        pass


class MyResponse(dgl.distributed.Response):
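    """Response fixture carrying a single integer field."""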
    def __init__(self):
        self.x = 432

    def __getstate__(self):
        return self.x

    def __setstate__(self, state):
        self.x = state


def simple_func(tensor):
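    """Identity function used as the callable payload of HelloRequest."""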
    return tensor


class HelloResponse(dgl.distributed.Response):
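    """Response that echoes back the string, integer, and tensor payload."""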
    def __init__(self, hello_str, integer, tensor):
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor

    def __getstate__(self):
        return self.hello_str, self.integer, self.tensor

    def __setstate__(self, state):
        self.hello_str, self.integer, self.tensor = state


class HelloRequest(dgl.distributed.Request):
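    """Request whose process_request applies `func` and echoes the payload."""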
    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def __getstate__(self):
        return self.hello_str, self.integer, self.tensor, self.func

    def __setstate__(self, state):
        self.hello_str, self.integer, self.tensor, self.func = state

    def process_request(self, server_state):
        assert self.hello_str == STR
        assert self.integer == INTEGER
        new_tensor = self.func(self.tensor)
        res = HelloResponse(self.hello_str, self.integer, new_tensor)
        return res


TIMEOUT_SERVICE_ID = 123456789
TIMEOUT_META = "timeout_test"


class TimeoutResponse(dgl.distributed.Response):
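    """Response fixture for the timeout tests; carries a meta string."""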
    def __init__(self, meta):
        self.meta = meta

    def __getstate__(self):
        return self.meta

    def __setstate__(self, state):
        self.meta = state


class TimeoutRequest(dgl.distributed.Request):
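    """Request whose handler sleeps `timeout` ms and optionally replies."""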
    def __init__(self, meta, timeout, response=True):
        self.meta = meta
        self.timeout = timeout
        self.response = response

    def __getstate__(self):
        return self.meta, self.timeout, self.response

    def __setstate__(self, state):
        self.meta, self.timeout, self.response = state

    def process_request(self, server_state):
        assert self.meta == TIMEOUT_META
        # convert from milliseconds to seconds
        time.sleep(self.timeout / 1000)
        if not self.response:
            return None
        res = TimeoutResponse(self.meta)
        return res


def start_server(
    num_clients,
    ip_config,
    server_id=0,
    num_servers=1,
):
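    """Register the test services and run one blocking RPC server process."""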
    print("Sleep 1 second to test client re-connect.")
    time.sleep(1)
    server_state = dgl.distributed.ServerState(
        None, local_g=None, partition_book=None
    )
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse
    )
    dgl.distributed.register_service(
        TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse
    )
    print("Start server {}".format(server_id))
    dgl.distributed.start_server(
        server_id=server_id,
        ip_config=ip_config,
        num_servers=num_servers,
        num_clients=num_clients,
        server_state=server_state,
    )


def start_client(ip_config, group_id=0, num_servers=1):
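    """Connect to the server and exercise send/recv and the remote_call APIs."""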
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse
    )
    dgl.distributed.connect_to_server(
        ip_config=ip_config,
        num_servers=num_servers,
        group_id=group_id,
    )
    req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
    # test send and recv
    dgl.distributed.send_request(0, req)
    res = dgl.distributed.recv_response()
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test remote_call
    target_and_requests = []
    for i in range(10):
        target_and_requests.append((0, req))
    res_list = dgl.distributed.remote_call(target_and_requests)
    for res in res_list:
        assert res.hello_str == STR
        assert res.integer == INTEGER
        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test send_request_to_machine
    dgl.distributed.send_request_to_machine(0, req)
    res = dgl.distributed.recv_response()
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test remote_call_to_machine
    target_and_requests = []
    for i in range(10):
        target_and_requests.append((0, req))
    res_list = dgl.distributed.remote_call_to_machine(target_and_requests)
    for res in res_list:
        assert res.hello_str == STR
        assert res.integer == INTEGER
        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))


def start_client_timeout(ip_config, group_id=0, num_servers=1):
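    """Exercise the timeout paths of recv_response and remote_call."""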
    dgl.distributed.register_service(
        TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse
    )
    dgl.distributed.connect_to_server(
        ip_config=ip_config,
        num_servers=num_servers,
        group_id=group_id,
    )
    timeout = 1 * 1000  # milliseconds
    req = TimeoutRequest(TIMEOUT_META, timeout)
    # test send and recv
    dgl.distributed.send_request(0, req)
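    # The server handler sleeps for `timeout` ms, so waiting only half that
    # time is expected to give up and return None.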
    res = dgl.distributed.recv_response(timeout=int(timeout / 2))
    assert res is None
    res = dgl.distributed.recv_response()
    assert res.meta == TIMEOUT_META
    # test remote_call
    req = TimeoutRequest(TIMEOUT_META, timeout, response=False)
    target_and_requests = []
    for i in range(3):
        target_and_requests.append((0, req))
    expect_except = False
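    # A remote_call that cannot finish within the timeout should raise
    # dgl.DGLError.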
    try:
        res_list = dgl.distributed.remote_call(
            target_and_requests, timeout=int(timeout / 2)
        )
    except dgl.DGLError:
        expect_except = True
    assert expect_except
    # test send_request_to_machine
    req = TimeoutRequest(TIMEOUT_META, timeout)
    dgl.distributed.send_request_to_machine(0, req)
    res = dgl.distributed.recv_response(timeout=int(timeout / 2))
    assert res is None
    res = dgl.distributed.recv_response()
    assert res.meta == TIMEOUT_META
    # test remote_call_to_machine
    req = TimeoutRequest(TIMEOUT_META, timeout, response=False)
    target_and_requests = []
    for i in range(3):
        target_and_requests.append((0, req))
    expect_except = False
    try:
        res_list = dgl.distributed.remote_call_to_machine(
            target_and_requests, timeout=int(timeout / 2)
        )
    except dgl.DGLError:
        expect_except = True
    assert expect_except


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_rpc_timeout():
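    """End-to-end timeout test with one server and one client process."""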
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, 1))
    pclient = ctx.Process(target=start_client_timeout, args=(ip_config, 0, 1))
    pserver.start()
    pclient.start()
    pserver.join()
    pclient.join()


def test_serialize():
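    """Round-trip a request and a response through serialize/deserialize."""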
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    from dgl.distributed.rpc import (
        deserialize_from_payload,
        serialize_to_payload,
    )

    SERVICE_ID = 12345
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
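    # serialize_to_payload yields a byte buffer plus a list of tensors;
    # deserialize_from_payload reassembles them into an equivalent object.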
    req1 = deserialize_from_payload(MyRequest, data, tensors)
    req1.foo(req1.x, req1.y)
    assert req.x == req1.x
    assert req.y == req1.y
    assert F.array_equal(req.z, req1.z)

    res = MyResponse()
    data, tensors = serialize_to_payload(res)
    res1 = deserialize_from_payload(MyResponse, data, tensors)
    assert res.x == res1.x


def test_rpc_msg():
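    """Construct an RPCMessage and verify its fields and tensor payload."""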
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    from dgl.distributed.rpc import (
        deserialize_from_payload,
        RPCMessage,
        serialize_to_payload,
    )

    SERVICE_ID = 32452
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
    rpcmsg = RPCMessage(SERVICE_ID, 23, 0, 1, data, tensors)
    assert rpcmsg.service_id == SERVICE_ID
    assert rpcmsg.msg_seq == 23
    assert rpcmsg.client_id == 0
    assert rpcmsg.server_id == 1
    assert len(rpcmsg.data) == len(data)
    assert len(rpcmsg.tensors) == 1
    assert F.array_equal(rpcmsg.tensors[0], req.z)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_multi_client():
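    """Launch one server and 20 client processes issuing RPCs concurrently."""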
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    num_clients = 20
    pserver = ctx.Process(
        target=start_server,
        args=(num_clients, ip_config, 0, 1),
    )
    pclient_list = []
    for i in range(num_clients):
        pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1))
        pclient_list.append(pclient)
    pserver.start()
    for i in range(num_clients):
        pclient_list[i].start()
    for i in range(num_clients):
        pclient_list[i].join()
    pserver.join()


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_multi_thread_rpc():
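    """Issue RPCs to two servers from the main thread and a subthread."""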
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    num_servers = 2
    ip_config = "rpc_ip_config_multithread.txt"
    generate_ip_config(ip_config, num_servers, num_servers)
    ctx = mp.get_context("spawn")
    pserver_list = []
    for i in range(num_servers):
        pserver = ctx.Process(target=start_server, args=(1, ip_config, i, 1))
        pserver.start()
        pserver_list.append(pserver)

    def start_client_multithread(ip_config):
        import threading

        dgl.distributed.connect_to_server(
            ip_config=ip_config,
            num_servers=1,
        )
        dgl.distributed.register_service(
            HELLO_SERVICE_ID, HelloRequest, HelloResponse
        )

        req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
        dgl.distributed.send_request(0, req)
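        # One request goes to server 0 from the main thread; a second goes to
        # server 1 from a subthread below.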

        def subthread_call(server_id):
            req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
            dgl.distributed.send_request(server_id, req)

        subthread = threading.Thread(target=subthread_call, args=(1,))
        subthread.start()
        subthread.join()

        res0 = dgl.distributed.recv_response()
        res1 = dgl.distributed.recv_response()
        # Order is not guaranteed
        assert_array_equal(F.asnumpy(res0.tensor), F.asnumpy(TENSOR))
        assert_array_equal(F.asnumpy(res1.tensor), F.asnumpy(TENSOR))
        dgl.distributed.exit_client()

    start_client_multithread(ip_config)
    for pserver in pserver_list:
        pserver.join()


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_multi_client_connect():
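    """Verify connect retry behavior controlled by DGL_DIST_MAX_TRY_TIMES."""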
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    num_clients = 1
    pserver = ctx.Process(
        target=start_server,
        args=(num_clients, ip_config, 0, 1),
    )

    # With a small retry budget, connecting before the server starts must fail.
    os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1"
    expect_except = False
    try:
        start_client(ip_config, 0, 1)
    except dgl.distributed.DistConnectError as err:
        print("Expected error: {}".format(err))
        expect_except = True
    assert expect_except

    # With a large retry budget, the client keeps retrying until the server
    # comes up.
    os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1024"
    pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1))
    pclient.start()
    pserver.start()
    pclient.join()
    pserver.join()
    reset_envs()


if __name__ == "__main__":
    test_serialize()
    test_rpc_msg()
    test_multi_client()
    test_multi_thread_rpc()
    test_multi_client_connect()