Unverified Commit 2823c61f authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Distributed] Fix synchronized APIs in kvstore (#1948)



* handle synchronized API.

* fix.

* fix.

* fix.
Co-authored-by: default avatarChao Ma <mctt90@gmail.com>
parent bab32d5b
...@@ -156,6 +156,14 @@ class InitDataRequest(rpc.Request): ...@@ -156,6 +156,14 @@ class InitDataRequest(rpc.Request):
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
dtype = F.data_type_dict[self.dtype] dtype = F.data_type_dict[self.dtype]
# We should see requests from multiple clients. We need to ignore the duplicated
# reqeusts.
if self.name in kv_store.data_store:
assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(self.shape)
assert F.reverse_data_type_dict[F.dtype(kv_store.data_store[self.name])] == self.dtype
assert kv_store.part_policy[self.name].policy_str == self.policy_str
else:
if not kv_store.is_backup_server(): if not kv_store.is_backup_server():
data_tensor = self.init_func(self.shape, dtype) data_tensor = self.init_func(self.shape, dtype)
kv_store.init_data(name=self.name, kv_store.init_data(name=self.name,
...@@ -469,6 +477,12 @@ class SendMetaToBackupRequest(rpc.Request): ...@@ -469,6 +477,12 @@ class SendMetaToBackupRequest(rpc.Request):
kv_store.part_policy[self.name] = kv_store.find_policy(self.policy_str) kv_store.part_policy[self.name] = kv_store.find_policy(self.policy_str)
kv_store.pull_handlers[self.name] = self.pull_handler kv_store.pull_handlers[self.name] = self.pull_handler
kv_store.push_handlers[self.name] = self.push_handler kv_store.push_handlers[self.name] = self.push_handler
else:
assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(self.shape)
assert F.reverse_data_type_dict[F.dtype(kv_store.data_store[self.name])] == self.dtype
assert kv_store.part_policy[self.name].policy_str == self.policy_str
assert kv_store.pull_handlers[self.name] == self.pull_handler
assert kv_store.push_handlers[self.name] == self.push_handler
res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG) res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG)
return res return res
...@@ -507,7 +521,7 @@ class DeleteDataRequest(rpc.Request): ...@@ -507,7 +521,7 @@ class DeleteDataRequest(rpc.Request):
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
assert self.name in kv_store.data_store, 'data name %s not exists.' % self.name if self.name in kv_store.data_store:
del kv_store.data_store[self.name] del kv_store.data_store[self.name]
del kv_store.part_policy[self.name] del kv_store.part_policy[self.name]
del kv_store.push_handlers[self.name] del kv_store.push_handlers[self.name]
...@@ -876,7 +890,6 @@ class KVClient(object): ...@@ -876,7 +890,6 @@ class KVClient(object):
self._machine_count = int(self._server_count / self._group_count) self._machine_count = int(self._server_count / self._group_count)
self._client_id = rpc.get_rank() self._client_id = rpc.get_rank()
self._machine_id = rpc.get_machine_id() self._machine_id = rpc.get_machine_id()
self._num_clients = rpc.get_num_client()
self._part_id = self._machine_id self._part_id = self._machine_id
self._main_server_id = self._machine_id * self._group_count self._main_server_id = self._machine_id * self._group_count
# push and pull handler # push and pull handler
...@@ -941,7 +954,6 @@ class KVClient(object): ...@@ -941,7 +954,6 @@ class KVClient(object):
The function to be called. The function to be called.
""" """
self.barrier() self.barrier()
if self._client_id == 0:
request = RegisterPushHandlerRequest(name, func) request = RegisterPushHandlerRequest(name, func)
# send request to all the server nodes # send request to all the server nodes
for server_id in range(self._server_count): for server_id in range(self._server_count):
...@@ -974,7 +986,6 @@ class KVClient(object): ...@@ -974,7 +986,6 @@ class KVClient(object):
The function to be called. The function to be called.
""" """
self.barrier() self.barrier()
if self._client_id == 0:
request = RegisterPullHandlerRequest(name, func) request = RegisterPullHandlerRequest(name, func)
# send request to all the server nodes # send request to all the server nodes
for server_id in range(self._server_count): for server_id in range(self._server_count):
...@@ -1008,12 +1019,9 @@ class KVClient(object): ...@@ -1008,12 +1019,9 @@ class KVClient(object):
assert name not in self._data_name_list, 'data name: %s already exists.' % name assert name not in self._data_name_list, 'data name: %s already exists.' % name
self.barrier() self.barrier()
shape = list(shape) shape = list(shape)
# One of the clients in each machine will issue requests to the local server.
assert rpc.get_num_client() % part_policy.partition_book.num_partitions() == 0, \ # Send request to the servers to initialize data.
'#clients ({}) is not divisable by #partitions ({})'.format( # The servers may handle the duplicated initializations.
rpc.get_num_client(), part_policy.partition_book.num_partitions())
num_clients_per_part = rpc.get_num_client() / part_policy.partition_book.num_partitions()
if self._client_id % num_clients_per_part == 0:
part_shape = shape.copy() part_shape = shape.copy()
part_shape[0] = part_policy.get_data_size() part_shape[0] = part_policy.get_data_size()
request = InitDataRequest(name, request = InitDataRequest(name,
...@@ -1021,12 +1029,14 @@ class KVClient(object): ...@@ -1021,12 +1029,14 @@ class KVClient(object):
F.reverse_data_type_dict[dtype], F.reverse_data_type_dict[dtype],
part_policy.policy_str, part_policy.policy_str,
init_func) init_func)
# The request is sent to the servers in one group, which are on the same machine.
for n in range(self._group_count): for n in range(self._group_count):
server_id = part_policy.part_id * self._group_count + n server_id = part_policy.part_id * self._group_count + n
rpc.send_request(server_id, request) rpc.send_request(server_id, request)
for _ in range(self._group_count): for _ in range(self._group_count):
response = rpc.recv_response() response = rpc.recv_response()
assert response.msg == INIT_MSG assert response.msg == INIT_MSG
self.barrier() self.barrier()
# Create local shared-data # Create local shared-data
local_shape = shape.copy() local_shape = shape.copy()
...@@ -1048,7 +1058,6 @@ class KVClient(object): ...@@ -1048,7 +1058,6 @@ class KVClient(object):
self._push_handlers[name] = default_push_handler self._push_handlers[name] = default_push_handler
# Now we need to tell the backup server the new tensor. # Now we need to tell the backup server the new tensor.
if self._client_id % num_clients_per_part == 0:
request = SendMetaToBackupRequest(name, F.reverse_data_type_dict[dtype], request = SendMetaToBackupRequest(name, F.reverse_data_type_dict[dtype],
part_shape, part_policy.policy_str, part_shape, part_policy.policy_str,
self._pull_handlers[name], self._pull_handlers[name],
...@@ -1075,9 +1084,7 @@ class KVClient(object): ...@@ -1075,9 +1084,7 @@ class KVClient(object):
assert name in self._data_name_list, 'data name: %s not exists.' % name assert name in self._data_name_list, 'data name: %s not exists.' % name
self.barrier() self.barrier()
part_policy = self._part_policy[name] part_policy = self._part_policy[name]
num_partitions = part_policy.partition_book.num_partitions()
num_clients_per_part = rpc.get_num_client() // num_partitions
if self._client_id % num_clients_per_part == 0:
# send request to every server nodes # send request to every server nodes
request = DeleteDataRequest(name) request = DeleteDataRequest(name)
for n in range(self._group_count): for n in range(self._group_count):
...@@ -1086,6 +1093,7 @@ class KVClient(object): ...@@ -1086,6 +1093,7 @@ class KVClient(object):
for _ in range(self._group_count): for _ in range(self._group_count):
response = rpc.recv_response() response = rpc.recv_response()
assert response.msg == DELETE_MSG assert response.msg == DELETE_MSG
self.barrier() self.barrier()
self._data_name_list.remove(name) self._data_name_list.remove(name)
# TODO(chao) : remove the delete log print # TODO(chao) : remove the delete log print
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment