test_rpc.py 9.58 KB
Newer Older
1
2
import os
import time
3
import socket
4
5
6
7

import dgl
import backend as F
import unittest, pytest
8
import multiprocessing as mp
9
from numpy.testing import assert_array_equal
10
from utils import reset_envs, generate_ip_config
11

12
13
14
15
if os.name != 'nt':
    import fcntl
    import struct

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Payload values shared by the tests below: clients put them into every
# HelloRequest and assert that the response carries the same values back.
INTEGER = 2
STR = 'hello world!'
# Service id under which HelloRequest/HelloResponse are registered.
HELLO_SERVICE_ID = 901231
# Tensor payload, built with F.zeros using shape (10, 10), F.int64 and F.cpu().
TENSOR = F.zeros((10, 10), F.int64, F.cpu())

def foo(x, y):
    """Assert that ``x`` and ``y`` equal the values MyRequest stores.

    Raises
    ------
    AssertionError
        If ``x != 123`` or ``y != "abc"``.
    """
    expected_x, expected_y = 123, "abc"
    assert x == expected_x
    assert y == expected_y

class MyRequest(dgl.distributed.Request):
    """Request used by the serialization tests.

    It carries an int (``x``), a str (``y``), a tensor (``z``) and a
    module-level function (``foo``), so that one payload mixes several kinds
    of values.
    """

    def __init__(self):
        self.x, self.y = 123, "abc"
        self.z = F.randn((3, 4))
        self.foo = foo

    def __getstate__(self):
        """Return the state as the tuple ``(x, y, z, foo)``."""
        state = (self.x, self.y, self.z, self.foo)
        return state

    def __setstate__(self, state):
        """Restore the attributes from a ``(x, y, z, foo)`` tuple."""
        x, y, z, func = state
        self.x = x
        self.y = y
        self.z = z
        self.foo = func

    def process_request(self, server_state):
        """Do nothing; this request is only serialized, never served."""
        return None

class MyResponse(dgl.distributed.Response):
    """Response used by the serialization tests; holds a single int ``x``."""

    def __init__(self):
        self.x = 432

    def __getstate__(self):
        """Return the state, which is the bare value of ``x`` (not a tuple)."""
        return self.x

    def __setstate__(self, state):
        """Restore ``x`` from the value produced by ``__getstate__``."""
        self.x = state
 
def simple_func(tensor):
    """Return ``tensor`` unchanged.

    Passed as the ``func`` of HelloRequest, so the tensor in the response is
    expected to equal the tensor that was sent.
    """
    return tensor

class HelloResponse(dgl.distributed.Response):
    """Response of the hello service: a string, an integer and a tensor."""

    def __init__(self, hello_str, integer, tensor):
        self.hello_str, self.integer, self.tensor = hello_str, integer, tensor

    def __getstate__(self):
        """Return the state as the tuple ``(hello_str, integer, tensor)``."""
        state = (self.hello_str, self.integer, self.tensor)
        return state

    def __setstate__(self, state):
        """Restore the attributes from a ``(hello_str, integer, tensor)`` tuple."""
        hello_str, integer, tensor = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor

class HelloRequest(dgl.distributed.Request):
    """Request of the hello service.

    Carries a string, an integer, a tensor and a function; the function is
    applied to the tensor when the request is processed.
    """

    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str, self.integer = hello_str, integer
        self.tensor, self.func = tensor, func

    def __getstate__(self):
        """Return the state as ``(hello_str, integer, tensor, func)``."""
        state = (self.hello_str, self.integer, self.tensor, self.func)
        return state

    def __setstate__(self, state):
        """Restore the attributes from a ``(hello_str, integer, tensor, func)`` tuple."""
        hello_str, integer, tensor, func = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def process_request(self, server_state):
        """Validate the payload and build the matching HelloResponse.

        Asserts that the string and integer equal the module constants
        ``STR`` and ``INTEGER``, applies ``func`` to the tensor and returns a
        HelloResponse with the string, the integer and the function's result.
        ``server_state`` is not used.
        """
        assert self.hello_str == STR
        assert self.integer == INTEGER
        processed = self.func(self.tensor)
        return HelloResponse(self.hello_str, self.integer, processed)

86
87
88
89
90
91
92
def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1):
    """Process entry point that registers the hello service and runs a server.

    Parameters
    ----------
    num_clients
        Forwarded to ``dgl.distributed.start_server`` as ``num_clients``.
    ip_config
        Forwarded to ``dgl.distributed.start_server`` as ``ip_config``.
    server_id
        Id of this server; printed and forwarded as ``server_id``.
    keep_alive
        Forwarded to the ``ServerState`` constructor as ``keep_alive``.
    num_servers
        Forwarded to ``dgl.distributed.start_server`` as ``num_servers``.
    """
    # Start late on purpose: as the message says, the delay is there to
    # exercise the client's re-connect path.
    print("Sleep 1 seconds to test client re-connect.")
    time.sleep(1)

    # No graph or partition book is attached to the server state.
    state = dgl.distributed.ServerState(
        None, local_g=None, partition_book=None, keep_alive=keep_alive)
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse)

    print("Start server {}".format(server_id))
    dgl.distributed.start_server(
        server_id=server_id,
        ip_config=ip_config,
        num_servers=num_servers,
        num_clients=num_clients,
        server_state=state,
    )
99

100
def start_client(ip_config, group_id=0, num_servers=1):
    """Process entry point that connects a client and exercises the RPC calls.

    Sends the same HelloRequest through ``send_request``/``recv_response``,
    ``remote_call``, ``send_request_to_machine``/``recv_response`` and
    ``remote_call_to_machine`` (always with target 0), and asserts that every
    response carries ``STR``, ``INTEGER`` and a tensor equal to ``TENSOR``.

    Parameters
    ----------
    ip_config
        Forwarded to ``dgl.distributed.connect_to_server`` as ``ip_config``.
    group_id
        Forwarded to ``connect_to_server`` as ``group_id``.
    num_servers
        Forwarded to ``connect_to_server`` as ``num_servers``.
    """
    def check_response(response):
        # Every response must hand back exactly what was sent.
        assert response.hello_str == STR
        assert response.integer == INTEGER
        assert_array_equal(F.asnumpy(response.tensor), F.asnumpy(TENSOR))

    dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    dgl.distributed.connect_to_server(
        ip_config=ip_config, num_servers=num_servers, group_id=group_id)
    req = HelloRequest(STR, INTEGER, TENSOR, simple_func)

    # test send and recv
    dgl.distributed.send_request(0, req)
    check_response(dgl.distributed.recv_response())

    # test remote_call
    batch = [(0, req) for _ in range(10)]
    for response in dgl.distributed.remote_call(batch):
        check_response(response)

    # test send_request_to_machine
    dgl.distributed.send_request_to_machine(0, req)
    check_response(dgl.distributed.recv_response())

    # test remote_call_to_machine
    machine_batch = [(0, req) for _ in range(10)]
    for response in dgl.distributed.remote_call_to_machine(machine_batch):
        check_response(response)
135

136
def test_serialize():
    """Round-trip MyRequest and MyResponse through the payload (de)serializers."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload

    SERVICE_ID = 12345
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)

    # Request: the restored copy must carry the same x, y and z, and its
    # restored function must still accept the restored x and y.
    request = MyRequest()
    req_data, req_tensors = serialize_to_payload(request)
    restored_request = deserialize_from_payload(MyRequest, req_data, req_tensors)
    restored_request.foo(restored_request.x, restored_request.y)
    assert request.x == restored_request.x
    assert request.y == restored_request.y
    assert F.array_equal(request.z, restored_request.z)

    # Response: the single attribute x must survive the round trip.
    response = MyResponse()
    res_data, res_tensors = serialize_to_payload(response)
    restored_response = deserialize_from_payload(MyResponse, res_data, res_tensors)
    assert response.x == restored_response.x

def test_rpc_msg():
    """Build an RPCMessage from a serialized MyRequest and check its fields."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload, RPCMessage

    SERVICE_ID = 32452
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)

    request = MyRequest()
    payload, payload_tensors = serialize_to_payload(request)
    # Positional arguments: service id, message sequence number, client id,
    # server id, data, tensors (as read back by the assertions below).
    message = RPCMessage(SERVICE_ID, 23, 0, 1, payload, payload_tensors)

    assert message.service_id == SERVICE_ID
    assert message.msg_seq == 23
    assert message.client_id == 0
    assert message.server_id == 1
    assert len(message.data) == len(payload)
    # The request holds exactly one tensor (z), so exactly one is expected here.
    assert len(message.tensors) == 1
    assert F.array_equal(message.tensors[0], request.z)

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_rpc():
    """Run one server and one client in spawned processes and wait for both."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    config_file = "rpc_ip_config.txt"
    generate_ip_config(config_file, 1, 1)

    spawn_ctx = mp.get_context('spawn')
    # The server is told to expect a single client.
    server_proc = spawn_ctx.Process(target=start_server, args=(1, config_file))
    client_proc = spawn_ctx.Process(target=start_client, args=(config_file,))

    # Server first, then client, for both starting and joining.
    processes = (server_proc, client_proc)
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()
184

185
186
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client():
    """Run one server against ten client processes and wait for all of them."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    config_file = "rpc_ip_config_mul_client.txt"
    generate_ip_config(config_file, 1, 1)

    spawn_ctx = mp.get_context('spawn')
    # The server is told to expect ten clients.
    server_proc = spawn_ctx.Process(target=start_server, args=(10, config_file))
    client_procs = [
        spawn_ctx.Process(target=start_client, args=(config_file,))
        for _ in range(10)
    ]

    server_proc.start()
    for proc in client_procs:
        proc.start()
    # Clients are joined before the server.
    for proc in client_procs:
        proc.join()
    server_proc.join()


204
205
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_thread_rpc():
    """Send requests from the main thread and a sub-thread of one client.

    Two server processes are spawned. The client, running in the test process
    itself, sends one HelloRequest to server 0 from the main thread and one
    to server 1 from a second thread, then receives both responses on the
    main thread and checks their tensors. Finally every server process is
    joined.
    """
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    num_servers = 2
    config_file = "rpc_ip_config_multithread.txt"
    generate_ip_config(config_file, num_servers, num_servers)
    ctx = mp.get_context('spawn')
    pserver_list = []
    for server_id in range(num_servers):
        # Each server is told to expect a single client.
        pserver = ctx.Process(target=start_server, args=(1, config_file, server_id))
        pserver.start()
        pserver_list.append(pserver)

    def start_client_multithread(ip_config):
        """Connect, send from two threads, and verify both responses."""
        import threading
        dgl.distributed.connect_to_server(ip_config=ip_config, num_servers=1)
        dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)

        # First request goes out from the main thread to server 0.
        req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
        dgl.distributed.send_request(0, req)

        def subthread_call(server_id):
            """Send one HelloRequest to ``server_id`` from the calling thread."""
            req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
            dgl.distributed.send_request(server_id, req)

        # Second request goes out from a separate thread to server 1.
        subthread = threading.Thread(target=subthread_call, args=(1,))
        subthread.start()
        subthread.join()

        res0 = dgl.distributed.recv_response()
        res1 = dgl.distributed.recv_response()
        # Order is not guaranteed, so both responses get the same check.
        assert_array_equal(F.asnumpy(res0.tensor), F.asnumpy(TENSOR))
        assert_array_equal(F.asnumpy(res1.tensor), F.asnumpy(TENSOR))
        dgl.distributed.exit_client()

    start_client_multithread(config_file)
    # Wait for every server, not only the last one started, so that no
    # server process is left un-joined when the test returns.
    for pserver in pserver_list:
        pserver.join()


244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client_groups():
    """Run many client groups against servers started with keep_alive=True.

    Servers are started with ``keep_alive=True``, then ``num_clients`` clients
    are launched for each of ``num_groups`` group ids. After all clients have
    finished, every server must still be alive; the servers are then stopped
    explicitly with ``dgl.distributed.shutdown_servers``.
    """
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config_mul_client_groups.txt"
    num_machines = 5
    # A larger number of servers should be tested, but it is kept at 1
    # because of possible port-in-use issues.
    num_servers = 1
    generate_ip_config(ip_config, num_machines, num_servers)
    # pressure test
    num_clients = 15
    num_groups = 15
    spawn_ctx = mp.get_context('spawn')

    server_procs = []
    for server_id in range(num_servers * num_machines):
        server = spawn_ctx.Process(
            target=start_server,
            args=(num_clients, ip_config, server_id, True, num_servers))
        server.start()
        server_procs.append(server)

    client_procs = []
    for _ in range(num_clients):
        for group_id in range(num_groups):
            client = spawn_ctx.Process(
                target=start_client, args=(ip_config, group_id, num_servers))
            client.start()
            client_procs.append(client)
    for client in client_procs:
        client.join()

    # The servers were started with keep_alive=True, so they are expected to
    # outlive their clients.
    assert all(server.is_alive() for server in server_procs)

    # force shutdown server
    dgl.distributed.shutdown_servers(ip_config, num_servers)
    for server in server_procs:
        server.join()


278
279
280
281
if __name__ == '__main__':
    # Run the tests in a fixed order when the file is executed as a script.
    # NOTE: test_multi_client_groups is not part of this list.
    for run_test in (
        test_serialize,
        test_rpc_msg,
        test_rpc,
        test_multi_client,
        test_multi_thread_rpc,
    ):
        run_test()