"dgl_sparse/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "d02e560e07754d4d01b25cf99c4d6457e2e9cdd1"
Unverified Commit 5b515cf6 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[KVStore] Support group barrier (#1880)



* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
Co-authored-by: default avatarDa Zheng <zhengda1936@gmail.com>
parent 444becf0
...@@ -193,26 +193,26 @@ class BarrierRequest(rpc.Request): ...@@ -193,26 +193,26 @@ class BarrierRequest(rpc.Request):
Parameters Parameters
---------- ----------
msg : string role : string
string msg client role
""" """
def __init__(self, msg): def __init__(self, role):
self.msg = msg self.role = role
def __getstate__(self): def __getstate__(self):
return self.msg return self.role
def __setstate__(self, state): def __setstate__(self, state):
self.msg = state self.role = state
def process_request(self, server_state): def process_request(self, server_state):
assert self.msg == BARRIER_MSG
kv_store = server_state.kv_store kv_store = server_state.kv_store
kv_store.barrier_count = kv_store.barrier_count + 1 count = kv_store.barrier_count[self.role]
if kv_store.barrier_count == kv_store.num_clients: kv_store.barrier_count[self.role] = count + 1
kv_store.barrier_count = 0 if kv_store.barrier_count[self.role] == len(kv_store.role[self.role]):
kv_store.barrier_count[self.role] = 0
res_list = [] res_list = []
for target_id in range(kv_store.num_clients): for target_id in kv_store.role[self.role]:
res_list.append((target_id, BarrierResponse(BARRIER_MSG))) res_list.append((target_id, BarrierResponse(BARRIER_MSG)))
return res_list return res_list
return None return None
...@@ -506,6 +506,52 @@ class DeleteDataRequest(rpc.Request): ...@@ -506,6 +506,52 @@ class DeleteDataRequest(rpc.Request):
res = DeleteDataResponse(DELETE_MSG) res = DeleteDataResponse(DELETE_MSG)
return res return res
# Service ID passed to rpc.register_service for the RegisterRoleRequest /
# RegisterRoleResponse pair.
REGISTER_ROLE = 901241
# Confirmation string placed in RegisterRoleResponse and checked by the client.
ROLE_MSG = "Register_Role"
class RegisterRoleResponse(rpc.Response):
    """Acknowledgement sent back to a client in reply to a RegisterRoleRequest.

    Parameters
    ----------
    msg : str
        Short confirmation string carried by the response.
    """
    def __init__(self, msg):
        self.msg = msg

    def __getstate__(self):
        # The message string is the whole serialized state.
        return self.msg

    def __setstate__(self, state):
        self.msg = state
class RegisterRoleRequest(rpc.Request):
    """Tell the server which role a client belongs to.

    Parameters
    ----------
    client_id : int
        ID of client
    role : str
        role of client
    """
    def __init__(self, client_id, role):
        self.client_id = client_id
        self.role = role

    def __getstate__(self):
        # Serialized as a (client_id, role) pair.
        return self.client_id, self.role

    def __setstate__(self, state):
        self.client_id, self.role = state

    def process_request(self, server_state):
        """Record this client under its role in the server's kv_store.

        The first registration for a role creates an empty member set and
        sets that role's barrier counter to 0. The client ID is then added
        to the role's member set.

        Returns
        -------
        RegisterRoleResponse
            Response carrying ROLE_MSG.
        """
        kv_store = server_state.kv_store
        members_by_role = kv_store.role
        first_of_role = self.role not in members_by_role
        if first_of_role:
            members_by_role[self.role] = set()
            kv_store.barrier_count[self.role] = 0
        members_by_role[self.role].add(self.client_id)
        return RegisterRoleResponse(ROLE_MSG)
############################ KVServer ############################### ############################ KVServer ###############################
def default_push_handler(target, name, id_tensor, data_tensor): def default_push_handler(target, name, id_tensor, data_tensor):
...@@ -604,6 +650,9 @@ class KVServer(object): ...@@ -604,6 +650,9 @@ class KVServer(object):
rpc.register_service(DELETE_DATA, rpc.register_service(DELETE_DATA,
DeleteDataRequest, DeleteDataRequest,
DeleteDataResponse) DeleteDataResponse)
rpc.register_service(REGISTER_ROLE,
RegisterRoleRequest,
RegisterRoleResponse)
# Store the tensor data with specified data name # Store the tensor data with specified data name
self._data_store = {} self._data_store = {}
# Store the partition information with specified data name # Store the partition information with specified data name
...@@ -620,10 +669,12 @@ class KVServer(object): ...@@ -620,10 +669,12 @@ class KVServer(object):
# We assume partition_id is equal to machine_id # We assume partition_id is equal to machine_id
self._part_id = self._machine_id self._part_id = self._machine_id
self._num_clients = num_clients self._num_clients = num_clients
self._barrier_count = 0 self._barrier_count = {}
# push and pull handler # push and pull handler
self._push_handlers = {} self._push_handlers = {}
self._pull_handlers = {} self._pull_handlers = {}
# store client role
self._role = {}
@property @property
def server_id(self): def server_id(self):
...@@ -665,6 +716,11 @@ class KVServer(object): ...@@ -665,6 +716,11 @@ class KVServer(object):
"""Get push handler""" """Get push handler"""
return self._push_handlers return self._push_handlers
@property
def role(self):
    """Return the role table stored in ``self._role``."""
    return self._role
@property @property
def pull_handlers(self): def pull_handlers(self):
"""Get pull handler""" """Get pull handler"""
...@@ -748,8 +804,10 @@ class KVClient(object): ...@@ -748,8 +804,10 @@ class KVClient(object):
---------- ----------
ip_config : str ip_config : str
Path of IP configuration file. Path of IP configuration file.
role : str
We can set different role for kvstore.
""" """
def __init__(self, ip_config): def __init__(self, ip_config, role='default'):
assert rpc.get_rank() != -1, 'Please invoke rpc.connect_to_server() \ assert rpc.get_rank() != -1, 'Please invoke rpc.connect_to_server() \
before creating KVClient.' before creating KVClient.'
assert os.path.exists(ip_config), 'Cannot open file: %s' % ip_config assert os.path.exists(ip_config), 'Cannot open file: %s' % ip_config
...@@ -784,6 +842,9 @@ class KVClient(object): ...@@ -784,6 +842,9 @@ class KVClient(object):
rpc.register_service(DELETE_DATA, rpc.register_service(DELETE_DATA,
DeleteDataRequest, DeleteDataRequest,
DeleteDataResponse) DeleteDataResponse)
rpc.register_service(REGISTER_ROLE,
RegisterRoleRequest,
RegisterRoleResponse)
# Store the tensor data with specified data name # Store the tensor data with specified data name
self._data_store = {} self._data_store = {}
# Store the partition information with specified data name # Store the partition information with specified data name
...@@ -805,12 +866,23 @@ class KVClient(object): ...@@ -805,12 +866,23 @@ class KVClient(object):
# push and pull handler # push and pull handler
self._pull_handlers = {} self._pull_handlers = {}
self._push_handlers = {} self._push_handlers = {}
# register role on server-0
self._role = role
request = RegisterRoleRequest(self._client_id, self._role)
rpc.send_request(0, request)
response = rpc.recv_response()
assert response.msg == ROLE_MSG
@property @property
def client_id(self): def client_id(self):
"""Get client ID""" """Get client ID"""
return self._client_id return self._client_id
@property
def role(self):
    """Return the role stored in ``self._role``."""
    return self._role
@property @property
def machine_id(self): def machine_id(self):
"""Get machine ID""" """Get machine ID"""
...@@ -821,14 +893,10 @@ class KVClient(object): ...@@ -821,14 +893,10 @@ class KVClient(object):
This API will be blocked untill all the clients invoke this API. This API will be blocked untill all the clients invoke this API.
""" """
request = BarrierRequest(BARRIER_MSG) request = BarrierRequest(self._role)
# send request to all the server nodes rpc.send_request(0, request)
for server_id in range(self._server_count): response = rpc.recv_response()
rpc.send_request(server_id, request) assert response.msg == BARRIER_MSG
# recv response from all the server nodes
for _ in range(self._server_count):
response = rpc.recv_response()
assert response.msg == BARRIER_MSG
def register_push_handler(self, name, func): def register_push_handler(self, name, func):
"""Register UDF push function. """Register UDF push function.
......
...@@ -16,7 +16,7 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \ ...@@ -16,7 +16,7 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \ 'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \ 'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \ 'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
'get_num_client', 'set_num_client'] 'get_num_client', 'set_num_client', 'client_barrier']
REQUEST_CLASS_TO_SERVICE_ID = {} REQUEST_CLASS_TO_SERVICE_ID = {}
RESPONSE_CLASS_TO_SERVICE_ID = {} RESPONSE_CLASS_TO_SERVICE_ID = {}
...@@ -899,6 +899,13 @@ def recv_rpc_message(timeout=0): ...@@ -899,6 +899,13 @@ def recv_rpc_message(timeout=0):
_CAPI_DGLRPCRecvRPCMessage(timeout, msg) _CAPI_DGLRPCRecvRPCMessage(timeout, msg)
return msg return msg
def client_barrier():
    """Barrier all client processes.

    Sends a ClientBarrierRequest to server 0 and blocks in recv_response()
    until a reply arrives; the reply message must be 'barrier'.
    """
    send_request(0, ClientBarrierRequest())
    reply = recv_response()
    assert reply.msg == 'barrier'
def finalize_server(): def finalize_server():
"""Finalize resources of current server """Finalize resources of current server
""" """
...@@ -1068,4 +1075,50 @@ class GetNumberClientsRequest(Request): ...@@ -1068,4 +1075,50 @@ class GetNumberClientsRequest(Request):
res = GetNumberClientsResponse(get_num_client()) res = GetNumberClientsResponse(get_num_client())
return res return res
# Service ID passed to register_service for the ClientBarrierRequest /
# ClientBarrierResponse pair.
CLIENT_BARRIER = 22454
class ClientBarrierResponse(Response):
    """Barrier confirmation delivered to a client.

    Parameters
    ----------
    msg : str
        string msg, 'barrier' by default
    """
    def __init__(self, msg='barrier'):
        self.msg = msg

    def __getstate__(self):
        # The message string is the whole serialized state.
        return self.msg

    def __setstate__(self, state):
        self.msg = state
class ClientBarrierRequest(Request):
    """Barrier notification sent from a client to the server.

    Parameters
    ----------
    msg : str
        string msg, 'barrier' by default
    """
    def __init__(self, msg='barrier'):
        self.msg = msg

    def __getstate__(self):
        # The message string is the whole serialized state.
        return self.msg

    def __setstate__(self, state):
        self.msg = state

    def process_request(self, server_state):
        """Count this arrival and answer once the count reaches get_num_client().

        Returns
        -------
        list of (int, ClientBarrierResponse) or None
            None while the barrier counter differs from get_num_client().
            Otherwise the counter is reset to 0 and one response is returned
            for every target id in range(get_num_client()).
        """
        arrived = _CAPI_DGLRPCGetBarrierCount() + 1
        _CAPI_DGLRPCSetBarrierCount(arrived)
        if _CAPI_DGLRPCGetBarrierCount() != get_num_client():
            return None
        # Last arrival: reset the counter for the next barrier round.
        _CAPI_DGLRPCSetBarrierCount(0)
        return [(target_id, ClientBarrierResponse())
                for target_id in range(get_num_client())]
_init_api("dgl.distributed.rpc") _init_api("dgl.distributed.rpc")
...@@ -128,6 +128,9 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket ...@@ -128,6 +128,9 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.register_service(rpc.GET_NUM_CLIENT, rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest, rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse) rpc.GetNumberClientsResponse)
rpc.register_service(rpc.CLIENT_BARRIER,
rpc.ClientBarrierRequest,
rpc.ClientBarrierResponse)
rpc.register_ctrl_c() rpc.register_ctrl_c()
server_namebook = rpc.read_ip_config(ip_config) server_namebook = rpc.read_ip_config(ip_config)
num_servers = len(server_namebook) num_servers = len(server_namebook)
...@@ -199,6 +202,7 @@ def exit_client(): ...@@ -199,6 +202,7 @@ def exit_client():
"""Register exit callback. """Register exit callback.
""" """
# Only client with rank_0 will send shutdown request to servers. # Only client with rank_0 will send shutdown request to servers.
rpc.client_barrier()
shutdown_servers() shutdown_servers()
finalize_client() finalize_client()
atexit.unregister(exit_client) atexit.unregister(exit_client)
......
...@@ -47,6 +47,9 @@ def start_server(server_id, ip_config, num_clients, server_state, \ ...@@ -47,6 +47,9 @@ def start_server(server_id, ip_config, num_clients, server_state, \
rpc.register_service(rpc.GET_NUM_CLIENT, rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest, rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse) rpc.GetNumberClientsResponse)
rpc.register_service(rpc.CLIENT_BARRIER,
rpc.ClientBarrierRequest,
rpc.ClientBarrierResponse)
rpc.set_rank(server_id) rpc.set_rank(server_id)
server_namebook = rpc.read_ip_config(ip_config) server_namebook = rpc.read_ip_config(ip_config)
machine_id = server_namebook[server_id][0] machine_id = server_namebook[server_id][0]
......
...@@ -207,6 +207,17 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq") ...@@ -207,6 +207,17 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
RPCContext::ThreadLocal()->msg_seq = msg_seq; RPCContext::ThreadLocal()->msg_seq = msg_seq;
}); });
// Registers "distributed.rpc._CAPI_DGLRPCGetBarrierCount": returns the
// barrier_count field of the RPCContext obtained from RPCContext::ThreadLocal().
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->barrier_count;
});
// Registers "distributed.rpc._CAPI_DGLRPCSetBarrierCount": stores args[0]
// (read as int32_t) into the barrier_count field of the RPCContext obtained
// from RPCContext::ThreadLocal(). Nothing is written to rv.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t count = args[0];
RPCContext::ThreadLocal()->barrier_count = count;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID") DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->machine_id; *rv = RPCContext::ThreadLocal()->machine_id;
......
...@@ -61,6 +61,11 @@ struct RPCContext { ...@@ -61,6 +61,11 @@ struct RPCContext {
*/ */
int32_t num_clients = 0; int32_t num_clients = 0;
/*!
* \brief Current barrier count
*/
int32_t barrier_count = 0;
/*! /*!
* \brief Total number of server per machine. * \brief Total number of server per machine.
*/ */
......
...@@ -47,13 +47,13 @@ global_eid = F.tensor([0,1,2,3,4,5,6], F.int64) ...@@ -47,13 +47,13 @@ global_eid = F.tensor([0,1,2,3,4,5,6], F.int64)
g = dgl.DGLGraph() g = dgl.DGLGraph()
g.add_nodes(6) g.add_nodes(6)
g.add_edge(0, 1) # 0 g.add_edges(0, 1) # 0
g.add_edge(0, 2) # 1 g.add_edges(0, 2) # 1
g.add_edge(0, 3) # 2 g.add_edges(0, 3) # 2
g.add_edge(2, 3) # 3 g.add_edges(2, 3) # 3
g.add_edge(1, 1) # 4 g.add_edges(1, 1) # 4
g.add_edge(0, 4) # 5 g.add_edges(0, 4) # 5
g.add_edge(2, 5) # 6 g.add_edges(2, 5) # 6
g.ndata[dgl.NID] = global_nid g.ndata[dgl.NID] = global_nid
g.edata[dgl.EID] = global_eid g.edata[dgl.EID] = global_eid
...@@ -129,11 +129,29 @@ def start_server(server_id, num_clients): ...@@ -129,11 +129,29 @@ def start_server(server_id, num_clients):
num_clients=num_clients, num_clients=num_clients,
server_state=server_state) server_state=server_state)
def start_server_mul_role(server_id, num_clients):
    """Create a KVServer for the multi-role test and start serving.

    Parameters
    ----------
    server_id : int
        ID of the server to start.
    num_clients : int
        Number of clients passed to both KVServer and start_server.
    """
    ip_config = 'kv_ip_mul_config.txt'
    # Init kvserver
    kvserver = dgl.distributed.KVServer(server_id=server_id,
                                        ip_config=ip_config,
                                        num_clients=num_clients)
    kvserver.add_part_policy(node_policy)
    # A backup server calls init_data without a tensor; otherwise data_0 is passed.
    init_args = ('data_0', 'node') if kvserver.is_backup_server() else ('data_0', 'node', data_0)
    kvserver.init_data(*init_args)
    # start server
    state = dgl.distributed.ServerState(kv_store=kvserver, local_g=None, partition_book=None)
    dgl.distributed.start_server(server_id=server_id,
                                 ip_config=ip_config,
                                 num_clients=num_clients,
                                 server_state=state)
def start_client(num_clients): def start_client(num_clients):
# Note: connect to server first ! # Note: connect to server first !
dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt') dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt')
# Init kvclient # Init kvclient
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt') kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt')
time.sleep(2)
assert dgl.distributed.get_num_client() == num_clients assert dgl.distributed.get_num_client() == num_clients
kvclient.init_data(name='data_1', kvclient.init_data(name='data_1',
shape=F.shape(data_1), shape=F.shape(data_1),
...@@ -249,6 +267,7 @@ def start_client(num_clients): ...@@ -249,6 +267,7 @@ def start_client(num_clients):
kvclient.register_push_handler('data_3', add_push) kvclient.register_push_handler('data_3', add_push)
kvclient.map_shared_data(partition_book=gpb) kvclient.map_shared_data(partition_book=gpb)
data_tensor = F.tensor([[6.,6.],[6.,6.],[6.,6.]], F.float32) data_tensor = F.tensor([[6.,6.],[6.,6.],[6.,6.]], F.float32)
kvclient.barrier()
time.sleep(kvclient.client_id + 1) time.sleep(kvclient.client_id + 1)
print("add...") print("add...")
kvclient.push(name='data_3', kvclient.push(name='data_3',
...@@ -258,8 +277,20 @@ def start_client(num_clients): ...@@ -258,8 +277,20 @@ def start_client(num_clients):
res = kvclient.pull(name='data_3', id_tensor=id_tensor) res = kvclient.pull(name='data_3', id_tensor=id_tensor)
data_tensor = data_tensor * num_clients data_tensor = data_tensor * num_clients
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor)) assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
# clean up
def start_client_mul_role(i, num_clients):
    """Connect a client, register it as trainer or sampler, and hit the barrier.

    Parameters
    ----------
    i : int
        Index of the client; even indices become 'trainer', odd ones 'sampler'.
    num_clients : int
        Total number of clients (not used inside this function).
    """
    ip_config = 'kv_ip_mul_config.txt'
    # Note: connect to server first !
    dgl.distributed.connect_to_server(ip_config=ip_config)
    # Init kvclient
    client_role = 'trainer' if i % 2 == 0 else 'sampler'
    kvclient = dgl.distributed.KVClient(ip_config=ip_config, role=client_role)
    time.sleep(2)
    if i == 2:
        # block one trainer
        time.sleep(5)
    kvclient.barrier()
    print("i: %d role: %s" % (i, kvclient.role))
@unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet') @unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
def test_kv_store(): def test_kv_store():
...@@ -285,6 +316,31 @@ def test_kv_store(): ...@@ -285,6 +316,31 @@ def test_kv_store():
for i in range(num_servers): for i in range(num_servers):
pserver_list[i].join() pserver_list[i].join()
@unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
def test_kv_multi_role():
    """Run the multi-role barrier scenario with 2 servers and 10 clients.

    Writes the IP config file, spawns one process per server
    (start_server_mul_role) and per client (start_client_mul_role), then
    joins all clients followed by all servers.
    """
    num_servers = 2
    num_clients = 10
    # Resolve the address before touching the file, and let the context
    # manager close the handle even if the write fails.
    ip_addr = get_local_usable_addr()
    with open("kv_ip_mul_config.txt", "w") as ip_config:
        ip_config.write('{} {}\n'.format(ip_addr, num_servers))
    ctx = mp.get_context('spawn')
    pserver_list = []
    pclient_list = []
    for i in range(num_servers):
        pserver = ctx.Process(target=start_server_mul_role, args=(i, num_clients))
        pserver.start()
        pserver_list.append(pserver)
    for i in range(num_clients):
        pclient = ctx.Process(target=start_client_mul_role, args=(i, num_clients))
        pclient.start()
        pclient_list.append(pclient)
    # Clients are joined first, then the servers.
    for pclient in pclient_list:
        pclient.join()
    for pserver in pserver_list:
        pserver.join()
if __name__ == '__main__': if __name__ == '__main__':
test_partition_policy() test_partition_policy()
test_kv_store() test_kv_store()
test_kv_multi_role()
...@@ -152,9 +152,6 @@ def start_client(ip_config): ...@@ -152,9 +152,6 @@ def start_client(ip_config):
assert res.integer == INTEGER assert res.integer == INTEGER
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR)) assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
# clean up
time.sleep(2)
def test_serialize(): def test_serialize():
from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
SERVICE_ID = 12345 SERVICE_ID = 12345
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment