import multiprocessing as mp
import os
import socket
import time
import unittest

import backend as F

import dgl
import pytest
from numpy.testing import assert_array_equal
from utils import generate_ip_config, reset_envs

if os.name != "nt":
    import fcntl
    import struct

INTEGER = 2
STR = "hello world!"
HELLO_SERVICE_ID = 901231
TENSOR = F.zeros((1000, 1000), F.int64, F.cpu())


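# Plain function carried inside MyRequest; test_serialize invokes it again
# after deserialization to check that callables survive the payload round trip.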
def foo(x, y):
    assert x == 123
    assert y == "abc"


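# Request/Response subclasses implement __getstate__/__setstate__ so that
# serialize_to_payload / deserialize_from_payload can pack their fields
# (tensors included) into an RPC payload and rebuild the object on the
# receiving side; process_request is the server-side handler.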
class MyRequest(dgl.distributed.Request):
    def __init__(self):
        self.x = 123
        self.y = "abc"
        self.z = F.randn((3, 4))
        self.foo = foo

    def __getstate__(self):
        return self.x, self.y, self.z, self.foo

    def __setstate__(self, state):
        self.x, self.y, self.z, self.foo = state

    def process_request(self, server_state):
        pass


class MyResponse(dgl.distributed.Response):
    def __init__(self):
        self.x = 432

    def __getstate__(self):
        return self.x

    def __setstate__(self, state):
        self.x = state


def simple_func(tensor):
    return tensor


class HelloResponse(dgl.distributed.Response):
    def __init__(self, hello_str, integer, tensor):
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor

    def __getstate__(self):
        return self.hello_str, self.integer, self.tensor

    def __setstate__(self, state):
        self.hello_str, self.integer, self.tensor = state


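# HelloRequest is the round-trip fixture: process_request runs on the server,
# validates the payload, applies the shipped function to the tensor, and
# returns a HelloResponse that echoes everything back to the client.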
class HelloRequest(dgl.distributed.Request):
    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def __getstate__(self):
        return self.hello_str, self.integer, self.tensor, self.func

    def __setstate__(self, state):
        self.hello_str, self.integer, self.tensor, self.func = state

    def process_request(self, server_state):
        assert self.hello_str == STR
        assert self.integer == INTEGER
        new_tensor = self.func(self.tensor)
        res = HelloResponse(self.hello_str, self.integer, new_tensor)
        return res


TIMEOUT_SERVICE_ID = 123456789
TIMEOUT_META = "timeout_test"
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124


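# TimeoutRequest makes the server sleep for `timeout` milliseconds (and can
# skip the response entirely) so the client-side recv_response/remote_call
# timeout paths can be exercised.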
class TimeoutResponse(dgl.distributed.Response):
    def __init__(self, meta):
        self.meta = meta

    def __getstate__(self):
        return self.meta

    def __setstate__(self, state):
        self.meta = state


class TimeoutRequest(dgl.distributed.Request):
    def __init__(self, meta, timeout, response=True):
        self.meta = meta
        self.timeout = timeout
        self.response = response

    def __getstate__(self):
        return self.meta, self.timeout, self.response

    def __setstate__(self, state):
        self.meta, self.timeout, self.response = state

    def process_request(self, server_state):
        assert self.meta == TIMEOUT_META
        # convert from milliseconds to seconds
        time.sleep(self.timeout / 1000)
        if not self.response:
            return None
        res = TimeoutResponse(self.meta)
        return res


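# Server-process entry point: builds a bare ServerState (no graph or
# partition book), registers both test services, then serves RPC requests
# for `num_clients` clients.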
def start_server(
    num_clients,
    ip_config,
    server_id=0,
    keep_alive=False,
    num_servers=1,
):
    print("Sleep 1 second to test client re-connect.")
    time.sleep(1)
    server_state = dgl.distributed.ServerState(
        None, local_g=None, partition_book=None, keep_alive=keep_alive
    )
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse
    )
    dgl.distributed.register_service(
        TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse
    )
    print("Start server {}".format(server_id))
    dgl.distributed.start_server(
        server_id=server_id,
        ip_config=ip_config,
        num_servers=num_servers,
        num_clients=num_clients,
        server_state=server_state,
    )


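# Client-process entry point: connects to the servers and exercises the four
# client-side calls (send_request/recv_response, remote_call,
# send_request_to_machine, remote_call_to_machine) against the hello service.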
def start_client(ip_config, group_id=0, num_servers=1):
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse
    )
    dgl.distributed.connect_to_server(
        ip_config=ip_config,
        num_servers=num_servers,
        group_id=group_id,
    )
    req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
    # test send and recv
    dgl.distributed.send_request(0, req)
    res = dgl.distributed.recv_response()
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test remote_call
    target_and_requests = []
    for i in range(10):
        target_and_requests.append((0, req))
    res_list = dgl.distributed.remote_call(target_and_requests)
    for res in res_list:
        assert res.hello_str == STR
        assert res.integer == INTEGER
        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test send_request_to_machine
    dgl.distributed.send_request_to_machine(0, req)
    res = dgl.distributed.recv_response()
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test remote_call_to_machine
    target_and_requests = []
    for i in range(10):
        target_and_requests.append((0, req))
    res_list = dgl.distributed.remote_call_to_machine(target_and_requests)
    for res in res_list:
        assert res.hello_str == STR
        assert res.integer == INTEGER
        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))


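# Client-process entry point for the timeout test: each blocking call is given
# a timeout of half the server-side sleep, so recv_response returns None and
# the batched remote calls raise dgl.DGLError.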
def start_client_timeout(ip_config, group_id=0, num_servers=1):
    dgl.distributed.register_service(
        TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse
    )
    dgl.distributed.connect_to_server(
        ip_config=ip_config,
        num_servers=num_servers,
        group_id=group_id,
    )
    timeout = 1 * 1000  # milliseconds
    req = TimeoutRequest(TIMEOUT_META, timeout)
    # test send and recv
    dgl.distributed.send_request(0, req)
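    # The server sleeps for `timeout` ms, so receiving with half that budget
    # must time out and return None; a second recv with no timeout succeeds.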
    res = dgl.distributed.recv_response(timeout=int(timeout / 2))
    assert res is None
    res = dgl.distributed.recv_response()
    assert res.meta == TIMEOUT_META
    # test remote_call
    req = TimeoutRequest(TIMEOUT_META, timeout, response=False)
    target_and_requests = []
    for i in range(3):
        target_and_requests.append((0, req))
    expect_except = False
    try:
        res_list = dgl.distributed.remote_call(
            target_and_requests, timeout=int(timeout / 2)
        )
    except dgl.DGLError:
        expect_except = True
    assert expect_except
    # test send_request_to_machine
    req = TimeoutRequest(TIMEOUT_META, timeout)
    dgl.distributed.send_request_to_machine(0, req)
    res = dgl.distributed.recv_response(timeout=int(timeout / 2))
    assert res is None
    res = dgl.distributed.recv_response()
    assert res.meta == TIMEOUT_META
    # test remote_call_to_machine
    req = TimeoutRequest(TIMEOUT_META, timeout, response=False)
    target_and_requests = []
    for i in range(3):
        target_and_requests.append((0, req))
    expect_except = False
    try:
        res_list = dgl.distributed.remote_call_to_machine(
            target_and_requests, timeout=int(timeout / 2)
        )
    except dgl.DGLError:
        expect_except = True
    assert expect_except


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
def test_rpc_timeout():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, False, 1))
    pclient = ctx.Process(target=start_client_timeout, args=(ip_config, 0, 1))
    pserver.start()
    pclient.start()
    pserver.join()
    pclient.join()


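# serialize_to_payload splits an object into a picklable byte buffer plus a
# separate list of tensors; deserialize_from_payload rebuilds an equivalent
# object from those two parts.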
def test_serialize():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    from dgl.distributed.rpc import (
        deserialize_from_payload,
        serialize_to_payload,
    )

    SERVICE_ID = 12345
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
    req1 = deserialize_from_payload(MyRequest, data, tensors)
    req1.foo(req1.x, req1.y)
    assert req.x == req1.x
    assert req.y == req1.y
    assert F.array_equal(req.z, req1.z)

    res = MyResponse()
    data, tensors = serialize_to_payload(res)
    res1 = deserialize_from_payload(MyResponse, data, tensors)
    assert res.x == res1.x


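# RPCMessage bundles a serialized payload with routing metadata; constructor
# order is service_id, msg_seq, client_id, server_id, data, tensors.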
def test_rpc_msg():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    from dgl.distributed.rpc import (
        deserialize_from_payload,
        RPCMessage,
        serialize_to_payload,
    )

    SERVICE_ID = 32452
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
    rpcmsg = RPCMessage(SERVICE_ID, 23, 0, 1, data, tensors)
    assert rpcmsg.service_id == SERVICE_ID
    assert rpcmsg.msg_seq == 23
    assert rpcmsg.client_id == 0
    assert rpcmsg.server_id == 1
    assert len(rpcmsg.data) == len(data)
    assert len(rpcmsg.tensors) == 1
    assert F.array_equal(rpcmsg.tensors[0], req.z)


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
def test_multi_client():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    num_clients = 20
    pserver = ctx.Process(
        target=start_server,
        args=(num_clients, ip_config, 0, False, 1),
    )
    pclient_list = []
    for i in range(num_clients):
        pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1))
        pclient_list.append(pclient)
    pserver.start()
    for i in range(num_clients):
        pclient_list[i].start()
    for i in range(num_clients):
        pclient_list[i].join()
    pserver.join()


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
def test_multi_thread_rpc():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    num_servers = 2
    ip_config = "rpc_ip_config_multithread.txt"
    generate_ip_config(ip_config, num_servers, num_servers)
    ctx = mp.get_context("spawn")
    pserver_list = []
    for i in range(num_servers):
        pserver = ctx.Process(
            target=start_server, args=(1, ip_config, i, False, 1)
        )
        pserver.start()
        pserver_list.append(pserver)

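    # One client process, two threads: the main thread sends to server 0 and a
    # subthread sends to server 1, then both responses are drained from the
    # same client in whatever order they arrive.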
    def start_client_multithread(ip_config):
        import threading

        dgl.distributed.connect_to_server(
            ip_config=ip_config,
            num_servers=1,
        )
        dgl.distributed.register_service(
            HELLO_SERVICE_ID, HelloRequest, HelloResponse
        )

        req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
        dgl.distributed.send_request(0, req)

        def subthread_call(server_id):
            req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
            dgl.distributed.send_request(server_id, req)

        subthread = threading.Thread(target=subthread_call, args=(1,))
        subthread.start()
        subthread.join()

        res0 = dgl.distributed.recv_response()
        res1 = dgl.distributed.recv_response()
        # Order is not guaranteed
        assert_array_equal(F.asnumpy(res0.tensor), F.asnumpy(TENSOR))
        assert_array_equal(F.asnumpy(res1.tensor), F.asnumpy(TENSOR))
        dgl.distributed.exit_client()

    start_client_multithread(ip_config)
    for p in pserver_list:
        p.join()


@unittest.skipIf(
    True,
    reason="Tests of multiple groups may fail, so disable them for now.",
)
@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
def test_multi_client_groups():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config_mul_client_groups.txt"
    num_machines = 5
    # Should test with a larger number, but keep it small to avoid
    # port-in-use issues.
    num_servers = 1
    generate_ip_config(ip_config, num_machines, num_servers)
    # pressure test
    num_clients = 2
    num_groups = 2
    ctx = mp.get_context("spawn")
    pserver_list = []
    for i in range(num_servers * num_machines):
        pserver = ctx.Process(
            target=start_server,
            args=(num_clients, ip_config, i, True, num_servers),
        )
        pserver.start()
        pserver_list.append(pserver)
    pclient_list = []
    for i in range(num_clients):
        for group_id in range(num_groups):
            pclient = ctx.Process(
                target=start_client, args=(ip_config, group_id, num_servers)
            )
            pclient.start()
            pclient_list.append(pclient)
    for p in pclient_list:
        p.join()
    for p in pserver_list:
        assert p.is_alive()
    # force shutdown server
    dgl.distributed.shutdown_servers(ip_config, num_servers)
    for p in pserver_list:
        p.join()


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
def test_multi_client_connect():
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    num_clients = 1
    pserver = ctx.Process(
        target=start_server,
        args=(num_clients, ip_config, 0, False, 1),
    )

    # With a small retry cap and no server running yet, the client should give
    # up quickly and raise DistConnectError.
    os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1"
    expect_except = False
    try:
        start_client(ip_config, 0, 1)
    except dgl.distributed.DistConnectError as err:
        print("Expected error: {}".format(err))
        expect_except = True
    assert expect_except

    # With a large retry cap the client keeps retrying until the server comes
    # up, so the connection should succeed.
    os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1024"
    pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1))
    pclient.start()
    pserver.start()
    pclient.join()
    pserver.join()
    reset_envs()


if __name__ == "__main__":
    test_serialize()
    test_rpc_msg()
    test_multi_client()
    test_multi_thread_rpc()
    test_multi_client_connect()