Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
77822769
Unverified
Commit
77822769
authored
Sep 11, 2019
by
Chao Ma
Committed by
GitHub
Sep 11, 2019
Browse files
[KVStore] Distributed kvstore (#851)
* update * speedup * add some comments
parent
5f2f100b
Changes
20
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1386 additions
and
105 deletions
+1386
-105
examples/mxnet/dis_kvstore/README.md
examples/mxnet/dis_kvstore/README.md
+8
-0
examples/mxnet/dis_kvstore/client.py
examples/mxnet/dis_kvstore/client.py
+62
-0
examples/mxnet/dis_kvstore/config.txt
examples/mxnet/dis_kvstore/config.txt
+6
-0
examples/mxnet/dis_kvstore/run.sh
examples/mxnet/dis_kvstore/run.sh
+6
-0
examples/mxnet/dis_kvstore/server.py
examples/mxnet/dis_kvstore/server.py
+22
-0
examples/pytorch/dis_kvstore/README.md
examples/pytorch/dis_kvstore/README.md
+8
-0
examples/pytorch/dis_kvstore/client.py
examples/pytorch/dis_kvstore/client.py
+62
-0
examples/pytorch/dis_kvstore/config.txt
examples/pytorch/dis_kvstore/config.txt
+6
-0
examples/pytorch/dis_kvstore/run.sh
examples/pytorch/dis_kvstore/run.sh
+6
-0
examples/pytorch/dis_kvstore/server.py
examples/pytorch/dis_kvstore/server.py
+22
-0
python/dgl/backend/backend.py
python/dgl/backend/backend.py
+20
-0
python/dgl/backend/mxnet/tensor.py
python/dgl/backend/mxnet/tensor.py
+3
-0
python/dgl/backend/pytorch/tensor.py
python/dgl/backend/pytorch/tensor.py
+3
-0
python/dgl/contrib/__init__.py
python/dgl/contrib/__init__.py
+2
-0
python/dgl/contrib/dis_kvstore.py
python/dgl/contrib/dis_kvstore.py
+542
-0
python/dgl/network.py
python/dgl/network.py
+139
-0
src/c_api_common.h
src/c_api_common.h
+3
-0
src/graph/network.cc
src/graph/network.cc
+295
-89
src/graph/network.h
src/graph/network.h
+100
-16
tests/compute/test_kvstore.py
tests/compute/test_kvstore.py
+71
-0
No files found.
examples/mxnet/dis_kvstore/README.md
0 → 100644
View file @
77822769
# Usage of DGL distributed KVStore
This is a simple example that shows how to use the DGL distributed KVStore on MXNet locally.
In this example, we start two servers and four clients, and you can run the example by:
```
./run.sh
```
\ No newline at end of file
examples/mxnet/dis_kvstore/client.py
0 → 100644
View file @
77822769
# This is a simple MXNet client demo shows how to use DGL distributed kvstore.
# In this demo, we initialize two embeddings on server and push/pull data to/from it.
import
dgl
import
mxnet
as
mx
import
time
import
argparse
server_namebook
,
client_namebook
=
dgl
.
contrib
.
ReadNetworkConfigure
(
'config.txt'
)
def start_client(args):
    """Demo client for the DGL distributed KVStore (MXNet backend).

    Connects to the servers, initializes two embeddings, pushes data into
    them, and (from client 0 only) pulls everything back, prints it, and
    shuts the servers down.
    """
    # Initialize client and connect to server
    client = dgl.contrib.KVClient(
        client_id=args.id,
        server_namebook=server_namebook,
        client_addr=client_namebook[args.id])
    client.connect()
    # Initialize data on server
    client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
    client.init_data(name='embed_1', shape=[11, 3], init_type='uniform',
                     low=0.0, high=0.0)
    tensor_data = mx.nd.array([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])
    # Push the same three rows five times into two different ID ranges.
    for row_ids in ([0, 1, 2], [6, 7, 8]):
        tensor_id = mx.nd.array(row_ids, dtype='int64')
        for _ in range(5):
            client.push('embed_0', tensor_id, tensor_data)
            client.push('embed_1', tensor_id, tensor_data)
    client.barrier()
    if client.get_id() == 0:
        # Pull every row of both embeddings, push them back unchanged,
        # then pull everything again and print it.
        new_tensor_0 = client.pull(
            'embed_0', mx.nd.array(list(range(10)), dtype='int64'))
        new_tensor_1 = client.pull(
            'embed_1', mx.nd.array(list(range(11)), dtype='int64'))
        client.push_all('embed_0', new_tensor_0)
        client.push_all('embed_1', new_tensor_1)
        print("embed_0: ")
        print(client.pull_all('embed_0'))
        print("embed_1: ")
        print(client.pull_all('embed_1'))
        # Shut-down all the servers
        client.shut_down()
if __name__ == '__main__':
    # Parse this node's ID, give the servers a head start, then run.
    arg_parser = argparse.ArgumentParser(description='kvstore')
    arg_parser.add_argument("--id", type=int, default=0, help="node ID")
    parsed_args = arg_parser.parse_args()
    time.sleep(2)  # wait server start
    start_client(parsed_args)
examples/mxnet/dis_kvstore/config.txt
0 → 100644
View file @
77822769
server 127.0.0.1:50051 0
server 127.0.0.1:50052 1
client 127.0.0.1:50053 0
client 127.0.0.1:50054 1
client 127.0.0.1:50055 2
client 127.0.0.1:50056 3
\ No newline at end of file
examples/mxnet/dis_kvstore/run.sh
0 → 100755
View file @
77822769
# Launch the local KVStore demo with the MXNet backend:
# two servers (ids 0-1) and four clients (ids 0-3).
# All but the last process run in the background; the script
# blocks on client 3.
DGLBACKEND=mxnet python3 ./server.py --id 0 &
DGLBACKEND=mxnet python3 ./server.py --id 1 &
DGLBACKEND=mxnet python3 ./client.py --id 0 &
DGLBACKEND=mxnet python3 ./client.py --id 1 &
DGLBACKEND=mxnet python3 ./client.py --id 2 &
DGLBACKEND=mxnet python3 ./client.py --id 3
\ No newline at end of file
examples/mxnet/dis_kvstore/server.py
0 → 100644
View file @
77822769
# This is a simple MXNet server demo shows how to use DGL distributed kvstore.
# In this demo, we initialize two embeddings on server and push/pull data to/from it.
import dgl
import argparse

# NOTE: the original also did `import torch` here, but this is the MXNet
# example (run.sh launches it with DGLBACKEND=mxnet) and torch is never used
# in this file — it looks like a copy-paste from the pytorch example and
# would crash in an MXNet-only environment, so it is removed.

# Read the server/client address books shared by all demo processes.
server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt')
def start_server(args):
    """Create a KVServer bound to this node's address and serve requests
    until a shutdown message arrives."""
    kv_server = dgl.contrib.KVServer(
        server_id=args.id,
        client_namebook=client_namebook,
        server_addr=server_namebook[args.id])
    kv_server.start()
if __name__ == '__main__':
    # Parse this node's ID and start serving immediately.
    arg_parser = argparse.ArgumentParser(description='kvstore')
    arg_parser.add_argument("--id", type=int, default=0, help="node ID")
    parsed_args = arg_parser.parse_args()
    start_server(parsed_args)
examples/pytorch/dis_kvstore/README.md
0 → 100644
View file @
77822769
# Usage of DGL distributed KVStore
This is a simple example that shows how to use the DGL distributed KVStore on PyTorch locally.
In this example, we start two servers and four clients, and you can run the example by:
```
./run.sh
```
examples/pytorch/dis_kvstore/client.py
0 → 100644
View file @
77822769
# This is a simple pytorch client demo shows how to use DGL distributed kvstore.
# In this demo, we initialize two embeddings on server and push/pull data to/from it.
import
dgl
import
torch
import
time
import
argparse
server_namebook
,
client_namebook
=
dgl
.
contrib
.
ReadNetworkConfigure
(
'config.txt'
)
def start_client(args):
    """Demo client for the DGL distributed KVStore (PyTorch backend).

    Connects to the servers, initializes two embeddings, pushes data into
    them, and (from client 0 only) pulls everything back, prints it, and
    shuts the servers down.
    """
    # Initialize client and connect to server
    client = dgl.contrib.KVClient(
        client_id=args.id,
        server_namebook=server_namebook,
        client_addr=client_namebook[args.id])
    client.connect()
    # Initialize data on server
    client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
    client.init_data(name='embed_1', shape=[11, 3], init_type='uniform',
                     low=0.0, high=0.0)
    tensor_data = torch.tensor([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])
    # Push the same three rows five times into two different ID ranges.
    for row_ids in ([0, 1, 2], [6, 7, 8]):
        tensor_id = torch.tensor(row_ids)
        for _ in range(5):
            client.push('embed_0', tensor_id, tensor_data)
            client.push('embed_1', tensor_id, tensor_data)
    client.barrier()
    if client.get_id() == 0:
        # Pull every row of both embeddings, push them back unchanged,
        # then pull everything again and print it.
        new_tensor_0 = client.pull('embed_0', torch.tensor(list(range(10))))
        new_tensor_1 = client.pull('embed_1', torch.tensor(list(range(11))))
        client.push_all('embed_0', new_tensor_0)
        client.push_all('embed_1', new_tensor_1)
        print("embed_0:")
        print(client.pull_all('embed_0'))
        print("embed_1:")
        print(client.pull_all('embed_1'))
        # Shut-down all the servers
        client.shut_down()
if __name__ == '__main__':
    # Parse this node's ID, give the servers a head start, then run.
    arg_parser = argparse.ArgumentParser(description='kvstore')
    arg_parser.add_argument("--id", type=int, default=0, help="node ID")
    parsed_args = arg_parser.parse_args()
    time.sleep(2)  # wait server start
    start_client(parsed_args)
examples/pytorch/dis_kvstore/config.txt
0 → 100644
View file @
77822769
server 127.0.0.1:50051 0
server 127.0.0.1:50052 1
client 127.0.0.1:50053 0
client 127.0.0.1:50054 1
client 127.0.0.1:50055 2
client 127.0.0.1:50056 3
\ No newline at end of file
examples/pytorch/dis_kvstore/run.sh
0 → 100755
View file @
77822769
# Launch the local KVStore demo with the PyTorch backend:
# two servers (ids 0-1) and four clients (ids 0-3).
# All but the last process run in the background; the script
# blocks on client 3.
python3 ./server.py --id 0 &
python3 ./server.py --id 1 &
python3 ./client.py --id 0 &
python3 ./client.py --id 1 &
python3 ./client.py --id 2 &
python3 ./client.py --id 3
\ No newline at end of file
examples/pytorch/dis_kvstore/server.py
0 → 100644
View file @
77822769
# This is a simple pytorch server demo shows how to use DGL distributed kvstore.
# In this demo, we initialize two embeddings on server and push/pull data to/from it.
import
dgl
import
torch
import
argparse
server_namebook
,
client_namebook
=
dgl
.
contrib
.
ReadNetworkConfigure
(
'config.txt'
)
def start_server(args):
    """Create a KVServer bound to this node's address and serve requests
    until a shutdown message arrives."""
    kv_server = dgl.contrib.KVServer(
        server_id=args.id,
        client_namebook=client_namebook,
        server_addr=server_namebook[args.id])
    kv_server.start()
if __name__ == '__main__':
    # Parse this node's ID and start serving immediately.
    arg_parser = argparse.ArgumentParser(description='kvstore')
    arg_parser.add_argument("--id", type=int, default=0, help="node ID")
    parsed_args = arg_parser.parse_args()
    start_server(parsed_args)
python/dgl/backend/backend.py
View file @
77822769
...
@@ -817,6 +817,26 @@ def ones(shape, dtype, ctx):
...
@@ -817,6 +817,26 @@ def ones(shape, dtype, ctx):
"""
"""
pass
pass
def uniform(shape, dtype, ctx, low, high):
    """Create a tensor with random values in a uniform
    distribution between low (inclusive) and high (exclusive).

    Parameters
    ----------
    shape : tuple of int
        The tensor shape.
    dtype : data type
        It should be one of the values in the data type dict.
    ctx : context
        The device of the result tensor.
    low : float
        Lower bound (inclusive) of the distribution.
    high : float
        Upper bound (exclusive) of the distribution.

    Returns
    -------
    Tensor
        The random tensor.
    """
    # Abstract backend interface; each backend (mxnet/pytorch) provides
    # the concrete implementation.
    pass
def
pad_packed_tensor
(
input
,
lengths
,
value
,
l_min
=
None
):
def
pad_packed_tensor
(
input
,
lengths
,
value
,
l_min
=
None
):
"""Pads a packed batch of variable length tensors with given value.
"""Pads a packed batch of variable length tensors with given value.
...
...
python/dgl/backend/mxnet/tensor.py
View file @
77822769
...
@@ -238,6 +238,9 @@ def zeros_like(input):
...
@@ -238,6 +238,9 @@ def zeros_like(input):
def
ones
(
shape
,
dtype
,
ctx
):
def
ones
(
shape
,
dtype
,
ctx
):
return
nd
.
ones
(
shape
,
dtype
=
dtype
,
ctx
=
ctx
)
return
nd
.
ones
(
shape
,
dtype
=
dtype
,
ctx
=
ctx
)
def uniform(shape, dtype, ctx, low, high):
    """Return a tensor of the given shape/dtype on `ctx`, filled with
    values drawn uniformly from [low, high)."""
    return nd.random.uniform(low, high, ctx=ctx, dtype=dtype, shape=shape)
def
pad_packed_tensor
(
input
,
lengths
,
value
,
l_min
=
None
):
def
pad_packed_tensor
(
input
,
lengths
,
value
,
l_min
=
None
):
old_shape
=
input
.
shape
old_shape
=
input
.
shape
if
isinstance
(
lengths
,
nd
.
NDArray
):
if
isinstance
(
lengths
,
nd
.
NDArray
):
...
...
python/dgl/backend/pytorch/tensor.py
View file @
77822769
...
@@ -188,6 +188,9 @@ def zeros_like(input):
...
@@ -188,6 +188,9 @@ def zeros_like(input):
def
ones
(
shape
,
dtype
,
ctx
):
def
ones
(
shape
,
dtype
,
ctx
):
return
th
.
ones
(
shape
,
dtype
=
dtype
,
device
=
ctx
)
return
th
.
ones
(
shape
,
dtype
=
dtype
,
device
=
ctx
)
def uniform(shape, dtype, ctx, low, high):
    """Return a tensor of the given shape/dtype on `ctx`, filled with
    values drawn uniformly from [low, high)."""
    result = th.empty(shape, dtype=dtype, device=ctx)
    return result.uniform_(low, high)
def
pad_packed_tensor
(
input
,
lengths
,
value
,
l_min
=
None
):
def
pad_packed_tensor
(
input
,
lengths
,
value
,
l_min
=
None
):
old_shape
=
input
.
shape
old_shape
=
input
.
shape
if
isinstance
(
lengths
,
th
.
Tensor
):
if
isinstance
(
lengths
,
th
.
Tensor
):
...
...
python/dgl/contrib/__init__.py
View file @
77822769
from
.
import
sampling
from
.
import
sampling
from
.
import
graph_store
from
.
import
graph_store
from
.dis_kvstore
import
KVClient
,
KVServer
from
.dis_kvstore
import
ReadNetworkConfigure
python/dgl/contrib/dis_kvstore.py
0 → 100644
View file @
77822769
# This file contains DGL distributed kvstore APIs.
from
..network
import
_create_sender
,
_create_receiver
from
..network
import
_finalize_sender
,
_finalize_receiver
from
..network
import
_network_wait
,
_add_receiver_addr
from
..network
import
_receiver_wait
,
_sender_connect
from
..network
import
_send_kv_msg
,
_recv_kv_msg
from
..network
import
KVMsgType
,
KVStoreMsg
import
math
import
dgl.backend
as
F
import
numpy
as
np
def ReadNetworkConfigure(filename):
    """Read networking configuration from file.

    Each line of the file has the form ``<node_type> <ip:port> <node_id>``,
    where ``node_type`` is either ``server`` or ``client``.

    Parameters
    ----------
    filename : str
        name of target configure file

    Returns
    -------
    dict
        server namebook, mapping int server ID -> 'ip:port'
    dict
        client namebook, mapping int client ID -> 'ip:port'

    Raises
    ------
    RuntimeError
        If a line has an unknown node type.
    """
    server_namebook = {}
    client_namebook = {}
    # Fixed: the original iterated over `open(filename)` directly and never
    # closed the file handle; `with` guarantees it is closed.
    with open(filename) as config_file:
        lines = [line.rstrip('\n') for line in config_file]
    for line in lines:
        node_type, addr, node_id = line.split(' ')
        if node_type == 'server':
            server_namebook[int(node_id)] = addr
        elif node_type == 'client':
            client_namebook[int(node_id)] = addr
        else:
            # Fixed: the original passed (fmt, arg) as two RuntimeError
            # arguments instead of formatting the message.
            raise RuntimeError("Unknown node type: %s" % node_type)
    return server_namebook, client_namebook
class KVServer(object):
    """KVServer is a lightweight key-value store service for DGL distributed training.

    In practice, developers can use KVServer to hold large-scale graph features or
    graph embeddings across machines in a distributed setting or store them in one
    standalone machine with big memory capability. DGL KVServer uses a very simple
    range-partition scheme to partition data into different KVServer nodes. For
    example, if the total embedding size is 200 and we have two KVServer nodes,
    the data (0~99) will be stored in kvserver_0, and the data (100~199) will be
    stored in kvserver_1.

    For KVServer, user can re-write UDF functions _push_handler and _pull_handler.

    DO NOT use KVServer in multiple threads!

    Parameters
    ----------
    server_id : int
        KVServer's ID (start from 0).
    client_namebook : dict
        IP address namebook of KVClient, where the key is the client's ID
        (start from 0) and the value is client's IP address, e.g.,

          { 0:'168.12.23.45:50051',
            1:'168.12.23.21:50051',
            2:'168.12.46.12:50051' }

    server_addr : str
        IP address of current KVServer node, e.g., '127.0.0.1:50051'
    net_type : str
        networking type, e.g., 'socket' (default) or 'mpi'.
    """
    def __init__(self, server_id, client_namebook, server_addr, net_type='socket'):
        assert server_id >= 0, 'server_id cannot be a negative number.'
        assert len(client_namebook) > 0, 'client_namebook cannot be empty.'
        assert len(server_addr.split(':')) == 2, 'Incorrect IP format.'
        self._is_init = set()     # Contains tensor name
        self._data_store = {}     # Key is name string and value is tensor
        self._barrier_count = 0;  # how many clients have reached the current barrier
        self._server_id = server_id
        self._client_namebook = client_namebook
        self._client_count = len(client_namebook)
        self._addr = server_addr
        self._sender = _create_sender(net_type)
        self._receiver = _create_receiver(net_type)

    def __del__(self):
        """Finalize KVServer: release the C sender/receiver handles.
        """
        _finalize_sender(self._sender)
        _finalize_receiver(self._receiver)

    def start(self):
        """Start service of KVServer.

        Blocks until all clients have connected, then loops handling one
        message at a time until a FINAL message arrives.
        """
        server_ip, server_port = self._addr.split(':')
        _receiver_wait(self._receiver, server_ip, int(server_port), self._client_count)
        _network_wait()  # wait client's start
        for ID, addr in self._client_namebook.items():
            client_ip, client_port = addr.split(':')
            _add_receiver_addr(self._sender, client_ip, int(client_port), ID)
        _sender_connect(self._sender)
        # Service loop
        while True:
            msg = _recv_kv_msg(self._receiver)
            if msg.type == KVMsgType.INIT:
                # Only the first INIT for a given name creates the tensor;
                # later INITs (from the other clients) are ignored.
                if (msg.name in self._is_init) == False:
                    # we hack the msg format here:
                    # msg.id stores the shape of the target tensor
                    # msg.data has two rows: the first row is the init_type,
                    # [0, 0] means 'zero' and [1, 1] means 'uniform'.
                    # The second row is the min & max threshold.
                    data_shape = F.asnumpy(msg.id).tolist()
                    row_0 = (F.asnumpy(msg.data).tolist())[0]
                    row_1 = (F.asnumpy(msg.data).tolist())[1]
                    init_type = 'zero' if row_0[0] == 0.0 else 'uniform'
                    self._init_data(name=msg.name,
                                    shape=data_shape,
                                    init_type=init_type,
                                    low=row_1[0],
                                    high=row_1[1])
                    self._is_init.add(msg.name)
            elif msg.type == KVMsgType.PUSH:
                # convert global ID to local ID
                local_id = self._remap_id(msg.name, msg.id)
                self._push_handler(msg.name, local_id, msg.data)
            elif msg.type == KVMsgType.PULL:
                # convert global ID to local ID
                local_id = self._remap_id(msg.name, msg.id)
                res_tensor = self._pull_handler(msg.name, local_id)
                back_msg = KVStoreMsg(
                    type=KVMsgType.PULL_BACK,
                    rank=self._server_id,
                    name=msg.name,
                    id=msg.id,  # echo back the *global* ids the client sent
                    data=res_tensor)
                _send_kv_msg(self._sender, back_msg, msg.rank)
            elif msg.type == KVMsgType.BARRIER:
                # Release everyone only once all clients have checked in.
                self._barrier_count += 1
                if self._barrier_count == self._client_count:
                    back_msg = KVStoreMsg(
                        type=KVMsgType.BARRIER,
                        rank=self._server_id,
                        name=None,
                        id=None,
                        data=None)
                    for i in range(self._client_count):
                        _send_kv_msg(self._sender, back_msg, i)
                    self._barrier_count = 0
            elif msg.type == KVMsgType.FINAL:
                print("Exit KVStore service, server ID: %d" % self.get_id())
                break  # exit loop
            else:
                raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value)

    def get_id(self):
        """Get server id

        Return
        ------
        int
            KVServer ID
        """
        return self._server_id

    def _init_data(self, name, shape, init_type, low, high):
        """Initialize kvstore tensor.

        Parameters
        ----------
        name : str
            data name
        shape : list of int
            The tensor shape
        init_type : str
            initialize method, including 'zero' and 'uniform'
        low : float
            min threshold (only used for 'uniform')
        high : float
            max threshold (only used for 'uniform')
        """
        if init_type == 'uniform':
            self._data_store[name] = F.uniform(
                shape=shape, dtype=F.float32, ctx=F.cpu(), low=low, high=high)
        elif init_type == 'zero':
            self._data_store[name] = F.zeros(
                shape=shape, dtype=F.float32, ctx=F.cpu())
        else:
            raise RuntimeError('Unknown initial method')

    def _push_handler(self, name, ID, data):
        """User-defined handler for PUSH message.

        On default, _push_handler performs an ADD operation on the tensor.

        Parameters
        ----------
        name : str
            data name
        ID : tensor (mx.ndarray or torch.tensor)
            a vector storing the IDs that have been re-mapped to local id.
        data : tensor (mx.ndarray or torch.tensor)
            a matrix with the same row size of id
        """
        for idx in range(ID.shape[0]):  # For each row
            self._data_store[name][ID[idx]] += data[idx]

    def _pull_handler(self, name, ID):
        """User-defined handler for PULL operation.

        On default, _pull_handler performs a gather_row() operation on the tensor.

        Parameters
        ----------
        name : str
            data name
        ID : tensor (mx.ndarray or torch.tensor)
            a vector storing the IDs that have been re-mapped to local id.

        Return
        ------
        tensor
            a matrix with the same row size of ID
        """
        new_tensor = F.gather_row(self._data_store[name], ID)
        return new_tensor

    def _remap_id(self, name, ID):
        """Re-map global ID to local ID.

        Parameters
        ----------
        name : str
            data name
        ID : tensor (mx.ndarray or torch.tensor)
            a vector storing the global data ID

        Return
        ------
        tensor
            re-mapped local ID
        """
        # NOTE(review): this assumes the range-partition described in the class
        # docstring, where each server holds a contiguous chunk of equal size;
        # modulo by the local row count maps a global id into this chunk.
        row_size = self._data_store[name].shape[0]
        return ID % row_size
class KVClient(object):
    """KVClient is used to push/pull tensors to/from KVServer on a DGL trainer.

    There are four operations supported by KVClient:

      * init_data(name, shape, low, high): initialize tensor on KVServer
      * push(name, id, data): push data to KVServer
      * pull(name, id): pull data from KVServer
      * shut_down(): shut down all KVServer nodes

    DO NOT use KVClient in multiple threads!

    Parameters
    ----------
    client_id : int
        KVClient's ID (start from 0)
    server_namebook: dict
        IP address namebook of KVServer, where key is the KVServer's ID
        (start from 0) and value is the server's IP address, e.g.,

          { 0:'168.12.23.45:50051',
            1:'168.12.23.21:50051',
            2:'168.12.46.12:50051' }

    client_addr : str
        IP address of current KVClient, e.g., '168.12.23.22:50051'
    net_type : str
        networking type, e.g., 'socket' (default) or 'mpi'.
    """
    def __init__(self, client_id, server_namebook, client_addr, net_type='socket'):
        assert client_id >= 0, 'client_id cannot be a nagative number.'
        assert len(server_namebook) > 0, 'server_namebook cannot be empty.'
        assert len(client_addr.split(':')) == 2, 'Incorrect IP format.'
        # self._data_size is a key-value store where the key is data name
        # and value is the size of tensor. It is used to partition data into
        # different KVServer nodes.
        self._data_size = {}
        self._client_id = client_id
        self._server_namebook = server_namebook
        self._server_count = len(server_namebook)
        self._addr = client_addr
        self._sender = _create_sender(net_type)
        self._receiver = _create_receiver(net_type)

    def __del__(self):
        """Finalize KVClient: release the C sender/receiver handles.
        """
        _finalize_sender(self._sender)
        _finalize_receiver(self._receiver)

    def connect(self):
        """Connect to all KVServer nodes, then wait until every server has
        connected back to this client.
        """
        for ID, addr in self._server_namebook.items():
            server_ip, server_port = addr.split(':')
            _add_receiver_addr(self._sender, server_ip, int(server_port), ID)
        _sender_connect(self._sender)
        client_ip, client_port = self._addr.split(':')
        _receiver_wait(self._receiver, client_ip, int(client_port), self._server_count)

    def init_data(self, name, shape, init_type='zero', low=0.0, high=0.0):
        """Initialize kvstore tensor

        Parameters
        ----------
        name : str
            data name
        shape : list of int
            shape of tensor
        init_type : str
            initialize method, including 'zero' and 'uniform'
        low : float
            min threshold, if use 'uniform'
        high : float
            max threshold, if use 'uniform'
        """
        self._data_size[name] = shape[0]
        # Rows per server under the range partition (last server may get fewer).
        count = math.ceil(shape[0] / self._server_count)
        # We hack the msg format here: encode init_type and the uniform
        # thresholds into a 2x2 tensor sent as the message payload.
        init_type = 0.0 if init_type == 'zero' else 1.0
        threshold = F.tensor([[init_type, init_type], [low, high]])
        # partition shape on server
        for server_id in range(self._server_count):
            par_shape = shape.copy()
            if shape[0] - server_id * count >= count:
                par_shape[0] = count
            else:
                par_shape[0] = shape[0] - server_id * count
            tensor_shape = F.tensor(par_shape)
            msg = KVStoreMsg(
                type=KVMsgType.INIT,
                rank=self._client_id,
                name=name,
                id=tensor_shape,  # hack: id field carries the partition shape
                data=threshold)
            _send_kv_msg(self._sender, msg, server_id)

    def push(self, name, ID, data):
        """Push sparse message to KVServer

        The push() API will partition the message into different
        KVServer nodes automatically.

        Note that we assume the row IDs in ID are in ascending order.

        Parameters
        ----------
        name : str
            data name
        ID : tensor (mx.ndarray or torch.tensor)
            a vector storing the global IDs
        data : tensor (mx.ndarray or torch.tensor)
            a tensor with the same row size of id
        """
        assert F.ndim(ID) == 1, 'ID must be a vector.'
        assert F.shape(ID)[0] == F.shape(data)[0], 'The data must has the same row size with ID.'
        # Count how many of the (sorted) ids fall on each server.
        group_size = [0] * self._server_count
        numpy_id = F.asnumpy(ID)
        count = math.ceil(self._data_size[name] / self._server_count)
        # NOTE(review): float division; the result is truncated with int()
        # below, so it behaves like floor division for non-negative ids.
        server_id = numpy_id / count
        for id in server_id:
            group_size[int(id)] += 1
        # Slice the sorted ids/data into per-server contiguous ranges.
        min_idx = 0
        max_idx = 0
        for idx in range(self._server_count):
            if group_size[idx] == 0:
                continue
            max_idx += group_size[idx]
            range_id = ID[min_idx:max_idx]
            range_data = data[min_idx:max_idx]
            min_idx = max_idx
            msg = KVStoreMsg(
                type=KVMsgType.PUSH,
                rank=self._client_id,
                name=name,
                id=range_id,
                data=range_data)
            _send_kv_msg(self._sender, msg, idx)

    def push_all(self, name, data):
        """Push the whole data to KVServer

        The push_all() API will partition the message into different
        KVServer nodes automatically.

        Parameters
        ----------
        name : str
            data name
        data : tensor (mx.ndarray or torch.tensor)
            data tensor
        """
        ID = F.zerocopy_from_numpy(np.arange(F.shape(data)[0]))
        self.push(name, ID, data)

    def pull(self, name, ID):
        """Pull sparse message from KVServer

        Note that we assume the row IDs in ID are in ascending order.

        Parameters
        ----------
        name : str
            data name
        ID : tensor (mx.ndarray or torch.tensor)
            a vector storing the IDs

        Return
        ------
        tensor
            a tensor with the same row size of ID
        """
        assert F.ndim(ID) == 1, 'ID must be a vector.'
        # Same per-server grouping as push().
        group_size = [0] * self._server_count
        numpy_id = F.asnumpy(ID)
        count = math.ceil(self._data_size[name] / self._server_count)
        server_id = numpy_id / count
        for id in server_id:
            group_size[int(id)] += 1
        min_idx = 0
        max_idx = 0
        server_count = 0
        for idx in range(self._server_count):
            if group_size[idx] == 0:
                continue
            server_count += 1
            max_idx += group_size[idx]
            range_id = ID[min_idx:max_idx]
            min_idx = max_idx
            msg = KVStoreMsg(
                type=KVMsgType.PULL,
                rank=self._client_id,
                name=name,
                id=range_id,
                data=None)
            _send_kv_msg(self._sender, msg, idx)
        # Recv back message: one PULL_BACK per contacted server, then merge
        # the pieces back into a single matrix ordered by server rank.
        msg_list = []
        for idx in range(self._server_count):
            if group_size[idx] == 0:
                continue
            msg = _recv_kv_msg(self._receiver)
            assert msg.type == KVMsgType.PULL_BACK, 'Recv kv msg error.'
            msg_list.append(msg)
        return self._merge_msg(msg_list)

    def pull_all(self, name):
        """Pull the whole data from KVServer

        Parameters
        ----------
        name : str
            data name

        Return
        ------
        tensor
            target data tensor
        """
        ID = F.zerocopy_from_numpy(np.arange(self._data_size[name]))
        return self.pull(name, ID)

    def barrier(self):
        """Barrier for all client nodes

        This API will be blocked until all the clients call this API.
        """
        msg = KVStoreMsg(
            type=KVMsgType.BARRIER,
            rank=self._client_id,
            name=None,
            id=None,
            data=None)
        for server_id in range(self._server_count):
            _send_kv_msg(self._sender, msg, server_id)
        for server_id in range(self._server_count):
            back_msg = _recv_kv_msg(self._receiver)
            assert back_msg.type == KVMsgType.BARRIER, 'Recv kv msg error.'

    def shut_down(self):
        """Shutdown all KVServer nodes

        We usually invoke this API from just one client (e.g., client_0).
        """
        for server_id in range(self._server_count):
            msg = KVStoreMsg(
                type=KVMsgType.FINAL,
                rank=self._client_id,
                name=None,
                id=None,
                data=None)
            _send_kv_msg(self._sender, msg, server_id)

    def get_id(self):
        """Get client id

        Return
        ------
        int
            KVClient ID
        """
        return self._client_id

    def _sort_func(self, msg):
        """Sort key for KVStoreMsg: sort messages by sender rank.

        Parameters
        ----------
        msg : KVStoreMsg
            KVstore message
        """
        return msg.rank

    def _merge_msg(self, msg_list):
        """Merge separated messages into one big matrix.

        Parameters
        ----------
        msg_list : list
            a list of KVStoreMsg

        Return
        ------
        tensor (mx.ndarray or torch.tensor)
            a merged data matrix
        """
        msg_list.sort(key=self._sort_func)
        return F.cat([msg.data for msg in msg_list], 0)
\ No newline at end of file
python/dgl/network.py
View file @
77822769
"""DGL Distributed Training Infrastructure."""
"""DGL Distributed Training Infrastructure."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
time
import
signal
from
enum
import
Enum
from
collections
import
namedtuple
import
dgl.backend
as
F
from
._ffi.function
import
_init_api
from
._ffi.function
import
_init_api
from
.nodeflow
import
NodeFlow
from
.nodeflow
import
NodeFlow
from
.
import
utils
from
.
import
utils
...
@@ -10,6 +16,21 @@ _init_api("dgl.network")
...
@@ -10,6 +16,21 @@ _init_api("dgl.network")
################################ Common Network Components ##################################
################################ Common Network Components ##################################
_WAIT_TIME_SEC
=
3
# 3 seconds
def keyboard_interrupt_handler(my_signal, frame=None):
    """Signal handler so users can use [Ctrl + C] to exit a loop service.

    Parameters
    ----------
    my_signal : int
        The signal number delivered by the OS (e.g. signal.SIGINT).
    frame : frame, optional
        The interrupted stack frame; unused. Python invokes signal
        handlers as handler(signum, frame), so the original one-argument
        definition raised a TypeError on Ctrl+C; the default value keeps
        the old one-argument call style working.
    """
    print("KeyboardInterrupt (ID: {}) has been caught. Cleaning up DGL ...".format(my_signal))
    exit(0)

signal.signal(signal.SIGINT, keyboard_interrupt_handler)
def _network_wait():
    """Sleep for a few seconds.

    Used as a crude startup grace period (e.g. the server sleeps after
    binding so clients have time to come up).
    """
    time.sleep(_WAIT_TIME_SEC)
def
_create_sender
(
net_type
):
def
_create_sender
(
net_type
):
"""Create a Sender communicator via C api
"""Create a Sender communicator via C api
...
@@ -153,3 +174,121 @@ def _recv_nodeflow(receiver, graph):
...
@@ -153,3 +174,121 @@ def _recv_nodeflow(receiver, graph):
return
res
return
res
else
:
else
:
return
NodeFlow
(
graph
,
res
)
return
NodeFlow
(
graph
,
res
)
################################ Distributed KVStore Components ################################

class KVMsgType(Enum):
    """Type of kvstore message
    """
    FINAL = 1      # shut down the server's service loop
    INIT = 2       # create a tensor on the server
    PUSH = 3       # send ids + data rows to the server
    PULL = 4       # request rows by id from the server
    PULL_BACK = 5  # server's reply to a PULL
    BARRIER = 6    # synchronize all clients

KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data")
"""Message of DGL kvstore

Data Field
----------
type : KVMsgType
    Type of DGL kvstore message
rank : int
    sender's ID
name : str
    data name
id : tensor (mx.ndarray or torch.tensor)
    data vector storing the global IDs
data : tensor (mx.ndarray or torch.tensor)
    data matrix with the same row size of id
"""
def _send_kv_msg(sender, msg, recv_id):
    """Send kvstore message.

    The wire format depends on the message type: PULL carries name + ids
    only, FINAL/BARRIER carry no payload at all, and every other type
    carries name + ids + data.

    Parameters
    ----------
    sender : ctypes.c_void_p
        C sender handle
    msg : KVStoreMsg
        kvstore message
    recv_id : int
        receiver's ID
    """
    if msg.type == KVMsgType.PULL:
        tensor_id = F.zerocopy_to_dgl_ndarray(msg.id)
        _CAPI_SenderSendKVMsg(sender, int(recv_id), msg.type.value, msg.rank, msg.name, tensor_id)
    elif msg.type in (KVMsgType.FINAL, KVMsgType.BARRIER):
        _CAPI_SenderSendKVMsg(sender, int(recv_id), msg.type.value, msg.rank)
    else:
        tensor_id = F.zerocopy_to_dgl_ndarray(msg.id)
        data = F.zerocopy_to_dgl_ndarray(msg.data)
        _CAPI_SenderSendKVMsg(sender, int(recv_id), msg.type.value, msg.rank, msg.name, tensor_id, data)
def _recv_kv_msg(receiver):
    """Receive kvstore message.

    Mirrors the wire format of _send_kv_msg: FINAL/BARRIER carry no
    payload, PULL carries name + ids only, everything else carries
    name + ids + data.

    Parameters
    ----------
    receiver : ctypes.c_void_p
        C Receiver handle

    Return
    ------
    KVStoreMsg
        kvstore message
    """
    # Fixed: the original called CAPI_ReceiverRecvKVMsg without the leading
    # underscore — a NameError, since every FFI entry point registered by
    # _init_api here is named _CAPI_*.
    msg_ptr = _CAPI_ReceiverRecvKVMsg(receiver)
    # KVMsgType(...) raises ValueError for an unknown message-type code.
    msg_type = KVMsgType(_CAPI_ReceiverGetKVMsgType(msg_ptr))
    rank = _CAPI_ReceiverGetKVMsgRank(msg_ptr)
    if msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER):
        # Control messages carry no payload.
        return KVStoreMsg(type=msg_type, rank=rank, name=None, id=None, data=None)
    name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
    tensor_id = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgID(msg_ptr))
    if msg_type == KVMsgType.PULL:
        # PULL requests carry only the ids to fetch.
        return KVStoreMsg(type=msg_type, rank=rank, name=name, id=tensor_id, data=None)
    data = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgData(msg_ptr))
    return KVStoreMsg(type=msg_type, rank=rank, name=name, id=tensor_id, data=data)
src/c_api_common.h
View file @
77822769
...
@@ -39,6 +39,9 @@ namespace dgl {
...
@@ -39,6 +39,9 @@ namespace dgl {
// Communicator handler type
// Communicator handler type
typedef
void
*
CommunicatorHandle
;
typedef
void
*
CommunicatorHandle
;
// KVstore message handler type
typedef
void
*
KVMsgHandle
;
/*! \brief Enum type for bool value with unknown */
/*! \brief Enum type for bool value with unknown */
enum
BoolFlag
{
enum
BoolFlag
{
kBoolUnknown
=
-
1
,
kBoolUnknown
=
-
1
,
...
...
src/graph/network.cc
View file @
77822769
This diff is collapsed.
Click to expand it.
src/graph/network.h
View file @
77822769
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include <string.h>
#include <string.h>
#include <vector>
#include <vector>
#include <string>
#include "../c_api_common.h"
#include "../c_api_common.h"
#include "./network/msg_queue.h"
#include "./network/msg_queue.h"
...
@@ -24,38 +25,66 @@ namespace network {
...
@@ -24,38 +25,66 @@ namespace network {
// TODO(chao): Make this number configurable
// TODO(chao): Make this number configurable
const
int64_t
kQueueSize
=
200
*
1024
*
1024
;
const
int64_t
kQueueSize
=
200
*
1024
*
1024
;
/*!
* \brief Create NDArray from raw data
*/
NDArray
CreateNDArrayFromRaw
(
std
::
vector
<
int64_t
>
shape
,
DLDataType
dtype
,
DLContext
ctx
,
void
*
raw
);
/*!
/*!
* \brief Message type for DGL distributed training
* \brief Message type for DGL distributed training
*/
*/
enum
MessageType
{
enum
MessageType
{
/*!
/*!
* \brief Message for send/recv NodeFlow
* \brief Message for send/recv NodeFlow
*/
*/
kNodeFlowMsg
=
0
,
kNodeFlowMsg
=
0
,
/*!
/*!
* \brief Message for end-signal
* \brief Message for end-signal
*/
*/
kEndMsg
=
1
kFinalMsg
=
1
,
/*!
* \brief Initialize KVStore
*/
kInitMsg
=
2
,
/*!
* \brief Push msg to KVStore
*/
kPushMsg
=
3
,
/*!
* \brief Pull msg from KVStore
*/
kPullMsg
=
4
,
/*!
* \brief PullBack msg from KVStore
*/
kPullBackMsg
=
5
,
/*!
* \brief Barrier msg for KVStore
*/
kBarrierMsg
=
6
};
};
/*!
/*!
* \brief Meta data for
communicator
message
* \brief Meta data for
NDArray
message
*/
*/
class
Msg
Meta
{
class
Array
Meta
{
public:
public:
/*!
/*!
* \brief
Msg
Meta constructor.
* \brief
Array
Meta constructor.
* \param msg_type type of message
* \param msg_type type of message
*/
*/
explicit
Msg
Meta
(
int
msg_type
)
explicit
Array
Meta
(
int
msg_type
)
:
msg_type_
(
msg_type
),
ndarray_count_
(
0
)
{}
:
msg_type_
(
msg_type
),
ndarray_count_
(
0
)
{}
/*!
/*!
* \brief Construct
Msg
Meta from binary data buffer.
* \brief Construct
Array
Meta from binary data buffer.
* \param buffer data buffer
* \param buffer data buffer
* \param size data size
* \param size data size
*/
*/
Msg
Meta
(
char
*
buffer
,
int64_t
size
)
{
Array
Meta
(
char
*
buffer
,
int64_t
size
)
{
CHECK_NOTNULL
(
buffer
);
CHECK_NOTNULL
(
buffer
);
this
->
Deserialize
(
buffer
,
size
);
this
->
Deserialize
(
buffer
,
size
);
}
}
...
@@ -75,20 +104,20 @@ class MsgMeta {
...
@@ -75,20 +104,20 @@ class MsgMeta {
}
}
/*!
/*!
* \brief Add NDArray meta data to
Msg
Meta
* \brief Add NDArray meta data to
Array
Meta
* \param array DGL NDArray
* \param array DGL NDArray
*/
*/
void
AddArray
(
const
NDArray
&
array
);
void
AddArray
(
const
NDArray
&
array
);
/*!
/*!
* \brief Serialize
Msg
Meta to data buffer
* \brief Serialize
Array
Meta to data buffer
* \param size size of serialized message
* \param size size of serialized message
* \return pointer of data buffer
* \return pointer of data buffer
*/
*/
char
*
Serialize
(
int64_t
*
size
);
char
*
Serialize
(
int64_t
*
size
);
/*!
/*!
* \brief Deserialize
Msg
Meta from data buffer
* \brief Deserialize
Array
Meta from data buffer
* \param buffer data buffer
* \param buffer data buffer
* \param size size of data buffer
* \param size size of data buffer
*/
*/
...
@@ -111,6 +140,61 @@ class MsgMeta {
...
@@ -111,6 +140,61 @@ class MsgMeta {
std
::
vector
<
int64_t
>
data_shape_
;
std
::
vector
<
int64_t
>
data_shape_
;
};
};
/*!
 * \brief C structure for holding DGL KVServer message
 *
 * Carries the metadata (type, sender rank, data name) plus the ID and
 * payload NDArrays exchanged between KVClient and KVServer.
 */
class KVStoreMsg {
 public:
  /*!
   * \brief KVStoreMsg constructor.
   */
  KVStoreMsg() {}

  /*!
   * \brief Construct KVStoreMsg from binary data buffer.
   * \param buffer data buffer
   * \param size data size
   */
  KVStoreMsg(char* buffer, int64_t size) {
    CHECK_NOTNULL(buffer);
    this->Deserialize(buffer, size);
  }

  /*!
   * \brief Serialize KVStoreMsg to data buffer
   * Note that we don't serialize ID and data here.
   * \param size size of serialized message
   * \return pointer of data buffer
   */
  char* Serialize(int64_t* size);

  /*!
   * \brief Deserialize KVStoreMsg from data buffer
   * \param buffer data buffer
   * \param size size of data buffer
   */
  void Deserialize(char* buffer, int64_t size);

  /*!
   * \brief Message type of kvstore
   */
  int msg_type;
  /*!
   * \brief Sender's ID
   */
  int rank;
  /*!
   * \brief data name
   */
  std::string name;
  /*!
   * \brief data ID
   */
  NDArray id;
  /*!
   * \brief data matrix
   */
  NDArray data;
};
}
// namespace network
}
// namespace network
}
// namespace dgl
}
// namespace dgl
...
...
tests/compute/test_kvstore.py
0 → 100644
View file @
77822769
import
backend
as
F
import
numpy
as
np
import
scipy
as
sp
import
dgl
import
torch
from
dgl
import
utils
import
os
import
time
# Address books mapping ID -> "ip:port" for the single local client and
# server used by this test; both live on loopback with distinct ports.
client_namebook = {0: '127.0.0.1:50061'}

server_namebook = {0: '127.0.0.1:50062'}
def start_server():
    """Create the single KVServer (ID 0) and enter its serving loop.

    Blocks until the server receives a shutdown signal from the client.
    """
    kv_server = dgl.contrib.KVServer(
        server_id=0,
        client_namebook=client_namebook,
        server_addr=server_namebook[0],
    )
    kv_server.start()
def start_client():
    """Connect a KVClient to the local server and exercise push/pull.

    Initializes two embeddings (both effectively zero: 'embed_1' uses a
    uniform init with low == high == 0.0), pushes the same data tensor 5
    times to two disjoint ID ranges, then pulls both ranges back and
    checks the accumulated values. Shuts the server down at the end.
    """
    client = dgl.contrib.KVClient(
        client_id=0,
        server_namebook=server_namebook,
        client_addr=client_namebook[0])
    client.connect()

    client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
    client.init_data(name='embed_1', shape=[11, 3], init_type='uniform',
                     low=0.0, high=0.0)

    tensor_id = torch.tensor([0, 1, 2])
    tensor_data = torch.tensor([[0., 0., 0.],
                                [1., 1., 1.],
                                [2., 2., 2.]])

    # Push the same rows 5 times; push accumulates, so row k ends at 5*k.
    for _ in range(5):
        client.push('embed_0', tensor_id, tensor_data)
        client.push('embed_1', tensor_id, tensor_data)

    tensor_id = torch.tensor([6, 7, 8])
    for _ in range(5):
        client.push('embed_0', tensor_id, tensor_data)
        client.push('embed_1', tensor_id, tensor_data)

    # Pull both pushed ID ranges back in one request.
    tensor_id = torch.tensor([0, 1, 2, 6, 7, 8])
    new_tensor_0 = client.pull('embed_0', tensor_id)
    new_tensor_1 = client.pull('embed_1', tensor_id)

    target_tensor = torch.tensor(
        [[0., 0., 0.],
         [5., 5., 5.],
         [10., 10., 10.],
         [0., 0., 0.],
         [5., 5., 5.],
         [10., 10., 10.]])

    # Fixed: `assert expr == True` is redundant (flake8 E712); assert the
    # boolean directly.
    assert torch.equal(new_tensor_0, target_tensor)
    assert torch.equal(new_tensor_1, target_tensor)

    client.shut_down()
if __name__ == '__main__':
    # NOTE(review): os.fork is POSIX-only; this test cannot run on Windows.
    pid = os.fork()
    if pid == 0:
        # Child process runs the blocking server loop.
        start_server()
    else:
        # Parent: give the server time to bind before connecting.
        # NOTE(review): fixed 2s sleep is a startup race on slow machines —
        # a readiness handshake would be more robust; confirm acceptable.
        time.sleep(2)  # wait server start
        start_client()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment