Unverified Commit e8a56dc1 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[KVStore] make pull/push handler per tensor. (#1646)

* make pull/push handler per tensor.

* update.
parent 41349dce
...@@ -55,12 +55,12 @@ class PullRequest(rpc.Request): ...@@ -55,12 +55,12 @@ class PullRequest(rpc.Request):
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
if kv_store.part_policy.__contains__(self.name) is False: if self.name not in kv_store.part_policy:
raise RuntimeError("KVServer cannot find partition policy with name: %s" % self.name) raise RuntimeError("KVServer cannot find partition policy with name: %s" % self.name)
if kv_store.data_store.__contains__(self.name) is False: if self.name not in kv_store.data_store:
raise RuntimeError("KVServer Cannot find data tensor with name: %s" % self.name) raise RuntimeError("KVServer Cannot find data tensor with name: %s" % self.name)
local_id = kv_store.part_policy[self.name].to_local(self.id_tensor) local_id = kv_store.part_policy[self.name].to_local(self.id_tensor)
data = kv_store.pull_handler(kv_store.data_store, self.name, local_id) data = kv_store.pull_handlers[self.name](kv_store.data_store, self.name, local_id)
res = PullResponse(kv_store.server_id, data) res = PullResponse(kv_store.server_id, data)
return res return res
...@@ -93,12 +93,13 @@ class PushRequest(rpc.Request): ...@@ -93,12 +93,13 @@ class PushRequest(rpc.Request):
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
if kv_store.part_policy.__contains__(self.name) is False: if self.name not in kv_store.part_policy:
raise RuntimeError("KVServer cannot find partition policy with name: %s" % self.name) raise RuntimeError("KVServer cannot find partition policy with name: %s" % self.name)
if kv_store.data_store.__contains__(self.name) is False: if self.name not in kv_store.data_store:
raise RuntimeError("KVServer Cannot find data tensor with name: %s" % self.name) raise RuntimeError("KVServer Cannot find data tensor with name: %s" % self.name)
local_id = kv_store.part_policy[self.name].to_local(self.id_tensor) local_id = kv_store.part_policy[self.name].to_local(self.id_tensor)
kv_store.push_handler(kv_store.data_store, self.name, local_id, self.data_tensor) kv_store.push_handlers[self.name](kv_store.data_store, self.name,
local_id, self.data_tensor)
INIT_DATA = 901233 INIT_DATA = 901233
INIT_MSG = 'Init' INIT_MSG = 'Init'
...@@ -244,18 +245,19 @@ class RegisterPullHandlerRequest(rpc.Request): ...@@ -244,18 +245,19 @@ class RegisterPullHandlerRequest(rpc.Request):
pull_func : func pull_func : func
UDF pull handler UDF pull handler
""" """
def __init__(self, pull_func): def __init__(self, name, pull_func):
self.name = name
self.pull_func = pull_func self.pull_func = pull_func
def __getstate__(self): def __getstate__(self):
return self.pull_func return self.name, self.pull_func
def __setstate__(self, state): def __setstate__(self, state):
self.pull_func = state self.name, self.pull_func = state
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
kv_store.pull_handler = self.pull_func kv_store.pull_handlers[self.name] = self.pull_func
res = RegisterPullHandlerResponse(REGISTER_PULL_MSG) res = RegisterPullHandlerResponse(REGISTER_PULL_MSG)
return res return res
...@@ -288,18 +290,19 @@ class RegisterPushHandlerRequest(rpc.Request): ...@@ -288,18 +290,19 @@ class RegisterPushHandlerRequest(rpc.Request):
push_func : func push_func : func
UDF push handler UDF push handler
""" """
def __init__(self, push_func): def __init__(self, name, push_func):
self.name = name
self.push_func = push_func self.push_func = push_func
def __getstate__(self): def __getstate__(self):
return self.push_func return self.name, self.push_func
def __setstate__(self, state): def __setstate__(self, state):
self.push_func = state self.name, self.push_func = state
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
kv_store.push_handler = self.push_func kv_store.push_handlers[self.name] = self.push_func
res = RegisterPushHandlerResponse(REGISTER_PUSH_MSG) res = RegisterPushHandlerResponse(REGISTER_PUSH_MSG)
return res return res
...@@ -569,8 +572,8 @@ class KVServer(object): ...@@ -569,8 +572,8 @@ class KVServer(object):
self._num_clients = num_clients self._num_clients = num_clients
self._barrier_count = 0 self._barrier_count = 0
# push and pull handler # push and pull handler
self._push_handler = default_push_handler self._push_handlers = {}
self._pull_handler = default_pull_handler self._pull_handlers = {}
@property @property
def server_id(self): def server_id(self):
...@@ -608,24 +611,14 @@ class KVServer(object): ...@@ -608,24 +611,14 @@ class KVServer(object):
return self._part_id return self._part_id
@property @property
def push_handler(self): def push_handlers(self):
"""Get push handler""" """Get push handler"""
return self._push_handler return self._push_handlers
@property @property
def pull_handler(self): def pull_handlers(self):
"""Get pull handler""" """Get pull handler"""
return self._pull_handler return self._pull_handlers
@pull_handler.setter
def pull_handler(self, pull_handler):
"""Set pull handler"""
self._pull_handler = pull_handler
@push_handler.setter
def push_handler(self, push_handler):
"""Set push handler"""
self._push_handler = push_handler
def is_backup_server(self): def is_backup_server(self):
"""Return True if current server is a backup server. """Return True if current server is a backup server.
...@@ -667,6 +660,8 @@ class KVServer(object): ...@@ -667,6 +660,8 @@ class KVServer(object):
self._data_store[name] = F.zerocopy_from_dlpack(dlpack) self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._data_store[name][:] = data_tensor[:] self._data_store[name][:] = data_tensor[:]
self._part_policy[name] = self.find_policy(policy_str) self._part_policy[name] = self.find_policy(policy_str)
self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler
def find_policy(self, policy_str): def find_policy(self, policy_str):
"""Find a partition policy from existing policy set """Find a partition policy from existing policy set
...@@ -748,8 +743,8 @@ class KVClient(object): ...@@ -748,8 +743,8 @@ class KVClient(object):
self._part_id = self._machine_id self._part_id = self._machine_id
self._main_server_id = self._machine_id * self._group_count self._main_server_id = self._machine_id * self._group_count
# push and pull handler # push and pull handler
self._pull_handler = default_pull_handler self._pull_handlers = {}
self._push_handler = default_push_handler self._push_handlers = {}
@property @property
def client_id(self): def client_id(self):
...@@ -775,18 +770,29 @@ class KVClient(object): ...@@ -775,18 +770,29 @@ class KVClient(object):
response = rpc.recv_response() response = rpc.recv_response()
assert response.msg == BARRIER_MSG assert response.msg == BARRIER_MSG
def register_push_handler(self, func): def register_push_handler(self, name, func):
"""Register UDF push function on server. """Register UDF push function.
This UDF is triggered for every push. The signature of the UDF is
client_0 will send this request to all servers, and the other ```
clients will just invoke the barrier() api. def push_handler(data_store, name, local_offset, data)
```
`data_store` is a dict that contains all tensors in the kvstore. `name` is the name
of the tensor where new data is pushed to. `local_offset` is the offset where new
data should be written in the tensor in the local partition. `data` is the new data
to be written.
Parameters Parameters
---------- ----------
func : UDF push function name : str
The name of the tensor
func : callable
The function to be called.
""" """
if self._client_id == 0: if self._client_id == 0:
request = RegisterPushHandlerRequest(func) request = RegisterPushHandlerRequest(name, func)
# send request to all the server nodes # send request to all the server nodes
for server_id in range(self._server_count): for server_id in range(self._server_count):
rpc.send_request(server_id, request) rpc.send_request(server_id, request)
...@@ -794,21 +800,31 @@ class KVClient(object): ...@@ -794,21 +800,31 @@ class KVClient(object):
for _ in range(self._server_count): for _ in range(self._server_count):
response = rpc.recv_response() response = rpc.recv_response()
assert response.msg == REGISTER_PUSH_MSG assert response.msg == REGISTER_PUSH_MSG
self._push_handler = func self._push_handlers[name] = func
self.barrier() self.barrier()
def register_pull_handler(self, func): def register_pull_handler(self, name, func):
"""Register UDF pull function on server. """Register UDF pull function.
client_0 will send this request to all servers, and the other This UDF is triggered for every pull. The signature of the UDF is
clients will just invoke the barrier() api.
```
def pull_handler(data_store, name, local_offset)
```
`data_store` is a dict that contains all tensors in the kvstore. `name` is the name
of the tensor where new data is pushed to. `local_offset` is the offset where new
data should be written in the tensor in the local partition.
Parameters Parameters
---------- ----------
func : UDF pull function name : str
The name of the tensor
func : callable
The function to be called.
""" """
if self._client_id == 0: if self._client_id == 0:
request = RegisterPullHandlerRequest(func) request = RegisterPullHandlerRequest(name, func)
# send request to all the server nodes # send request to all the server nodes
for server_id in range(self._server_count): for server_id in range(self._server_count):
rpc.send_request(server_id, request) rpc.send_request(server_id, request)
...@@ -816,7 +832,7 @@ class KVClient(object): ...@@ -816,7 +832,7 @@ class KVClient(object):
for _ in range(self._server_namebook): for _ in range(self._server_namebook):
response = rpc.recv_response() response = rpc.recv_response()
assert response.msg == REGISTER_PULL_MSG assert response.msg == REGISTER_PULL_MSG
self._pull_handler = func self._pull_handlers[name] = func
self.barrier() self.barrier()
def init_data(self, name, shape, dtype, policy_str, partition_book, init_func): def init_data(self, name, shape, dtype, policy_str, partition_book, init_func):
...@@ -887,6 +903,8 @@ class KVClient(object): ...@@ -887,6 +903,8 @@ class KVClient(object):
self._data_store[name] = F.zerocopy_from_dlpack(dlpack) self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._data_name_list.add(name) self._data_name_list.add(name)
self._full_data_shape[name] = tuple(shape) self._full_data_shape[name] = tuple(shape)
self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler
def map_shared_data(self, partition_book): def map_shared_data(self, partition_book):
"""Mapping shared-memory tensor from server to client. """Mapping shared-memory tensor from server to client.
...@@ -907,6 +925,8 @@ class KVClient(object): ...@@ -907,6 +925,8 @@ class KVClient(object):
dlpack = shared_data.to_dlpack() dlpack = shared_data.to_dlpack()
self._data_store[name] = F.zerocopy_from_dlpack(dlpack) self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book) self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book)
self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler
# Get full data shape across servers # Get full data shape across servers
for name, meta in response.meta.items(): for name, meta in response.meta.items():
if name not in self._data_name_list: if name not in self._data_name_list:
...@@ -995,7 +1015,7 @@ class KVClient(object): ...@@ -995,7 +1015,7 @@ class KVClient(object):
rpc.send_request_to_machine(machine_idx, request) rpc.send_request_to_machine(machine_idx, request)
start += count[idx] start += count[idx]
if local_id is not None: # local push if local_id is not None: # local push
self._push_handler(self._data_store, name, local_id, local_data) self._push_handlers[name](self._data_store, name, local_id, local_data)
def pull(self, name, id_tensor): def pull(self, name, id_tensor):
"""Pull message from KVServer. """Pull message from KVServer.
...@@ -1043,7 +1063,7 @@ class KVClient(object): ...@@ -1043,7 +1063,7 @@ class KVClient(object):
# recv response # recv response
response_list = [] response_list = []
if local_id is not None: # local pull if local_id is not None: # local pull
local_data = self._pull_handler(self._data_store, name, local_id) local_data = self._pull_handlers[name](self._data_store, name, local_id)
server_id = self._main_server_id server_id = self._main_server_id
local_response = PullResponse(server_id, local_data) local_response = PullResponse(server_id, local_data)
response_list.append(local_response) response_list.append(local_response)
......
...@@ -210,7 +210,9 @@ def start_client(): ...@@ -210,7 +210,9 @@ def start_client():
res = kvclient.pull(name='data_2', id_tensor=id_tensor) res = kvclient.pull(name='data_2', id_tensor=id_tensor)
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor)) assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
# Register new push handler # Register new push handler
kvclient.register_push_handler(udf_push) kvclient.register_push_handler('data_0', udf_push)
kvclient.register_push_handler('data_1', udf_push)
kvclient.register_push_handler('data_2', udf_push)
# Test push and pull # Test push and pull
kvclient.push(name='data_0', kvclient.push(name='data_0',
id_tensor=id_tensor, id_tensor=id_tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment