Unverified Commit dd8d5289 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[rpc] Add num_clients and sort client ID by machine-id (#1713)



* add num_clients to kvstore

* update

* update

* update

* update

* update

* update

* update

* fix lint

* update

* update

* update

* add test

* update

* update

* update

* update

* fix test

* update

* update

* update

* update

* update

* update

* update
Co-authored-by: default avatarDa Zheng <zhengda1936@gmail.com>
parent 0015eff7
...@@ -740,6 +740,7 @@ class KVClient(object): ...@@ -740,6 +740,7 @@ class KVClient(object):
self._machine_count = int(self._server_count / self._group_count) self._machine_count = int(self._server_count / self._group_count)
self._client_id = rpc.get_rank() self._client_id = rpc.get_rank()
self._machine_id = rpc.get_machine_id() self._machine_id = rpc.get_machine_id()
self._num_clients = rpc.get_num_client()
self._part_id = self._machine_id self._part_id = self._machine_id
self._main_server_id = self._machine_id * self._group_count self._main_server_id = self._machine_id * self._group_count
# push and pull handler # push and pull handler
......
...@@ -15,7 +15,8 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \ ...@@ -15,7 +15,8 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'receiver_wait', 'add_receiver_addr', 'sender_connect', 'read_ip_config', \ 'receiver_wait', 'add_receiver_addr', 'sender_connect', 'read_ip_config', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \ 'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \ 'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull'] 'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
'get_num_client', 'set_num_client']
REQUEST_CLASS_TO_SERVICE_ID = {} REQUEST_CLASS_TO_SERVICE_ID = {}
RESPONSE_CLASS_TO_SERVICE_ID = {} RESPONSE_CLASS_TO_SERVICE_ID = {}
...@@ -223,6 +224,16 @@ def get_num_server(): ...@@ -223,6 +224,16 @@ def get_num_server():
""" """
return _CAPI_DGLRPCGetNumServer() return _CAPI_DGLRPCGetNumServer()
def set_num_client(num_client):
"""Set the total number of client.
"""
_CAPI_DGLRPCSetNumClient(int(num_client))
def get_num_client():
"""Get the total number of client.
"""
return _CAPI_DGLRPCGetNumClient()
def set_num_server_per_machine(num_server): def set_num_server_per_machine(num_server):
"""Set the total number of server per machine """Set the total number of server per machine
""" """
...@@ -1017,4 +1028,44 @@ class ShutDownRequest(Request): ...@@ -1017,4 +1028,44 @@ class ShutDownRequest(Request):
finalize_server() finalize_server()
return 'exit' return 'exit'
GET_NUM_CLIENT = 22453
class GetNumberClientsResponse(Response):
"""This reponse will send total number of clients.
Parameters
----------
num_client : int
total number of clients
"""
def __init__(self, num_client):
self.num_client = num_client
def __getstate__(self):
return self.num_client
def __setstate__(self, state):
self.num_client = state
class GetNumberClientsRequest(Request):
"""Client send this request to get the total number of client.
Parameters
----------
client_id : int
client's ID
"""
def __init__(self, client_id):
self.client_id = client_id
def __getstate__(self):
return self.client_id
def __setstate__(self, state):
self.client_id = state
def process_request(self, server_state):
res = GetNumberClientsResponse(get_num_client())
return res
_init_api("dgl.distributed.rpc") _init_api("dgl.distributed.rpc")
...@@ -111,6 +111,9 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket ...@@ -111,6 +111,9 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.register_service(rpc.SHUT_DOWN_SERVER,
rpc.ShutDownRequest, rpc.ShutDownRequest,
None) None)
rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse)
rpc.register_ctrl_c() rpc.register_ctrl_c()
server_namebook = rpc.read_ip_config(ip_config) server_namebook = rpc.read_ip_config(ip_config)
num_servers = len(server_namebook) num_servers = len(server_namebook)
...@@ -150,6 +153,11 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket ...@@ -150,6 +153,11 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.set_rank(res.client_id) rpc.set_rank(res.client_id)
print("Machine (%d) client (%d) connect to server successfuly!" \ print("Machine (%d) client (%d) connect to server successfuly!" \
% (machine_id, rpc.get_rank())) % (machine_id, rpc.get_rank()))
# get total number of client
get_client_num_req = rpc.GetNumberClientsRequest(rpc.get_rank())
rpc.send_request(0, get_client_num_req)
res = rpc.recv_response()
rpc.set_num_client(res.num_client)
def finalize_client(): def finalize_client():
"""Release resources of this client.""" """Release resources of this client."""
......
...@@ -44,6 +44,9 @@ def start_server(server_id, ip_config, num_clients, server_state, \ ...@@ -44,6 +44,9 @@ def start_server(server_id, ip_config, num_clients, server_state, \
rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.register_service(rpc.SHUT_DOWN_SERVER,
rpc.ShutDownRequest, rpc.ShutDownRequest,
None) None)
rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse)
rpc.set_rank(server_id) rpc.set_rank(server_id)
server_namebook = rpc.read_ip_config(ip_config) server_namebook = rpc.read_ip_config(ip_config)
machine_id = server_namebook[server_id][0] machine_id = server_namebook[server_id][0]
...@@ -58,6 +61,7 @@ def start_server(server_id, ip_config, num_clients, server_state, \ ...@@ -58,6 +61,7 @@ def start_server(server_id, ip_config, num_clients, server_state, \
print("Wait connections ...") print("Wait connections ...")
rpc.receiver_wait(ip_addr, port, num_clients) rpc.receiver_wait(ip_addr, port, num_clients)
print("%d clients connected!" % num_clients) print("%d clients connected!" % num_clients)
rpc.set_num_client(num_clients)
# Recv all the client's IP and assign ID to clients # Recv all the client's IP and assign ID to clients
addr_list = [] addr_list = []
client_namebook = {} client_namebook = {}
......
...@@ -169,6 +169,17 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer") ...@@ -169,6 +169,17 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
*rv = RPCContext::ThreadLocal()->num_servers; *rv = RPCContext::ThreadLocal()->num_servers;
}); });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t num_clients = args[0];
*rv = RPCContext::ThreadLocal()->num_clients = num_clients;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->num_clients;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine") DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t num_servers = args[0]; const int32_t num_servers = args[0];
......
...@@ -56,6 +56,11 @@ struct RPCContext { ...@@ -56,6 +56,11 @@ struct RPCContext {
*/ */
int32_t num_servers = 0; int32_t num_servers = 0;
/*!
* \brief Total number of client.
*/
int32_t num_clients = 0;
/*! /*!
* \brief Total number of server per machine. * \brief Total number of server per machine.
*/ */
......
...@@ -130,6 +130,7 @@ def start_client(): ...@@ -130,6 +130,7 @@ def start_client():
dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt') dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt')
# Init kvclient # Init kvclient
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt') kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt')
assert dgl.distributed.get_num_client() == 1
kvclient.init_data(name='data_1', kvclient.init_data(name='data_1',
shape=F.shape(data_1), shape=F.shape(data_1),
dtype=F.dtype(data_1), dtype=F.dtype(data_1),
......
...@@ -107,17 +107,17 @@ class HelloRequest(dgl.distributed.Request): ...@@ -107,17 +107,17 @@ class HelloRequest(dgl.distributed.Request):
res = HelloResponse(self.hello_str, self.integer, new_tensor) res = HelloResponse(self.hello_str, self.integer, new_tensor)
return res return res
def start_server(): def start_server(num_clients, ip_config):
server_state = dgl.distributed.ServerState(None, local_g=None, partition_book=None) server_state = dgl.distributed.ServerState(None, local_g=None, partition_book=None)
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse) dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.start_server(server_id=0, dgl.distributed.start_server(server_id=0,
ip_config='rpc_ip_config.txt', ip_config=ip_config,
num_clients=1, num_clients=num_clients,
server_state=server_state) server_state=server_state)
def start_client(): def start_client(ip_config):
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse) dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.connect_to_server(ip_config='rpc_ip_config.txt') dgl.distributed.connect_to_server(ip_config=ip_config)
req = HelloRequest(STR, INTEGER, TENSOR, simple_func) req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
# test send and recv # test send and recv
dgl.distributed.send_request(0, req) dgl.distributed.send_request(0, req)
...@@ -149,9 +149,13 @@ def start_client(): ...@@ -149,9 +149,13 @@ def start_client():
assert res.hello_str == STR assert res.hello_str == STR
assert res.integer == INTEGER assert res.integer == INTEGER
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR)) assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
# clean up # clean up
dgl.distributed.shutdown_servers() time.sleep(2)
if dgl.distributed.get_rank() == 0:
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client() dgl.distributed.finalize_client()
print("Get rank: %d" % dgl.distributed.get_rank())
def test_serialize(): def test_serialize():
from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
...@@ -192,15 +196,37 @@ def test_rpc(): ...@@ -192,15 +196,37 @@ def test_rpc():
ip_config.write('%s 1\n' % ip_addr) ip_config.write('%s 1\n' % ip_addr)
ip_config.close() ip_config.close()
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
pserver = ctx.Process(target=start_server) pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config.txt"))
pclient = ctx.Process(target=start_client) pclient = ctx.Process(target=start_client, args=("rpc_ip_config.txt",))
pserver.start() pserver.start()
time.sleep(1) time.sleep(1)
pclient.start() pclient.start()
pserver.join() pserver.join()
pclient.join() pclient.join()
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client():
ip_config = open("rpc_ip_config_mul_client.txt", "w")
ip_addr = get_local_usable_addr()
ip_config.write('%s 1\n' % ip_addr)
ip_config.close()
ctx = mp.get_context('spawn')
pserver = ctx.Process(target=start_server, args=(10, "rpc_ip_config_mul_client.txt"))
pclient_list = []
for i in range(10):
pclient = ctx.Process(target=start_client, args=("rpc_ip_config_mul_client.txt",))
pclient_list.append(pclient)
pserver.start()
time.sleep(1)
for i in range(10):
pclient_list[i].start()
for i in range(10):
pclient_list[i].join()
pserver.join()
if __name__ == '__main__': if __name__ == '__main__':
test_serialize() test_serialize()
test_rpc_msg() test_rpc_msg()
test_rpc() test_rpc()
test_multi_client()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment