Unverified commit 5a290a56, authored by Simo Lin, committed by GitHub
Browse files

[router][grpc-server] Fix gRPC server shutdown (#11094)

parent 580051c5
......@@ -116,16 +116,16 @@ class GrpcRequestManager:
self.port_args = port_args
# ZMQ Communication Setup (same pattern as TokenizerManager)
context = zmq.asyncio.Context(2)
self.context = zmq.asyncio.Context(2)
# Socket for receiving outputs from scheduler
self.recv_from_scheduler = get_zmq_socket(
context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
self.context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
)
# Socket for sending requests to scheduler
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
)
# State Management (from TokenizerManager)
......@@ -472,6 +472,15 @@ class GrpcRequestManager:
if self.gracefully_exit:
break
continue
except zmq.error.ZMQError as e:
# Socket closed or other ZMQ error - exit cleanly if shutting down
if self.gracefully_exit:
logger.debug(f"ZMQ recv interrupted during shutdown: {e}")
break
logger.error(
f"ZMQ error in handle loop: {e}\n{get_exception_traceback()}"
)
break
except Exception as e:
logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
if self.gracefully_exit:
......@@ -722,8 +731,17 @@ class GrpcRequestManager:
logger.info("Shutting down GrpcRequestManager")
self.gracefully_exit = True
# Cancel all asyncio tasks FIRST - this will interrupt blocked recv() calls
for task in list(self.asyncio_tasks):
if not task.done():
task.cancel()
# Give tasks a moment to process cancellation
if self.asyncio_tasks:
await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
# Cancel all pending requests
for rid, state in self.rid_to_state.items():
for rid, state in list(self.rid_to_state.items()):
if not state.finished:
await state.out_queue.put(
{"error": "Server shutting down", "shutdown": True}
......@@ -731,14 +749,13 @@ class GrpcRequestManager:
state.finished = True
state.event.set()
# Wait for tasks to complete
if self.asyncio_tasks:
await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
# Close ZMQ sockets
self.recv_from_scheduler.close()
self.send_to_scheduler.close()
# Terminate the ZMQ context - this is critical for asyncio loop to exit cleanly
self.context.term()
logger.info("GrpcRequestManager shutdown complete")
def get_server_info(self) -> Dict[str, Any]:
......
......@@ -36,6 +36,20 @@ logger = logging.getLogger(__name__)
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
def _run_scheduler_with_signal_handling(*args, **kwargs):
    """
    Subprocess entry point: disable SIGINT, then run the scheduler.

    The SIGINT disposition of the current process is set to "ignore" before
    every positional and keyword argument is forwarded, unchanged, to
    run_scheduler_process. Ctrl+C therefore never interrupts the scheduler
    directly; it is expected to go away only once the parent gRPC server
    exits (via kill_itself_when_parent_died).
    """
    # Leave Ctrl+C to the parent process; this subprocess must not react to it.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    # With SIGINT out of the way, hand over to the real scheduler entry point.
    run_scheduler_process(*args, **kwargs)
def _launch_scheduler_process_only(
server_args: ServerArgs,
port_args: Optional[PortArgs] = None,
......@@ -88,7 +102,7 @@ def _launch_scheduler_process_only(
)
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
proc = mp.Process(
target=run_scheduler_process,
target=_run_scheduler_with_signal_handling,
args=(
server_args,
port_args,
......@@ -676,19 +690,28 @@ async def serve_grpc(
await stop_event.wait()
finally:
logger.info("Shutting down gRPC server")
# Shutdown request manager first - this closes ZMQ sockets and stops background tasks
await servicer.shutdown()
# Stop the gRPC server
await server.stop(5.0)
# Terminate scheduler processes
# Terminate scheduler processes before exiting to avoid atexit hang
# The scheduler processes have SIGINT ignored, so they won't get KeyboardInterrupt
for i, proc in enumerate(scheduler_procs):
if proc and proc.is_alive():
if proc.is_alive():
logger.info(f"Terminating scheduler process {i}...")
proc.terminate()
proc.join(timeout=5.0)
proc.join(timeout=2.0)
if proc.is_alive():
logger.warning(f"Force killing scheduler process {i}...")
logger.warning(
f"Scheduler process {i} did not terminate, killing..."
)
proc.kill()
proc.join()
proc.join(timeout=1.0)
logger.info("All scheduler processes terminated")
def main():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment