Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8eab08d0
Unverified
Commit
8eab08d0
authored
Jun 09, 2020
by
Chao Ma
Committed by
GitHub
Jun 09, 2020
Browse files
[KVStore] Remove Freeze flag (#1605)
* remove freeze * update * update * fix lint
parent
cbe4c28f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
48 deletions
+27
-48
python/dgl/distributed/kvstore.py
python/dgl/distributed/kvstore.py
+27
-48
No files found.
python/dgl/distributed/kvstore.py
View file @
8eab08d0
"""Define distributed kvstore"""
"""Define distributed kvstore"""
import
os
import
os
import
time
import
random
import
random
import
numpy
as
np
import
numpy
as
np
...
@@ -356,8 +355,6 @@ class GetSharedDataRequest(rpc.Request):
...
@@ -356,8 +355,6 @@ class GetSharedDataRequest(rpc.Request):
kv_store
.
part_policy
[
name
].
policy_str
)
kv_store
.
part_policy
[
name
].
policy_str
)
if
len
(
meta
)
==
0
:
if
len
(
meta
)
==
0
:
raise
RuntimeError
(
'There is no data on kvserver.'
)
raise
RuntimeError
(
'There is no data on kvserver.'
)
# Freeze data init
kv_store
.
freeze
=
True
res
=
GetSharedDataResponse
(
meta
)
res
=
GetSharedDataResponse
(
meta
)
return
res
return
res
...
@@ -451,6 +448,7 @@ class SendMetaToBackupRequest(rpc.Request):
...
@@ -451,6 +448,7 @@ class SendMetaToBackupRequest(rpc.Request):
def
process_request
(
self
,
server_state
):
def
process_request
(
self
,
server_state
):
kv_store
=
server_state
.
kv_store
kv_store
=
server_state
.
kv_store
assert
kv_store
.
is_backup_server
()
assert
kv_store
.
is_backup_server
()
if
self
.
name
not
in
kv_store
.
data_store
:
shared_data
=
empty_shared_mem
(
self
.
name
+
'-kvdata-'
,
False
,
self
.
shape
,
self
.
dtype
)
shared_data
=
empty_shared_mem
(
self
.
name
+
'-kvdata-'
,
False
,
self
.
shape
,
self
.
dtype
)
dlpack
=
shared_data
.
to_dlpack
()
dlpack
=
shared_data
.
to_dlpack
()
kv_store
.
data_store
[
self
.
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
kv_store
.
data_store
[
self
.
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
...
@@ -570,8 +568,6 @@ class KVServer(object):
...
@@ -570,8 +568,6 @@ class KVServer(object):
# push and pull handler
# push and pull handler
self
.
_push_handler
=
default_push_handler
self
.
_push_handler
=
default_push_handler
self
.
_pull_handler
=
default_pull_handler
self
.
_pull_handler
=
default_pull_handler
# We cannot create new data on kvstore when freeze == True
self
.
_freeze
=
False
@
property
@
property
def
server_id
(
self
):
def
server_id
(
self
):
...
@@ -588,16 +584,6 @@ class KVServer(object):
...
@@ -588,16 +584,6 @@ class KVServer(object):
"""Set barrier count"""
"""Set barrier count"""
self
.
_barrier_count
=
count
self
.
_barrier_count
=
count
@
property
def
freeze
(
self
):
"""Get freeze"""
return
self
.
_freeze
@
freeze
.
setter
def
freeze
(
self
,
freeze
):
"""Set freeze"""
self
.
_freeze
=
freeze
@
property
@
property
def
num_clients
(
self
):
def
num_clients
(
self
):
"""Get number of clients"""
"""Get number of clients"""
...
@@ -669,9 +655,6 @@ class KVServer(object):
...
@@ -669,9 +655,6 @@ class KVServer(object):
read shared-memory when client invoking get_shared_data().
read shared-memory when client invoking get_shared_data().
"""
"""
assert
len
(
name
)
>
0
,
'name cannot be empty.'
assert
len
(
name
)
>
0
,
'name cannot be empty.'
if
self
.
_freeze
:
raise
RuntimeError
(
"KVServer cannot create new data
\
after client invoking get_shared_data() API."
)
if
self
.
_data_store
.
__contains__
(
name
):
if
self
.
_data_store
.
__contains__
(
name
):
raise
RuntimeError
(
"Data %s has already exists!"
%
name
)
raise
RuntimeError
(
"Data %s has already exists!"
%
name
)
if
data_tensor
is
not
None
:
# Create shared-tensor
if
data_tensor
is
not
None
:
# Create shared-tensor
...
@@ -764,9 +747,6 @@ class KVClient(object):
...
@@ -764,9 +747,6 @@ class KVClient(object):
# push and pull handler
# push and pull handler
self
.
_pull_handler
=
default_pull_handler
self
.
_pull_handler
=
default_pull_handler
self
.
_push_handler
=
default_push_handler
self
.
_push_handler
=
default_push_handler
# We cannot create new data on kvstore when freeze == True
self
.
_freeze
=
False
random
.
seed
(
time
.
time
())
@
property
@
property
def
client_id
(
self
):
def
client_id
(
self
):
...
@@ -858,9 +838,7 @@ class KVClient(object):
...
@@ -858,9 +838,7 @@ class KVClient(object):
assert
len
(
name
)
>
0
,
'name cannot be empty.'
assert
len
(
name
)
>
0
,
'name cannot be empty.'
assert
len
(
shape
)
>
0
,
'shape cannot be empty'
assert
len
(
shape
)
>
0
,
'shape cannot be empty'
assert
policy_str
in
(
'edge'
,
'node'
),
'policy_str must be
\'
edge
\'
or
\'
node
\'
.'
assert
policy_str
in
(
'edge'
,
'node'
),
'policy_str must be
\'
edge
\'
or
\'
node
\'
.'
if
self
.
_freeze
:
assert
name
not
in
self
.
_data_name_list
,
'data name: %s already exists.'
%
name
raise
RuntimeError
(
"KVClient cannot create new
\
data after invoking get_shared_data() API."
)
shape
=
list
(
shape
)
shape
=
list
(
shape
)
if
self
.
_client_id
==
0
:
if
self
.
_client_id
==
0
:
for
machine_id
in
range
(
self
.
_machine_count
):
for
machine_id
in
range
(
self
.
_machine_count
):
...
@@ -920,14 +898,15 @@ class KVClient(object):
...
@@ -920,14 +898,15 @@ class KVClient(object):
rpc
.
send_request
(
self
.
_main_server_id
,
request
)
rpc
.
send_request
(
self
.
_main_server_id
,
request
)
response
=
rpc
.
recv_response
()
response
=
rpc
.
recv_response
()
for
name
,
meta
in
response
.
meta
.
items
():
for
name
,
meta
in
response
.
meta
.
items
():
if
name
not
in
self
.
_data_name_list
:
shape
,
dtype
,
policy_str
=
meta
shape
,
dtype
,
policy_str
=
meta
shared_data
=
empty_shared_mem
(
name
+
'-kvdata-'
,
False
,
shape
,
dtype
)
shared_data
=
empty_shared_mem
(
name
+
'-kvdata-'
,
False
,
shape
,
dtype
)
dlpack
=
shared_data
.
to_dlpack
()
dlpack
=
shared_data
.
to_dlpack
()
self
.
_data_store
[
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
self
.
_data_store
[
name
]
=
F
.
zerocopy_from_dlpack
(
dlpack
)
self
.
_part_policy
[
name
]
=
PartitionPolicy
(
policy_str
,
self
.
_part_id
,
partition_book
)
self
.
_part_policy
[
name
]
=
PartitionPolicy
(
policy_str
,
self
.
_part_id
,
partition_book
)
self
.
_data_name_list
.
add
(
name
)
# Get full data shape across servers
# Get full data shape across servers
for
name
,
meta
in
response
.
meta
.
items
():
for
name
,
meta
in
response
.
meta
.
items
():
if
name
not
in
self
.
_data_name_list
:
shape
,
_
,
_
=
meta
shape
,
_
,
_
=
meta
data_shape
=
list
(
shape
)
data_shape
=
list
(
shape
)
data_shape
[
0
]
=
0
data_shape
[
0
]
=
0
...
@@ -953,7 +932,7 @@ class KVClient(object):
...
@@ -953,7 +932,7 @@ class KVClient(object):
for
_
in
range
(
self
.
_group_count
-
1
):
for
_
in
range
(
self
.
_group_count
-
1
):
response
=
rpc
.
recv_response
()
response
=
rpc
.
recv_response
()
assert
response
.
msg
==
SEND_META_TO_BACKUP_MSG
assert
response
.
msg
==
SEND_META_TO_BACKUP_MSG
self
.
_
freeze
=
True
self
.
_
data_name_list
.
add
(
name
)
def
data_name_list
(
self
):
def
data_name_list
(
self
):
"""Get all the data name"""
"""Get all the data name"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment