Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e8a56dc1
Unverified
Commit
e8a56dc1
authored
Jun 15, 2020
by
Da Zheng
Committed by
GitHub
Jun 15, 2020
Browse files
[KVStore] make pull/push handler per tensor. (#1646)
* make pull/push handler per tensor. * update.
parent
41349dce
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
71 additions
and
49 deletions
+71
-49
python/dgl/distributed/kvstore.py
python/dgl/distributed/kvstore.py
+68
-48
tests/distributed/test_new_kvstore.py
tests/distributed/test_new_kvstore.py
+3
-1
No files found.
python/dgl/distributed/kvstore.py
View file @
e8a56dc1
...
...
@@ -55,12 +55,12 @@ class PullRequest(rpc.Request):
def
process_request
(
self
,
server_state
):
kv_store
=
server_state
.
kv_store
if
kv_store
.
part_policy
.
__contains__
(
self
.
name
)
is
False
:
if
self
.
name
not
in
kv_store
.
part_policy
:
raise
RuntimeError
(
"KVServer cannot find partition policy with name: %s"
%
self
.
name
)
if
kv_store
.
data_store
.
__contains__
(
self
.
name
)
is
False
:
if
self
.
name
not
in
kv_store
.
data_store
:
raise
RuntimeError
(
"KVServer Cannot find data tensor with name: %s"
%
self
.
name
)
local_id
=
kv_store
.
part_policy
[
self
.
name
].
to_local
(
self
.
id_tensor
)
data
=
kv_store
.
pull_handler
(
kv_store
.
data_store
,
self
.
name
,
local_id
)
data
=
kv_store
.
pull_handler
s
[
self
.
name
]
(
kv_store
.
data_store
,
self
.
name
,
local_id
)
res
=
PullResponse
(
kv_store
.
server_id
,
data
)
return
res
...
...
@@ -93,12 +93,13 @@ class PushRequest(rpc.Request):
def
process_request
(
self
,
server_state
):
kv_store
=
server_state
.
kv_store
if
kv_store
.
part_policy
.
__contains__
(
self
.
name
)
is
False
:
if
self
.
name
not
in
kv_store
.
part_policy
:
raise
RuntimeError
(
"KVServer cannot find partition policy with name: %s"
%
self
.
name
)
if
kv_store
.
data_store
.
__contains__
(
self
.
name
)
is
False
:
if
self
.
name
not
in
kv_store
.
data_store
:
raise
RuntimeError
(
"KVServer Cannot find data tensor with name: %s"
%
self
.
name
)
local_id
=
kv_store
.
part_policy
[
self
.
name
].
to_local
(
self
.
id_tensor
)
kv_store
.
push_handler
(
kv_store
.
data_store
,
self
.
name
,
local_id
,
self
.
data_tensor
)
kv_store
.
push_handlers
[
self
.
name
](
kv_store
.
data_store
,
self
.
name
,
local_id
,
self
.
data_tensor
)
INIT_DATA
=
901233
INIT_MSG
=
'Init'
...
...
@@ -244,18 +245,19 @@ class RegisterPullHandlerRequest(rpc.Request):
pull_func : func
UDF pull handler
"""
def
__init__
(
self
,
pull_func
):
def
__init__
(
self
,
name
,
pull_func
):
self
.
name
=
name
self
.
pull_func
=
pull_func
def
__getstate__
(
self
):
return
self
.
pull_func
return
self
.
name
,
self
.
pull_func
def
__setstate__
(
self
,
state
):
self
.
pull_func
=
state
self
.
name
,
self
.
pull_func
=
state
def
process_request
(
self
,
server_state
):
kv_store
=
server_state
.
kv_store
kv_store
.
pull_handler
=
self
.
pull_func
kv_store
.
pull_handler
s
[
self
.
name
]
=
self
.
pull_func
res
=
RegisterPullHandlerResponse
(
REGISTER_PULL_MSG
)
return
res
...
...
@@ -288,18 +290,19 @@ class RegisterPushHandlerRequest(rpc.Request):
push_func : func
UDF push handler
"""
def
__init__
(
self
,
push_func
):
def
__init__
(
self
,
name
,
push_func
):
self
.
name
=
name
self
.
push_func
=
push_func
def
__getstate__
(
self
):
return
self
.
push_func
return
self
.
name
,
self
.
push_func
def
__setstate__
(
self
,
state
):
self
.
push_func
=
state
self
.
name
,
self
.
push_func
=
state
def
process_request
(
self
,
server_state
):
kv_store
=
server_state
.
kv_store
kv_store
.
push_handler
=
self
.
push_func
kv_store
.
push_handler
s
[
self
.
name
]
=
self
.
push_func
res
=
RegisterPushHandlerResponse
(
REGISTER_PUSH_MSG
)
return
res
...
...
@@ -569,8 +572,8 @@ class KVServer(object):
self
.
_num_clients
=
num_clients
self
.
_barrier_count
=
0
# push and pull handler
self
.
_push_handler
=
default_push_handler
self
.
_pull_handler
=
default_pull_handler
self
.
_push_handler
s
=
{}
self
.
_pull_handler
s
=
{}
@
property
def
server_id
(
self
):
...
...
@@ -608,24 +611,14 @@ class KVServer(object):
return
self
.
_part_id
@
property
def
push_handler
(
self
):
def
push_handler
s
(
self
):
"""Get push handler"""
return
self
.
_push_handler
return
self
.
_push_handler
s
@
property
def
pull_handler
(
self
):
def
pull_handler
s
(
self
):
"""Get pull handler"""
return
self
.
_pull_handler
@
pull_handler
.
setter
def
pull_handler
(
self
,
pull_handler
):
"""Set pull handler"""
self
.
_pull_handler
=
pull_handler
@
push_handler
.
setter
def
push_handler
(
self
,
push_handler
):
"""Set push handler"""
self
.
_push_handler
=
push_handler
return
self
.
_pull_handlers
def
is_backup_server
(
self
):
"""Return True if current server is a backup server.
...
...
@@ -667,6 +660,8 @@ class KVServer(object):
self
.
_data_store
[
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
self
.
_data_store
[
name
][:]
=
data_tensor
[:]
self
.
_part_policy
[
name
]
=
self
.
find_policy
(
policy_str
)
self
.
_pull_handlers
[
name
]
=
default_pull_handler
self
.
_push_handlers
[
name
]
=
default_push_handler
def
find_policy
(
self
,
policy_str
):
"""Find a partition policy from existing policy set
...
...
@@ -748,8 +743,8 @@ class KVClient(object):
self
.
_part_id
=
self
.
_machine_id
self
.
_main_server_id
=
self
.
_machine_id
*
self
.
_group_count
# push and pull handler
self
.
_pull_handler
=
default_pull_handler
self
.
_push_handler
=
default_push_handler
self
.
_pull_handler
s
=
{}
self
.
_push_handler
s
=
{}
@
property
def
client_id
(
self
):
...
...
@@ -775,18 +770,29 @@ class KVClient(object):
response
=
rpc
.
recv_response
()
assert
response
.
msg
==
BARRIER_MSG
def
register_push_handler
(
self
,
func
):
"""Register UDF push function on server.
def
register_push_handler
(
self
,
name
,
func
):
"""Register UDF push function.
This UDF is triggered for every push. The signature of the UDF is
client_0 will send this request to all servers, and the other
clients will just invoke the barrier() api.
```
def push_handler(data_store, name, local_offset, data)
```
`data_store` is a dict that contains all tensors in the kvstore. `name` is the name
of the tensor where new data is pushed to. `local_offset` is the offset where new
data should be written in the tensor in the local partition. `data` is the new data
to be written.
Parameters
----------
func : UDF push function
name : str
The name of the tensor
func : callable
The function to be called.
"""
if
self
.
_client_id
==
0
:
request
=
RegisterPushHandlerRequest
(
func
)
request
=
RegisterPushHandlerRequest
(
name
,
func
)
# send request to all the server nodes
for
server_id
in
range
(
self
.
_server_count
):
rpc
.
send_request
(
server_id
,
request
)
...
...
@@ -794,21 +800,31 @@ class KVClient(object):
for
_
in
range
(
self
.
_server_count
):
response
=
rpc
.
recv_response
()
assert
response
.
msg
==
REGISTER_PUSH_MSG
self
.
_push_handler
=
func
self
.
_push_handler
s
[
name
]
=
func
self
.
barrier
()
def
register_pull_handler
(
self
,
func
):
"""Register UDF pull function
on server
.
def
register_pull_handler
(
self
,
name
,
func
):
"""Register UDF pull function.
client_0 will send this request to all servers, and the other
clients will just invoke the barrier() api.
This UDF is triggered for every pull. The signature of the UDF is
```
def pull_handler(data_store, name, local_offset)
```
`data_store` is a dict that contains all tensors in the kvstore. `name` is the name
of the tensor where data is pulled from. `local_offset` is the offset where data
should be read from the tensor in the local partition.
Parameters
----------
func : UDF pull function
name : str
The name of the tensor
func : callable
The function to be called.
"""
if
self
.
_client_id
==
0
:
request
=
RegisterPullHandlerRequest
(
func
)
request
=
RegisterPullHandlerRequest
(
name
,
func
)
# send request to all the server nodes
for
server_id
in
range
(
self
.
_server_count
):
rpc
.
send_request
(
server_id
,
request
)
...
...
@@ -816,7 +832,7 @@ class KVClient(object):
for
_
in
range
(
self
.
_server_namebook
):
response
=
rpc
.
recv_response
()
assert
response
.
msg
==
REGISTER_PULL_MSG
self
.
_pull_handler
=
func
self
.
_pull_handler
s
[
name
]
=
func
self
.
barrier
()
def
init_data
(
self
,
name
,
shape
,
dtype
,
policy_str
,
partition_book
,
init_func
):
...
...
@@ -887,6 +903,8 @@ class KVClient(object):
self
.
_data_store
[
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
self
.
_data_name_list
.
add
(
name
)
self
.
_full_data_shape
[
name
]
=
tuple
(
shape
)
self
.
_pull_handlers
[
name
]
=
default_pull_handler
self
.
_push_handlers
[
name
]
=
default_push_handler
def
map_shared_data
(
self
,
partition_book
):
"""Mapping shared-memory tensor from server to client.
...
...
@@ -907,6 +925,8 @@ class KVClient(object):
dlpack
=
shared_data
.
to_dlpack
()
self
.
_data_store
[
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
self
.
_part_policy
[
name
]
=
PartitionPolicy
(
policy_str
,
self
.
_part_id
,
partition_book
)
self
.
_pull_handlers
[
name
]
=
default_pull_handler
self
.
_push_handlers
[
name
]
=
default_push_handler
# Get full data shape across servers
for
name
,
meta
in
response
.
meta
.
items
():
if
name
not
in
self
.
_data_name_list
:
...
...
@@ -995,7 +1015,7 @@ class KVClient(object):
rpc
.
send_request_to_machine
(
machine_idx
,
request
)
start
+=
count
[
idx
]
if
local_id
is
not
None
:
# local push
self
.
_push_handler
(
self
.
_data_store
,
name
,
local_id
,
local_data
)
self
.
_push_handler
s
[
name
]
(
self
.
_data_store
,
name
,
local_id
,
local_data
)
def
pull
(
self
,
name
,
id_tensor
):
"""Pull message from KVServer.
...
...
@@ -1043,7 +1063,7 @@ class KVClient(object):
# recv response
response_list
=
[]
if
local_id
is
not
None
:
# local pull
local_data
=
self
.
_pull_handler
(
self
.
_data_store
,
name
,
local_id
)
local_data
=
self
.
_pull_handler
s
[
name
]
(
self
.
_data_store
,
name
,
local_id
)
server_id
=
self
.
_main_server_id
local_response
=
PullResponse
(
server_id
,
local_data
)
response_list
.
append
(
local_response
)
...
...
tests/distributed/test_new_kvstore.py
View file @
e8a56dc1
...
...
@@ -210,7 +210,9 @@ def start_client():
res
=
kvclient
.
pull
(
name
=
'data_2'
,
id_tensor
=
id_tensor
)
assert_array_equal
(
F
.
asnumpy
(
res
),
F
.
asnumpy
(
data_tensor
))
# Register new push handler
kvclient
.
register_push_handler
(
udf_push
)
kvclient
.
register_push_handler
(
'data_0'
,
udf_push
)
kvclient
.
register_push_handler
(
'data_1'
,
udf_push
)
kvclient
.
register_push_handler
(
'data_2'
,
udf_push
)
# Test push and pull
kvclient
.
push
(
name
=
'data_0'
,
id_tensor
=
id_tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment