Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
0d9b6bfd
Unverified
Commit
0d9b6bfd
authored
Jul 03, 2023
by
Rhett Ying
Committed by
GitHub
Jul 03, 2023
Browse files
[TensorpipeDeprecation] remove long live server support from DistDGL (#5931)
parent
4015c5fe
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
9 additions
and
160 deletions
+9
-160
python/dgl/distributed/__init__.py
python/dgl/distributed/__init__.py
+1
-1
python/dgl/distributed/constants.py
python/dgl/distributed/constants.py
+0
-1
python/dgl/distributed/dist_context.py
python/dgl/distributed/dist_context.py
+0
-2
python/dgl/distributed/dist_graph.py
python/dgl/distributed/dist_graph.py
+0
-5
python/dgl/distributed/kvstore.py
python/dgl/distributed/kvstore.py
+0
-3
python/dgl/distributed/rpc.py
python/dgl/distributed/rpc.py
+1
-3
python/dgl/distributed/rpc_client.py
python/dgl/distributed/rpc_client.py
+0
-41
python/dgl/distributed/rpc_server.py
python/dgl/distributed/rpc_server.py
+1
-9
python/dgl/distributed/server_state.py
python/dgl/distributed/server_state.py
+1
-9
tests/distributed/test_dist_graph_store.py
tests/distributed/test_dist_graph_store.py
+0
-17
tests/distributed/test_distributed_sampling.py
tests/distributed/test_distributed_sampling.py
+0
-9
tests/distributed/test_mp_dataloader.py
tests/distributed/test_mp_dataloader.py
+0
-9
tests/distributed/test_rpc.py
tests/distributed/test_rpc.py
+5
-51
No files found.
python/dgl/distributed/__init__.py
View file @
0d9b6bfd
...
...
@@ -16,6 +16,6 @@ from .partition import (
partition_graph
,
)
from
.rpc
import
*
from
.rpc_client
import
connect_to_server
,
shutdown_servers
from
.rpc_client
import
connect_to_server
from
.rpc_server
import
start_server
from
.server_state
import
ServerState
python/dgl/distributed/constants.py
View file @
0d9b6bfd
...
...
@@ -4,7 +4,6 @@
MAX_QUEUE_SIZE
=
20
*
1024
*
1024
*
1024
SERVER_EXIT
=
"server_exit"
SERVER_KEEP_ALIVE
=
"server_keep_alive"
DEFAULT_NTYPE
=
"_N"
DEFAULT_ETYPE
=
(
DEFAULT_NTYPE
,
"_E"
,
DEFAULT_NTYPE
)
python/dgl/distributed/dist_context.py
View file @
0d9b6bfd
...
...
@@ -263,7 +263,6 @@ def initialize(
formats
=
os
.
environ
.
get
(
"DGL_GRAPH_FORMAT"
,
"csc"
).
split
(
","
)
formats
=
[
f
.
strip
()
for
f
in
formats
]
rpc
.
reset
()
keep_alive
=
bool
(
int
(
os
.
environ
.
get
(
"DGL_KEEP_ALIVE"
,
0
)))
serv
=
DistGraphServer
(
int
(
os
.
environ
.
get
(
"DGL_SERVER_ID"
)),
os
.
environ
.
get
(
"DGL_IP_CONFIG"
),
...
...
@@ -271,7 +270,6 @@ def initialize(
int
(
os
.
environ
.
get
(
"DGL_NUM_CLIENT"
)),
os
.
environ
.
get
(
"DGL_CONF_PATH"
),
graph_format
=
formats
,
keep_alive
=
keep_alive
,
)
serv
.
start
()
sys
.
exit
()
...
...
python/dgl/distributed/dist_graph.py
View file @
0d9b6bfd
...
...
@@ -330,8 +330,6 @@ class DistGraphServer(KVServer):
Disable shared memory.
graph_format : str or list of str
The graph formats.
keep_alive : bool
Whether to keep server alive when clients exit
"""
def
__init__
(
...
...
@@ -343,7 +341,6 @@ class DistGraphServer(KVServer):
part_config
,
disable_shared_mem
=
False
,
graph_format
=
(
"csc"
,
"coo"
),
keep_alive
=
False
,
):
super
(
DistGraphServer
,
self
).
__init__
(
server_id
=
server_id
,
...
...
@@ -353,7 +350,6 @@ class DistGraphServer(KVServer):
)
self
.
ip_config
=
ip_config
self
.
num_servers
=
num_servers
self
.
keep_alive
=
keep_alive
# Load graph partition data.
if
self
.
is_backup_server
():
# The backup server doesn't load the graph partition. It'll initialized afterwards.
...
...
@@ -457,7 +453,6 @@ class DistGraphServer(KVServer):
kv_store
=
self
,
local_g
=
self
.
client_g
,
partition_book
=
self
.
gpb
,
keep_alive
=
self
.
keep_alive
,
)
print
(
"start graph service on server {} for part {}"
.
format
(
...
...
python/dgl/distributed/kvstore.py
View file @
0d9b6bfd
...
...
@@ -431,9 +431,6 @@ class GetSharedDataRequest(rpc.Request):
meta
=
{}
kv_store
=
server_state
.
kv_store
for
name
,
data
in
kv_store
.
data_store
.
items
():
if
server_state
.
keep_alive
:
if
name
not
in
kv_store
.
orig_data
:
continue
meta
[
name
]
=
(
F
.
shape
(
data
),
F
.
reverse_data_type_dict
[
F
.
dtype
(
data
)],
...
...
python/dgl/distributed/rpc.py
View file @
0d9b6bfd
...
...
@@ -11,7 +11,7 @@ from .. import backend as F
from
.._ffi.function
import
_init_api
from
.._ffi.object
import
ObjectBase
,
register_object
from
..base
import
DGLError
from
.constants
import
SERVER_EXIT
,
SERVER_KEEP_ALIVE
from
.constants
import
SERVER_EXIT
__all__
=
[
"set_rank"
,
...
...
@@ -1256,8 +1256,6 @@ class ShutDownRequest(Request):
def
process_request
(
self
,
server_state
):
assert
self
.
client_id
==
0
if
server_state
.
keep_alive
and
not
self
.
force_shutdown_server
:
return
SERVER_KEEP_ALIVE
finalize_server
()
return
SERVER_EXIT
...
...
python/dgl/distributed/rpc_client.py
View file @
0d9b6bfd
...
...
@@ -226,44 +226,3 @@ def connect_to_server(
atexit
.
register
(
exit_client
)
set_initialized
(
True
)
def
shutdown_servers
(
ip_config
,
num_servers
):
"""Issue commands to remote servers to shut them down.
This function is required to be called manually only when we
have booted servers which keep alive even clients exit. In
order to shut down server elegantly, we utilize existing
client logic/code to boot a special client which does nothing
but send shut down request to servers. Once such request is
received, servers will exit from endless wait loop, release
occupied resources and end its process. Please call this function
with same arguments used in `dgl.distributed.connect_to_server`.
Parameters
----------
ip_config : str
Path of server IP configuration file.
num_servers : int
server count on each machine.
Raises
------
ConnectionError : If anything wrong with the connection.
"""
rpc
.
register_service
(
rpc
.
SHUT_DOWN_SERVER
,
rpc
.
ShutDownRequest
,
None
)
rpc
.
register_sig_handler
()
server_namebook
=
rpc
.
read_ip_config
(
ip_config
,
num_servers
)
num_servers
=
len
(
server_namebook
)
rpc
.
create_sender
(
MAX_QUEUE_SIZE
)
# Get connected with all server nodes
for
server_id
,
addr
in
server_namebook
.
items
():
server_ip
=
addr
[
1
]
server_port
=
addr
[
2
]
while
not
rpc
.
connect_receiver
(
server_ip
,
server_port
,
server_id
):
time
.
sleep
(
1
)
# send ShutDownRequest to all servers
req
=
rpc
.
ShutDownRequest
(
0
,
True
)
for
server_id
in
range
(
num_servers
):
rpc
.
send_request
(
server_id
,
req
)
rpc
.
finalize_sender
()
python/dgl/distributed/rpc_server.py
View file @
0d9b6bfd
...
...
@@ -5,7 +5,7 @@ import time
from
..base
import
DGLError
from
.
import
rpc
from
.constants
import
MAX_QUEUE_SIZE
,
SERVER_EXIT
,
SERVER_KEEP_ALIVE
from
.constants
import
MAX_QUEUE_SIZE
,
SERVER_EXIT
def
start_server
(
...
...
@@ -52,8 +52,6 @@ def start_server(
assert
max_queue_size
>
0
,
(
"queue_size (%d) cannot be a negative number."
%
max_queue_size
)
if
server_state
.
keep_alive
:
assert
False
,
"Long live server is not supported any more."
# Register signal handler.
rpc
.
register_sig_handler
()
# Register some basic services
...
...
@@ -146,12 +144,6 @@ def start_server(
if
res
==
SERVER_EXIT
:
print
(
"Server is exiting..."
)
return
elif
res
==
SERVER_KEEP_ALIVE
:
print
(
"Server keeps alive while client group~{} is exiting..."
.
format
(
group_id
)
)
else
:
raise
DGLError
(
"Unexpected response: {}"
.
format
(
res
))
else
:
...
...
python/dgl/distributed/server_state.py
View file @
0d9b6bfd
...
...
@@ -38,15 +38,12 @@ class ServerState:
Total number of edges
partition_book : GraphPartitionBook
Graph Partition book
keep_alive : bool
whether to keep alive which supports any number of client groups connect
"""
def
__init__
(
self
,
kv_store
,
local_g
,
partition_book
,
keep_alive
=
False
):
def
__init__
(
self
,
kv_store
,
local_g
,
partition_book
):
self
.
_kv_store
=
kv_store
self
.
_graph
=
local_g
self
.
partition_book
=
partition_book
self
.
_keep_alive
=
keep_alive
self
.
_roles
=
{}
@
property
...
...
@@ -72,10 +69,5 @@ class ServerState:
def
graph
(
self
,
graph
):
self
.
_graph
=
graph
@
property
def
keep_alive
(
self
):
"""Flag of whether keep alive"""
return
self
.
_keep_alive
_init_api
(
"dgl.distributed.server_state"
)
tests/distributed/test_dist_graph_store.py
View file @
0d9b6bfd
...
...
@@ -44,7 +44,6 @@ def run_server(
server_count
,
num_clients
,
shared_mem
,
keep_alive
=
False
,
):
g
=
DistGraphServer
(
server_id
,
...
...
@@ -54,7 +53,6 @@ def run_server(
"/tmp/dist_graph/{}.json"
.
format
(
graph_name
),
disable_shared_mem
=
not
shared_mem
,
graph_format
=
[
"csc"
,
"coo"
],
keep_alive
=
keep_alive
,
)
print
(
"start server"
,
server_id
)
# verify dtype of underlying graph
...
...
@@ -479,7 +477,6 @@ def check_dist_emb_server_client(
# We cannot run multiple servers and clients on the same machine.
serv_ps
=
[]
ctx
=
mp
.
get_context
(
"spawn"
)
keep_alive
=
num_groups
>
1
for
serv_id
in
range
(
num_servers
):
p
=
ctx
.
Process
(
target
=
run_server
,
...
...
@@ -489,7 +486,6 @@ def check_dist_emb_server_client(
num_servers
,
num_clients
,
shared_mem
,
keep_alive
,
),
)
serv_ps
.
append
(
p
)
...
...
@@ -519,11 +515,6 @@ def check_dist_emb_server_client(
p
.
join
()
assert
p
.
exitcode
==
0
if
keep_alive
:
for
p
in
serv_ps
:
assert
p
.
is_alive
()
# force shutdown server
dgl
.
distributed
.
shutdown_servers
(
"kv_ip_config.txt"
,
num_servers
)
for
p
in
serv_ps
:
p
.
join
()
assert
p
.
exitcode
==
0
...
...
@@ -546,7 +537,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
# We cannot run multiple servers and clients on the same machine.
serv_ps
=
[]
ctx
=
mp
.
get_context
(
"spawn"
)
keep_alive
=
num_groups
>
1
for
serv_id
in
range
(
num_servers
):
p
=
ctx
.
Process
(
target
=
run_server
,
...
...
@@ -556,7 +546,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
num_servers
,
num_clients
,
shared_mem
,
keep_alive
,
),
)
serv_ps
.
append
(
p
)
...
...
@@ -586,11 +575,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
p
.
join
()
assert
p
.
exitcode
==
0
if
keep_alive
:
for
p
in
serv_ps
:
assert
p
.
is_alive
()
# force shutdown server
dgl
.
distributed
.
shutdown_servers
(
"kv_ip_config.txt"
,
num_servers
)
for
p
in
serv_ps
:
p
.
join
()
assert
p
.
exitcode
==
0
...
...
@@ -988,7 +972,6 @@ def check_dist_optim_server_client(
num_servers
,
num_clients
,
True
,
False
,
),
)
serv_ps
.
append
(
p
)
...
...
tests/distributed/test_distributed_sampling.py
View file @
0d9b6bfd
...
...
@@ -31,7 +31,6 @@ def start_server(
disable_shared_mem
,
graph_name
,
graph_format
=
[
"csc"
,
"coo"
],
keep_alive
=
False
,
):
g
=
DistGraphServer
(
rank
,
...
...
@@ -41,7 +40,6 @@ def start_server(
tmpdir
/
(
graph_name
+
".json"
),
disable_shared_mem
=
disable_shared_mem
,
graph_format
=
graph_format
,
keep_alive
=
keep_alive
,
)
g
.
start
()
...
...
@@ -399,7 +397,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
pserver_list
=
[]
ctx
=
mp
.
get_context
(
"spawn"
)
keep_alive
=
num_groups
>
1
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
target
=
start_server
,
...
...
@@ -409,7 +406,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
num_server
>
1
,
"test_sampling"
,
[
"csc"
,
"coo"
],
keep_alive
,
),
)
p
.
start
()
...
...
@@ -439,11 +435,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
for
p
in
pclient_list
:
p
.
join
()
assert
p
.
exitcode
==
0
if
keep_alive
:
for
p
in
pserver_list
:
assert
p
.
is_alive
()
# force shutdown server
dgl
.
distributed
.
shutdown_servers
(
"rpc_ip_config.txt"
,
1
)
for
p
in
pserver_list
:
p
.
join
()
assert
p
.
exitcode
==
0
...
...
tests/distributed/test_mp_dataloader.py
View file @
0d9b6bfd
...
...
@@ -52,7 +52,6 @@ def start_server(
part_config
,
disable_shared_mem
,
num_clients
,
keep_alive
=
False
,
):
print
(
"server: #clients="
+
str
(
num_clients
))
g
=
DistGraphServer
(
...
...
@@ -63,7 +62,6 @@ def start_server(
part_config
,
disable_shared_mem
=
disable_shared_mem
,
graph_format
=
[
"csc"
,
"coo"
],
keep_alive
=
keep_alive
,
)
g
.
start
()
...
...
@@ -344,7 +342,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
part_config
=
os
.
path
.
join
(
test_dir
,
"test_sampling.json"
)
pserver_list
=
[]
ctx
=
mp
.
get_context
(
"spawn"
)
keep_alive
=
num_groups
>
1
for
i
in
range
(
num_server
):
p
=
ctx
.
Process
(
target
=
start_server
,
...
...
@@ -354,7 +351,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
part_config
,
num_server
>
1
,
num_workers
+
1
,
keep_alive
,
),
)
p
.
start
()
...
...
@@ -389,11 +385,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
for
p
in
ptrainer_list
:
p
.
join
()
assert
p
.
exitcode
==
0
if
keep_alive
:
for
p
in
pserver_list
:
assert
p
.
is_alive
()
# force shutdown server
dgl
.
distributed
.
shutdown_servers
(
"mp_ip_config.txt"
,
1
)
for
p
in
pserver_list
:
p
.
join
()
assert
p
.
exitcode
==
0
...
...
tests/distributed/test_rpc.py
View file @
0d9b6bfd
...
...
@@ -133,13 +133,12 @@ def start_server(
num_clients
,
ip_config
,
server_id
=
0
,
keep_alive
=
False
,
num_servers
=
1
,
):
print
(
"Sleep 1 seconds to test client re-connect."
)
time
.
sleep
(
1
)
server_state
=
dgl
.
distributed
.
ServerState
(
None
,
local_g
=
None
,
partition_book
=
None
,
keep_alive
=
keep_alive
None
,
local_g
=
None
,
partition_book
=
None
)
dgl
.
distributed
.
register_service
(
HELLO_SERVICE_ID
,
HelloRequest
,
HelloResponse
...
...
@@ -258,7 +257,7 @@ def test_rpc_timeout():
ip_config
=
"rpc_ip_config.txt"
generate_ip_config
(
ip_config
,
1
,
1
)
ctx
=
mp
.
get_context
(
"spawn"
)
pserver
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
1
,
ip_config
,
0
,
False
,
1
))
pserver
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
1
,
ip_config
,
0
,
1
))
pclient
=
ctx
.
Process
(
target
=
start_client_timeout
,
args
=
(
ip_config
,
0
,
1
))
pserver
.
start
()
pclient
.
start
()
...
...
@@ -323,7 +322,7 @@ def test_multi_client():
num_clients
=
20
pserver
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
num_clients
,
ip_config
,
0
,
False
,
1
),
args
=
(
num_clients
,
ip_config
,
0
,
1
),
)
pclient_list
=
[]
for
i
in
range
(
num_clients
):
...
...
@@ -347,9 +346,7 @@ def test_multi_thread_rpc():
ctx
=
mp
.
get_context
(
"spawn"
)
pserver_list
=
[]
for
i
in
range
(
num_servers
):
pserver
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
1
,
ip_config
,
i
,
False
,
1
)
)
pserver
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
1
,
ip_config
,
i
,
1
))
pserver
.
start
()
pserver_list
.
append
(
pserver
)
...
...
@@ -386,49 +383,6 @@ def test_multi_thread_rpc():
pserver
.
join
()
@
unittest
.
skipIf
(
True
,
reason
=
"Tests of multiple groups may fail and let's disable them for now."
,
)
@
unittest
.
skipIf
(
os
.
name
==
"nt"
,
reason
=
"Do not support windows yet"
)
def
test_multi_client_groups
():
reset_envs
()
os
.
environ
[
"DGL_DIST_MODE"
]
=
"distributed"
ip_config
=
"rpc_ip_config_mul_client_groups.txt"
num_machines
=
5
# should test with larger number but due to possible port in-use issue.
num_servers
=
1
generate_ip_config
(
ip_config
,
num_machines
,
num_servers
)
# presssue test
num_clients
=
2
num_groups
=
2
ctx
=
mp
.
get_context
(
"spawn"
)
pserver_list
=
[]
for
i
in
range
(
num_servers
*
num_machines
):
pserver
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
num_clients
,
ip_config
,
i
,
True
,
num_servers
),
)
pserver
.
start
()
pserver_list
.
append
(
pserver
)
pclient_list
=
[]
for
i
in
range
(
num_clients
):
for
group_id
in
range
(
num_groups
):
pclient
=
ctx
.
Process
(
target
=
start_client
,
args
=
(
ip_config
,
group_id
,
num_servers
)
)
pclient
.
start
()
pclient_list
.
append
(
pclient
)
for
p
in
pclient_list
:
p
.
join
()
for
p
in
pserver_list
:
assert
p
.
is_alive
()
# force shutdown server
dgl
.
distributed
.
shutdown_servers
(
ip_config
,
num_servers
)
for
p
in
pserver_list
:
p
.
join
()
@
unittest
.
skipIf
(
os
.
name
==
"nt"
,
reason
=
"Do not support windows yet"
)
def
test_multi_client_connect
():
reset_envs
()
...
...
@@ -439,7 +393,7 @@ def test_multi_client_connect():
num_clients
=
1
pserver
=
ctx
.
Process
(
target
=
start_server
,
args
=
(
num_clients
,
ip_config
,
0
,
False
,
1
),
args
=
(
num_clients
,
ip_config
,
0
,
1
),
)
# small max try times
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment