Unverified commit a1e697b2, authored by Byron Hsu and committed by GitHub
Browse files

[router] Improve cleanup logic (#2411)

parent a6ca736c
...@@ -10,12 +10,12 @@ import time ...@@ -10,12 +10,12 @@ import time
from typing import List from typing import List
import requests import requests
from setproctitle import setproctitle
from sglang_router.launch_router import RouterArgs, launch_router from sglang_router.launch_router import RouterArgs, launch_router
from sglang.srt.server import launch_server from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_port_available from sglang.srt.utils import is_port_available
from sglang.utils import get_exception_traceback
def setup_logger(): def setup_logger():
...@@ -34,10 +34,12 @@ def setup_logger(): ...@@ -34,10 +34,12 @@ def setup_logger():
return logger return logger
logger = setup_logger()
# Create new process group # Create new process group
def run_server(server_args, dp_rank): def run_server(server_args, dp_rank):
os.setpgrp() # Create new process group setproctitle(f"sglang::server")
# Set SGLANG_DP_RANK environment variable # Set SGLANG_DP_RANK environment variable
os.environ["SGLANG_DP_RANK"] = str(dp_rank) os.environ["SGLANG_DP_RANK"] = str(dp_rank)
...@@ -58,36 +60,6 @@ def launch_server_process( ...@@ -58,36 +60,6 @@ def launch_server_process(
return proc return proc
def cleanup_processes(processes: List[mp.Process]):
    """Stop every live process, escalating from SIGTERM to SIGKILL.

    For each process that is still alive, SIGTERM is sent to the process
    group the process belongs to and the process is given 3 seconds to exit.
    If it is still alive after that, SIGKILL is sent to the same group and
    the process is joined again so that it gets reaped.

    Args:
        processes: the processes to shut down. Entries that are no longer
            alive are skipped.
    """
    logger = logging.getLogger("router")
    logger.info("Cleaning up processes...")
    for proc in processes:
        if not proc.is_alive():
            continue
        try:
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
            proc.join(timeout=3)
            if proc.is_alive():
                logger.warning(
                    f"Process {proc.pid} did not terminate gracefully, force killing..."
                )
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
                # Reap the force-killed child; without this join it is never
                # waited on. The timeout keeps shutdown bounded.
                proc.join(timeout=3)
        except ProcessLookupError:
            # The process went away between the liveness check and the
            # signal; there is nothing left to clean up for it.
            pass
def setup_signal_handlers(cleanup_func):
    """Install one shared handler for the common termination signals.

    The handler is registered for SIGTERM and SIGINT, and also for SIGQUIT
    when the ``signal`` module defines it. When one of these signals arrives
    the handler calls ``cleanup_func()`` with no arguments and then exits
    the interpreter with status 1.
    """

    def _on_signal(signum, frame):
        cleanup_func()
        sys.exit(1)

    handled = [signal.SIGTERM, signal.SIGINT]
    # SIGQUIT is only registered when the signal module exposes it.
    sigquit = getattr(signal, "SIGQUIT", None)
    if sigquit is not None:
        handled.append(sigquit)

    for signo in handled:
        signal.signal(signo, _on_signal)
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool: def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
"""Wait for server to be healthy by checking /health endpoint.""" """Wait for server to be healthy by checking /health endpoint."""
start_time = time.time() start_time = time.time()
...@@ -117,8 +89,12 @@ def find_available_ports(base_port: int, count: int) -> List[int]: ...@@ -117,8 +89,12 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
return available_ports return available_ports
def cleanup_processes(processes: List[mp.Process]):
    """Call ``terminate()`` on each given process, in list order.

    The function does not wait for the processes to exit.
    """
    for worker in processes:
        worker.terminate()
def main(): def main():
logger = setup_logger()
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
mp.set_start_method("spawn") mp.set_start_method("spawn")
...@@ -148,52 +124,33 @@ def main(): ...@@ -148,52 +124,33 @@ def main():
# Start server processes # Start server processes
server_processes = [] server_processes = []
try: for i, worker_port in enumerate(worker_ports):
for i, worker_port in enumerate(worker_ports): logger.info(f"Launching DP server process {i} on port {worker_port}")
logger.info(f"Launching DP server process {i} on port {worker_port}") proc = launch_server_process(server_args, worker_port, i)
proc = launch_server_process(server_args, worker_port, i) server_processes.append(proc)
server_processes.append(proc)
signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
# Setup cleanup handler signal.signal(
setup_signal_handlers(lambda: cleanup_processes(server_processes)) signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
)
# Wait for all servers to be healthy signal.signal(
all_healthy = True signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
)
for port in worker_ports:
if not wait_for_server_health(server_args.host, port): for port in worker_ports:
logger.error(f"Server on port {port} failed to become healthy") if not wait_for_server_health(server_args.host, port):
all_healthy = False logger.error(f"Server on port {port} failed to become healthy")
break break
if not all_healthy: logger.info("All servers are healthy. Starting router...")
logger.error("Not all servers are healthy. Shutting down...")
cleanup_processes(server_processes) # Update router args with worker URLs
sys.exit(1) router_args.worker_urls = [
f"http://{server_args.host}:{port}" for port in worker_ports
logger.info("All servers are healthy. Starting router...") ]
# Update router args with worker URLs # Start the router
router_args.worker_urls = [ router = launch_router(router_args)
f"http://{server_args.host}:{port}" for port in worker_ports
]
# Start the router
router = launch_router(router_args)
if router is None:
logger.error("Failed to start router. Shutting down...")
cleanup_processes(server_processes)
sys.exit(1)
except KeyboardInterrupt:
logger.info("Received shutdown signal...")
except Exception as e:
logger.error(f"Error occurred: {e}")
logger.error(get_exception_traceback())
finally:
logger.info("Cleaning up processes...")
cleanup_processes(server_processes)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -6,7 +6,6 @@ from types import SimpleNamespace ...@@ -6,7 +6,6 @@ from types import SimpleNamespace
import requests import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
...@@ -104,23 +103,52 @@ def popen_launch_server( ...@@ -104,23 +103,52 @@ def popen_launch_server(
return process return process
def terminate_and_wait(process, timeout=300):
    """Terminate a process and wait until it is terminated.

    Args:
        process: subprocess.Popen object, or None (in which case nothing
            is done)
        timeout: maximum time to wait in seconds

    Raises:
        TimeoutError: if process does not terminate within timeout
    """
    # Imported locally so the helper does not rely on the module's imports.
    import subprocess

    if process is None:
        return

    process.terminate()
    print(f"Terminating process {process.pid}")
    try:
        # wait() returns as soon as the child exits, so there is no need for
        # a hand-rolled polling loop with one-second sleeps and wall-clock
        # arithmetic.
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired as exc:
        # Keep the exception type promised by the docstring.
        raise TimeoutError(
            f"Process {process.pid} failed to terminate within {timeout}s"
        ) from exc

    print(f"Process {process.pid} is successfully terminated")
class TestLaunchServer(unittest.TestCase): class TestLaunchServer(unittest.TestCase):
@classmethod def setUp(self):
def setUpClass(cls): self.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.model = DEFAULT_MODEL_NAME_FOR_TEST self.base_url = DEFAULT_URL_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST self.process = None
cls.process = None self.other_process = []
cls.other_process = []
def tearDown(self):
@classmethod print("Running tearDown...")
def tearDownClass(cls): if self.process:
kill_process_tree(cls.process.pid) terminate_and_wait(self.process)
for process in cls.other_process: for process in self.other_process:
kill_process_tree(process.pid) terminate_and_wait(process)
print("tearDown done")
def test_mmlu(self):
def test_1_mmlu(self):
print("Running test_1_mmlu...")
# DP size = 2 # DP size = 2
TestLaunchServer.process = popen_launch_router( self.process = popen_launch_router(
self.model, self.model,
self.base_url, self.base_url,
dp_size=2, dp_size=2,
...@@ -144,9 +172,10 @@ class TestLaunchServer(unittest.TestCase): ...@@ -144,9 +172,10 @@ class TestLaunchServer(unittest.TestCase):
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
self.assertGreaterEqual(score, THRESHOLD, msg) self.assertGreaterEqual(score, THRESHOLD, msg)
def test_add_and_remove_worker(self): def test_2_add_and_remove_worker(self):
print("Running test_2_add_and_remove_worker...")
# DP size = 1 # DP size = 1
TestLaunchServer.process = popen_launch_router( self.process = popen_launch_router(
self.model, self.model,
self.base_url, self.base_url,
dp_size=1, dp_size=1,
...@@ -159,7 +188,7 @@ class TestLaunchServer(unittest.TestCase): ...@@ -159,7 +188,7 @@ class TestLaunchServer(unittest.TestCase):
worker_process = popen_launch_server( worker_process = popen_launch_server(
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
) )
TestLaunchServer.other_process.append(worker_process) self.other_process.append(worker_process)
# 2. use /add_worker api to add it to the router. It will be used by router after it is healthy # 2. use /add_worker api to add it to the router. It will be used by router after it is healthy
with requests.Session() as session: with requests.Session() as session:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment