Unverified commit 0d9b6bfd, authored by Rhett Ying, committed by GitHub
Browse files

[TensorpipeDeprecation] remove long live server support from DistDGL (#5931)

parent 4015c5fe
......@@ -16,6 +16,6 @@ from .partition import (
partition_graph,
)
from .rpc import *
from .rpc_client import connect_to_server, shutdown_servers
from .rpc_client import connect_to_server
from .rpc_server import start_server
from .server_state import ServerState
......@@ -4,7 +4,6 @@
MAX_QUEUE_SIZE = 20 * 1024 * 1024 * 1024
SERVER_EXIT = "server_exit"
SERVER_KEEP_ALIVE = "server_keep_alive"
DEFAULT_NTYPE = "_N"
DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE)
......@@ -263,7 +263,6 @@ def initialize(
formats = os.environ.get("DGL_GRAPH_FORMAT", "csc").split(",")
formats = [f.strip() for f in formats]
rpc.reset()
keep_alive = bool(int(os.environ.get("DGL_KEEP_ALIVE", 0)))
serv = DistGraphServer(
int(os.environ.get("DGL_SERVER_ID")),
os.environ.get("DGL_IP_CONFIG"),
......@@ -271,7 +270,6 @@ def initialize(
int(os.environ.get("DGL_NUM_CLIENT")),
os.environ.get("DGL_CONF_PATH"),
graph_format=formats,
keep_alive=keep_alive,
)
serv.start()
sys.exit()
......
......@@ -330,8 +330,6 @@ class DistGraphServer(KVServer):
Disable shared memory.
graph_format : str or list of str
The graph formats.
keep_alive : bool
Whether to keep server alive when clients exit
"""
def __init__(
......@@ -343,7 +341,6 @@ class DistGraphServer(KVServer):
part_config,
disable_shared_mem=False,
graph_format=("csc", "coo"),
keep_alive=False,
):
super(DistGraphServer, self).__init__(
server_id=server_id,
......@@ -353,7 +350,6 @@ class DistGraphServer(KVServer):
)
self.ip_config = ip_config
self.num_servers = num_servers
self.keep_alive = keep_alive
# Load graph partition data.
if self.is_backup_server():
# The backup server doesn't load the graph partition. It'll be initialized afterwards.
......@@ -457,7 +453,6 @@ class DistGraphServer(KVServer):
kv_store=self,
local_g=self.client_g,
partition_book=self.gpb,
keep_alive=self.keep_alive,
)
print(
"start graph service on server {} for part {}".format(
......
......@@ -431,9 +431,6 @@ class GetSharedDataRequest(rpc.Request):
meta = {}
kv_store = server_state.kv_store
for name, data in kv_store.data_store.items():
if server_state.keep_alive:
if name not in kv_store.orig_data:
continue
meta[name] = (
F.shape(data),
F.reverse_data_type_dict[F.dtype(data)],
......
......@@ -11,7 +11,7 @@ from .. import backend as F
from .._ffi.function import _init_api
from .._ffi.object import ObjectBase, register_object
from ..base import DGLError
from .constants import SERVER_EXIT, SERVER_KEEP_ALIVE
from .constants import SERVER_EXIT
__all__ = [
"set_rank",
......@@ -1256,8 +1256,6 @@ class ShutDownRequest(Request):
def process_request(self, server_state):
assert self.client_id == 0
if server_state.keep_alive and not self.force_shutdown_server:
return SERVER_KEEP_ALIVE
finalize_server()
return SERVER_EXIT
......
......@@ -226,44 +226,3 @@ def connect_to_server(
atexit.register(exit_client)
set_initialized(True)
def shutdown_servers(ip_config, num_servers):
    """Send a shut-down command to every remote server.

    Calling this manually is only required for servers that were booted
    to stay alive after their clients exit. To shut such servers down
    gracefully, the existing client logic is reused to act as a special
    client that does nothing except send a shut-down request to each
    server. Once a server receives the request it leaves its endless
    wait loop, releases the resources it holds and ends its process.
    Call this function with the same arguments that were used for
    `dgl.distributed.connect_to_server`.

    Parameters
    ----------
    ip_config : str
        Path of server IP configuration file.
    num_servers : int
        Server count on each machine.

    Raises
    ------
    ConnectionError : If anything wrong with the connection.
    """
    rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)
    rpc.register_sig_handler()

    namebook = rpc.read_ip_config(ip_config, num_servers)
    # From here on the count is the number of entries in the namebook,
    # not the per-machine value that was passed in.
    total_servers = len(namebook)

    rpc.create_sender(MAX_QUEUE_SIZE)

    # Connect to every server node, retrying once per second until the
    # connection attempt reports success.
    for sid, entry in namebook.items():
        ip_addr, port = entry[1], entry[2]
        while True:
            if rpc.connect_receiver(ip_addr, port, sid):
                break
            time.sleep(1)

    # Send the same request to all servers.
    # NOTE(review): presumably (client_id=0, force_shutdown_server=True);
    # confirm against the ShutDownRequest constructor.
    shutdown_req = rpc.ShutDownRequest(0, True)
    for sid in range(total_servers):
        rpc.send_request(sid, shutdown_req)

    rpc.finalize_sender()
......@@ -5,7 +5,7 @@ import time
from ..base import DGLError
from . import rpc
from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE
from .constants import MAX_QUEUE_SIZE, SERVER_EXIT
def start_server(
......@@ -52,8 +52,6 @@ def start_server(
assert max_queue_size > 0, (
"queue_size (%d) cannot be a negative number." % max_queue_size
)
if server_state.keep_alive:
assert False, "Long live server is not supported any more."
# Register signal handler.
rpc.register_sig_handler()
# Register some basic services
......@@ -146,12 +144,6 @@ def start_server(
if res == SERVER_EXIT:
print("Server is exiting...")
return
elif res == SERVER_KEEP_ALIVE:
print(
"Server keeps alive while client group~{} is exiting...".format(
group_id
)
)
else:
raise DGLError("Unexpected response: {}".format(res))
else:
......
......@@ -38,15 +38,12 @@ class ServerState:
Total number of edges
partition_book : GraphPartitionBook
Graph Partition book
keep_alive : bool
whether to keep alive which supports any number of client groups connect
"""
def __init__(self, kv_store, local_g, partition_book, keep_alive=False):
def __init__(self, kv_store, local_g, partition_book):
self._kv_store = kv_store
self._graph = local_g
self.partition_book = partition_book
self._keep_alive = keep_alive
self._roles = {}
@property
......@@ -72,10 +69,5 @@ class ServerState:
def graph(self, graph):
self._graph = graph
@property
def keep_alive(self):
"""Flag of whether keep alive"""
return self._keep_alive
_init_api("dgl.distributed.server_state")
......@@ -44,7 +44,6 @@ def run_server(
server_count,
num_clients,
shared_mem,
keep_alive=False,
):
g = DistGraphServer(
server_id,
......@@ -54,7 +53,6 @@ def run_server(
"/tmp/dist_graph/{}.json".format(graph_name),
disable_shared_mem=not shared_mem,
graph_format=["csc", "coo"],
keep_alive=keep_alive,
)
print("start server", server_id)
# verify dtype of underlying graph
......@@ -479,7 +477,6 @@ def check_dist_emb_server_client(
# We cannot run multiple servers and clients on the same machine.
serv_ps = []
ctx = mp.get_context("spawn")
keep_alive = num_groups > 1
for serv_id in range(num_servers):
p = ctx.Process(
target=run_server,
......@@ -489,7 +486,6 @@ def check_dist_emb_server_client(
num_servers,
num_clients,
shared_mem,
keep_alive,
),
)
serv_ps.append(p)
......@@ -519,11 +515,6 @@ def check_dist_emb_server_client(
p.join()
assert p.exitcode == 0
if keep_alive:
for p in serv_ps:
assert p.is_alive()
# force shutdown server
dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
for p in serv_ps:
p.join()
assert p.exitcode == 0
......@@ -546,7 +537,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
# We cannot run multiple servers and clients on the same machine.
serv_ps = []
ctx = mp.get_context("spawn")
keep_alive = num_groups > 1
for serv_id in range(num_servers):
p = ctx.Process(
target=run_server,
......@@ -556,7 +546,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
num_servers,
num_clients,
shared_mem,
keep_alive,
),
)
serv_ps.append(p)
......@@ -586,11 +575,6 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
p.join()
assert p.exitcode == 0
if keep_alive:
for p in serv_ps:
assert p.is_alive()
# force shutdown server
dgl.distributed.shutdown_servers("kv_ip_config.txt", num_servers)
for p in serv_ps:
p.join()
assert p.exitcode == 0
......@@ -988,7 +972,6 @@ def check_dist_optim_server_client(
num_servers,
num_clients,
True,
False,
),
)
serv_ps.append(p)
......
......@@ -31,7 +31,6 @@ def start_server(
disable_shared_mem,
graph_name,
graph_format=["csc", "coo"],
keep_alive=False,
):
g = DistGraphServer(
rank,
......@@ -41,7 +40,6 @@ def start_server(
tmpdir / (graph_name + ".json"),
disable_shared_mem=disable_shared_mem,
graph_format=graph_format,
keep_alive=keep_alive,
)
g.start()
......@@ -399,7 +397,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
pserver_list = []
ctx = mp.get_context("spawn")
keep_alive = num_groups > 1
for i in range(num_server):
p = ctx.Process(
target=start_server,
......@@ -409,7 +406,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
num_server > 1,
"test_sampling",
["csc", "coo"],
keep_alive,
),
)
p.start()
......@@ -439,11 +435,6 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
for p in pclient_list:
p.join()
assert p.exitcode == 0
if keep_alive:
for p in pserver_list:
assert p.is_alive()
# force shutdown server
dgl.distributed.shutdown_servers("rpc_ip_config.txt", 1)
for p in pserver_list:
p.join()
assert p.exitcode == 0
......
......@@ -52,7 +52,6 @@ def start_server(
part_config,
disable_shared_mem,
num_clients,
keep_alive=False,
):
print("server: #clients=" + str(num_clients))
g = DistGraphServer(
......@@ -63,7 +62,6 @@ def start_server(
part_config,
disable_shared_mem=disable_shared_mem,
graph_format=["csc", "coo"],
keep_alive=keep_alive,
)
g.start()
......@@ -344,7 +342,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
part_config = os.path.join(test_dir, "test_sampling.json")
pserver_list = []
ctx = mp.get_context("spawn")
keep_alive = num_groups > 1
for i in range(num_server):
p = ctx.Process(
target=start_server,
......@@ -354,7 +351,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
part_config,
num_server > 1,
num_workers + 1,
keep_alive,
),
)
p.start()
......@@ -389,11 +385,6 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
for p in ptrainer_list:
p.join()
assert p.exitcode == 0
if keep_alive:
for p in pserver_list:
assert p.is_alive()
# force shutdown server
dgl.distributed.shutdown_servers("mp_ip_config.txt", 1)
for p in pserver_list:
p.join()
assert p.exitcode == 0
......
......@@ -133,13 +133,12 @@ def start_server(
num_clients,
ip_config,
server_id=0,
keep_alive=False,
num_servers=1,
):
print("Sleep 1 seconds to test client re-connect.")
time.sleep(1)
server_state = dgl.distributed.ServerState(
None, local_g=None, partition_book=None, keep_alive=keep_alive
None, local_g=None, partition_book=None
)
dgl.distributed.register_service(
HELLO_SERVICE_ID, HelloRequest, HelloResponse
......@@ -258,7 +257,7 @@ def test_rpc_timeout():
ip_config = "rpc_ip_config.txt"
generate_ip_config(ip_config, 1, 1)
ctx = mp.get_context("spawn")
pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, False, 1))
pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, 1))
pclient = ctx.Process(target=start_client_timeout, args=(ip_config, 0, 1))
pserver.start()
pclient.start()
......@@ -323,7 +322,7 @@ def test_multi_client():
num_clients = 20
pserver = ctx.Process(
target=start_server,
args=(num_clients, ip_config, 0, False, 1),
args=(num_clients, ip_config, 0, 1),
)
pclient_list = []
for i in range(num_clients):
......@@ -347,9 +346,7 @@ def test_multi_thread_rpc():
ctx = mp.get_context("spawn")
pserver_list = []
for i in range(num_servers):
pserver = ctx.Process(
target=start_server, args=(1, ip_config, i, False, 1)
)
pserver = ctx.Process(target=start_server, args=(1, ip_config, i, 1))
pserver.start()
pserver_list.append(pserver)
......@@ -386,49 +383,6 @@ def test_multi_thread_rpc():
pserver.join()
@unittest.skipIf(
    True,
    reason="Tests of multiple groups may fail and let's disable them for now.",
)
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_multi_client_groups():
    """Run several client groups against servers started with keep-alive.

    Starts one server process per (machine, server) slot, then one client
    process per (client, group) pair. After all clients have been joined
    the servers must still be alive; they are then shut down explicitly
    via ``dgl.distributed.shutdown_servers`` and joined.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"

    ip_config = "rpc_ip_config_mul_client_groups.txt"
    num_machines = 5
    # should test with larger number but due to possible port in-use issue.
    num_servers = 1
    generate_ip_config(ip_config, num_machines, num_servers)

    # pressure test
    num_clients = 2
    num_groups = 2
    ctx = mp.get_context("spawn")

    # Each server is started right after it is created.
    server_procs = []
    for server_id in range(num_servers * num_machines):
        server_proc = ctx.Process(
            target=start_server,
            args=(num_clients, ip_config, server_id, True, num_servers),
        )
        server_proc.start()
        server_procs.append(server_proc)

    # One client process per (client, group) pair.
    client_procs = []
    for _ in range(num_clients):
        for group_id in range(num_groups):
            client_proc = ctx.Process(
                target=start_client, args=(ip_config, group_id, num_servers)
            )
            client_proc.start()
            client_procs.append(client_proc)

    for client_proc in client_procs:
        client_proc.join()

    # The servers must still be running once every client has finished.
    for server_proc in server_procs:
        assert server_proc.is_alive()

    # force shutdown server
    dgl.distributed.shutdown_servers(ip_config, num_servers)
    for server_proc in server_procs:
        server_proc.join()
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_multi_client_connect():
reset_envs()
......@@ -439,7 +393,7 @@ def test_multi_client_connect():
num_clients = 1
pserver = ctx.Process(
target=start_server,
args=(num_clients, ip_config, 0, False, 1),
args=(num_clients, ip_config, 0, 1),
)
# small max try times
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment