Commit 0d9b6bfd (unverified) in OpenDAS/dgl
Authored by Rhett Ying on Jul 03, 2023; committed by GitHub on Jul 03, 2023
[TensorpipeDeprecation] remove long live server support from DistDGL (#5931)
Parent: 4015c5fe

Showing 13 changed files with 9 additions and 160 deletions (+9, -160):
  python/dgl/distributed/__init__.py              +1   -1
  python/dgl/distributed/constants.py             +0   -1
  python/dgl/distributed/dist_context.py          +0   -2
  python/dgl/distributed/dist_graph.py            +0   -5
  python/dgl/distributed/kvstore.py               +0   -3
  python/dgl/distributed/rpc.py                   +1   -3
  python/dgl/distributed/rpc_client.py            +0   -41
  python/dgl/distributed/rpc_server.py            +1   -9
  python/dgl/distributed/server_state.py          +1   -9
  tests/distributed/test_dist_graph_store.py      +0   -17
  tests/distributed/test_distributed_sampling.py  +0   -9
  tests/distributed/test_mp_dataloader.py         +0   -9
  tests/distributed/test_rpc.py                   +5   -51
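In short, the commit drops DistDGL's "long live" server mode: a server can no longer outlive its client group, so the DGL_KEEP_ALIVE environment variable, the keep_alive flags threaded through DistGraphServer and ServerState, and the dgl.distributed.shutdown_servers() helper all go away. A rough before/after sketch of the user-facing workflow, reconstructed from the removed code below (the path and server count are illustrative, not taken from the commit):

# Workflow retired by this commit (reconstructed from the diff; it no longer
# runs on DGL after #5931): servers were kept alive across client groups and
# had to be shut down explicitly.
import os
import dgl

os.environ["DGL_KEEP_ALIVE"] = "1"                    # no longer read by initialize()
# ... run one or more client groups against the long-lived servers ...
dgl.distributed.shutdown_servers("ip_config.txt", 1)  # helper removed in this commit

# After this commit, servers always exit when their (single) client group
# disconnects; no explicit shutdown call is needed or available.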
python/dgl/distributed/__init__.py

@@ -16,6 +16,6 @@ from .partition import (
     partition_graph,
 )
 from .rpc import *
-from .rpc_client import connect_to_server, shutdown_servers
+from .rpc_client import connect_to_server
 from .rpc_server import start_server
 from .server_state import ServerState
python/dgl/distributed/constants.py

@@ -4,7 +4,6 @@
 MAX_QUEUE_SIZE = 20 * 1024 * 1024 * 1024
 SERVER_EXIT = "server_exit"
-SERVER_KEEP_ALIVE = "server_keep_alive"
 DEFAULT_NTYPE = "_N"
 DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE)
python/dgl/distributed/dist_context.py

@@ -263,7 +263,6 @@ def initialize(
         formats = os.environ.get("DGL_GRAPH_FORMAT", "csc").split(",")
         formats = [f.strip() for f in formats]
         rpc.reset()
-        keep_alive = bool(int(os.environ.get("DGL_KEEP_ALIVE", 0)))
         serv = DistGraphServer(
             int(os.environ.get("DGL_SERVER_ID")),
             os.environ.get("DGL_IP_CONFIG"),
@@ -271,7 +270,6 @@ def initialize(
             int(os.environ.get("DGL_NUM_CLIENT")),
             os.environ.get("DGL_CONF_PATH"),
             graph_format=formats,
-            keep_alive=keep_alive,
         )
         serv.start()
         sys.exit()
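For readers following the server bootstrap: after this hunk a server process is configured purely by the remaining DGL_* environment variables, and DGL_KEEP_ALIVE is simply never consulted. A minimal sketch of that environment, with variable names taken from the hunk above and concrete values that are only illustrative:

# Illustrative environment for a DistDGL server process after this change.
import os

os.environ.update(
    {
        "DGL_SERVER_ID": "0",              # index of this server process
        "DGL_IP_CONFIG": "ip_config.txt",  # server IP configuration file
        "DGL_NUM_CLIENT": "4",             # number of client processes
        "DGL_CONF_PATH": "data/part.json", # partition configuration path
        "DGL_GRAPH_FORMAT": "csc,coo",     # split on "," and stripped in initialize()
    }
)
# initialize() then constructs DistGraphServer(...) from these variables,
# calls serv.start(), and exits; there is no keep_alive flag to pass any more.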
python/dgl/distributed/dist_graph.py

@@ -330,8 +330,6 @@ class DistGraphServer(KVServer):
         Disable shared memory.
     graph_format : str or list of str
         The graph formats.
-    keep_alive : bool
-        Whether to keep server alive when clients exit
     """

     def __init__(
@@ -343,7 +341,6 @@ class DistGraphServer(KVServer):
         part_config,
         disable_shared_mem=False,
         graph_format=("csc", "coo"),
-        keep_alive=False,
     ):
         super(DistGraphServer, self).__init__(
             server_id=server_id,
@@ -353,7 +350,6 @@ class DistGraphServer(KVServer):
         )
         self.ip_config = ip_config
         self.num_servers = num_servers
-        self.keep_alive = keep_alive
         # Load graph partition data.
         if self.is_backup_server():
             # The backup server doesn't load the graph partition. It'll initialized afterwards.
@@ -457,7 +453,6 @@ class DistGraphServer(KVServer):
             kv_store=self,
             local_g=self.client_g,
             partition_book=self.gpb,
-            keep_alive=self.keep_alive,
         )
         print(
             "start graph service on server {} for part {}".format(
python/dgl/distributed/kvstore.py

@@ -431,9 +431,6 @@ class GetSharedDataRequest(rpc.Request):
         meta = {}
         kv_store = server_state.kv_store
         for name, data in kv_store.data_store.items():
-            if server_state.keep_alive:
-                if name not in kv_store.orig_data:
-                    continue
             meta[name] = (
                 F.shape(data),
                 F.reverse_data_type_dict[F.dtype(data)],
python/dgl/distributed/rpc.py

@@ -11,7 +11,7 @@ from .. import backend as F
 from .._ffi.function import _init_api
 from .._ffi.object import ObjectBase, register_object
 from ..base import DGLError
-from .constants import SERVER_EXIT, SERVER_KEEP_ALIVE
+from .constants import SERVER_EXIT

 __all__ = [
     "set_rank",
@@ -1256,8 +1256,6 @@ class ShutDownRequest(Request):

     def process_request(self, server_state):
         assert self.client_id == 0
-        if server_state.keep_alive and not self.force_shutdown_server:
-            return SERVER_KEEP_ALIVE
         finalize_server()
         return SERVER_EXIT
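The second hunk is the behavioral core of the commit: a ShutDownRequest from client 0 now always finalizes the server, with no keep-alive early return. Condensed from the hunk above, the resulting method is simply (excerpt only; SERVER_EXIT and finalize_server are the existing names from dgl.distributed.rpc):

# Post-change ShutDownRequest.process_request, condensed from the diff.
def process_request(self, server_state):
    assert self.client_id == 0
    finalize_server()   # release RPC resources on this server
    return SERVER_EXIT  # start_server() exits on this; anything else raises DGLError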
python/dgl/distributed/rpc_client.py

@@ -226,44 +226,3 @@ def connect_to_server(
     atexit.register(exit_client)
     set_initialized(True)
-
-
-def shutdown_servers(ip_config, num_servers):
-    """Issue commands to remote servers to shut them down.
-
-    This function is required to be called manually only when we
-    have booted servers which keep alive even clients exit. In
-    order to shut down server elegantly, we utilize existing
-    client logic/code to boot a special client which does nothing
-    but send shut down request to servers. Once such request is
-    received, servers will exit from endless wait loop, release
-    occupied resources and end its process. Please call this function
-    with same arguments used in `dgl.distributed.connect_to_server`.
-
-    Parameters
-    ----------
-    ip_config : str
-        Path of server IP configuration file.
-    num_servers : int
-        server count on each machine.
-
-    Raises
-    ------
-    ConnectionError : If anything wrong with the connection.
-    """
-    rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)
-    rpc.register_sig_handler()
-    server_namebook = rpc.read_ip_config(ip_config, num_servers)
-    num_servers = len(server_namebook)
-    rpc.create_sender(MAX_QUEUE_SIZE)
-    # Get connected with all server nodes
-    for server_id, addr in server_namebook.items():
-        server_ip = addr[1]
-        server_port = addr[2]
-        while not rpc.connect_receiver(server_ip, server_port, server_id):
-            time.sleep(1)
-    # send ShutDownRequest to all servers
-    req = rpc.ShutDownRequest(0, True)
-    for server_id in range(num_servers):
-        rpc.send_request(server_id, req)
-    rpc.finalize_sender()
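With shutdown_servers() removed here (and from dgl.distributed.__init__ above), a client no longer shuts servers down explicitly: connect_to_server() registers exit_client with atexit, and the servers terminate when the client group disconnects. A minimal sketch of the remaining flow, assuming the two-argument connect_to_server(ip_config, num_servers) call shape implied by the removed docstring (the path and count are illustrative):

# Sketch only: the explicit shutdown step is gone after this commit.
import dgl

dgl.distributed.connect_to_server("ip_config.txt", 1)  # registers exit_client via atexit
# ... issue RPCs / run training ...
# No dgl.distributed.shutdown_servers(...) call any more: when this process
# exits, exit_client() runs, the servers receive ShutDownRequest, and
# process_request() now always returns SERVER_EXIT (see rpc.py above).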
python/dgl/distributed/rpc_server.py

@@ -5,7 +5,7 @@ import time
 from ..base import DGLError
 from . import rpc
-from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE
+from .constants import MAX_QUEUE_SIZE, SERVER_EXIT


 def start_server(
@@ -52,8 +52,6 @@ def start_server(
     assert max_queue_size > 0, (
         "queue_size (%d) cannot be a negative number." % max_queue_size
     )
-    if server_state.keep_alive:
-        assert False, "Long live server is not supported any more."
     # Register signal handler.
     rpc.register_sig_handler()
     # Register some basic services
@@ -146,12 +144,6 @@ def start_server(
             if res == SERVER_EXIT:
                 print("Server is exiting...")
                 return
-            elif res == SERVER_KEEP_ALIVE:
-                print(
-                    "Server keeps alive while client group~{} is exiting...".format(
-                        group_id
-                    )
-                )
             else:
                 raise DGLError("Unexpected response: {}".format(res))
         else:
python/dgl/distributed/server_state.py

@@ -38,15 +38,12 @@ class ServerState:
         Total number of edges
     partition_book : GraphPartitionBook
         Graph Partition book
-    keep_alive : bool
-        whether to keep alive which supports any number of client groups connect
     """

-    def __init__(self, kv_store, local_g, partition_book, keep_alive=False):
+    def __init__(self, kv_store, local_g, partition_book):
         self._kv_store = kv_store
         self._graph = local_g
         self.partition_book = partition_book
-        self._keep_alive = keep_alive
         self._roles = {}

     @property
@@ -72,10 +69,5 @@ class ServerState:
     def graph(self, graph):
         self._graph = graph

-    @property
-    def keep_alive(self):
-        """Flag of whether keep alive"""
-        return self._keep_alive
-

 _init_api("dgl.distributed.server_state")
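The constructor change shows up directly in tests/distributed/test_rpc.py further down; after this commit a server-side state object is built without any keep_alive flag:

# ServerState construction after the change, mirroring the call in
# tests/distributed/test_rpc.py (the None placeholders are the test's own).
import dgl

server_state = dgl.distributed.ServerState(
    None, local_g=None, partition_book=None
)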
tests/distributed/test_dist_graph_store.py

@@ -44,7 +44,6 @@ def run_server(
     server_count,
     num_clients,
     shared_mem,
-    keep_alive=False,
 ):
     g = DistGraphServer(
         server_id,
@@ -54,7 +53,6 @@ def run_server(
         "/tmp/dist_graph/{}.json".format(graph_name),
         disable_shared_mem=not shared_mem,
         graph_format=["csc", "coo"],
-        keep_alive=keep_alive,
     )
     print("start server", server_id)
     # verify dtype of underlying graph
@@ -479,7 +477,6 @@ def check_dist_emb_server_client(
     # We cannot run multiple servers and clients on the same machine.
     serv_ps = []
     ctx = mp.get_context("spawn")
-    keep_alive = num_groups > 1
     for serv_id in range(num_servers):
         p = ctx.Process(
             target=run_server,
@@ -489,7 +486,6 @@ def check_dist_emb_server_client(
                 num_servers,
                 num_clients,
                 shared_mem,
-                keep_alive,
             ),
         )
         serv_ps.append(p)
@@ -519,11 +515,6 @@ def check_dist_emb_server_client(
         p.join()
         assert p.exitcode == 0

-    if keep_alive:
-        for p in serv_ps:
-            assert p.is_alive()
-        # force shutdown server
-        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
     for p in serv_ps:
         p.join()
         assert p.exitcode == 0
@@ -546,7 +537,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
     # We cannot run multiple servers and clients on the same machine.
     serv_ps = []
     ctx = mp.get_context("spawn")
-    keep_alive = num_groups > 1
     for serv_id in range(num_servers):
         p = ctx.Process(
             target=run_server,
@@ -556,7 +546,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
                 num_servers,
                 num_clients,
                 shared_mem,
-                keep_alive,
             ),
         )
         serv_ps.append(p)
@@ -586,11 +575,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
         p.join()
         assert p.exitcode == 0

-    if keep_alive:
-        for p in serv_ps:
-            assert p.is_alive()
-        # force shutdown server
-        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
     for p in serv_ps:
         p.join()
         assert p.exitcode == 0
@@ -988,7 +972,6 @@ def check_dist_optim_server_client(
                 num_servers,
                 num_clients,
                 True,
-                False,
             ),
         )
         serv_ps.append(p)
tests/distributed/test_distributed_sampling.py

@@ -31,7 +31,6 @@ def start_server(
     disable_shared_mem,
     graph_name,
     graph_format=["csc", "coo"],
-    keep_alive=False,
 ):
     g = DistGraphServer(
         rank,
@@ -41,7 +40,6 @@ def start_server(
         tmpdir / (graph_name + ".json"),
         disable_shared_mem=disable_shared_mem,
         graph_format=graph_format,
-        keep_alive=keep_alive,
     )
     g.start()
@@ -399,7 +397,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
     pserver_list = []
     ctx = mp.get_context("spawn")
-    keep_alive = num_groups > 1
     for i in range(num_server):
         p = ctx.Process(
             target=start_server,
@@ -409,7 +406,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
                 num_server > 1,
                 "test_sampling",
                 ["csc", "coo"],
-                keep_alive,
             ),
         )
         p.start()
@@ -439,11 +435,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
     for p in pclient_list:
         p.join()
         assert p.exitcode == 0

-    if keep_alive:
-        for p in pserver_list:
-            assert p.is_alive()
-        # force shutdown server
-        dgl.distributed.shutdown_servers("rpc_ip_config.txt", 1)
     for p in pserver_list:
         p.join()
         assert p.exitcode == 0
tests/distributed/test_mp_dataloader.py

@@ -52,7 +52,6 @@ def start_server(
     part_config,
     disable_shared_mem,
     num_clients,
-    keep_alive=False,
 ):
     print("server: #clients=" + str(num_clients))
     g = DistGraphServer(
@@ -63,7 +62,6 @@ def start_server(
         part_config,
         disable_shared_mem=disable_shared_mem,
         graph_format=["csc", "coo"],
-        keep_alive=keep_alive,
     )
     g.start()
@@ -344,7 +342,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
     part_config = os.path.join(test_dir, "test_sampling.json")
     pserver_list = []
     ctx = mp.get_context("spawn")
-    keep_alive = num_groups > 1
     for i in range(num_server):
         p = ctx.Process(
             target=start_server,
@@ -354,7 +351,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
                 part_config,
                 num_server > 1,
                 num_workers + 1,
-                keep_alive,
             ),
         )
         p.start()
@@ -389,11 +385,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
     for p in ptrainer_list:
         p.join()
         assert p.exitcode == 0

-    if keep_alive:
-        for p in pserver_list:
-            assert p.is_alive()
-        # force shutdown server
-        dgl.distributed.shutdown_servers("mp_ip_config.txt", 1)
     for p in pserver_list:
         p.join()
         assert p.exitcode == 0
tests/distributed/test_rpc.py

@@ -133,13 +133,12 @@ def start_server(
     num_clients,
     ip_config,
     server_id=0,
-    keep_alive=False,
     num_servers=1,
 ):
     print("Sleep 1 seconds to test client re-connect.")
     time.sleep(1)
     server_state = dgl.distributed.ServerState(
-        None, local_g=None, partition_book=None, keep_alive=keep_alive
+        None, local_g=None, partition_book=None
     )
     dgl.distributed.register_service(
         HELLO_SERVICE_ID, HelloRequest, HelloResponse
@@ -258,7 +257,7 @@ def test_rpc_timeout():
     ip_config = "rpc_ip_config.txt"
     generate_ip_config(ip_config, 1, 1)
     ctx = mp.get_context("spawn")
-    pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, False, 1))
+    pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, 1))
     pclient = ctx.Process(target=start_client_timeout, args=(ip_config, 0, 1))
     pserver.start()
     pclient.start()
@@ -323,7 +322,7 @@ def test_multi_client():
     num_clients = 20
     pserver = ctx.Process(
         target=start_server,
-        args=(num_clients, ip_config, 0, False, 1),
+        args=(num_clients, ip_config, 0, 1),
     )
     pclient_list = []
     for i in range(num_clients):
@@ -347,9 +346,7 @@ def test_multi_thread_rpc():
     ctx = mp.get_context("spawn")
     pserver_list = []
     for i in range(num_servers):
-        pserver = ctx.Process(
-            target=start_server, args=(1, ip_config, i, False, 1)
-        )
+        pserver = ctx.Process(target=start_server, args=(1, ip_config, i, 1))
         pserver.start()
         pserver_list.append(pserver)
@@ -386,49 +383,6 @@ def test_multi_thread_rpc():
         pserver.join()


-@unittest.skipIf(
-    True,
-    reason="Tests of multiple groups may fail and let's disable them for now.",
-)
-@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
-def test_multi_client_groups():
-    reset_envs()
-    os.environ["DGL_DIST_MODE"] = "distributed"
-    ip_config = "rpc_ip_config_mul_client_groups.txt"
-    num_machines = 5
-    # should test with larger number but due to possible port in-use issue.
-    num_servers = 1
-    generate_ip_config(ip_config, num_machines, num_servers)
-    # presssue test
-    num_clients = 2
-    num_groups = 2
-    ctx = mp.get_context("spawn")
-    pserver_list = []
-    for i in range(num_servers * num_machines):
-        pserver = ctx.Process(
-            target=start_server,
-            args=(num_clients, ip_config, i, True, num_servers),
-        )
-        pserver.start()
-        pserver_list.append(pserver)
-    pclient_list = []
-    for i in range(num_clients):
-        for group_id in range(num_groups):
-            pclient = ctx.Process(
-                target=start_client, args=(ip_config, group_id, num_servers)
-            )
-            pclient.start()
-            pclient_list.append(pclient)
-    for p in pclient_list:
-        p.join()
-    for p in pserver_list:
-        assert p.is_alive()
-    # force shutdown server
-    dgl.distributed.shutdown_servers(ip_config, num_servers)
-    for p in pserver_list:
-        p.join()
-
-
 @unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
 def test_multi_client_connect():
     reset_envs()
@@ -439,7 +393,7 @@ def test_multi_client_connect():
     num_clients = 1
     pserver = ctx.Process(
         target=start_server,
-        args=(num_clients, ip_config, 0, False, 1),
+        args=(num_clients, ip_config, 0, 1),
     )
     # small max try times