Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a0193fd5
Unverified
Commit
a0193fd5
authored
Nov 08, 2019
by
Chao Ma
Committed by
GitHub
Nov 08, 2019
Browse files
Small change for kvstore api (#981)
* Small change for kvstore api * fix ci * fix ci
parent
0b4935d4
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
49 additions
and
48 deletions
+49
-48
examples/mxnet/dis_kvstore/client.py
examples/mxnet/dis_kvstore/client.py
+15
-15
examples/pytorch/dis_kvstore/client.py
examples/pytorch/dis_kvstore/client.py
+15
-15
python/dgl/contrib/dis_kvstore.py
python/dgl/contrib/dis_kvstore.py
+1
-1
tests/compute/test_kvstore.py
tests/compute/test_kvstore.py
+18
-17
No files found.
examples/mxnet/dis_kvstore/client.py
View file @
a0193fd5
...
@@ -37,31 +37,31 @@ def start_client(args):
...
@@ -37,31 +37,31 @@ def start_client(args):
if
client
.
get_id
()
==
0
:
if
client
.
get_id
()
==
0
:
client
.
pull
(
name
=
'embed_0'
,
server_id
=
0
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
],
dtype
=
'int64'
))
client
.
pull
(
name
=
'embed_0'
,
server_id
=
0
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
],
dtype
=
'int64'
))
server_id
,
new_tensor
_0
=
client
.
pull_wait
()
msg
_0
=
client
.
pull_wait
()
assert
server_id
==
0
assert
msg_0
.
rank
==
0
client
.
pull
(
name
=
'embed_0'
,
server_id
=
1
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
,
5
],
dtype
=
'int64'
))
client
.
pull
(
name
=
'embed_0'
,
server_id
=
1
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
,
5
],
dtype
=
'int64'
))
server_id
,
new_tensor
_1
=
client
.
pull_wait
()
msg
_1
=
client
.
pull_wait
()
assert
server_id
==
1
assert
msg_1
.
rank
==
1
print
(
"embed_0:"
)
print
(
"embed_0:"
)
print
(
mx
.
nd
.
concat
(
new_tensor_0
,
new_tensor_1
,
dim
=
0
))
print
(
mx
.
nd
.
concat
(
msg_0
.
data
,
msg_1
.
data
,
dim
=
0
))
client
.
pull
(
name
=
'embed_1'
,
server_id
=
0
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
],
dtype
=
'int64'
))
client
.
pull
(
name
=
'embed_1'
,
server_id
=
0
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
],
dtype
=
'int64'
))
server_id
,
new_tensor
_0
=
client
.
pull_wait
()
msg
_0
=
client
.
pull_wait
()
assert
server_id
==
0
assert
msg_0
.
rank
==
0
client
.
pull
(
name
=
'embed_1'
,
server_id
=
1
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
,
5
],
dtype
=
'int64'
))
client
.
pull
(
name
=
'embed_1'
,
server_id
=
1
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
,
5
],
dtype
=
'int64'
))
server_id
,
new_tensor
_1
=
client
.
pull_wait
()
msg
_1
=
client
.
pull_wait
()
assert
server_id
==
1
assert
msg_1
.
rank
==
1
print
(
"embed_1:"
)
print
(
"embed_1:"
)
print
(
mx
.
nd
.
concat
(
new_tensor_0
,
new_tensor_1
,
dim
=
0
))
print
(
mx
.
nd
.
concat
(
msg_0
.
data
,
msg_1
.
data
,
dim
=
0
))
client
.
pull
(
name
=
'server_embed'
,
server_id
=
0
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
],
dtype
=
'int64'
))
client
.
pull
(
name
=
'server_embed'
,
server_id
=
0
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
],
dtype
=
'int64'
))
server_id
,
new_tensor
_0
=
client
.
pull_wait
()
msg
_0
=
client
.
pull_wait
()
assert
server_id
==
0
assert
msg_0
.
rank
==
0
client
.
pull
(
name
=
'server_embed'
,
server_id
=
1
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
],
dtype
=
'int64'
))
client
.
pull
(
name
=
'server_embed'
,
server_id
=
1
,
id_tensor
=
mx
.
nd
.
array
([
0
,
1
,
2
,
3
,
4
],
dtype
=
'int64'
))
server_id
,
new_tensor
_1
=
client
.
pull_wait
()
msg
_1
=
client
.
pull_wait
()
assert
server_id
==
1
assert
msg_1
.
rank
==
1
print
(
"server_embed:"
)
print
(
"server_embed:"
)
print
(
mx
.
nd
.
concat
(
new_tensor_0
,
new_tensor_1
,
dim
=
0
))
print
(
mx
.
nd
.
concat
(
msg_0
.
data
,
msg_1
.
data
,
dim
=
0
))
# Shut-down all the servers
# Shut-down all the servers
if
client
.
get_id
()
==
0
:
if
client
.
get_id
()
==
0
:
...
...
examples/pytorch/dis_kvstore/client.py
View file @
a0193fd5
...
@@ -37,31 +37,31 @@ def start_client(args):
...
@@ -37,31 +37,31 @@ def start_client(args):
if
client
.
get_id
()
==
0
:
if
client
.
get_id
()
==
0
:
client
.
pull
(
name
=
'embed_0'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'embed_0'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
server_id
,
new_tensor
_0
=
client
.
pull_wait
()
msg
_0
=
client
.
pull_wait
()
assert
server_id
==
0
assert
msg_0
.
rank
==
0
client
.
pull
(
name
=
'embed_0'
,
server_id
=
1
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
,
5
]))
client
.
pull
(
name
=
'embed_0'
,
server_id
=
1
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
,
5
]))
server_id
,
new_tensor
_1
=
client
.
pull_wait
()
msg
_1
=
client
.
pull_wait
()
assert
server_id
==
1
assert
msg_1
.
rank
==
1
print
(
"embed_0:"
)
print
(
"embed_0:"
)
print
(
th
.
cat
([
new_tensor_0
,
new_tensor_1
]))
print
(
th
.
cat
([
msg_0
.
data
,
msg_1
.
data
]))
client
.
pull
(
name
=
'embed_1'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'embed_1'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
server_id
,
new_tensor
_0
=
client
.
pull_wait
()
msg
_0
=
client
.
pull_wait
()
assert
server_id
==
0
assert
msg_0
.
rank
==
0
client
.
pull
(
name
=
'embed_1'
,
server_id
=
1
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
,
5
]))
client
.
pull
(
name
=
'embed_1'
,
server_id
=
1
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
,
5
]))
server_id
,
new_tensor
_1
=
client
.
pull_wait
()
msg
_1
=
client
.
pull_wait
()
assert
server_id
==
1
assert
msg_1
.
rank
==
1
print
(
"embed_1:"
)
print
(
"embed_1:"
)
print
(
th
.
cat
([
new_tensor_0
,
new_tensor_1
]))
print
(
th
.
cat
([
msg_0
.
data
,
msg_1
.
data
]))
client
.
pull
(
name
=
'server_embed'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'server_embed'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
server_id
,
new_tensor
_0
=
client
.
pull_wait
()
msg
_0
=
client
.
pull_wait
()
assert
server_id
==
0
assert
msg_0
.
rank
==
0
client
.
pull
(
name
=
'server_embed'
,
server_id
=
1
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'server_embed'
,
server_id
=
1
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
server_id
,
new_tensor
_1
=
client
.
pull_wait
()
msg
_1
=
client
.
pull_wait
()
assert
server_id
==
1
assert
msg_1
.
rank
==
1
print
(
"server_embed:"
)
print
(
"server_embed:"
)
print
(
th
.
cat
([
new_tensor_0
,
new_tensor_1
]))
print
(
th
.
cat
([
msg_0
.
data
,
msg_1
.
data
]))
# Shut-down all the servers
# Shut-down all the servers
if
client
.
get_id
()
==
0
:
if
client
.
get_id
()
==
0
:
...
...
python/dgl/contrib/dis_kvstore.py
View file @
a0193fd5
...
@@ -407,7 +407,7 @@ class KVClient(object):
...
@@ -407,7 +407,7 @@ class KVClient(object):
"""
"""
msg
=
_recv_kv_msg
(
self
.
_receiver
)
msg
=
_recv_kv_msg
(
self
.
_receiver
)
assert
msg
.
type
==
KVMsgType
.
PULL_BACK
,
'Recv kv msg error.'
assert
msg
.
type
==
KVMsgType
.
PULL_BACK
,
'Recv kv msg error.'
return
msg
.
rank
,
msg
.
data
return
msg
def
barrier
(
self
):
def
barrier
(
self
):
"""Barrier for all client nodes
"""Barrier for all client nodes
...
...
tests/compute/test_kvstore.py
View file @
a0193fd5
...
@@ -45,8 +45,8 @@ def start_client(server_embed):
...
@@ -45,8 +45,8 @@ def start_client(server_embed):
client
.
barrier
()
client
.
barrier
()
client
.
pull
(
name
=
'embed_0'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'embed_0'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
server_id
,
new_tensor
=
client
.
pull_wait
()
msg
=
client
.
pull_wait
()
assert
server_id
==
0
assert
msg
.
rank
==
0
target_tensor_0
=
th
.
tensor
(
target_tensor_0
=
th
.
tensor
(
[[
0.
,
0.
,
0.
],
[[
0.
,
0.
,
0.
],
...
@@ -55,14 +55,14 @@ def start_client(server_embed):
...
@@ -55,14 +55,14 @@ def start_client(server_embed):
[
0.
,
0.
,
0.
],
[
0.
,
0.
,
0.
],
[
10.
,
10.
,
10.
]])
[
10.
,
10.
,
10.
]])
assert
th
.
equal
(
new_tensor
,
target_tensor_0
)
==
True
assert
th
.
equal
(
msg
.
data
,
target_tensor_0
)
==
True
client
.
pull
(
name
=
'embed_1'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'embed_1'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
server_id
,
new_tensor
=
client
.
pull_wait
()
msg
=
client
.
pull_wait
()
target_tensor_1
=
th
.
tensor
([
0.
,
0.
,
5.
,
0.
,
10.
])
target_tensor_1
=
th
.
tensor
([
0.
,
0.
,
5.
,
0.
,
10.
])
assert
th
.
equal
(
new_tensor
,
target_tensor_1
)
==
True
assert
th
.
equal
(
msg
.
data
,
target_tensor_1
)
==
True
client
.
pull
(
name
=
'embed_0'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'embed_0'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'embed_1'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'embed_1'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
...
@@ -70,31 +70,32 @@ def start_client(server_embed):
...
@@ -70,31 +70,32 @@ def start_client(server_embed):
client
.
pull
(
name
=
'embed_1'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'embed_1'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'server_embed'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'server_embed'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
_
,
tensor
_0
=
client
.
pull_wait
()
msg
_0
=
client
.
pull_wait
()
_
,
tensor
_1
=
client
.
pull_wait
()
msg
_1
=
client
.
pull_wait
()
_
,
tensor
_2
=
client
.
pull_wait
()
msg
_2
=
client
.
pull_wait
()
_
,
tensor
_3
=
client
.
pull_wait
()
msg
_3
=
client
.
pull_wait
()
_
,
tensor
_4
=
client
.
pull_wait
()
msg
_4
=
client
.
pull_wait
()
target_tensor_2
=
th
.
tensor
([
2.
,
2.
,
7.
,
2.
,
12.
])
target_tensor_2
=
th
.
tensor
([
2.
,
2.
,
7.
,
2.
,
12.
])
assert
th
.
equal
(
tensor_0
,
target_tensor_0
)
==
True
assert
th
.
equal
(
msg_0
.
data
,
target_tensor_0
)
==
True
assert
th
.
equal
(
tensor_1
,
target_tensor_1
)
==
True
assert
th
.
equal
(
msg_1
.
data
,
target_tensor_1
)
==
True
assert
th
.
equal
(
tensor_2
,
target_tensor_0
)
==
True
assert
th
.
equal
(
msg_2
.
data
,
target_tensor_0
)
==
True
assert
th
.
equal
(
tensor_3
,
target_tensor_1
)
==
True
assert
th
.
equal
(
msg_3
.
data
,
target_tensor_1
)
==
True
assert
th
.
equal
(
tensor_4
,
target_tensor_2
)
==
True
assert
th
.
equal
(
msg_4
.
data
,
target_tensor_2
)
==
True
server_embed
+=
target_tensor_2
server_embed
+=
target_tensor_2
client
.
pull
(
name
=
'server_embed'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
client
.
pull
(
name
=
'server_embed'
,
server_id
=
0
,
id_tensor
=
th
.
tensor
([
0
,
1
,
2
,
3
,
4
]))
_
,
tensor
_5
=
client
.
pull_wait
()
msg
_5
=
client
.
pull_wait
()
assert
th
.
equal
(
tensor_5
,
target_tensor_2
*
2
)
==
True
assert
th
.
equal
(
msg_5
.
data
,
target_tensor_2
*
2
)
==
True
client
.
shut_down
()
client
.
shut_down
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
server_embed
=
th
.
tensor
([
2.
,
2.
,
2.
,
2.
,
2.
])
server_embed
=
th
.
tensor
([
2.
,
2.
,
2.
,
2.
,
2.
])
# use pytorch shared memory
server_embed
.
share_memory_
()
server_embed
.
share_memory_
()
pid
=
os
.
fork
()
pid
=
os
.
fork
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment