Unverified commit b520958e, authored by Chang Su and committed by GitHub
Browse files

[router][grpc] Replace fake health check with correct ones (#11387)

parent fa7e2c30
...@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
) )
from sglang.srt.managers.scheduler import is_health_check_generate_req
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
...@@ -338,12 +339,9 @@ class GrpcRequestManager: ...@@ -338,12 +339,9 @@ class GrpcRequestManager:
break break
except asyncio.TimeoutError: except asyncio.TimeoutError:
# Timeout waiting for response - abort and cleanup # Timeout is for periodic client cancellation check
logger.warning( # Continue waiting for scheduler response
f"Timeout waiting for response for request {request_id}" continue
)
await self.abort_request(request_id)
return
finally: finally:
# Always clean up request state when exiting # Always clean up request state when exiting
...@@ -412,6 +410,10 @@ class GrpcRequestManager: ...@@ -412,6 +410,10 @@ class GrpcRequestManager:
async def abort_request(self, request_id: str) -> bool: async def abort_request(self, request_id: str) -> bool:
"""Abort a running request.""" """Abort a running request."""
# Skip aborting health check requests (they clean themselves up)
if request_id.startswith("HEALTH_CHECK"):
return False
if request_id not in self.rid_to_state: if request_id not in self.rid_to_state:
return False return False
......
...@@ -197,7 +197,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -197,7 +197,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
context: grpc.aio.ServicerContext, context: grpc.aio.ServicerContext,
) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]: ) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]:
"""Handle generation requests with streaming responses.""" """Handle generation requests with streaming responses."""
logger.debug(f"Receive generation request: {request.request_id}") logger.info(f"Receive generation request: {request.request_id}")
try: try:
# Convert gRPC request to internal format # Convert gRPC request to internal format
...@@ -211,6 +211,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -211,6 +211,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
) )
async for output in response_generator: async for output in response_generator:
# Check if client cancelled before processing/yielding
if context.cancelled():
logger.info(f"Client cancelled request {request.request_id}")
# Explicitly abort the request to notify scheduler
await self.request_manager.abort_request(request.request_id)
break
# Handle batch responses (for n>1 non-streaming) # Handle batch responses (for n>1 non-streaming)
if isinstance(output, list): if isinstance(output, list):
for batch_output in output: for batch_output in output:
...@@ -268,7 +275,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -268,7 +275,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
_context: grpc.aio.ServicerContext, _context: grpc.aio.ServicerContext,
) -> sglang_scheduler_pb2.EmbedResponse: ) -> sglang_scheduler_pb2.EmbedResponse:
"""Handle embedding requests.""" """Handle embedding requests."""
logger.debug(f"Receive embedding request: {request.request_id}") logger.info(f"Receive embedding request: {request.request_id}")
try: try:
# Convert request # Convert request
...@@ -313,9 +320,86 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -313,9 +320,86 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
request: sglang_scheduler_pb2.HealthCheckRequest, request: sglang_scheduler_pb2.HealthCheckRequest,
context: grpc.aio.ServicerContext, context: grpc.aio.ServicerContext,
) -> sglang_scheduler_pb2.HealthCheckResponse: ) -> sglang_scheduler_pb2.HealthCheckResponse:
"""Health check - always returns healthy after server started.""" """
Check the health of the inference server by sending a special request to generate one token.
Similar to HTTP server's /health endpoint.
"""
logger.info("Receive health check request")
if self.request_manager.gracefully_exit:
logger.info(
"Health check request received during shutdown. Returning unhealthy."
)
return sglang_scheduler_pb2.HealthCheckResponse(
healthy=False, message="Server is shutting down"
)
# Create a special health check request
rid = f"HEALTH_CHECK_{time.time()}"
sampling_params = SGLSamplingParams(max_new_tokens=1, temperature=0.0)
sampling_params.normalize(tokenizer=None)
# Create health check request
is_generation = self.scheduler_info.get("is_generation", True)
if is_generation:
health_req = TokenizedGenerateReqInput(
rid=rid,
input_text="",
input_ids=[0],
sampling_params=sampling_params,
return_logprob=False,
logprob_start_len=-1,
top_logprobs_num=0,
stream=False,
mm_inputs=None,
token_ids_logprob=None,
)
# Set disaggregation params if needed
if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
health_req.bootstrap_host = FAKE_BOOTSTRAP_HOST
health_req.bootstrap_room = 0
else:
health_req = TokenizedEmbeddingReqInput(
rid=rid,
input_text="",
input_ids=[0],
)
# Submit health check request
async def run_health_check():
try:
async for _ in self.request_manager.generate_request(
obj=health_req,
request_id=rid,
):
# Got at least one response, server is healthy
return True
except Exception as e:
logger.warning(f"Health check failed: {e}")
return False
return False
task = asyncio.create_task(run_health_check())
# Wait for response with timeout
tic = time.time()
while time.time() < tic + HEALTH_CHECK_TIMEOUT:
await asyncio.sleep(1)
# Check if we got a response from scheduler
if self.request_manager.last_receive_tstamp > tic:
task.cancel()
# Clean up health check state
self.request_manager._cleanup_request_state(rid)
return sglang_scheduler_pb2.HealthCheckResponse(
healthy=True, message="Health check passed"
)
# Timeout - server not responding
task.cancel()
self.request_manager._cleanup_request_state(rid)
logger.warning(f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s")
return sglang_scheduler_pb2.HealthCheckResponse( return sglang_scheduler_pb2.HealthCheckResponse(
healthy=True, message="Health check passed" healthy=False, message=f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s"
) )
async def Abort( async def Abort(
...@@ -324,7 +408,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -324,7 +408,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
_context: grpc.aio.ServicerContext, _context: grpc.aio.ServicerContext,
) -> sglang_scheduler_pb2.AbortResponse: ) -> sglang_scheduler_pb2.AbortResponse:
"""Abort an ongoing request.""" """Abort an ongoing request."""
logger.debug(f"Receive abort request: {request.request_id}") logger.info(f"Receive abort request: {request.request_id}")
try: try:
success = await self.request_manager.abort_request(request.request_id) success = await self.request_manager.abort_request(request.request_id)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment