"git@developer.sourcefind.cn:change/sglang.git" did not exist on "5dddb331c48039c2649ccdba3a81a7fbe43a7fe5"
Unverified commit dd8d5289, authored by Chao Ma and committed by GitHub
Browse files

[rpc] Add num_clients and sort client ID by machine-id (#1713)



* add num_clients to kvstore

* update

* update

* update

* update

* update

* update

* update

* fix lint

* update

* update

* update

* add test

* update

* update

* update

* update

* fix test

* update

* update

* update

* update

* update

* update

* update
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
parent 0015eff7
......@@ -740,6 +740,7 @@ class KVClient(object):
self._machine_count = int(self._server_count / self._group_count)
self._client_id = rpc.get_rank()
self._machine_id = rpc.get_machine_id()
self._num_clients = rpc.get_num_client()
self._part_id = self._machine_id
self._main_server_id = self._machine_id * self._group_count
# push and pull handler
......
......@@ -15,7 +15,8 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'receiver_wait', 'add_receiver_addr', 'sender_connect', 'read_ip_config', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull']
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
'get_num_client', 'set_num_client']
REQUEST_CLASS_TO_SERVICE_ID = {}
RESPONSE_CLASS_TO_SERVICE_ID = {}
......@@ -223,6 +224,16 @@ def get_num_server():
"""
return _CAPI_DGLRPCGetNumServer()
def set_num_client(num_client):
    """Set the total number of clients.

    Parameters
    ----------
    num_client : int
        Total number of clients. The value is coerced with ``int()``
        before it is handed to the C API.
    """
    total = int(num_client)
    _CAPI_DGLRPCSetNumClient(total)
def get_num_client():
    """Get the total number of clients.

    Returns
    -------
    The value reported by the ``_CAPI_DGLRPCGetNumClient`` C API call.
    """
    total = _CAPI_DGLRPCGetNumClient()
    return total
def set_num_server_per_machine(num_server):
"""Set the total number of server per machine
"""
......@@ -1017,4 +1028,44 @@ class ShutDownRequest(Request):
finalize_server()
return 'exit'
# Service ID under which GetNumberClientsRequest / GetNumberClientsResponse
# are registered via register_service() on both the client and the server.
GET_NUM_CLIENT = 22453
class GetNumberClientsResponse(Response):
    """Response that carries the total number of clients.

    Parameters
    ----------
    num_client : int
        total number of clients
    """

    def __init__(self, num_client):
        self.num_client = num_client

    def __getstate__(self):
        # The serialized state is the bare client count, not a dict.
        return self.num_client

    def __setstate__(self, state):
        # Mirror of __getstate__: ``state`` is the bare client count.
        self.num_client = state
class GetNumberClientsRequest(Request):
    """Request sent by a client to obtain the total number of clients.

    Parameters
    ----------
    client_id : int
        client's ID
    """

    def __init__(self, client_id):
        self.client_id = client_id

    def __getstate__(self):
        # The serialized state is the bare client ID, not a dict.
        return self.client_id

    def __setstate__(self, state):
        # Mirror of __getstate__: ``state`` is the bare client ID.
        self.client_id = state

    def process_request(self, server_state):
        """Build the reply holding the current ``get_num_client()`` value.

        ``server_state`` is accepted as part of the handler signature but
        is not read here.
        """
        return GetNumberClientsResponse(get_num_client())
_init_api("dgl.distributed.rpc")
......@@ -111,6 +111,9 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.register_service(rpc.SHUT_DOWN_SERVER,
rpc.ShutDownRequest,
None)
rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse)
rpc.register_ctrl_c()
server_namebook = rpc.read_ip_config(ip_config)
num_servers = len(server_namebook)
......@@ -150,6 +153,11 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.set_rank(res.client_id)
print("Machine (%d) client (%d) connect to server successfuly!" \
% (machine_id, rpc.get_rank()))
# get total number of client
get_client_num_req = rpc.GetNumberClientsRequest(rpc.get_rank())
rpc.send_request(0, get_client_num_req)
res = rpc.recv_response()
rpc.set_num_client(res.num_client)
def finalize_client():
"""Release resources of this client."""
......
......@@ -44,6 +44,9 @@ def start_server(server_id, ip_config, num_clients, server_state, \
rpc.register_service(rpc.SHUT_DOWN_SERVER,
rpc.ShutDownRequest,
None)
rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse)
rpc.set_rank(server_id)
server_namebook = rpc.read_ip_config(ip_config)
machine_id = server_namebook[server_id][0]
......@@ -58,6 +61,7 @@ def start_server(server_id, ip_config, num_clients, server_state, \
print("Wait connections ...")
rpc.receiver_wait(ip_addr, port, num_clients)
print("%d clients connected!" % num_clients)
rpc.set_num_client(num_clients)
# Recv all the client's IP and assign ID to clients
addr_list = []
client_namebook = {}
......
......@@ -169,6 +169,17 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
*rv = RPCContext::ThreadLocal()->num_servers;
});
// Setter for the total number of clients: reads the count from args[0],
// stores it in the RPCContext::ThreadLocal() context, and also hands the
// stored value back through rv.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t num_clients = args[0];
// Chained assignment: update the context field, then return the same value.
*rv = RPCContext::ThreadLocal()->num_clients = num_clients;
});
// Getter for the total number of clients: returns the num_clients field of
// the RPCContext::ThreadLocal() context through rv. Takes no arguments.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->num_clients;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t num_servers = args[0];
......
......@@ -56,6 +56,11 @@ struct RPCContext {
*/
int32_t num_servers = 0;
/*!
* \brief Total number of client.
*/
int32_t num_clients = 0;
/*!
* \brief Total number of server per machine.
*/
......
......@@ -130,6 +130,7 @@ def start_client():
dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt')
# Init kvclient
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt')
assert dgl.distributed.get_num_client() == 1
kvclient.init_data(name='data_1',
shape=F.shape(data_1),
dtype=F.dtype(data_1),
......
......@@ -107,17 +107,17 @@ class HelloRequest(dgl.distributed.Request):
res = HelloResponse(self.hello_str, self.integer, new_tensor)
return res
def start_server():
def start_server(num_clients, ip_config):
server_state = dgl.distributed.ServerState(None, local_g=None, partition_book=None)
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.start_server(server_id=0,
ip_config='rpc_ip_config.txt',
num_clients=1,
ip_config=ip_config,
num_clients=num_clients,
server_state=server_state)
def start_client():
def start_client(ip_config):
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.connect_to_server(ip_config='rpc_ip_config.txt')
dgl.distributed.connect_to_server(ip_config=ip_config)
req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
# test send and recv
dgl.distributed.send_request(0, req)
......@@ -149,9 +149,13 @@ def start_client():
assert res.hello_str == STR
assert res.integer == INTEGER
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
# clean up
time.sleep(2)
if dgl.distributed.get_rank() == 0:
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
print("Get rank: %d" % dgl.distributed.get_rank())
def test_serialize():
from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
......@@ -192,15 +196,37 @@ def test_rpc():
ip_config.write('%s 1\n' % ip_addr)
ip_config.close()
ctx = mp.get_context('spawn')
pserver = ctx.Process(target=start_server)
pclient = ctx.Process(target=start_client)
pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config.txt"))
pclient = ctx.Process(target=start_client, args=("rpc_ip_config.txt",))
pserver.start()
time.sleep(1)
pclient.start()
pserver.join()
pclient.join()
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_multi_client():
    """Run one RPC server against ten clients started from spawned processes.

    Writes a single-line IP config file, launches ``start_server`` expecting
    ten clients, launches ten ``start_client`` processes, and waits for every
    process to finish.
    """
    num_clients = 10
    ip_config_path = "rpc_ip_config_mul_client.txt"
    # The context manager closes the config file even if the write raises.
    with open(ip_config_path, "w") as ip_config:
        ip_addr = get_local_usable_addr()
        ip_config.write('%s 1\n' % ip_addr)
    ctx = mp.get_context('spawn')
    pserver = ctx.Process(target=start_server, args=(num_clients, ip_config_path))
    pclient_list = [
        ctx.Process(target=start_client, args=(ip_config_path,))
        for _ in range(num_clients)
    ]
    pserver.start()
    # Pause before launching the clients so the server process has been
    # started for a second by the time they try to connect.
    time.sleep(1)
    for pclient in pclient_list:
        pclient.start()
    for pclient in pclient_list:
        pclient.join()
    pserver.join()
if __name__ == '__main__':
    # Run every test in this file, in order, when executed as a script.
    for run_case in (test_serialize, test_rpc_msg, test_rpc, test_multi_client):
        run_case()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment