# test_rpc.py: tests for dgl.distributed RPC (payload serialization, RPCMessage, client/server round trips).
1
2
import os
import time
3
import socket
4
5
6
7

import dgl
import backend as F
import unittest, pytest
8
import multiprocessing as mp
9
from numpy.testing import assert_array_equal
10
from utils import reset_envs, generate_ip_config
11

12
13
14
15
if os.name != 'nt':
    import fcntl
    import struct

16
17
18
# Payload values every HelloRequest carries. HelloRequest.process_request
# asserts them on the server, and the clients check that responses echo them.
INTEGER = 2
STR = 'hello world!'
# Service id under which HelloRequest/HelloResponse are registered on both
# the client and the server side.
HELLO_SERVICE_ID = 901231
# (1000, 1000) int64 CPU tensor sent to the server and echoed back.
TENSOR = F.zeros((1000, 1000), F.int64, F.cpu())

def foo(x, y):
    """Sanity-check the fields of a deserialized ``MyRequest``.

    Raises:
        AssertionError: unless ``x == 123`` and ``y == "abc"``.
    """
    expected_x, expected_y = 123, "abc"
    assert x == expected_x
    assert y == expected_y

class MyRequest(dgl.distributed.Request):
    """Dummy request used to exercise RPC payload (de)serialization.

    Holds an int, a str, a random (3, 4) tensor and a plain module-level
    function, so a round trip has to preserve each kind of field.
    """

    def __init__(self):
        self.x, self.y = 123, "abc"
        self.z = F.randn((3, 4))
        self.foo = foo

    def __getstate__(self):
        # State is a flat tuple in the order (x, y, z, foo).
        state = (self.x, self.y, self.z, self.foo)
        return state

    def __setstate__(self, state):
        x, y, z, fn = state
        self.x = x
        self.y = y
        self.z = z
        self.foo = fn

    def process_request(self, server_state):
        # Never served in these tests; deliberately produces no response.
        return None

class MyResponse(dgl.distributed.Response):
    """Dummy response whose only state is the integer ``x`` (432)."""

    def __init__(self):
        self.x = 432

    def __getstate__(self):
        # The state is the bare integer, not wrapped in a tuple.
        state = self.x
        return state

    def __setstate__(self, state):
        self.x = state
 
def simple_func(tensor):
    """Return *tensor* unchanged.

    Shipped as the ``func`` field of ``HelloRequest`` so the server applies
    a callable that arrived over RPC to the request's tensor.
    """
    unchanged = tensor
    return unchanged

class HelloResponse(dgl.distributed.Response):
    """Response echoing a HelloRequest's string, integer and processed tensor."""

    def __init__(self, hello_str, integer, tensor):
        self.hello_str, self.integer, self.tensor = hello_str, integer, tensor

    def __getstate__(self):
        # State order: (hello_str, integer, tensor).
        return (self.hello_str, self.integer, self.tensor)

    def __setstate__(self, state):
        hello_str, integer, tensor = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor

class HelloRequest(dgl.distributed.Request):
    """Request carrying a string, an integer, a tensor and a callable.

    On the server, ``process_request`` validates the string and integer
    against the module constants and replies with a ``HelloResponse`` whose
    tensor is ``func(tensor)``.
    """

    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str, self.integer = hello_str, integer
        self.tensor, self.func = tensor, func

    def __getstate__(self):
        # State order: (hello_str, integer, tensor, func).
        return (self.hello_str, self.integer, self.tensor, self.func)

    def __setstate__(self, state):
        hello_str, integer, tensor, func = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def process_request(self, server_state):
        assert self.hello_str == STR
        assert self.integer == INTEGER
        return HelloResponse(self.hello_str, self.integer, self.func(self.tensor))

def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1):
    """Run one RPC server that answers ``HelloRequest``s (meant as a process target).

    Args:
        num_clients: number of client processes expected to connect.
        ip_config: path to the IP config file shared with the clients.
        server_id: id of this server.
        keep_alive: keep serving after the clients exit; the caller then
            stops the server with ``dgl.distributed.shutdown_servers``.
        num_servers: number of servers per machine in ``ip_config``.
    """
    # Start late on purpose so clients must retry their connection.
    print("Sleep 1 seconds to test client re-connect.")
    time.sleep(1)
    server_state = dgl.distributed.ServerState(
        None, local_g=None, partition_book=None, keep_alive=keep_alive)
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    print("Start server {}".format(server_id))
    dgl.distributed.start_server(server_id=server_id,
                                 ip_config=ip_config,
                                 num_servers=num_servers,
                                 num_clients=num_clients,
                                 server_state=server_state)

def _check_hello_response(res):
    """Assert that *res* echoes a HelloRequest built from the module constants."""
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))


def start_client(ip_config, group_id=0, num_servers=1):
    """Connect to the servers and exercise each request/response API once.

    Args:
        ip_config: path to the IP config file shared with the servers.
        group_id: client group this client joins.
        num_servers: number of servers per machine in ``ip_config``.
    """
    dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    dgl.distributed.connect_to_server(
        ip_config=ip_config, num_servers=num_servers, group_id=group_id)
    req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
    # test send and recv
    dgl.distributed.send_request(0, req)
    _check_hello_response(dgl.distributed.recv_response())
    # test remote_call: ten copies of the same request, all to target 0
    for res in dgl.distributed.remote_call([(0, req)] * 10):
        _check_hello_response(res)
    # test send_request_to_machine
    dgl.distributed.send_request_to_machine(0, req)
    _check_hello_response(dgl.distributed.recv_response())
    # test remote_call_to_machine
    for res in dgl.distributed.remote_call_to_machine([(0, req)] * 10):
        _check_hello_response(res)

def test_serialize():
    """Round-trip a request and a response through the RPC payload codec."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
    SERVICE_ID = 12345
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
    req1 = deserialize_from_payload(MyRequest, data, tensors)
    # The function field must survive the round trip and still be callable.
    req1.foo(req1.x, req1.y)
    assert req.x == req1.x
    assert req.y == req1.y
    assert F.array_equal(req.z, req1.z)

    res = MyResponse()
    data, tensors = serialize_to_payload(res)
    res1 = deserialize_from_payload(MyResponse, data, tensors)
    assert res.x == res1.x

def test_rpc_msg():
    """Check that an RPCMessage exposes the fields it was constructed with."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    from dgl.distributed.rpc import serialize_to_payload, RPCMessage
    SERVICE_ID = 32452
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
    # Positional order: service_id, msg_seq, client_id, server_id, data, tensors.
    rpcmsg = RPCMessage(SERVICE_ID, 23, 0, 1, data, tensors)
    assert rpcmsg.service_id == SERVICE_ID
    assert rpcmsg.msg_seq == 23
    assert rpcmsg.client_id == 0
    assert rpcmsg.server_id == 1
    assert len(rpcmsg.data) == len(data)
    # MyRequest holds exactly one tensor (z), so the payload carries one tensor.
    assert len(rpcmsg.tensors) == 1
    assert F.array_equal(rpcmsg.tensors[0], req.z)

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_rpc():
    """One server and one client exchange HelloRequests in separate processes."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context('spawn')
    pserver = ctx.Process(target=start_server, args=(1, ip_config))
    pclient = ctx.Process(target=start_client, args=(ip_config,))
    pserver.start()
    pclient.start()
    pserver.join()
    pclient.join()
    # A failed assertion inside a child process only shows up as a non-zero
    # exit code; without these checks the test would pass regardless.
    assert pclient.exitcode == 0
    assert pserver.exitcode == 0

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client():
    """One server serves 20 concurrent client processes."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context('spawn')
    num_clients = 20
    pserver = ctx.Process(target=start_server, args=(num_clients, ip_config))
    pclient_list = [ctx.Process(target=start_client, args=(ip_config,))
                    for _ in range(num_clients)]
    pserver.start()
    for pclient in pclient_list:
        pclient.start()
    for pclient in pclient_list:
        pclient.join()
    pserver.join()
    # Child-process assertion failures surface only as non-zero exit codes.
    for pclient in pclient_list:
        assert pclient.exitcode == 0
    assert pserver.exitcode == 0

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_thread_rpc():
    """A client sends requests from its main thread and from a sub-thread."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    num_servers = 2
    ip_config = "rpc_ip_config_multithread.txt"
    generate_ip_config(ip_config, num_servers, num_servers)
    ctx = mp.get_context('spawn')
    pserver_list = []
    for i in range(num_servers):
        pserver = ctx.Process(target=start_server, args=(1, ip_config, i))
        pserver.start()
        pserver_list.append(pserver)

    def start_client_multithread(ip_config):
        import threading
        dgl.distributed.connect_to_server(ip_config=ip_config, num_servers=1)
        dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)

        # The main thread sends to server 0 ...
        req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
        dgl.distributed.send_request(0, req)

        # ... and a sub-thread sends to server 1 through the same client.
        def subthread_call(server_id):
            req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
            dgl.distributed.send_request(server_id, req)

        subthread = threading.Thread(target=subthread_call, args=(1,))
        subthread.start()
        subthread.join()

        res0 = dgl.distributed.recv_response()
        res1 = dgl.distributed.recv_response()
        # Order is not guaranteed
        assert_array_equal(F.asnumpy(res0.tensor), F.asnumpy(TENSOR))
        assert_array_equal(F.asnumpy(res1.tensor), F.asnumpy(TENSOR))
        dgl.distributed.exit_client()

    start_client_multithread(ip_config)
    # Join every server; the leaked loop variable only refers to the last one.
    for pserver in pserver_list:
        pserver.join()
    for pserver in pserver_list:
        assert pserver.exitcode == 0

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client_groups():
    """Several client groups share keep-alive servers spread over several machines."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config_mul_client_groups.txt"
    num_machines = 5
    # Should test with a larger number; kept small because of possible
    # port-in-use issues.
    num_servers = 1
    generate_ip_config(ip_config, num_machines, num_servers)
    # pressure test
    num_clients = 2
    num_groups = 2
    ctx = mp.get_context('spawn')
    pserver_list = []
    for i in range(num_servers * num_machines):
        pserver = ctx.Process(target=start_server,
                              args=(num_clients, ip_config, i, True, num_servers))
        pserver.start()
        pserver_list.append(pserver)
    pclient_list = []
    for _ in range(num_clients):
        for group_id in range(num_groups):
            pclient = ctx.Process(target=start_client,
                                  args=(ip_config, group_id, num_servers))
            pclient.start()
            pclient_list.append(pclient)
    for p in pclient_list:
        p.join()
    # keep_alive=True servers must outlive all of their clients.
    for p in pserver_list:
        assert p.is_alive()
    # force shutdown server
    dgl.distributed.shutdown_servers(ip_config, num_servers)
    for p in pserver_list:
        p.join()
    # Checked after shutdown so a failing client never leaves servers running.
    for p in pclient_list:
        assert p.exitcode == 0

if __name__ == '__main__':
    # Run every test in this module when executed as a script.
    test_serialize()
    test_rpc_msg()
    test_rpc()
    test_multi_client()
    test_multi_thread_rpc()
    test_multi_client_groups()