Unverified commit 2823c61f, authored by Da Zheng, committed by GitHub
Browse files

[Distributed] Fix synchronized APIs in kvstore (#1948)



* handle synchronized API.

* fix.

* fix.

* fix.
Co-authored-by: Chao Ma <mctt90@gmail.com>
parent bab32d5b
...@@ -156,13 +156,21 @@ class InitDataRequest(rpc.Request): ...@@ -156,13 +156,21 @@ class InitDataRequest(rpc.Request):
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
dtype = F.data_type_dict[self.dtype] dtype = F.data_type_dict[self.dtype]
if not kv_store.is_backup_server():
data_tensor = self.init_func(self.shape, dtype) # We should see requests from multiple clients. We need to ignore the duplicated
kv_store.init_data(name=self.name, # requests.
policy_str=self.policy_str, if self.name in kv_store.data_store:
data_tensor=data_tensor) assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(self.shape)
assert F.reverse_data_type_dict[F.dtype(kv_store.data_store[self.name])] == self.dtype
assert kv_store.part_policy[self.name].policy_str == self.policy_str
else: else:
kv_store.init_data(name=self.name, policy_str=self.policy_str) if not kv_store.is_backup_server():
data_tensor = self.init_func(self.shape, dtype)
kv_store.init_data(name=self.name,
policy_str=self.policy_str,
data_tensor=data_tensor)
else:
kv_store.init_data(name=self.name, policy_str=self.policy_str)
res = InitDataResponse(INIT_MSG) res = InitDataResponse(INIT_MSG)
return res return res
...@@ -469,6 +477,12 @@ class SendMetaToBackupRequest(rpc.Request): ...@@ -469,6 +477,12 @@ class SendMetaToBackupRequest(rpc.Request):
kv_store.part_policy[self.name] = kv_store.find_policy(self.policy_str) kv_store.part_policy[self.name] = kv_store.find_policy(self.policy_str)
kv_store.pull_handlers[self.name] = self.pull_handler kv_store.pull_handlers[self.name] = self.pull_handler
kv_store.push_handlers[self.name] = self.push_handler kv_store.push_handlers[self.name] = self.push_handler
else:
assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(self.shape)
assert F.reverse_data_type_dict[F.dtype(kv_store.data_store[self.name])] == self.dtype
assert kv_store.part_policy[self.name].policy_str == self.policy_str
assert kv_store.pull_handlers[self.name] == self.pull_handler
assert kv_store.push_handlers[self.name] == self.push_handler
res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG) res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG)
return res return res
...@@ -507,11 +521,11 @@ class DeleteDataRequest(rpc.Request): ...@@ -507,11 +521,11 @@ class DeleteDataRequest(rpc.Request):
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
assert self.name in kv_store.data_store, 'data name %s not exists.' % self.name if self.name in kv_store.data_store:
del kv_store.data_store[self.name] del kv_store.data_store[self.name]
del kv_store.part_policy[self.name] del kv_store.part_policy[self.name]
del kv_store.push_handlers[self.name] del kv_store.push_handlers[self.name]
del kv_store.pull_handlers[self.name] del kv_store.pull_handlers[self.name]
res = DeleteDataResponse(DELETE_MSG) res = DeleteDataResponse(DELETE_MSG)
return res return res
...@@ -876,7 +890,6 @@ class KVClient(object): ...@@ -876,7 +890,6 @@ class KVClient(object):
self._machine_count = int(self._server_count / self._group_count) self._machine_count = int(self._server_count / self._group_count)
self._client_id = rpc.get_rank() self._client_id = rpc.get_rank()
self._machine_id = rpc.get_machine_id() self._machine_id = rpc.get_machine_id()
self._num_clients = rpc.get_num_client()
self._part_id = self._machine_id self._part_id = self._machine_id
self._main_server_id = self._machine_id * self._group_count self._main_server_id = self._machine_id * self._group_count
# push and pull handler # push and pull handler
...@@ -941,15 +954,14 @@ class KVClient(object): ...@@ -941,15 +954,14 @@ class KVClient(object):
The function to be called. The function to be called.
""" """
self.barrier() self.barrier()
if self._client_id == 0: request = RegisterPushHandlerRequest(name, func)
request = RegisterPushHandlerRequest(name, func) # send request to all the server nodes
# send request to all the server nodes for server_id in range(self._server_count):
for server_id in range(self._server_count): rpc.send_request(server_id, request)
rpc.send_request(server_id, request) # recv response from all the server nodes
# recv response from all the server nodes for _ in range(self._server_count):
for _ in range(self._server_count): response = rpc.recv_response()
response = rpc.recv_response() assert response.msg == REGISTER_PUSH_MSG
assert response.msg == REGISTER_PUSH_MSG
self._push_handlers[name] = func self._push_handlers[name] = func
self.barrier() self.barrier()
...@@ -974,15 +986,14 @@ class KVClient(object): ...@@ -974,15 +986,14 @@ class KVClient(object):
The function to be called. The function to be called.
""" """
self.barrier() self.barrier()
if self._client_id == 0: request = RegisterPullHandlerRequest(name, func)
request = RegisterPullHandlerRequest(name, func) # send request to all the server nodes
# send request to all the server nodes for server_id in range(self._server_count):
for server_id in range(self._server_count): rpc.send_request(server_id, request)
rpc.send_request(server_id, request) # recv response from all the server nodes
# recv response from all the server nodes for _ in range(self._server_count):
for _ in range(self._server_count): response = rpc.recv_response()
response = rpc.recv_response() assert response.msg == REGISTER_PULL_MSG
assert response.msg == REGISTER_PULL_MSG
self._pull_handlers[name] = func self._pull_handlers[name] = func
self.barrier() self.barrier()
...@@ -1008,25 +1019,24 @@ class KVClient(object): ...@@ -1008,25 +1019,24 @@ class KVClient(object):
assert name not in self._data_name_list, 'data name: %s already exists.' % name assert name not in self._data_name_list, 'data name: %s already exists.' % name
self.barrier() self.barrier()
shape = list(shape) shape = list(shape)
# One of the clients in each machine will issue requests to the local server.
assert rpc.get_num_client() % part_policy.partition_book.num_partitions() == 0, \ # Send request to the servers to initialize data.
'#clients ({}) is not divisable by #partitions ({})'.format( # The servers may handle the duplicated initializations.
rpc.get_num_client(), part_policy.partition_book.num_partitions()) part_shape = shape.copy()
num_clients_per_part = rpc.get_num_client() / part_policy.partition_book.num_partitions() part_shape[0] = part_policy.get_data_size()
if self._client_id % num_clients_per_part == 0: request = InitDataRequest(name,
part_shape = shape.copy() tuple(part_shape),
part_shape[0] = part_policy.get_data_size() F.reverse_data_type_dict[dtype],
request = InitDataRequest(name, part_policy.policy_str,
tuple(part_shape), init_func)
F.reverse_data_type_dict[dtype], # The request is sent to the servers in one group, which are on the same machine.
part_policy.policy_str, for n in range(self._group_count):
init_func) server_id = part_policy.part_id * self._group_count + n
for n in range(self._group_count): rpc.send_request(server_id, request)
server_id = part_policy.part_id * self._group_count + n for _ in range(self._group_count):
rpc.send_request(server_id, request) response = rpc.recv_response()
for _ in range(self._group_count): assert response.msg == INIT_MSG
response = rpc.recv_response()
assert response.msg == INIT_MSG
self.barrier() self.barrier()
# Create local shared-data # Create local shared-data
local_shape = shape.copy() local_shape = shape.copy()
...@@ -1048,19 +1058,18 @@ class KVClient(object): ...@@ -1048,19 +1058,18 @@ class KVClient(object):
self._push_handlers[name] = default_push_handler self._push_handlers[name] = default_push_handler
# Now we need to tell the backup server the new tensor. # Now we need to tell the backup server the new tensor.
if self._client_id % num_clients_per_part == 0: request = SendMetaToBackupRequest(name, F.reverse_data_type_dict[dtype],
request = SendMetaToBackupRequest(name, F.reverse_data_type_dict[dtype], part_shape, part_policy.policy_str,
part_shape, part_policy.policy_str, self._pull_handlers[name],
self._pull_handlers[name], self._push_handlers[name])
self._push_handlers[name]) # send request to all the backup server nodes
# send request to all the backup server nodes for i in range(self._group_count-1):
for i in range(self._group_count-1): server_id = self._machine_id * self._group_count + i + 1
server_id = self._machine_id * self._group_count + i + 1 rpc.send_request(server_id, request)
rpc.send_request(server_id, request) # recv response from all the backup server nodes
# recv response from all the backup server nodes for _ in range(self._group_count-1):
for _ in range(self._group_count-1): response = rpc.recv_response()
response = rpc.recv_response() assert response.msg == SEND_META_TO_BACKUP_MSG
assert response.msg == SEND_META_TO_BACKUP_MSG
self.barrier() self.barrier()
def delete_data(self, name): def delete_data(self, name):
...@@ -1075,17 +1084,16 @@ class KVClient(object): ...@@ -1075,17 +1084,16 @@ class KVClient(object):
assert name in self._data_name_list, 'data name: %s not exists.' % name assert name in self._data_name_list, 'data name: %s not exists.' % name
self.barrier() self.barrier()
part_policy = self._part_policy[name] part_policy = self._part_policy[name]
num_partitions = part_policy.partition_book.num_partitions()
num_clients_per_part = rpc.get_num_client() // num_partitions # send request to every server nodes
if self._client_id % num_clients_per_part == 0: request = DeleteDataRequest(name)
# send request to every server nodes for n in range(self._group_count):
request = DeleteDataRequest(name) server_id = part_policy.part_id * self._group_count + n
for n in range(self._group_count): rpc.send_request(server_id, request)
server_id = part_policy.part_id * self._group_count + n for _ in range(self._group_count):
rpc.send_request(server_id, request) response = rpc.recv_response()
for _ in range(self._group_count): assert response.msg == DELETE_MSG
response = rpc.recv_response()
assert response.msg == DELETE_MSG
self.barrier() self.barrier()
self._data_name_list.remove(name) self._data_name_list.remove(name)
# TODO(chao) : remove the delete log print # TODO(chao) : remove the delete log print
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment