Unverified Commit 5b515cf6 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[KVStore] Support group barrier (#1880)



* udpate

* update

* update

* update

* update

* update

* update

* update

* fix lint

* update

* update

* update

* update

* udpate

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
Co-authored-by: default avatarDa Zheng <zhengda1936@gmail.com>
parent 444becf0
......@@ -193,26 +193,26 @@ class BarrierRequest(rpc.Request):
Parameters
----------
msg : string
string msg
role : string
client role
"""
def __init__(self, msg):
self.msg = msg
def __init__(self, role):
self.role = role
def __getstate__(self):
return self.msg
return self.role
def __setstate__(self, state):
self.msg = state
self.role = state
def process_request(self, server_state):
assert self.msg == BARRIER_MSG
kv_store = server_state.kv_store
kv_store.barrier_count = kv_store.barrier_count + 1
if kv_store.barrier_count == kv_store.num_clients:
kv_store.barrier_count = 0
count = kv_store.barrier_count[self.role]
kv_store.barrier_count[self.role] = count + 1
if kv_store.barrier_count[self.role] == len(kv_store.role[self.role]):
kv_store.barrier_count[self.role] = 0
res_list = []
for target_id in range(kv_store.num_clients):
for target_id in kv_store.role[self.role]:
res_list.append((target_id, BarrierResponse(BARRIER_MSG)))
return res_list
return None
......@@ -506,6 +506,52 @@ class DeleteDataRequest(rpc.Request):
res = DeleteDataResponse(DELETE_MSG)
return res
REGISTER_ROLE = 901241
ROLE_MSG = "Register_Role"
class RegisterRoleResponse(rpc.Response):
"""Send a confirmation signal (just a short string message)
of RegisterRoleRequest to client.
"""
def __init__(self, msg):
self.msg = msg
def __getstate__(self):
return self.msg
def __setstate__(self, state):
self.msg = state
class RegisterRoleRequest(rpc.Request):
"""Send client id and role to server
Parameters
----------
client_id : int
ID of client
role : str
role of client
"""
def __init__(self, client_id, role):
self.client_id = client_id
self.role = role
def __getstate__(self):
return self.client_id, self.role
def __setstate__(self, state):
self.client_id, self.role = state
def process_request(self, server_state):
kv_store = server_state.kv_store
role = kv_store.role
if self.role not in role:
role[self.role] = set()
kv_store.barrier_count[self.role] = 0
role[self.role].add(self.client_id)
res = RegisterRoleResponse(ROLE_MSG)
return res
############################ KVServer ###############################
def default_push_handler(target, name, id_tensor, data_tensor):
......@@ -604,6 +650,9 @@ class KVServer(object):
rpc.register_service(DELETE_DATA,
DeleteDataRequest,
DeleteDataResponse)
rpc.register_service(REGISTER_ROLE,
RegisterRoleRequest,
RegisterRoleResponse)
# Store the tensor data with specified data name
self._data_store = {}
# Store the partition information with specified data name
......@@ -620,10 +669,12 @@ class KVServer(object):
# We assume partition_id is equal to machine_id
self._part_id = self._machine_id
self._num_clients = num_clients
self._barrier_count = 0
self._barrier_count = {}
# push and pull handler
self._push_handlers = {}
self._pull_handlers = {}
# store client role
self._role = {}
@property
def server_id(self):
......@@ -665,6 +716,11 @@ class KVServer(object):
"""Get push handler"""
return self._push_handlers
@property
def role(self):
"""Get client role"""
return self._role
@property
def pull_handlers(self):
"""Get pull handler"""
......@@ -748,8 +804,10 @@ class KVClient(object):
----------
ip_config : str
Path of IP configuration file.
role : str
We can set different role for kvstore.
"""
def __init__(self, ip_config):
def __init__(self, ip_config, role='default'):
assert rpc.get_rank() != -1, 'Please invoke rpc.connect_to_server() \
before creating KVClient.'
assert os.path.exists(ip_config), 'Cannot open file: %s' % ip_config
......@@ -784,6 +842,9 @@ class KVClient(object):
rpc.register_service(DELETE_DATA,
DeleteDataRequest,
DeleteDataResponse)
rpc.register_service(REGISTER_ROLE,
RegisterRoleRequest,
RegisterRoleResponse)
# Store the tensor data with specified data name
self._data_store = {}
# Store the partition information with specified data name
......@@ -805,12 +866,23 @@ class KVClient(object):
# push and pull handler
self._pull_handlers = {}
self._push_handlers = {}
# register role on server-0
self._role = role
request = RegisterRoleRequest(self._client_id, self._role)
rpc.send_request(0, request)
response = rpc.recv_response()
assert response.msg == ROLE_MSG
@property
def client_id(self):
"""Get client ID"""
return self._client_id
@property
def role(self):
"""Get client role"""
return self._role
@property
def machine_id(self):
"""Get machine ID"""
......@@ -821,12 +893,8 @@ class KVClient(object):
This API will be blocked untill all the clients invoke this API.
"""
request = BarrierRequest(BARRIER_MSG)
# send request to all the server nodes
for server_id in range(self._server_count):
rpc.send_request(server_id, request)
# recv response from all the server nodes
for _ in range(self._server_count):
request = BarrierRequest(self._role)
rpc.send_request(0, request)
response = rpc.recv_response()
assert response.msg == BARRIER_MSG
......
......@@ -16,7 +16,7 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
'get_num_client', 'set_num_client']
'get_num_client', 'set_num_client', 'client_barrier']
REQUEST_CLASS_TO_SERVICE_ID = {}
RESPONSE_CLASS_TO_SERVICE_ID = {}
......@@ -899,6 +899,13 @@ def recv_rpc_message(timeout=0):
_CAPI_DGLRPCRecvRPCMessage(timeout, msg)
return msg
def client_barrier():
"""Barrier all client processes"""
req = ClientBarrierRequest()
send_request(0, req)
res = recv_response()
assert res.msg == 'barrier'
def finalize_server():
"""Finalize resources of current server
"""
......@@ -1068,4 +1075,50 @@ class GetNumberClientsRequest(Request):
res = GetNumberClientsResponse(get_num_client())
return res
CLIENT_BARRIER = 22454
class ClientBarrierResponse(Response):
"""Send the barrier confirmation to client
Parameters
----------
msg : str
string msg
"""
def __init__(self, msg='barrier'):
self.msg = msg
def __getstate__(self):
return self.msg
def __setstate__(self, state):
self.msg = state
class ClientBarrierRequest(Request):
"""Send the barrier information to server
Parameters
----------
msg : str
string msg
"""
def __init__(self, msg='barrier'):
self.msg = msg
def __getstate__(self):
return self.msg
def __setstate__(self, state):
self.msg = state
def process_request(self, server_state):
_CAPI_DGLRPCSetBarrierCount(_CAPI_DGLRPCGetBarrierCount()+1)
if _CAPI_DGLRPCGetBarrierCount() == get_num_client():
_CAPI_DGLRPCSetBarrierCount(0)
res_list = []
for target_id in range(get_num_client()):
res_list.append((target_id, ClientBarrierResponse()))
return res_list
return None
_init_api("dgl.distributed.rpc")
......@@ -128,6 +128,9 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse)
rpc.register_service(rpc.CLIENT_BARRIER,
rpc.ClientBarrierRequest,
rpc.ClientBarrierResponse)
rpc.register_ctrl_c()
server_namebook = rpc.read_ip_config(ip_config)
num_servers = len(server_namebook)
......@@ -199,6 +202,7 @@ def exit_client():
"""Register exit callback.
"""
# Only client with rank_0 will send shutdown request to servers.
rpc.client_barrier()
shutdown_servers()
finalize_client()
atexit.unregister(exit_client)
......
......@@ -47,6 +47,9 @@ def start_server(server_id, ip_config, num_clients, server_state, \
rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse)
rpc.register_service(rpc.CLIENT_BARRIER,
rpc.ClientBarrierRequest,
rpc.ClientBarrierResponse)
rpc.set_rank(server_id)
server_namebook = rpc.read_ip_config(ip_config)
machine_id = server_namebook[server_id][0]
......
......@@ -207,6 +207,17 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
RPCContext::ThreadLocal()->msg_seq = msg_seq;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->barrier_count;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t count = args[0];
RPCContext::ThreadLocal()->barrier_count = count;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->machine_id;
......
......@@ -61,6 +61,11 @@ struct RPCContext {
*/
int32_t num_clients = 0;
/*!
* \brief Current barrier count
*/
int32_t barrier_count = 0;
/*!
* \brief Total number of server per machine.
*/
......
......@@ -47,13 +47,13 @@ global_eid = F.tensor([0,1,2,3,4,5,6], F.int64)
g = dgl.DGLGraph()
g.add_nodes(6)
g.add_edge(0, 1) # 0
g.add_edge(0, 2) # 1
g.add_edge(0, 3) # 2
g.add_edge(2, 3) # 3
g.add_edge(1, 1) # 4
g.add_edge(0, 4) # 5
g.add_edge(2, 5) # 6
g.add_edges(0, 1) # 0
g.add_edges(0, 2) # 1
g.add_edges(0, 3) # 2
g.add_edges(2, 3) # 3
g.add_edges(1, 1) # 4
g.add_edges(0, 4) # 5
g.add_edges(2, 5) # 6
g.ndata[dgl.NID] = global_nid
g.edata[dgl.EID] = global_eid
......@@ -129,11 +129,29 @@ def start_server(server_id, num_clients):
num_clients=num_clients,
server_state=server_state)
def start_server_mul_role(server_id, num_clients):
# Init kvserver
kvserver = dgl.distributed.KVServer(server_id=server_id,
ip_config='kv_ip_mul_config.txt',
num_clients=num_clients)
kvserver.add_part_policy(node_policy)
if kvserver.is_backup_server():
kvserver.init_data('data_0', 'node')
else:
kvserver.init_data('data_0', 'node', data_0)
# start server
server_state = dgl.distributed.ServerState(kv_store=kvserver, local_g=None, partition_book=None)
dgl.distributed.start_server(server_id=server_id,
ip_config='kv_ip_mul_config.txt',
num_clients=num_clients,
server_state=server_state)
def start_client(num_clients):
# Note: connect to server first !
dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt')
# Init kvclient
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt')
time.sleep(2)
assert dgl.distributed.get_num_client() == num_clients
kvclient.init_data(name='data_1',
shape=F.shape(data_1),
......@@ -249,6 +267,7 @@ def start_client(num_clients):
kvclient.register_push_handler('data_3', add_push)
kvclient.map_shared_data(partition_book=gpb)
data_tensor = F.tensor([[6.,6.],[6.,6.],[6.,6.]], F.float32)
kvclient.barrier()
time.sleep(kvclient.client_id + 1)
print("add...")
kvclient.push(name='data_3',
......@@ -258,8 +277,20 @@ def start_client(num_clients):
res = kvclient.pull(name='data_3', id_tensor=id_tensor)
data_tensor = data_tensor * num_clients
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
# clean up
def start_client_mul_role(i, num_clients):
# Note: connect to server first !
dgl.distributed.connect_to_server(ip_config='kv_ip_mul_config.txt')
# Init kvclient
if i % 2 == 0:
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_mul_config.txt', role='trainer')
else:
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_mul_config.txt', role='sampler')
time.sleep(2)
if i == 2: # block one trainer
time.sleep(5)
kvclient.barrier()
print("i: %d role: %s" % (i, kvclient.role))
@unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
def test_kv_store():
......@@ -285,6 +316,31 @@ def test_kv_store():
for i in range(num_servers):
pserver_list[i].join()
@unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
def test_kv_multi_role():
ip_config = open("kv_ip_mul_config.txt", "w")
num_servers = 2
num_clients = 10
ip_addr = get_local_usable_addr()
ip_config.write('{} {}\n'.format(ip_addr, num_servers))
ip_config.close()
ctx = mp.get_context('spawn')
pserver_list = []
pclient_list = []
for i in range(num_servers):
pserver = ctx.Process(target=start_server_mul_role, args=(i, num_clients))
pserver.start()
pserver_list.append(pserver)
for i in range(num_clients):
pclient = ctx.Process(target=start_client_mul_role, args=(i, num_clients))
pclient.start()
pclient_list.append(pclient)
for i in range(num_clients):
pclient_list[i].join()
for i in range(num_servers):
pserver_list[i].join()
if __name__ == '__main__':
test_partition_policy()
test_kv_store()
test_kv_multi_role()
......@@ -152,9 +152,6 @@ def start_client(ip_config):
assert res.integer == INTEGER
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
# clean up
time.sleep(2)
def test_serialize():
from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
SERVICE_ID = 12345
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment