"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "1a64530d7401be4b03a7f62fb57973e1684d3f9c"
Unverified Commit 0d9b6bfd authored by Rhett Ying, committed by GitHub

[TensorpipeDeprecation] remove long live server support from DistDGL (#5931)

parent 4015c5fe
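This commit drops the long-lived ("keep alive") server mode from DistDGL: the `keep_alive` arguments of `DistGraphServer` and `ServerState`, the `DGL_KEEP_ALIVE` environment variable, the `SERVER_KEEP_ALIVE` constant, and the `dgl.distributed.shutdown_servers()` helper are removed, together with the multi-client-group test. Servers now always exit once their clients do. Below is a minimal client-side sketch of the flow that remains; the ip_config path and graph name are placeholders, not values taken from this commit.

# Minimal client-side sketch after this change (illustrative only).
# "ip_config.txt" and "my_graph" are placeholder names, not from this commit.
import dgl

dgl.distributed.initialize("ip_config.txt")   # connect to the running servers
g = dgl.distributed.DistGraph("my_graph")     # attach to the partitioned graph
# ... sampling / training work ...
# There is no dgl.distributed.shutdown_servers() any more: exit_client() is
# registered via atexit, sends ShutDownRequest when the client exits, and each
# server returns SERVER_EXIT and terminates once all clients have left.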
@@ -16,6 +16,6 @@ from .partition import (
     partition_graph,
 )
 from .rpc import *
-from .rpc_client import connect_to_server, shutdown_servers
+from .rpc_client import connect_to_server
 from .rpc_server import start_server
 from .server_state import ServerState
@@ -4,7 +4,6 @@
 MAX_QUEUE_SIZE = 20 * 1024 * 1024 * 1024
 SERVER_EXIT = "server_exit"
-SERVER_KEEP_ALIVE = "server_keep_alive"
 DEFAULT_NTYPE = "_N"
 DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE)
@@ -263,7 +263,6 @@ def initialize(
         formats = os.environ.get("DGL_GRAPH_FORMAT", "csc").split(",")
         formats = [f.strip() for f in formats]
         rpc.reset()
-        keep_alive = bool(int(os.environ.get("DGL_KEEP_ALIVE", 0)))
         serv = DistGraphServer(
             int(os.environ.get("DGL_SERVER_ID")),
             os.environ.get("DGL_IP_CONFIG"),
@@ -271,7 +270,6 @@ def initialize(
             int(os.environ.get("DGL_NUM_CLIENT")),
             os.environ.get("DGL_CONF_PATH"),
             graph_format=formats,
-            keep_alive=keep_alive,
         )
         serv.start()
         sys.exit()
......
@@ -330,8 +330,6 @@ class DistGraphServer(KVServer):
         Disable shared memory.
     graph_format : str or list of str
         The graph formats.
-    keep_alive : bool
-        Whether to keep server alive when clients exit
     """
     def __init__(
@@ -343,7 +341,6 @@ class DistGraphServer(KVServer):
         part_config,
         disable_shared_mem=False,
         graph_format=("csc", "coo"),
-        keep_alive=False,
     ):
         super(DistGraphServer, self).__init__(
             server_id=server_id,
@@ -353,7 +350,6 @@ class DistGraphServer(KVServer):
         )
         self.ip_config = ip_config
         self.num_servers = num_servers
-        self.keep_alive = keep_alive
         # Load graph partition data.
         if self.is_backup_server():
             # The backup server doesn't load the graph partition. It'll initialized afterwards.
@@ -457,7 +453,6 @@ class DistGraphServer(KVServer):
             kv_store=self,
             local_g=self.client_g,
             partition_book=self.gpb,
-            keep_alive=self.keep_alive,
         )
         print(
             "start graph service on server {} for part {}".format(
......
@@ -431,9 +431,6 @@ class GetSharedDataRequest(rpc.Request):
         meta = {}
         kv_store = server_state.kv_store
         for name, data in kv_store.data_store.items():
-            if server_state.keep_alive:
-                if name not in kv_store.orig_data:
-                    continue
             meta[name] = (
                 F.shape(data),
                 F.reverse_data_type_dict[F.dtype(data)],
......
@@ -11,7 +11,7 @@ from .. import backend as F
 from .._ffi.function import _init_api
 from .._ffi.object import ObjectBase, register_object
 from ..base import DGLError
-from .constants import SERVER_EXIT, SERVER_KEEP_ALIVE
+from .constants import SERVER_EXIT
 __all__ = [
     "set_rank",
@@ -1256,8 +1256,6 @@ class ShutDownRequest(Request):
     def process_request(self, server_state):
         assert self.client_id == 0
-        if server_state.keep_alive and not self.force_shutdown_server:
-            return SERVER_KEEP_ALIVE
         finalize_server()
         return SERVER_EXIT
......
@@ -226,44 +226,3 @@ def connect_to_server(
     atexit.register(exit_client)
     set_initialized(True)
-def shutdown_servers(ip_config, num_servers):
-    """Issue commands to remote servers to shut them down.
-    This function is required to be called manually only when we
-    have booted servers which keep alive even clients exit. In
-    order to shut down server elegantly, we utilize existing
-    client logic/code to boot a special client which does nothing
-    but send shut down request to servers. Once such request is
-    received, servers will exit from endless wait loop, release
-    occupied resources and end its process. Please call this function
-    with same arguments used in `dgl.distributed.connect_to_server`.
-    Parameters
-    ----------
-    ip_config : str
-        Path of server IP configuration file.
-    num_servers : int
-        server count on each machine.
-    Raises
-    ------
-    ConnectionError : If anything wrong with the connection.
-    """
-    rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)
-    rpc.register_sig_handler()
-    server_namebook = rpc.read_ip_config(ip_config, num_servers)
-    num_servers = len(server_namebook)
-    rpc.create_sender(MAX_QUEUE_SIZE)
-    # Get connected with all server nodes
-    for server_id, addr in server_namebook.items():
-        server_ip = addr[1]
-        server_port = addr[2]
-        while not rpc.connect_receiver(server_ip, server_port, server_id):
-            time.sleep(1)
-    # send ShutDownRequest to all servers
-    req = rpc.ShutDownRequest(0, True)
-    for server_id in range(num_servers):
-        rpc.send_request(server_id, req)
-    rpc.finalize_sender()
@@ -5,7 +5,7 @@ import time
 from ..base import DGLError
 from . import rpc
-from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE
+from .constants import MAX_QUEUE_SIZE, SERVER_EXIT
 def start_server(
@@ -52,8 +52,6 @@ def start_server(
     assert max_queue_size > 0, (
         "queue_size (%d) cannot be a negative number." % max_queue_size
     )
-    if server_state.keep_alive:
-        assert False, "Long live server is not supported any more."
     # Register signal handler.
     rpc.register_sig_handler()
     # Register some basic services
@@ -146,12 +144,6 @@ def start_server(
             if res == SERVER_EXIT:
                 print("Server is exiting...")
                 return
-            elif res == SERVER_KEEP_ALIVE:
-                print(
-                    "Server keeps alive while client group~{} is exiting...".format(
-                        group_id
-                    )
-                )
             else:
                 raise DGLError("Unexpected response: {}".format(res))
         else:
......
@@ -38,15 +38,12 @@ class ServerState:
         Total number of edges
     partition_book : GraphPartitionBook
         Graph Partition book
-    keep_alive : bool
-        whether to keep alive which supports any number of client groups connect
     """
-    def __init__(self, kv_store, local_g, partition_book, keep_alive=False):
+    def __init__(self, kv_store, local_g, partition_book):
         self._kv_store = kv_store
         self._graph = local_g
         self.partition_book = partition_book
-        self._keep_alive = keep_alive
         self._roles = {}
     @property
@@ -72,10 +69,5 @@ class ServerState:
     def graph(self, graph):
         self._graph = graph
-    @property
-    def keep_alive(self):
-        """Flag of whether keep alive"""
-        return self._keep_alive
 _init_api("dgl.distributed.server_state")
@@ -44,7 +44,6 @@ def run_server(
     server_count,
     num_clients,
     shared_mem,
-    keep_alive=False,
 ):
     g = DistGraphServer(
         server_id,
@@ -54,7 +53,6 @@ def run_server(
         "/tmp/dist_graph/{}.json".format(graph_name),
         disable_shared_mem=not shared_mem,
         graph_format=["csc", "coo"],
-        keep_alive=keep_alive,
     )
     print("start server", server_id)
     # verify dtype of underlying graph
@@ -479,7 +477,6 @@ def check_dist_emb_server_client(
     # We cannot run multiple servers and clients on the same machine.
     serv_ps = []
     ctx = mp.get_context("spawn")
-    keep_alive = num_groups > 1
     for serv_id in range(num_servers):
         p = ctx.Process(
             target=run_server,
@@ -489,7 +486,6 @@ def check_dist_emb_server_client(
                 num_servers,
                 num_clients,
                 shared_mem,
-                keep_alive,
             ),
         )
         serv_ps.append(p)
@@ -519,11 +515,6 @@ def check_dist_emb_server_client(
         p.join()
         assert p.exitcode == 0
-    if keep_alive:
-        for p in serv_ps:
-            assert p.is_alive()
-        # force shutdown server
-        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
     for p in serv_ps:
         p.join()
         assert p.exitcode == 0
@@ -546,7 +537,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
     # We cannot run multiple servers and clients on the same machine.
     serv_ps = []
     ctx = mp.get_context("spawn")
-    keep_alive = num_groups > 1
     for serv_id in range(num_servers):
         p = ctx.Process(
             target=run_server,
@@ -556,7 +546,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
                 num_servers,
                 num_clients,
                 shared_mem,
-                keep_alive,
             ),
         )
         serv_ps.append(p)
@@ -586,11 +575,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
         p.join()
         assert p.exitcode == 0
-    if keep_alive:
-        for p in serv_ps:
-            assert p.is_alive()
-        # force shutdown server
-        dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
     for p in serv_ps:
         p.join()
         assert p.exitcode == 0
@@ -988,7 +972,6 @@ def check_dist_optim_server_client(
                 num_servers,
                 num_clients,
                 True,
-                False,
             ),
         )
         serv_ps.append(p)
......
@@ -31,7 +31,6 @@ def start_server(
     disable_shared_mem,
     graph_name,
     graph_format=["csc", "coo"],
-    keep_alive=False,
 ):
     g = DistGraphServer(
         rank,
@@ -41,7 +40,6 @@ def start_server(
         tmpdir / (graph_name + ".json"),
         disable_shared_mem=disable_shared_mem,
         graph_format=graph_format,
-        keep_alive=keep_alive,
     )
     g.start()
@@ -399,7 +397,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
     pserver_list = []
     ctx = mp.get_context("spawn")
-    keep_alive = num_groups > 1
     for i in range(num_server):
         p = ctx.Process(
             target=start_server,
@@ -409,7 +406,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
                 num_server > 1,
                 "test_sampling",
                 ["csc", "coo"],
-                keep_alive,
             ),
         )
         p.start()
@@ -439,11 +435,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
     for p in pclient_list:
         p.join()
         assert p.exitcode == 0
-    if keep_alive:
-        for p in pserver_list:
-            assert p.is_alive()
-        # force shutdown server
-        dgl.distributed.shutdown_servers("rpc_ip_config.txt", 1)
     for p in pserver_list:
         p.join()
         assert p.exitcode == 0
......
@@ -52,7 +52,6 @@ def start_server(
     part_config,
     disable_shared_mem,
     num_clients,
-    keep_alive=False,
 ):
     print("server: #clients=" + str(num_clients))
     g = DistGraphServer(
@@ -63,7 +62,6 @@ def start_server(
         part_config,
         disable_shared_mem=disable_shared_mem,
         graph_format=["csc", "coo"],
-        keep_alive=keep_alive,
     )
     g.start()
@@ -344,7 +342,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
     part_config = os.path.join(test_dir, "test_sampling.json")
     pserver_list = []
     ctx = mp.get_context("spawn")
-    keep_alive = num_groups > 1
     for i in range(num_server):
         p = ctx.Process(
             target=start_server,
@@ -354,7 +351,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
                 part_config,
                 num_server > 1,
                 num_workers + 1,
-                keep_alive,
             ),
         )
         p.start()
@@ -389,11 +385,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
     for p in ptrainer_list:
         p.join()
         assert p.exitcode == 0
-    if keep_alive:
-        for p in pserver_list:
-            assert p.is_alive()
-        # force shutdown server
-        dgl.distributed.shutdown_servers("mp_ip_config.txt", 1)
     for p in pserver_list:
         p.join()
         assert p.exitcode == 0
......
@@ -133,13 +133,12 @@ def start_server(
     num_clients,
     ip_config,
     server_id=0,
-    keep_alive=False,
     num_servers=1,
 ):
     print("Sleep 1 seconds to test client re-connect.")
     time.sleep(1)
     server_state = dgl.distributed.ServerState(
-        None, local_g=None, partition_book=None, keep_alive=keep_alive
+        None, local_g=None, partition_book=None
     )
     dgl.distributed.register_service(
         HELLO_SERVICE_ID, HelloRequest, HelloResponse
@@ -258,7 +257,7 @@ def test_rpc_timeout():
     ip_config = "rpc_ip_config.txt"
     generate_ip_config(ip_config, 1, 1)
     ctx = mp.get_context("spawn")
-    pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, False, 1))
+    pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, 1))
     pclient = ctx.Process(target=start_client_timeout, args=(ip_config, 0, 1))
     pserver.start()
     pclient.start()
@@ -323,7 +322,7 @@ def test_multi_client():
     num_clients = 20
     pserver = ctx.Process(
         target=start_server,
-        args=(num_clients, ip_config, 0, False, 1),
+        args=(num_clients, ip_config, 0, 1),
     )
     pclient_list = []
     for i in range(num_clients):
@@ -347,9 +346,7 @@ def test_multi_thread_rpc():
     ctx = mp.get_context("spawn")
     pserver_list = []
     for i in range(num_servers):
-        pserver = ctx.Process(
-            target=start_server, args=(1, ip_config, i, False, 1)
-        )
+        pserver = ctx.Process(target=start_server, args=(1, ip_config, i, 1))
         pserver.start()
         pserver_list.append(pserver)
@@ -386,49 +383,6 @@ def test_multi_thread_rpc():
         pserver.join()
-@unittest.skipIf(
-    True,
-    reason="Tests of multiple groups may fail and let's disable them for now.",
-)
-@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
-def test_multi_client_groups():
-    reset_envs()
-    os.environ["DGL_DIST_MODE"] = "distributed"
-    ip_config = "rpc_ip_config_mul_client_groups.txt"
-    num_machines = 5
-    # should test with larger number but due to possible port in-use issue.
-    num_servers = 1
-    generate_ip_config(ip_config, num_machines, num_servers)
-    # presssue test
-    num_clients = 2
-    num_groups = 2
-    ctx = mp.get_context("spawn")
-    pserver_list = []
-    for i in range(num_servers * num_machines):
-        pserver = ctx.Process(
-            target=start_server,
-            args=(num_clients, ip_config, i, True, num_servers),
-        )
-        pserver.start()
-        pserver_list.append(pserver)
-    pclient_list = []
-    for i in range(num_clients):
-        for group_id in range(num_groups):
-            pclient = ctx.Process(
-                target=start_client, args=(ip_config, group_id, num_servers)
-            )
-            pclient.start()
-            pclient_list.append(pclient)
-    for p in pclient_list:
-        p.join()
-    for p in pserver_list:
-        assert p.is_alive()
-    # force shutdown server
-    dgl.distributed.shutdown_servers(ip_config, num_servers)
-    for p in pserver_list:
-        p.join()
 @unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
 def test_multi_client_connect():
     reset_envs()
@@ -439,7 +393,7 @@ def test_multi_client_connect():
     num_clients = 1
     pserver = ctx.Process(
         target=start_server,
-        args=(num_clients, ip_config, 0, False, 1),
+        args=(num_clients, ip_config, 0, 1),
     )
     # small max try times
......