# test_rpc.py: tests for the dgl.distributed RPC layer (payload serialization,
# RPC messages, and server/client round trips).
import os
import time
3
import socket
4
5
6
7

import dgl
import backend as F
import unittest, pytest
8
import multiprocessing as mp
9
10
from numpy.testing import assert_array_equal

11
12
13
14
if os.name != 'nt':
    import fcntl
    import struct

15
16
17
18
19
# Payload values shared by the hello-service client and server in this file.
# INTEGER and STR are sent by the client and re-checked inside
# HelloRequest.process_request and on every response.
INTEGER = 2
STR = 'hello world!'
# Service id under which HelloRequest / HelloResponse are registered.
HELLO_SERVICE_ID = 901231
# 10 x 10 zero tensor (int64, CPU) sent with every HelloRequest.
TENSOR = F.zeros((10, 10), F.int64, F.cpu())

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def get_local_usable_addr():
    """Get a local usable IP address and a currently free TCP port.

    Returns
    -------
    str
        IP address and port separated by a single space,
        e.g., '192.168.8.12 50051'. Falls back to '127.0.0.1' for the
        address part when the local address cannot be determined.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        try:
            # doesn't even have to be reachable
            probe.connect(('10.255.255.255', 1))
            ip_addr = probe.getsockname()[0]
        except (OSError, ValueError):
            # Socket-level failures (e.g. no route to the probe address) are
            # reported as OSError, so catching ValueError alone would let them
            # escape instead of using the loopback fallback.
            ip_addr = '127.0.0.1'

    # Bind to port 0 so the OS picks a free port, read it back, then release
    # the socket (the context manager closes it even if bind/listen fails).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
        listener.bind(("", 0))
        listener.listen(1)
        port = listener.getsockname()[1]

    return ip_addr + ' ' + str(port)

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
def foo(x, y):
    """Check that the arguments carry the values ``MyRequest`` stores.

    Raises ``AssertionError`` unless ``x == 123`` and ``y == "abc"``.
    """
    expected_x, expected_y = 123, "abc"
    assert x == expected_x
    assert y == expected_y

class MyRequest(dgl.distributed.Request):
    """Request used by the serialization tests.

    Carries an int, a str, a tensor created with ``F.randn((3, 4))`` and a
    reference to the module-level function ``foo``.
    """

    def __init__(self):
        self.x, self.y = 123, "abc"
        self.z = F.randn((3, 4))
        self.foo = foo

    def __getstate__(self):
        """Return the state as the tuple ``(x, y, z, foo)``."""
        state = (self.x, self.y, self.z, self.foo)
        return state

    def __setstate__(self, state):
        """Restore the four fields from a tuple produced by ``__getstate__``."""
        x, y, z, func = state
        self.x = x
        self.y = y
        self.z = z
        self.foo = func

    def process_request(self, server_state):
        """Do nothing; this request is never processed in these tests."""
        return None

class MyResponse(dgl.distributed.Response):
    """Response used by the serialization tests; holds one integer ``x``."""

    def __init__(self):
        self.x = 432

    def __getstate__(self):
        """Return ``x`` itself as the whole state."""
        state = self.x
        return state

    def __setstate__(self, state):
        """Restore ``x`` from the value produced by ``__getstate__``."""
        self.x = state
 
def simple_func(tensor):
    """Return ``tensor`` unchanged (the callable shipped inside HelloRequest)."""
    result = tensor
    return result

class HelloResponse(dgl.distributed.Response):
    """Response of the hello service: a string, an integer and a tensor."""

    def __init__(self, hello_str, integer, tensor):
        self.hello_str, self.integer, self.tensor = hello_str, integer, tensor

    def __getstate__(self):
        """Return the state as the tuple ``(hello_str, integer, tensor)``."""
        state = (self.hello_str, self.integer, self.tensor)
        return state

    def __setstate__(self, state):
        """Restore the three fields from a tuple produced by ``__getstate__``."""
        hello_str, integer, tensor = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor

class HelloRequest(dgl.distributed.Request):
    """Request of the hello service.

    Carries a string, an integer, a tensor and a callable that is applied to
    the tensor when the request is processed.
    """

    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str, self.integer = hello_str, integer
        self.tensor, self.func = tensor, func

    def __getstate__(self):
        """Return the state as the tuple ``(hello_str, integer, tensor, func)``."""
        state = (self.hello_str, self.integer, self.tensor, self.func)
        return state

    def __setstate__(self, state):
        """Restore the four fields from a tuple produced by ``__getstate__``."""
        hello_str, integer, tensor, func = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def process_request(self, server_state):
        """Validate the payload and answer with a ``HelloResponse``.

        Asserts that the string and integer equal the module constants
        ``STR`` and ``INTEGER``, applies ``func`` to the tensor, and returns
        the three values wrapped in a ``HelloResponse``.
        """
        assert self.hello_str == STR
        assert self.integer == INTEGER
        return HelloResponse(self.hello_str, self.integer, self.func(self.tensor))

110
def start_server(num_clients, ip_config):
    """Run the hello-service RPC server (server id 0, one server in total).

    Parameters
    ----------
    num_clients : int
        Forwarded to ``dgl.distributed.start_server`` as ``num_clients``.
    ip_config : str
        Forwarded to ``dgl.distributed.start_server`` as ``ip_config``; the
        tests in this file pass the path of the config file they wrote.
    """
    # Start late on purpose so that clients have to retry their connection.
    print("Sleep 5 seconds to test client re-connect.")
    time.sleep(5)
    server_state = dgl.distributed.ServerState(None, local_g=None, partition_book=None)
    dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    dgl.distributed.start_server(server_id=0,
                                 ip_config=ip_config,
                                 num_servers=1,
                                 num_clients=num_clients,
                                 server_state=server_state)
120

121
def start_client(ip_config):
    """Connect to the hello-service server and exercise four RPC entry points.

    Sends the same ``HelloRequest`` through ``send_request`` /
    ``recv_response``, ``remote_call``, ``send_request_to_machine`` and
    ``remote_call_to_machine``, and asserts that every response echoes the
    module-level STR / INTEGER / TENSOR payload.

    Parameters
    ----------
    ip_config : str
        Forwarded to ``dgl.distributed.connect_to_server`` as ``ip_config``.
    """
    dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    dgl.distributed.connect_to_server(ip_config=ip_config, num_servers=1)
    req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
    # test send and recv
    dgl.distributed.send_request(0, req)
    _check_hello_response(dgl.distributed.recv_response())
    # test remote_call
    target_and_requests = [(0, req) for _ in range(10)]
    for res in dgl.distributed.remote_call(target_and_requests):
        _check_hello_response(res)
    # test send_request_to_machine
    dgl.distributed.send_request_to_machine(0, req)
    _check_hello_response(dgl.distributed.recv_response())
    # test remote_call_to_machine
    target_and_requests = [(0, req) for _ in range(10)]
    for res in dgl.distributed.remote_call_to_machine(target_and_requests):
        _check_hello_response(res)


def _check_hello_response(res):
    """Assert that ``res`` carries the STR / INTEGER / TENSOR payload."""
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
155

156
def test_serialize():
    """Round-trip ``MyRequest`` and ``MyResponse`` through the payload codec.

    Serializes each object with ``serialize_to_payload``, rebuilds it with
    ``deserialize_from_payload`` and checks that the fields survived,
    including the function reference stored on the request.
    """
    os.environ['DGL_DIST_MODE'] = 'distributed'
    from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
    SERVICE_ID = 12345
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
    req1 = deserialize_from_payload(MyRequest, data, tensors)
    # The restored callable must still accept the restored x / y values.
    req1.foo(req1.x, req1.y)
    assert req.x == req1.x
    assert req.y == req1.y
    assert F.array_equal(req.z, req1.z)

    res = MyResponse()
    data, tensors = serialize_to_payload(res)
    res1 = deserialize_from_payload(MyResponse, data, tensors)
    assert res.x == res1.x

def test_rpc_msg():
    """Build an ``RPCMessage`` from a serialized request and check its fields.

    The message is created with service id 32452, message sequence 23,
    client id 0 and server id 1; the test asserts each attribute reads back
    and that the single tensor in the payload equals the request's ``z``.
    """
    os.environ['DGL_DIST_MODE'] = 'distributed'
    from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload, RPCMessage
    SERVICE_ID = 32452
    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
    req = MyRequest()
    data, tensors = serialize_to_payload(req)
    rpcmsg = RPCMessage(SERVICE_ID, 23, 0, 1, data, tensors)
    assert rpcmsg.service_id == SERVICE_ID
    assert rpcmsg.msg_seq == 23
    assert rpcmsg.client_id == 0
    assert rpcmsg.server_id == 1
    assert len(rpcmsg.data) == len(data)
    assert len(rpcmsg.tensors) == 1
    assert F.array_equal(rpcmsg.tensors[0], req.z)

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_rpc():
    """Run one server and one client in spawned processes.

    Writes the address returned by ``get_local_usable_addr`` to
    ``rpc_ip_config.txt``, starts ``start_server`` and, one second later,
    ``start_client``, then waits for both processes to finish.
    """
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_addr = get_local_usable_addr()
    # The context manager closes the config file even if the write fails.
    with open("rpc_ip_config.txt", "w") as ip_config:
        ip_config.write('%s\n' % ip_addr)
    ctx = mp.get_context('spawn')
    pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config.txt"))
    pclient = ctx.Process(target=start_client, args=("rpc_ip_config.txt",))
    pserver.start()
    time.sleep(1)
    pclient.start()
    pserver.join()
    pclient.join()
    # The client's assertions run in a child process; without this check a
    # failure there would not fail the test.
    assert pclient.exitcode == 0
205

206
207
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client():
    """Run one server and ten clients in spawned processes.

    Writes the address returned by ``get_local_usable_addr`` to
    ``rpc_ip_config_mul_client.txt``, starts ``start_server`` (expecting ten
    clients) and ten ``start_client`` processes, then waits for all of them.
    """
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_addr = get_local_usable_addr()
    # The context manager closes the config file even if the write fails.
    with open("rpc_ip_config_mul_client.txt", "w") as ip_config:
        ip_config.write('%s\n' % ip_addr)
    ctx = mp.get_context('spawn')
    pserver = ctx.Process(target=start_server, args=(10, "rpc_ip_config_mul_client.txt"))
    pclient_list = [
        ctx.Process(target=start_client, args=("rpc_ip_config_mul_client.txt",))
        for _ in range(10)
    ]
    pserver.start()
    for pclient in pclient_list:
        pclient.start()
    for pclient in pclient_list:
        pclient.join()
    pserver.join()
    # The clients' assertions run in child processes; without this check a
    # failure there would not fail the test.
    for pclient in pclient_list:
        assert pclient.exitcode == 0


227
228
229
230
# Allow running the tests directly as a script, in the same order as above.
if __name__ == '__main__':
    test_serialize()
    test_rpc_msg()
    test_rpc()
    test_multi_client()