"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1926331eaf59dae54aeb97cde19dae16c2fdaa48"
Unverified Commit 65e7805e authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[KVStore] Support delete_data() api on client (#1824)

* delete tensor

* update

* update

* update

* update

* update

* update

* udpate

* update

* update

* update
parent ea420c0a
...@@ -463,6 +463,49 @@ class SendMetaToBackupRequest(rpc.Request): ...@@ -463,6 +463,49 @@ class SendMetaToBackupRequest(rpc.Request):
res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG) res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG)
return res return res
DELETE_DATA = 901240
DELETE_MSG = "Delete_Data"
class DeleteDataResponse(rpc.Response):
"""Send a confirmation signal (just a short string message)
of DeleteDataRequest to client.
"""
def __init__(self, msg):
self.msg = msg
def __getstate__(self):
return self.msg
def __setstate__(self, state):
self.msg = state
class DeleteDataRequest(rpc.Request):
"""Send message to server to delete data tensor
Parameters
----------
name : str
data name
"""
def __init__(self, name):
self.name = name
def __getstate__(self):
return self.name
def __setstate__(self, state):
self.name = state
def process_request(self, server_state):
kv_store = server_state.kv_store
assert self.name in kv_store.data_store, 'data name %s not exists.' % self.name
del kv_store.data_store[self.name]
del kv_store.part_policy[self.name]
del kv_store.push_handlers[self.name]
del kv_store.pull_handlers[self.name]
res = DeleteDataResponse(DELETE_MSG)
return res
############################ KVServer ############################### ############################ KVServer ###############################
def default_push_handler(target, name, id_tensor, data_tensor): def default_push_handler(target, name, id_tensor, data_tensor):
...@@ -558,6 +601,9 @@ class KVServer(object): ...@@ -558,6 +601,9 @@ class KVServer(object):
rpc.register_service(SEND_META_TO_BACKUP, rpc.register_service(SEND_META_TO_BACKUP,
SendMetaToBackupRequest, SendMetaToBackupRequest,
SendMetaToBackupResponse) SendMetaToBackupResponse)
rpc.register_service(DELETE_DATA,
DeleteDataRequest,
DeleteDataResponse)
# Store the tensor data with specified data name # Store the tensor data with specified data name
self._data_store = {} self._data_store = {}
# Store the partition information with specified data name # Store the partition information with specified data name
...@@ -726,6 +772,9 @@ class KVClient(object): ...@@ -726,6 +772,9 @@ class KVClient(object):
rpc.register_service(SEND_META_TO_BACKUP, rpc.register_service(SEND_META_TO_BACKUP,
SendMetaToBackupRequest, SendMetaToBackupRequest,
SendMetaToBackupResponse) SendMetaToBackupResponse)
rpc.register_service(DELETE_DATA,
DeleteDataRequest,
DeleteDataResponse)
# Store the tensor data with specified data name # Store the tensor data with specified data name
self._data_store = {} self._data_store = {}
# Store the partition information with specified data name # Store the partition information with specified data name
...@@ -901,6 +950,39 @@ class KVClient(object): ...@@ -901,6 +950,39 @@ class KVClient(object):
self._push_handlers[name] = default_push_handler self._push_handlers[name] = default_push_handler
self.barrier() self.barrier()
def delete_data(self, name):
"""Send message to kvserver to delete tensor and clear the meta data
Parameters
----------
name : str
data name
"""
assert len(name) > 0, 'name cannot be empty.'
assert name in self._data_name_list, 'data name: %s not exists.' % name
self.barrier()
part_policy = self._part_policy[name]
num_partitions = part_policy.partition_book.num_partitions()
num_clients_per_part = rpc.get_num_client() / num_partitions
if self._client_id % num_clients_per_part == 0:
# send request to every server nodes
request = DeleteDataRequest(name)
for n in range(self._group_count):
server_id = part_policy.part_id * self._group_count + n
rpc.send_request(server_id, request)
for _ in range(self._group_count):
response = rpc.recv_response()
assert response.msg == DELETE_MSG
self.barrier()
self._data_name_list.remove(name)
# TODO(chao) : remove the delete log print
del self._data_store[name]
del self._full_data_shape[name]
del self._part_policy[name]
del self._pull_handlers[name]
del self._push_handlers[name]
self.barrier()
def map_shared_data(self, partition_book): def map_shared_data(self, partition_book):
"""Mapping shared-memory tensor from server to client. """Mapping shared-memory tensor from server to client.
......
...@@ -239,6 +239,12 @@ def start_client(num_clients): ...@@ -239,6 +239,12 @@ def start_client(num_clients):
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor)) assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
res = kvclient.pull(name='data_2', id_tensor=id_tensor) res = kvclient.pull(name='data_2', id_tensor=id_tensor)
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor)) assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
# Test delete data
kvclient.delete_data('data_0')
kvclient.delete_data('data_1')
kvclient.delete_data('data_2')
# Register new push handler # Register new push handler
kvclient.init_data(name='data_3', kvclient.init_data(name='data_3',
shape=F.shape(data_2), shape=F.shape(data_2),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment