".github/vscode:/vscode.git/clone" did not exist on "b2b1b683fe89e47fe80c56b00cbcd0d912973a44"
Unverified commit 2823c61f, authored by Da Zheng, committed by GitHub
Browse files

[Distributed] Fix synchronized APIs in kvstore (#1948)



* handle synchronized API.

* fix.

* fix.

* fix.
Co-authored-by: Chao Ma <mctt90@gmail.com>
parent bab32d5b
......@@ -156,6 +156,14 @@ class InitDataRequest(rpc.Request):
def process_request(self, server_state):
kv_store = server_state.kv_store
dtype = F.data_type_dict[self.dtype]
# We should see requests from multiple clients. We need to ignore the duplicated
# requests.
if self.name in kv_store.data_store:
assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(self.shape)
assert F.reverse_data_type_dict[F.dtype(kv_store.data_store[self.name])] == self.dtype
assert kv_store.part_policy[self.name].policy_str == self.policy_str
else:
if not kv_store.is_backup_server():
data_tensor = self.init_func(self.shape, dtype)
kv_store.init_data(name=self.name,
......@@ -469,6 +477,12 @@ class SendMetaToBackupRequest(rpc.Request):
kv_store.part_policy[self.name] = kv_store.find_policy(self.policy_str)
kv_store.pull_handlers[self.name] = self.pull_handler
kv_store.push_handlers[self.name] = self.push_handler
else:
assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(self.shape)
assert F.reverse_data_type_dict[F.dtype(kv_store.data_store[self.name])] == self.dtype
assert kv_store.part_policy[self.name].policy_str == self.policy_str
assert kv_store.pull_handlers[self.name] == self.pull_handler
assert kv_store.push_handlers[self.name] == self.push_handler
res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG)
return res
......@@ -507,7 +521,7 @@ class DeleteDataRequest(rpc.Request):
def process_request(self, server_state):
kv_store = server_state.kv_store
assert self.name in kv_store.data_store, 'data name %s not exists.' % self.name
if self.name in kv_store.data_store:
del kv_store.data_store[self.name]
del kv_store.part_policy[self.name]
del kv_store.push_handlers[self.name]
......@@ -876,7 +890,6 @@ class KVClient(object):
self._machine_count = int(self._server_count / self._group_count)
self._client_id = rpc.get_rank()
self._machine_id = rpc.get_machine_id()
self._num_clients = rpc.get_num_client()
self._part_id = self._machine_id
self._main_server_id = self._machine_id * self._group_count
# push and pull handler
......@@ -941,7 +954,6 @@ class KVClient(object):
The function to be called.
"""
self.barrier()
if self._client_id == 0:
request = RegisterPushHandlerRequest(name, func)
# send request to all the server nodes
for server_id in range(self._server_count):
......@@ -974,7 +986,6 @@ class KVClient(object):
The function to be called.
"""
self.barrier()
if self._client_id == 0:
request = RegisterPullHandlerRequest(name, func)
# send request to all the server nodes
for server_id in range(self._server_count):
......@@ -1008,12 +1019,9 @@ class KVClient(object):
assert name not in self._data_name_list, 'data name: %s already exists.' % name
self.barrier()
shape = list(shape)
# One of the clients in each machine will issue requests to the local server.
assert rpc.get_num_client() % part_policy.partition_book.num_partitions() == 0, \
'#clients ({}) is not divisable by #partitions ({})'.format(
rpc.get_num_client(), part_policy.partition_book.num_partitions())
num_clients_per_part = rpc.get_num_client() / part_policy.partition_book.num_partitions()
if self._client_id % num_clients_per_part == 0:
# Send request to the servers to initialize data.
# The servers may handle the duplicated initializations.
part_shape = shape.copy()
part_shape[0] = part_policy.get_data_size()
request = InitDataRequest(name,
......@@ -1021,12 +1029,14 @@ class KVClient(object):
F.reverse_data_type_dict[dtype],
part_policy.policy_str,
init_func)
# The request is sent to the servers in one group, which are on the same machine.
for n in range(self._group_count):
server_id = part_policy.part_id * self._group_count + n
rpc.send_request(server_id, request)
for _ in range(self._group_count):
response = rpc.recv_response()
assert response.msg == INIT_MSG
self.barrier()
# Create local shared-data
local_shape = shape.copy()
......@@ -1048,7 +1058,6 @@ class KVClient(object):
self._push_handlers[name] = default_push_handler
# Now we need to tell the backup server the new tensor.
if self._client_id % num_clients_per_part == 0:
request = SendMetaToBackupRequest(name, F.reverse_data_type_dict[dtype],
part_shape, part_policy.policy_str,
self._pull_handlers[name],
......@@ -1075,9 +1084,7 @@ class KVClient(object):
assert name in self._data_name_list, 'data name: %s not exists.' % name
self.barrier()
part_policy = self._part_policy[name]
num_partitions = part_policy.partition_book.num_partitions()
num_clients_per_part = rpc.get_num_client() // num_partitions
if self._client_id % num_clients_per_part == 0:
# send request to every server node
request = DeleteDataRequest(name)
for n in range(self._group_count):
......@@ -1086,6 +1093,7 @@ class KVClient(object):
for _ in range(self._group_count):
response = rpc.recv_response()
assert response.msg == DELETE_MSG
self.barrier()
self._data_name_list.remove(name)
# TODO(chao) : remove the delete log print
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment