Commit 2823c61f (unverified), authored by Da Zheng, committed by GitHub
Browse files

[Distributed] Fix synchronized APIs in kvstore (#1948)



* handle synchronized API.

* fix.

* fix.

* fix.
Co-authored-by: Chao Ma <mctt90@gmail.com>
parent bab32d5b
......@@ -156,13 +156,21 @@ class InitDataRequest(rpc.Request):
def process_request(self, server_state):
kv_store = server_state.kv_store
dtype = F.data_type_dict[self.dtype]
if not kv_store.is_backup_server():
data_tensor = self.init_func(self.shape, dtype)
kv_store.init_data(name=self.name,
policy_str=self.policy_str,
data_tensor=data_tensor)
# We should see requests from multiple clients. We need to ignore the duplicated
# requests.
if self.name in kv_store.data_store:
assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(self.shape)
assert F.reverse_data_type_dict[F.dtype(kv_store.data_store[self.name])] == self.dtype
assert kv_store.part_policy[self.name].policy_str == self.policy_str
else:
kv_store.init_data(name=self.name, policy_str=self.policy_str)
if not kv_store.is_backup_server():
data_tensor = self.init_func(self.shape, dtype)
kv_store.init_data(name=self.name,
policy_str=self.policy_str,
data_tensor=data_tensor)
else:
kv_store.init_data(name=self.name, policy_str=self.policy_str)
res = InitDataResponse(INIT_MSG)
return res
......@@ -469,6 +477,12 @@ class SendMetaToBackupRequest(rpc.Request):
kv_store.part_policy[self.name] = kv_store.find_policy(self.policy_str)
kv_store.pull_handlers[self.name] = self.pull_handler
kv_store.push_handlers[self.name] = self.push_handler
else:
assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(self.shape)
assert F.reverse_data_type_dict[F.dtype(kv_store.data_store[self.name])] == self.dtype
assert kv_store.part_policy[self.name].policy_str == self.policy_str
assert kv_store.pull_handlers[self.name] == self.pull_handler
assert kv_store.push_handlers[self.name] == self.push_handler
res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG)
return res
......@@ -507,11 +521,11 @@ class DeleteDataRequest(rpc.Request):
def process_request(self, server_state):
kv_store = server_state.kv_store
assert self.name in kv_store.data_store, 'data name %s not exists.' % self.name
del kv_store.data_store[self.name]
del kv_store.part_policy[self.name]
del kv_store.push_handlers[self.name]
del kv_store.pull_handlers[self.name]
if self.name in kv_store.data_store:
del kv_store.data_store[self.name]
del kv_store.part_policy[self.name]
del kv_store.push_handlers[self.name]
del kv_store.pull_handlers[self.name]
res = DeleteDataResponse(DELETE_MSG)
return res
......@@ -876,7 +890,6 @@ class KVClient(object):
self._machine_count = int(self._server_count / self._group_count)
self._client_id = rpc.get_rank()
self._machine_id = rpc.get_machine_id()
self._num_clients = rpc.get_num_client()
self._part_id = self._machine_id
self._main_server_id = self._machine_id * self._group_count
# push and pull handler
......@@ -941,15 +954,14 @@ class KVClient(object):
The function to be called.
"""
self.barrier()
if self._client_id == 0:
request = RegisterPushHandlerRequest(name, func)
# send request to all the server nodes
for server_id in range(self._server_count):
rpc.send_request(server_id, request)
# recv response from all the server nodes
for _ in range(self._server_count):
response = rpc.recv_response()
assert response.msg == REGISTER_PUSH_MSG
request = RegisterPushHandlerRequest(name, func)
# send request to all the server nodes
for server_id in range(self._server_count):
rpc.send_request(server_id, request)
# recv response from all the server nodes
for _ in range(self._server_count):
response = rpc.recv_response()
assert response.msg == REGISTER_PUSH_MSG
self._push_handlers[name] = func
self.barrier()
......@@ -974,15 +986,14 @@ class KVClient(object):
The function to be called.
"""
self.barrier()
if self._client_id == 0:
request = RegisterPullHandlerRequest(name, func)
# send request to all the server nodes
for server_id in range(self._server_count):
rpc.send_request(server_id, request)
# recv response from all the server nodes
for _ in range(self._server_count):
response = rpc.recv_response()
assert response.msg == REGISTER_PULL_MSG
request = RegisterPullHandlerRequest(name, func)
# send request to all the server nodes
for server_id in range(self._server_count):
rpc.send_request(server_id, request)
# recv response from all the server nodes
for _ in range(self._server_count):
response = rpc.recv_response()
assert response.msg == REGISTER_PULL_MSG
self._pull_handlers[name] = func
self.barrier()
......@@ -1008,25 +1019,24 @@ class KVClient(object):
assert name not in self._data_name_list, 'data name: %s already exists.' % name
self.barrier()
shape = list(shape)
# One of the clients in each machine will issue requests to the local server.
assert rpc.get_num_client() % part_policy.partition_book.num_partitions() == 0, \
'#clients ({}) is not divisable by #partitions ({})'.format(
rpc.get_num_client(), part_policy.partition_book.num_partitions())
num_clients_per_part = rpc.get_num_client() / part_policy.partition_book.num_partitions()
if self._client_id % num_clients_per_part == 0:
part_shape = shape.copy()
part_shape[0] = part_policy.get_data_size()
request = InitDataRequest(name,
tuple(part_shape),
F.reverse_data_type_dict[dtype],
part_policy.policy_str,
init_func)
for n in range(self._group_count):
server_id = part_policy.part_id * self._group_count + n
rpc.send_request(server_id, request)
for _ in range(self._group_count):
response = rpc.recv_response()
assert response.msg == INIT_MSG
# Send request to the servers to initialize data.
# The servers may handle the duplicated initializations.
part_shape = shape.copy()
part_shape[0] = part_policy.get_data_size()
request = InitDataRequest(name,
tuple(part_shape),
F.reverse_data_type_dict[dtype],
part_policy.policy_str,
init_func)
# The request is sent to the servers in one group, which are on the same machine.
for n in range(self._group_count):
server_id = part_policy.part_id * self._group_count + n
rpc.send_request(server_id, request)
for _ in range(self._group_count):
response = rpc.recv_response()
assert response.msg == INIT_MSG
self.barrier()
# Create local shared-data
local_shape = shape.copy()
......@@ -1048,19 +1058,18 @@ class KVClient(object):
self._push_handlers[name] = default_push_handler
# Now we need to tell the backup server the new tensor.
if self._client_id % num_clients_per_part == 0:
request = SendMetaToBackupRequest(name, F.reverse_data_type_dict[dtype],
part_shape, part_policy.policy_str,
self._pull_handlers[name],
self._push_handlers[name])
# send request to all the backup server nodes
for i in range(self._group_count-1):
server_id = self._machine_id * self._group_count + i + 1
rpc.send_request(server_id, request)
# recv response from all the backup server nodes
for _ in range(self._group_count-1):
response = rpc.recv_response()
assert response.msg == SEND_META_TO_BACKUP_MSG
request = SendMetaToBackupRequest(name, F.reverse_data_type_dict[dtype],
part_shape, part_policy.policy_str,
self._pull_handlers[name],
self._push_handlers[name])
# send request to all the backup server nodes
for i in range(self._group_count-1):
server_id = self._machine_id * self._group_count + i + 1
rpc.send_request(server_id, request)
# recv response from all the backup server nodes
for _ in range(self._group_count-1):
response = rpc.recv_response()
assert response.msg == SEND_META_TO_BACKUP_MSG
self.barrier()
def delete_data(self, name):
......@@ -1075,17 +1084,16 @@ class KVClient(object):
assert name in self._data_name_list, 'data name: %s not exists.' % name
self.barrier()
part_policy = self._part_policy[name]
num_partitions = part_policy.partition_book.num_partitions()
num_clients_per_part = rpc.get_num_client() // num_partitions
if self._client_id % num_clients_per_part == 0:
# send the request to every server node
request = DeleteDataRequest(name)
for n in range(self._group_count):
server_id = part_policy.part_id * self._group_count + n
rpc.send_request(server_id, request)
for _ in range(self._group_count):
response = rpc.recv_response()
assert response.msg == DELETE_MSG
# send the request to every server node
request = DeleteDataRequest(name)
for n in range(self._group_count):
server_id = part_policy.part_id * self._group_count + n
rpc.send_request(server_id, request)
for _ in range(self._group_count):
response = rpc.recv_response()
assert response.msg == DELETE_MSG
self.barrier()
self._data_name_list.remove(name)
# TODO(chao) : remove the delete log print
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment