import multiprocessing as mp
import os
import socket
import time
import unittest

import backend as F

import dgl
import pytest
from numpy.testing import assert_array_equal
from utils import generate_ip_config, reset_envs

if os.name != "nt":
    import fcntl
    import struct

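# Shared fixtures: the payload constants and service IDs used by the
# request/response classes and tests below.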
INTEGER = 2
STR = "hello world!"
HELLO_SERVICE_ID = 901231
TENSOR = F.zeros((1000, 1000), F.int64, F.cpu())


def foo(x, y):
    assert x == 123
    assert y == "abc"


class MyRequest(dgl.distributed.Request):
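    """Request used by the serialization tests; carries ints, a string, a tensor, and a function."""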
    def __init__(self):
        self.x = 123
        self.y = "abc"
        self.z = F.randn((3, 4))
        self.foo = foo

    def __getstate__(self):
        return self.x, self.y, self.z, self.foo

    def __setstate__(self, state):
        self.x, self.y, self.z, self.foo = state

    def process_request(self, server_state):
        pass


class MyResponse(dgl.distributed.Response):
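    """Response used by the serialization tests; carries a single integer."""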
    def __init__(self):
        self.x = 432

    def __getstate__(self):
        return self.x

    def __setstate__(self, state):
        self.x = state


def simple_func(tensor):
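    """Identity function shipped to the server inside a HelloRequest."""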
    return tensor


class HelloResponse(dgl.distributed.Response):
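    """Response that echoes the string, integer, and tensor of a HelloRequest."""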
    def __init__(self, hello_str, integer, tensor):
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor

    def __getstate__(self):
        return self.hello_str, self.integer, self.tensor

    def __setstate__(self, state):
        self.hello_str, self.integer, self.tensor = state


class HelloRequest(dgl.distributed.Request):
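    """Request that validates its payload on the server, applies self.func to the tensor, and replies with a HelloResponse."""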
    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def __getstate__(self):
        return self.hello_str, self.integer, self.tensor, self.func

    def __setstate__(self, state):
        self.hello_str, self.integer, self.tensor, self.func = state

    def process_request(self, server_state):
        assert self.hello_str == STR
        assert self.integer == INTEGER
        new_tensor = self.func(self.tensor)
        res = HelloResponse(self.hello_str, self.integer, new_tensor)
        return res


TIMEOUT_SERVICE_ID = 123456789
TIMEOUT_META = "timeout_test"


class TimeoutResponse(dgl.distributed.Response):
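    """Response carrying the metadata string of a TimeoutRequest."""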
    def __init__(self, meta):
        self.meta = meta

    def __getstate__(self):
        return self.meta

    def __setstate__(self, state):
        self.meta = state


class TimeoutRequest(dgl.distributed.Request):
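    """Request that sleeps for `timeout` milliseconds on the server and optionally skips the response."""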
    def __init__(self, meta, timeout, response=True):
        self.meta = meta
        self.timeout = timeout
        self.response = response

    def __getstate__(self):
        return self.meta, self.timeout, self.response

    def __setstate__(self, state):
        self.meta, self.timeout, self.response = state

    def process_request(self, server_state):
        assert self.meta == TIMEOUT_META
        # convert from milliseconds to seconds
        time.sleep(self.timeout / 1000)
        if not self.response:
            return None
        res = TimeoutResponse(self.meta)
        return res


def start_server(
    num_clients,
    ip_config,
    server_id=0,
    keep_alive=False,
    num_servers=1,
    net_type="tensorpipe",
):
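    """Register the hello and timeout services and start a blocking RPC server."""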
    print("Sleep 1 seconds to test client re-connect.")
    time.sleep(1)
    server_state = dgl.distributed.ServerState(
        None, local_g=None, partition_book=None, keep_alive=keep_alive
    )
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse
    )
    dgl.distributed.register_service(
        TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse
    )
    print("Start server {}".format(server_id))
    dgl.distributed.start_server(
        server_id=server_id,
        ip_config=ip_config,
        num_servers=num_servers,
        num_clients=num_clients,
        server_state=server_state,
        net_type=net_type,
    )


def start_client(ip_config, group_id=0, num_servers=1, net_type="tensorpipe"):
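    """Connect to the server and exercise send_request/recv_response, remote_call, and the *_to_machine variants."""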
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse
    )
    dgl.distributed.connect_to_server(
        ip_config=ip_config,
        num_servers=num_servers,
        group_id=group_id,
        net_type=net_type,
    )
    req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
    # test send and recv
    dgl.distributed.send_request(0, req)
    res = dgl.distributed.recv_response()
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test remote_call
    target_and_requests = []
    for i in range(10):
        target_and_requests.append((0, req))
    res_list = dgl.distributed.remote_call(target_and_requests)
    for res in res_list:
        assert res.hello_str == STR
        assert res.integer == INTEGER
        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test send_request_to_machine
    dgl.distributed.send_request_to_machine(0, req)
    res = dgl.distributed.recv_response()
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test remote_call_to_machine
    target_and_requests = []
    for i in range(10):
        target_and_requests.append((0, req))
    res_list = dgl.distributed.remote_call_to_machine(target_and_requests)
    for res in res_list:
        assert res.hello_str == STR
        assert res.integer == INTEGER
        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))


def start_client_timeout(
    ip_config, group_id=0, num_servers=1, net_type="tensorpipe"
):
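    """Exercise client-side timeouts: recv_response and remote_call against a server that sleeps before responding."""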
    dgl.distributed.register_service(
        TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse
    )
    dgl.distributed.connect_to_server(
        ip_config=ip_config,
        num_servers=num_servers,
        group_id=group_id,
        net_type=net_type,
    )
    timeout = 1 * 1000  # milliseconds
    req = TimeoutRequest(TIMEOUT_META, timeout)
    # test send and recv
    dgl.distributed.send_request(0, req)
    res = dgl.distributed.recv_response(timeout=int(timeout / 2))
    assert res is None
    res = dgl.distributed.recv_response()
    assert res.meta == TIMEOUT_META
    # test remote_call
    req = TimeoutRequest(TIMEOUT_META, timeout, response=False)
    target_and_requests = []
    for i in range(3):
        target_and_requests.append((0, req))
    expect_except = False
    try:
        res_list = dgl.distributed.remote_call(
            target_and_requests, timeout=int(timeout / 2)
        )
    except dgl.DGLError:
        expect_except = True
    assert expect_except
    # test send_request_to_machine
    req = TimeoutRequest(TIMEOUT_META, timeout)
    dgl.distributed.send_request_to_machine(0, req)
    res = dgl.distributed.recv_response(timeout=int(timeout / 2))
    assert res is None
    res = dgl.distributed.recv_response()
    assert res.meta == TIMEOUT_META
    # test remote_call_to_machine
    req = TimeoutRequest(TIMEOUT_META, timeout, response=False)
    target_and_requests = []
    for i in range(3):
        target_and_requests.append((0, req))
    expect_except = False
    try:
        res_list = dgl.distributed.remote_call_to_machine(
            target_and_requests, timeout=int(timeout / 2)
        )
    except dgl.DGLError:
        expect_except = True
    assert expect_except


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_rpc_timeout(net_type):
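    """A request that outlives the receive timeout must first yield None from recv_response, then DGLError from remote_call."""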
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    pserver = ctx.Process(
        target=start_server, args=(1, ip_config, 0, False, 1, net_type)
    )
    pclient = ctx.Process(
        target=start_client_timeout, args=(ip_config, 0, 1, net_type)
    )
    pserver.start()
    pclient.start()
    pserver.join()
    pclient.join()


def test_serialize():
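    """Round-trip a request and a response through serialize_to_payload/deserialize_from_payload."""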
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    from dgl.distributed.rpc import (
        deserialize_from_payload,
        serialize_to_payload,
    )

    SERVICE_ID = 12345
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
    req1 = deserialize_from_payload(MyRequest, data, tensors)
    req1.foo(req1.x, req1.y)
    assert req.x == req1.x
    assert req.y == req1.y
    assert F.array_equal(req.z, req1.z)

    res = MyResponse()
    data, tensors = serialize_to_payload(res)
    res1 = deserialize_from_payload(MyResponse, data, tensors)
    assert res.x == res1.x


def test_rpc_msg():
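    """RPCMessage must preserve the service ID, sequence/client/server IDs, data, and tensor payload."""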
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    from dgl.distributed.rpc import (
        deserialize_from_payload,
        RPCMessage,
        serialize_to_payload,
    )

    SERVICE_ID = 32452
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
    rpcmsg = RPCMessage(SERVICE_ID, 23, 0, 1, data, tensors)
    assert rpcmsg.service_id == SERVICE_ID
    assert rpcmsg.msg_seq == 23
    assert rpcmsg.client_id == 0
    assert rpcmsg.server_id == 1
    assert len(rpcmsg.data) == len(data)
    assert len(rpcmsg.tensors) == 1
    assert F.array_equal(rpcmsg.tensors[0], req.z)


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
@pytest.mark.parametrize("net_type", ["tensorpipe"])
def test_rpc(net_type):
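    """Basic end-to-end RPC: one server process, one client process."""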
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    generate_ip_config("rpc_ip_config.txt", 1, 1)
    ctx = mp.get_context("spawn")
    pserver = ctx.Process(
        target=start_server,
        args=(1, "rpc_ip_config.txt", 0, False, 1, net_type),
    )
    pclient = ctx.Process(
        target=start_client, args=("rpc_ip_config.txt", 0, 1, net_type)
    )
    pserver.start()
    pclient.start()
    pserver.join()
    pclient.join()


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_multi_client(net_type):
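    """One server must serve 20 concurrent client processes."""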
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    num_clients = 20
    pserver = ctx.Process(
        target=start_server,
        args=(num_clients, ip_config, 0, False, 1, net_type),
    )
    pclient_list = []
    for i in range(num_clients):
        pclient = ctx.Process(
            target=start_client, args=(ip_config, 0, 1, net_type)
        )
        pclient_list.append(pclient)
    pserver.start()
    for i in range(num_clients):
        pclient_list[i].start()
    for i in range(num_clients):
        pclient_list[i].join()
    pserver.join()


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_multi_thread_rpc(net_type):
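    """Send requests to two servers from the main thread and a subthread of the same client."""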
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    num_servers = 2
    ip_config = "rpc_ip_config_multithread.txt"
    generate_ip_config(ip_config, num_servers, num_servers)
    ctx = mp.get_context("spawn")
    pserver_list = []
    for i in range(num_servers):
        pserver = ctx.Process(
            target=start_server, args=(1, ip_config, i, False, 1, net_type)
        )
        pserver.start()
        pserver_list.append(pserver)

    def start_client_multithread(ip_config):
        import threading

        dgl.distributed.connect_to_server(
            ip_config=ip_config, num_servers=1, net_type=net_type
        )
        dgl.distributed.register_service(
            HELLO_SERVICE_ID, HelloRequest, HelloResponse
        )

        req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
        dgl.distributed.send_request(0, req)

        def subthread_call(server_id):
            req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
            dgl.distributed.send_request(server_id, req)

        subthread = threading.Thread(target=subthread_call, args=(1,))
        subthread.start()
        subthread.join()

        res0 = dgl.distributed.recv_response()
        res1 = dgl.distributed.recv_response()
        # Order is not guaranteed
        assert_array_equal(F.asnumpy(res0.tensor), F.asnumpy(TENSOR))
        assert_array_equal(F.asnumpy(res1.tensor), F.asnumpy(TENSOR))
        dgl.distributed.exit_client()

    start_client_multithread(ip_config)
    # Join every server process, not just the last one started.
    for pserver in pserver_list:
        pserver.join()


@unittest.skipIf(
    True,
    reason="Tests of multiple groups may fail, so they are disabled for now.",
)
@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
def test_multi_client_groups():
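    """Keep-alive servers must outlive multiple client groups until explicitly shut down."""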
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config_mul_client_groups.txt"
    num_machines = 5
    # Should test with a larger number, but keep it at 1 to avoid possible port-in-use issues.
    num_servers = 1
    generate_ip_config(ip_config, num_machines, num_servers)
    # pressure test
    num_clients = 2
    num_groups = 2
    ctx = mp.get_context("spawn")
    pserver_list = []
    for i in range(num_servers * num_machines):
        pserver = ctx.Process(
            target=start_server,
            args=(num_clients, ip_config, i, True, num_servers),
        )
        pserver.start()
        pserver_list.append(pserver)
    pclient_list = []
    for i in range(num_clients):
        for group_id in range(num_groups):
            pclient = ctx.Process(
                target=start_client, args=(ip_config, group_id, num_servers)
            )
            pclient.start()
            pclient_list.append(pclient)
    for p in pclient_list:
        p.join()
    for p in pserver_list:
        assert p.is_alive()
    # force shutdown server
    dgl.distributed.shutdown_servers(ip_config, num_servers)
    for p in pserver_list:
        p.join()


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_multi_client_connect(net_type):
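    """Connecting before the server starts must fail fast with a small DGL_DIST_MAX_TRY_TIMES and succeed with a large one."""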
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    num_clients = 1
    pserver = ctx.Process(
        target=start_server,
        args=(num_clients, ip_config, 0, False, 1, net_type),
    )

    # Small max-try count: connecting before the server starts should fail fast.
    os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1"
    expect_except = False
    try:
        start_client(ip_config, 0, 1, net_type)
    except dgl.distributed.DistConnectError as err:
        print("Expected error: {}".format(err))
        expect_except = True
    assert expect_except

    # Large max-try count: the client keeps retrying until the server is up.
    os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1024"
    pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1, net_type))
    pclient.start()
    pserver.start()
    pclient.join()
    pserver.join()
    reset_envs()


if __name__ == "__main__":
    test_serialize()
    test_rpc_msg()
    test_rpc("tensorpipe")
    test_multi_client("socket")
    test_multi_client("tesnsorpipe")
    test_multi_thread_rpc("socket")
    test_multi_thread_rpc("tensorpipe")
    test_multi_client_connect("socket")
    test_multi_client_connect("tensorpipe")