Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e8a56dc1
Unverified
Commit
e8a56dc1
authored
Jun 15, 2020
by
Da Zheng
Committed by
GitHub
Jun 15, 2020
Browse files
[KVStore] make pull/push handler per tensor. (#1646)
* make pull/push handler per tensor. * update.
parent
41349dce
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
71 additions
and
49 deletions
+71
-49
python/dgl/distributed/kvstore.py
python/dgl/distributed/kvstore.py
+68
-48
tests/distributed/test_new_kvstore.py
tests/distributed/test_new_kvstore.py
+3
-1
No files found.
python/dgl/distributed/kvstore.py
View file @
e8a56dc1
...
@@ -55,12 +55,12 @@ class PullRequest(rpc.Request):
...
@@ -55,12 +55,12 @@ class PullRequest(rpc.Request):
def
process_request
(
self
,
server_state
):
def
process_request
(
self
,
server_state
):
kv_store
=
server_state
.
kv_store
kv_store
=
server_state
.
kv_store
if
kv_store
.
part_policy
.
__contains__
(
self
.
name
)
is
False
:
if
self
.
name
not
in
kv_store
.
part_policy
:
raise
RuntimeError
(
"KVServer cannot find partition policy with name: %s"
%
self
.
name
)
raise
RuntimeError
(
"KVServer cannot find partition policy with name: %s"
%
self
.
name
)
if
kv_store
.
data_store
.
__contains__
(
self
.
name
)
is
False
:
if
self
.
name
not
in
kv_store
.
data_store
:
raise
RuntimeError
(
"KVServer Cannot find data tensor with name: %s"
%
self
.
name
)
raise
RuntimeError
(
"KVServer Cannot find data tensor with name: %s"
%
self
.
name
)
local_id
=
kv_store
.
part_policy
[
self
.
name
].
to_local
(
self
.
id_tensor
)
local_id
=
kv_store
.
part_policy
[
self
.
name
].
to_local
(
self
.
id_tensor
)
data
=
kv_store
.
pull_handler
(
kv_store
.
data_store
,
self
.
name
,
local_id
)
data
=
kv_store
.
pull_handler
s
[
self
.
name
]
(
kv_store
.
data_store
,
self
.
name
,
local_id
)
res
=
PullResponse
(
kv_store
.
server_id
,
data
)
res
=
PullResponse
(
kv_store
.
server_id
,
data
)
return
res
return
res
...
@@ -93,12 +93,13 @@ class PushRequest(rpc.Request):
...
@@ -93,12 +93,13 @@ class PushRequest(rpc.Request):
def
process_request
(
self
,
server_state
):
def
process_request
(
self
,
server_state
):
kv_store
=
server_state
.
kv_store
kv_store
=
server_state
.
kv_store
if
kv_store
.
part_policy
.
__contains__
(
self
.
name
)
is
False
:
if
self
.
name
not
in
kv_store
.
part_policy
:
raise
RuntimeError
(
"KVServer cannot find partition policy with name: %s"
%
self
.
name
)
raise
RuntimeError
(
"KVServer cannot find partition policy with name: %s"
%
self
.
name
)
if
kv_store
.
data_store
.
__contains__
(
self
.
name
)
is
False
:
if
self
.
name
not
in
kv_store
.
data_store
:
raise
RuntimeError
(
"KVServer Cannot find data tensor with name: %s"
%
self
.
name
)
raise
RuntimeError
(
"KVServer Cannot find data tensor with name: %s"
%
self
.
name
)
local_id
=
kv_store
.
part_policy
[
self
.
name
].
to_local
(
self
.
id_tensor
)
local_id
=
kv_store
.
part_policy
[
self
.
name
].
to_local
(
self
.
id_tensor
)
kv_store
.
push_handler
(
kv_store
.
data_store
,
self
.
name
,
local_id
,
self
.
data_tensor
)
kv_store
.
push_handlers
[
self
.
name
](
kv_store
.
data_store
,
self
.
name
,
local_id
,
self
.
data_tensor
)
INIT_DATA
=
901233
INIT_DATA
=
901233
INIT_MSG
=
'Init'
INIT_MSG
=
'Init'
...
@@ -244,18 +245,19 @@ class RegisterPullHandlerRequest(rpc.Request):
...
@@ -244,18 +245,19 @@ class RegisterPullHandlerRequest(rpc.Request):
pull_func : func
pull_func : func
UDF pull handler
UDF pull handler
"""
"""
def
__init__
(
self
,
pull_func
):
def
__init__
(
self
,
name
,
pull_func
):
self
.
name
=
name
self
.
pull_func
=
pull_func
self
.
pull_func
=
pull_func
def
__getstate__
(
self
):
def
__getstate__
(
self
):
return
self
.
pull_func
return
self
.
name
,
self
.
pull_func
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
self
.
pull_func
=
state
self
.
name
,
self
.
pull_func
=
state
def
process_request
(
self
,
server_state
):
def
process_request
(
self
,
server_state
):
kv_store
=
server_state
.
kv_store
kv_store
=
server_state
.
kv_store
kv_store
.
pull_handler
=
self
.
pull_func
kv_store
.
pull_handler
s
[
self
.
name
]
=
self
.
pull_func
res
=
RegisterPullHandlerResponse
(
REGISTER_PULL_MSG
)
res
=
RegisterPullHandlerResponse
(
REGISTER_PULL_MSG
)
return
res
return
res
...
@@ -288,18 +290,19 @@ class RegisterPushHandlerRequest(rpc.Request):
...
@@ -288,18 +290,19 @@ class RegisterPushHandlerRequest(rpc.Request):
push_func : func
push_func : func
UDF push handler
UDF push handler
"""
"""
def
__init__
(
self
,
push_func
):
def
__init__
(
self
,
name
,
push_func
):
self
.
name
=
name
self
.
push_func
=
push_func
self
.
push_func
=
push_func
def
__getstate__
(
self
):
def
__getstate__
(
self
):
return
self
.
push_func
return
self
.
name
,
self
.
push_func
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
self
.
push_func
=
state
self
.
name
,
self
.
push_func
=
state
def
process_request
(
self
,
server_state
):
def
process_request
(
self
,
server_state
):
kv_store
=
server_state
.
kv_store
kv_store
=
server_state
.
kv_store
kv_store
.
push_handler
=
self
.
push_func
kv_store
.
push_handler
s
[
self
.
name
]
=
self
.
push_func
res
=
RegisterPushHandlerResponse
(
REGISTER_PUSH_MSG
)
res
=
RegisterPushHandlerResponse
(
REGISTER_PUSH_MSG
)
return
res
return
res
...
@@ -569,8 +572,8 @@ class KVServer(object):
...
@@ -569,8 +572,8 @@ class KVServer(object):
self
.
_num_clients
=
num_clients
self
.
_num_clients
=
num_clients
self
.
_barrier_count
=
0
self
.
_barrier_count
=
0
# push and pull handler
# push and pull handler
self
.
_push_handler
=
default_push_handler
self
.
_push_handler
s
=
{}
self
.
_pull_handler
=
default_pull_handler
self
.
_pull_handler
s
=
{}
@
property
@
property
def
server_id
(
self
):
def
server_id
(
self
):
...
@@ -608,24 +611,14 @@ class KVServer(object):
...
@@ -608,24 +611,14 @@ class KVServer(object):
return
self
.
_part_id
return
self
.
_part_id
@
property
@
property
def
push_handler
(
self
):
def
push_handler
s
(
self
):
"""Get push handler"""
"""Get push handler"""
return
self
.
_push_handler
return
self
.
_push_handler
s
@
property
@
property
def
pull_handler
(
self
):
def
pull_handler
s
(
self
):
"""Get pull handler"""
"""Get pull handler"""
return
self
.
_pull_handler
return
self
.
_pull_handlers
@
pull_handler
.
setter
def
pull_handler
(
self
,
pull_handler
):
"""Set pull handler"""
self
.
_pull_handler
=
pull_handler
@
push_handler
.
setter
def
push_handler
(
self
,
push_handler
):
"""Set push handler"""
self
.
_push_handler
=
push_handler
def
is_backup_server
(
self
):
def
is_backup_server
(
self
):
"""Return True if current server is a backup server.
"""Return True if current server is a backup server.
...
@@ -667,6 +660,8 @@ class KVServer(object):
...
@@ -667,6 +660,8 @@ class KVServer(object):
self
.
_data_store
[
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
self
.
_data_store
[
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
self
.
_data_store
[
name
][:]
=
data_tensor
[:]
self
.
_data_store
[
name
][:]
=
data_tensor
[:]
self
.
_part_policy
[
name
]
=
self
.
find_policy
(
policy_str
)
self
.
_part_policy
[
name
]
=
self
.
find_policy
(
policy_str
)
self
.
_pull_handlers
[
name
]
=
default_pull_handler
self
.
_push_handlers
[
name
]
=
default_push_handler
def
find_policy
(
self
,
policy_str
):
def
find_policy
(
self
,
policy_str
):
"""Find a partition policy from existing policy set
"""Find a partition policy from existing policy set
...
@@ -748,8 +743,8 @@ class KVClient(object):
...
@@ -748,8 +743,8 @@ class KVClient(object):
self
.
_part_id
=
self
.
_machine_id
self
.
_part_id
=
self
.
_machine_id
self
.
_main_server_id
=
self
.
_machine_id
*
self
.
_group_count
self
.
_main_server_id
=
self
.
_machine_id
*
self
.
_group_count
# push and pull handler
# push and pull handler
self
.
_pull_handler
=
default_pull_handler
self
.
_pull_handler
s
=
{}
self
.
_push_handler
=
default_push_handler
self
.
_push_handler
s
=
{}
@
property
@
property
def
client_id
(
self
):
def
client_id
(
self
):
...
@@ -775,18 +770,29 @@ class KVClient(object):
...
@@ -775,18 +770,29 @@ class KVClient(object):
response
=
rpc
.
recv_response
()
response
=
rpc
.
recv_response
()
assert
response
.
msg
==
BARRIER_MSG
assert
response
.
msg
==
BARRIER_MSG
def
register_push_handler
(
self
,
func
):
def
register_push_handler
(
self
,
name
,
func
):
"""Register UDF push function on server.
"""Register UDF push function.
This UDF is triggered for every push. The signature of the UDF is
client_0 will send this request to all servers, and the other
```
clients will just invoke the barrier() api.
def push_handler(data_store, name, local_offset, data)
```
`data_store` is a dict that contains all tensors in the kvstore. `name` is the name
of the tensor where new data is pushed to. `local_offset` is the offset where new
data should be written in the tensor in the local partition. `data` is the new data
to be written.
Parameters
Parameters
----------
----------
func : UDF push function
name : str
The name of the tensor
func : callable
The function to be called.
"""
"""
if
self
.
_client_id
==
0
:
if
self
.
_client_id
==
0
:
request
=
RegisterPushHandlerRequest
(
func
)
request
=
RegisterPushHandlerRequest
(
name
,
func
)
# send request to all the server nodes
# send request to all the server nodes
for
server_id
in
range
(
self
.
_server_count
):
for
server_id
in
range
(
self
.
_server_count
):
rpc
.
send_request
(
server_id
,
request
)
rpc
.
send_request
(
server_id
,
request
)
...
@@ -794,21 +800,31 @@ class KVClient(object):
...
@@ -794,21 +800,31 @@ class KVClient(object):
for
_
in
range
(
self
.
_server_count
):
for
_
in
range
(
self
.
_server_count
):
response
=
rpc
.
recv_response
()
response
=
rpc
.
recv_response
()
assert
response
.
msg
==
REGISTER_PUSH_MSG
assert
response
.
msg
==
REGISTER_PUSH_MSG
self
.
_push_handler
=
func
self
.
_push_handler
s
[
name
]
=
func
self
.
barrier
()
self
.
barrier
()
def
register_pull_handler
(
self
,
func
):
def
register_pull_handler
(
self
,
name
,
func
):
"""Register UDF pull function
on server
.
"""Register UDF pull function.
client_0 will send this request to all servers, and the other
This UDF is triggered for every pull. The signature of the UDF is
clients will just invoke the barrier() api.
```
def pull_handler(data_store, name, local_offset)
```
`data_store` is a dict that contains all tensors in the kvstore. `name` is the name
of the tensor where data is pulled from. `local_offset` is the offset where data
should be read from the tensor in the local partition.
Parameters
Parameters
----------
----------
func : UDF pull function
name : str
The name of the tensor
func : callable
The function to be called.
"""
"""
if
self
.
_client_id
==
0
:
if
self
.
_client_id
==
0
:
request
=
RegisterPullHandlerRequest
(
func
)
request
=
RegisterPullHandlerRequest
(
name
,
func
)
# send request to all the server nodes
# send request to all the server nodes
for
server_id
in
range
(
self
.
_server_count
):
for
server_id
in
range
(
self
.
_server_count
):
rpc
.
send_request
(
server_id
,
request
)
rpc
.
send_request
(
server_id
,
request
)
...
@@ -816,7 +832,7 @@ class KVClient(object):
...
@@ -816,7 +832,7 @@ class KVClient(object):
for
_
in
range
(
self
.
_server_namebook
):
for
_
in
range
(
self
.
_server_namebook
):
response
=
rpc
.
recv_response
()
response
=
rpc
.
recv_response
()
assert
response
.
msg
==
REGISTER_PULL_MSG
assert
response
.
msg
==
REGISTER_PULL_MSG
self
.
_pull_handler
=
func
self
.
_pull_handler
s
[
name
]
=
func
self
.
barrier
()
self
.
barrier
()
def
init_data
(
self
,
name
,
shape
,
dtype
,
policy_str
,
partition_book
,
init_func
):
def
init_data
(
self
,
name
,
shape
,
dtype
,
policy_str
,
partition_book
,
init_func
):
...
@@ -887,6 +903,8 @@ class KVClient(object):
...
@@ -887,6 +903,8 @@ class KVClient(object):
self
.
_data_store
[
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
self
.
_data_store
[
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
self
.
_data_name_list
.
add
(
name
)
self
.
_data_name_list
.
add
(
name
)
self
.
_full_data_shape
[
name
]
=
tuple
(
shape
)
self
.
_full_data_shape
[
name
]
=
tuple
(
shape
)
self
.
_pull_handlers
[
name
]
=
default_pull_handler
self
.
_push_handlers
[
name
]
=
default_push_handler
def
map_shared_data
(
self
,
partition_book
):
def
map_shared_data
(
self
,
partition_book
):
"""Mapping shared-memory tensor from server to client.
"""Mapping shared-memory tensor from server to client.
...
@@ -907,6 +925,8 @@ class KVClient(object):
...
@@ -907,6 +925,8 @@ class KVClient(object):
dlpack
=
shared_data
.
to_dlpack
()
dlpack
=
shared_data
.
to_dlpack
()
self
.
_data_store
[
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
self
.
_data_store
[
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
self
.
_part_policy
[
name
]
=
PartitionPolicy
(
policy_str
,
self
.
_part_id
,
partition_book
)
self
.
_part_policy
[
name
]
=
PartitionPolicy
(
policy_str
,
self
.
_part_id
,
partition_book
)
self
.
_pull_handlers
[
name
]
=
default_pull_handler
self
.
_push_handlers
[
name
]
=
default_push_handler
# Get full data shape across servers
# Get full data shape across servers
for
name
,
meta
in
response
.
meta
.
items
():
for
name
,
meta
in
response
.
meta
.
items
():
if
name
not
in
self
.
_data_name_list
:
if
name
not
in
self
.
_data_name_list
:
...
@@ -995,7 +1015,7 @@ class KVClient(object):
...
@@ -995,7 +1015,7 @@ class KVClient(object):
rpc
.
send_request_to_machine
(
machine_idx
,
request
)
rpc
.
send_request_to_machine
(
machine_idx
,
request
)
start
+=
count
[
idx
]
start
+=
count
[
idx
]
if
local_id
is
not
None
:
# local push
if
local_id
is
not
None
:
# local push
self
.
_push_handler
(
self
.
_data_store
,
name
,
local_id
,
local_data
)
self
.
_push_handler
s
[
name
]
(
self
.
_data_store
,
name
,
local_id
,
local_data
)
def
pull
(
self
,
name
,
id_tensor
):
def
pull
(
self
,
name
,
id_tensor
):
"""Pull message from KVServer.
"""Pull message from KVServer.
...
@@ -1043,7 +1063,7 @@ class KVClient(object):
...
@@ -1043,7 +1063,7 @@ class KVClient(object):
# recv response
# recv response
response_list
=
[]
response_list
=
[]
if
local_id
is
not
None
:
# local pull
if
local_id
is
not
None
:
# local pull
local_data
=
self
.
_pull_handler
(
self
.
_data_store
,
name
,
local_id
)
local_data
=
self
.
_pull_handler
s
[
name
]
(
self
.
_data_store
,
name
,
local_id
)
server_id
=
self
.
_main_server_id
server_id
=
self
.
_main_server_id
local_response
=
PullResponse
(
server_id
,
local_data
)
local_response
=
PullResponse
(
server_id
,
local_data
)
response_list
.
append
(
local_response
)
response_list
.
append
(
local_response
)
...
...
tests/distributed/test_new_kvstore.py
View file @
e8a56dc1
...
@@ -210,7 +210,9 @@ def start_client():
...
@@ -210,7 +210,9 @@ def start_client():
res
=
kvclient
.
pull
(
name
=
'data_2'
,
id_tensor
=
id_tensor
)
res
=
kvclient
.
pull
(
name
=
'data_2'
,
id_tensor
=
id_tensor
)
assert_array_equal
(
F
.
asnumpy
(
res
),
F
.
asnumpy
(
data_tensor
))
assert_array_equal
(
F
.
asnumpy
(
res
),
F
.
asnumpy
(
data_tensor
))
# Register new push handler
# Register new push handler
kvclient
.
register_push_handler
(
udf_push
)
kvclient
.
register_push_handler
(
'data_0'
,
udf_push
)
kvclient
.
register_push_handler
(
'data_1'
,
udf_push
)
kvclient
.
register_push_handler
(
'data_2'
,
udf_push
)
# Test push and pull
# Test push and pull
kvclient
.
push
(
name
=
'data_0'
,
kvclient
.
push
(
name
=
'data_0'
,
id_tensor
=
id_tensor
,
id_tensor
=
id_tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment