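"""Unit tests for dgl.distributed RPC: request/response (de)serialization,
RPCMessage construction, and end-to-end client-server round trips."""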
import os
import time
import socket

import dgl
import backend as F
import unittest
import pytest
import multiprocessing as mp
from numpy.testing import assert_array_equal
from utils import reset_envs, generate_ip_config

if os.name != 'nt':
    import fcntl
    import struct

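# Payload constants echoed between clients and servers by the Hello service below.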
INTEGER = 2
STR = 'hello world!'
HELLO_SERVICE_ID = 901231
TENSOR = F.zeros((1000, 1000), F.int64, F.cpu())

def foo(x, y):
    assert x == 123
    assert y == "abc"

class MyRequest(dgl.distributed.Request):
    def __init__(self):
        self.x = 123
        self.y = "abc"
        self.z = F.randn((3, 4))
        self.foo = foo

    def __getstate__(self):
        return self.x, self.y, self.z, self.foo

    def __setstate__(self, state):
        self.x, self.y, self.z, self.foo = state

    def process_request(self, server_state):
        pass

class MyResponse(dgl.distributed.Response):
    def __init__(self):
        self.x = 432

    def __getstate__(self):
        return self.x

    def __setstate__(self, state):
        self.x = state
 
def simple_func(tensor):
    return tensor

class HelloResponse(dgl.distributed.Response):
    def __init__(self, hello_str, integer, tensor):
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor

    def __getstate__(self):
        return self.hello_str, self.integer, self.tensor

    def __setstate__(self, state):
        self.hello_str, self.integer, self.tensor = state

class HelloRequest(dgl.distributed.Request):
    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def __getstate__(self):
        return self.hello_str, self.integer, self.tensor, self.func

    def __setstate__(self, state):
        self.hello_str, self.integer, self.tensor, self.func = state

    def process_request(self, server_state):
        assert self.hello_str == STR
        assert self.integer == INTEGER
        new_tensor = self.func(self.tensor)
        res = HelloResponse(self.hello_str, self.integer, new_tensor)
        return res

def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1, net_type='tensorpipe'):
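    """Process entry point: register the Hello service and serve `num_clients` clients."""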
    print("Sleep 1 seconds to test client re-connect.")
    time.sleep(1)
    server_state = dgl.distributed.ServerState(
        None, local_g=None, partition_book=None, keep_alive=keep_alive)
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    print("Start server {}".format(server_id))
    dgl.distributed.start_server(server_id=server_id,
                                 ip_config=ip_config,
                                 num_servers=num_servers,
                                 num_clients=num_clients,
                                 server_state=server_state,
                                 net_type=net_type)

def start_client(ip_config, group_id=0, num_servers=1, net_type='tensorpipe'):
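    """Process entry point: connect to the servers and exercise send/recv and remote_call APIs."""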
    dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    dgl.distributed.connect_to_server(
        ip_config=ip_config, num_servers=num_servers, group_id=group_id, net_type=net_type)
    req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
    # test send and recv
    dgl.distributed.send_request(0, req)
    res = dgl.distributed.recv_response()
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test remote_call
    target_and_requests = []
    for i in range(10):
        target_and_requests.append((0, req))
    res_list = dgl.distributed.remote_call(target_and_requests)
    for res in res_list:
        assert res.hello_str == STR
        assert res.integer == INTEGER
        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test send_request_to_machine
    dgl.distributed.send_request_to_machine(0, req)
    res = dgl.distributed.recv_response()
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
    # test remote_call_to_machine
    target_and_requests = []
    for i in range(10):
        target_and_requests.append((0, req))
    res_list = dgl.distributed.remote_call_to_machine(target_and_requests)
    for res in res_list:
        assert res.hello_str == STR
        assert res.integer == INTEGER
        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))

def test_serialize():
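    """Request/Response objects survive a serialize/deserialize round trip."""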
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
    SERVICE_ID = 12345
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
    req1 = deserialize_from_payload(MyRequest, data, tensors)
    req1.foo(req1.x, req1.y)
    assert req.x == req1.x
    assert req.y == req1.y
    assert F.array_equal(req.z, req1.z)

    res = MyResponse()
    data, tensors = serialize_to_payload(res)
    res1 = deserialize_from_payload(MyResponse, data, tensors)
    assert res.x == res1.x

def test_rpc_msg():
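    """RPCMessage preserves service id, sequence number, endpoints, data, and tensors."""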
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload, RPCMessage
    SERVICE_ID = 32452
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
    rpcmsg = RPCMessage(SERVICE_ID, 23, 0, 1, data, tensors)
    assert rpcmsg.service_id == SERVICE_ID
    assert rpcmsg.msg_seq == 23
    assert rpcmsg.client_id == 0
    assert rpcmsg.server_id == 1
    assert len(rpcmsg.data) == len(data)
    assert len(rpcmsg.tensors) == 1
    assert F.array_equal(rpcmsg.tensors[0], req.z)

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_rpc():
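    """Single server, single client: basic end-to-end RPC."""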
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    generate_ip_config("rpc_ip_config.txt", 1, 1)
    ctx = mp.get_context('spawn')
    pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config.txt"))
    pclient = ctx.Process(target=start_client, args=("rpc_ip_config.txt",))
    pserver.start()
    pclient.start()
    pserver.join()
    pclient.join()

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("net_type", ['socket', 'tensorpipe'])
def test_multi_client(net_type):
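    """One server serving 20 concurrent client processes over the given transport."""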
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context('spawn')
    num_clients = 20
    pserver = ctx.Process(target=start_server, args=(num_clients, ip_config, 0, False, 1, net_type))
    pclient_list = []
    for i in range(num_clients):
        pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1, net_type))
        pclient_list.append(pclient)
    pserver.start()
    for i in range(num_clients):
        pclient_list[i].start()
    for i in range(num_clients):
        pclient_list[i].join()
    pserver.join()


@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_thread_rpc():
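    """Two servers; requests are issued from both the main thread and a subthread."""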
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    num_servers = 2
    generate_ip_config("rpc_ip_config_multithread.txt", num_servers, num_servers)
    ctx = mp.get_context('spawn')
    pserver_list = []
    for i in range(num_servers):
        pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config_multithread.txt", i))
        pserver.start()
        pserver_list.append(pserver)
    def start_client_multithread(ip_config):
        import threading
        dgl.distributed.connect_to_server(ip_config=ip_config, num_servers=1)
        dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
        
        req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
        dgl.distributed.send_request(0, req)

        def subthread_call(server_id):
            req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
            dgl.distributed.send_request(server_id, req)

        subthread = threading.Thread(target=subthread_call, args=(1,))
        subthread.start()
        subthread.join()
        
        res0 = dgl.distributed.recv_response()
        res1 = dgl.distributed.recv_response()
        # Order is not guaranteed
        assert_array_equal(F.asnumpy(res0.tensor), F.asnumpy(TENSOR))
        assert_array_equal(F.asnumpy(res1.tensor), F.asnumpy(TENSOR))
        dgl.distributed.exit_client()

    start_client_multithread("rpc_ip_config_multithread.txt")
    for p in pserver_list:
        p.join()


@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client_groups():
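    """Multiple client groups against keep-alive servers, followed by a forced shutdown."""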
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config_mul_client_groups.txt"
    num_machines = 5
    # Should test with a larger number, but ports may already be in use.
    num_servers = 1
    generate_ip_config(ip_config, num_machines, num_servers)
    # pressure test
    num_clients = 2
    num_groups = 2
    ctx = mp.get_context('spawn')
    pserver_list = []
    for i in range(num_servers*num_machines):
        pserver = ctx.Process(target=start_server, args=(num_clients, ip_config, i, True, num_servers))
        pserver.start()
        pserver_list.append(pserver)
    pclient_list = []
    for i in range(num_clients):
        for group_id in range(num_groups):
            pclient = ctx.Process(target=start_client, args=(ip_config, group_id, num_servers))
            pclient.start()
            pclient_list.append(pclient)
    for p in pclient_list:
        p.join()
    for p in pserver_list:
        assert p.is_alive()
    # force shutdown server
    dgl.distributed.shutdown_servers(ip_config, num_servers)
    for p in pserver_list:
        p.join()


if __name__ == '__main__':
    test_serialize()
    test_rpc_msg()
    test_rpc()
    test_multi_client('socket')
    test_multi_client('tensorpipe')
    test_multi_thread_rpc()
    test_multi_client_groups()