"src/array/kernel_decl.h" did not exist on "90a103e75290571858e9ddc6bd2e29e81851282a"
Unverified Commit e8a56dc1 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[KVStore] make pull/push handler per tensor. (#1646)

* make pull/push handler per tensor.

* update.
parent 41349dce
......@@ -55,12 +55,12 @@ class PullRequest(rpc.Request):
def process_request(self, server_state):
kv_store = server_state.kv_store
if kv_store.part_policy.__contains__(self.name) is False:
if self.name not in kv_store.part_policy:
raise RuntimeError("KVServer cannot find partition policy with name: %s" % self.name)
if kv_store.data_store.__contains__(self.name) is False:
if self.name not in kv_store.data_store:
raise RuntimeError("KVServer Cannot find data tensor with name: %s" % self.name)
local_id = kv_store.part_policy[self.name].to_local(self.id_tensor)
data = kv_store.pull_handler(kv_store.data_store, self.name, local_id)
data = kv_store.pull_handlers[self.name](kv_store.data_store, self.name, local_id)
res = PullResponse(kv_store.server_id, data)
return res
......@@ -93,12 +93,13 @@ class PushRequest(rpc.Request):
def process_request(self, server_state):
kv_store = server_state.kv_store
if kv_store.part_policy.__contains__(self.name) is False:
if self.name not in kv_store.part_policy:
raise RuntimeError("KVServer cannot find partition policy with name: %s" % self.name)
if kv_store.data_store.__contains__(self.name) is False:
if self.name not in kv_store.data_store:
raise RuntimeError("KVServer Cannot find data tensor with name: %s" % self.name)
local_id = kv_store.part_policy[self.name].to_local(self.id_tensor)
kv_store.push_handler(kv_store.data_store, self.name, local_id, self.data_tensor)
kv_store.push_handlers[self.name](kv_store.data_store, self.name,
local_id, self.data_tensor)
INIT_DATA = 901233
INIT_MSG = 'Init'
......@@ -244,18 +245,19 @@ class RegisterPullHandlerRequest(rpc.Request):
pull_func : func
UDF pull handler
"""
def __init__(self, pull_func):
def __init__(self, name, pull_func):
self.name = name
self.pull_func = pull_func
def __getstate__(self):
return self.pull_func
return self.name, self.pull_func
def __setstate__(self, state):
self.pull_func = state
self.name, self.pull_func = state
def process_request(self, server_state):
kv_store = server_state.kv_store
kv_store.pull_handler = self.pull_func
kv_store.pull_handlers[self.name] = self.pull_func
res = RegisterPullHandlerResponse(REGISTER_PULL_MSG)
return res
......@@ -288,18 +290,19 @@ class RegisterPushHandlerRequest(rpc.Request):
push_func : func
UDF push handler
"""
def __init__(self, push_func):
def __init__(self, name, push_func):
self.name = name
self.push_func = push_func
def __getstate__(self):
return self.push_func
return self.name, self.push_func
def __setstate__(self, state):
self.push_func = state
self.name, self.push_func = state
def process_request(self, server_state):
kv_store = server_state.kv_store
kv_store.push_handler = self.push_func
kv_store.push_handlers[self.name] = self.push_func
res = RegisterPushHandlerResponse(REGISTER_PUSH_MSG)
return res
......@@ -569,8 +572,8 @@ class KVServer(object):
self._num_clients = num_clients
self._barrier_count = 0
# push and pull handler
self._push_handler = default_push_handler
self._pull_handler = default_pull_handler
self._push_handlers = {}
self._pull_handlers = {}
@property
def server_id(self):
......@@ -608,24 +611,14 @@ class KVServer(object):
return self._part_id
@property
def push_handler(self):
def push_handlers(self):
"""Get push handler"""
return self._push_handler
return self._push_handlers
@property
def pull_handler(self):
def pull_handlers(self):
"""Get pull handler"""
return self._pull_handler
@pull_handler.setter
def pull_handler(self, pull_handler):
"""Set pull handler"""
self._pull_handler = pull_handler
@push_handler.setter
def push_handler(self, push_handler):
"""Set push handler"""
self._push_handler = push_handler
return self._pull_handlers
def is_backup_server(self):
"""Return True if current server is a backup server.
......@@ -667,6 +660,8 @@ class KVServer(object):
self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._data_store[name][:] = data_tensor[:]
self._part_policy[name] = self.find_policy(policy_str)
self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler
def find_policy(self, policy_str):
"""Find a partition policy from existing policy set
......@@ -748,8 +743,8 @@ class KVClient(object):
self._part_id = self._machine_id
self._main_server_id = self._machine_id * self._group_count
# push and pull handler
self._pull_handler = default_pull_handler
self._push_handler = default_push_handler
self._pull_handlers = {}
self._push_handlers = {}
@property
def client_id(self):
......@@ -775,18 +770,29 @@ class KVClient(object):
response = rpc.recv_response()
assert response.msg == BARRIER_MSG
def register_push_handler(self, func):
"""Register UDF push function on server.
def register_push_handler(self, name, func):
"""Register UDF push function.
This UDF is triggered for every push. The signature of the UDF is
client_0 will send this request to all servers, and the other
clients will just invoke the barrier() api.
```
def push_handler(data_store, name, local_offset, data)
```
`data_store` is a dict that contains all tensors in the kvstore. `name` is the name
of the tensor where new data is pushed to. `local_offset` is the offset where new
data should be written in the tensor in the local partition. `data` is the new data
to be written.
Parameters
----------
func : UDF push function
name : str
The name of the tensor
func : callable
The function to be called.
"""
if self._client_id == 0:
request = RegisterPushHandlerRequest(func)
request = RegisterPushHandlerRequest(name, func)
# send request to all the server nodes
for server_id in range(self._server_count):
rpc.send_request(server_id, request)
......@@ -794,21 +800,31 @@ class KVClient(object):
for _ in range(self._server_count):
response = rpc.recv_response()
assert response.msg == REGISTER_PUSH_MSG
self._push_handler = func
self._push_handlers[name] = func
self.barrier()
def register_pull_handler(self, func):
"""Register UDF pull function on server.
def register_pull_handler(self, name, func):
"""Register UDF pull function.
client_0 will send this request to all servers, and the other
clients will just invoke the barrier() api.
This UDF is triggered for every pull. The signature of the UDF is
```
def pull_handler(data_store, name, local_offset)
```
`data_store` is a dict that contains all tensors in the kvstore. `name` is the name
of the tensor where new data is pushed to. `local_offset` is the offset where new
data should be written in the tensor in the local partition.
Parameters
----------
func : UDF pull function
name : str
The name of the tensor
func : callable
The function to be called.
"""
if self._client_id == 0:
request = RegisterPullHandlerRequest(func)
request = RegisterPullHandlerRequest(name, func)
# send request to all the server nodes
for server_id in range(self._server_count):
rpc.send_request(server_id, request)
......@@ -816,7 +832,7 @@ class KVClient(object):
for _ in range(self._server_namebook):
response = rpc.recv_response()
assert response.msg == REGISTER_PULL_MSG
self._pull_handler = func
self._pull_handlers[name] = func
self.barrier()
def init_data(self, name, shape, dtype, policy_str, partition_book, init_func):
......@@ -887,6 +903,8 @@ class KVClient(object):
self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._data_name_list.add(name)
self._full_data_shape[name] = tuple(shape)
self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler
def map_shared_data(self, partition_book):
"""Mapping shared-memory tensor from server to client.
......@@ -907,6 +925,8 @@ class KVClient(object):
dlpack = shared_data.to_dlpack()
self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book)
self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler
# Get full data shape across servers
for name, meta in response.meta.items():
if name not in self._data_name_list:
......@@ -995,7 +1015,7 @@ class KVClient(object):
rpc.send_request_to_machine(machine_idx, request)
start += count[idx]
if local_id is not None: # local push
self._push_handler(self._data_store, name, local_id, local_data)
self._push_handlers[name](self._data_store, name, local_id, local_data)
def pull(self, name, id_tensor):
"""Pull message from KVServer.
......@@ -1043,7 +1063,7 @@ class KVClient(object):
# recv response
response_list = []
if local_id is not None: # local pull
local_data = self._pull_handler(self._data_store, name, local_id)
local_data = self._pull_handlers[name](self._data_store, name, local_id)
server_id = self._main_server_id
local_response = PullResponse(server_id, local_data)
response_list.append(local_response)
......
......@@ -210,7 +210,9 @@ def start_client():
res = kvclient.pull(name='data_2', id_tensor=id_tensor)
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
# Register new push handler
kvclient.register_push_handler(udf_push)
kvclient.register_push_handler('data_0', udf_push)
kvclient.register_push_handler('data_1', udf_push)
kvclient.register_push_handler('data_2', udf_push)
# Test push and pull
kvclient.push(name='data_0',
id_tensor=id_tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment