Unverified Commit 4ea9d74a authored by Lianmin Zheng, committed by GitHub

Simplify health check (#9034)

parent dd949ace
@@ -26,7 +26,7 @@ import os
 import threading
 import time
 from http import HTTPStatus
-from typing import AsyncIterator, Callable, Dict, Optional
+from typing import Any, AsyncIterator, Callable, Dict, List, Optional

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -277,7 +277,7 @@ async def health_generate(request: Request) -> Response:
         logger.info("Health check request received during shutdown. Returning 503.")
         return Response(status_code=503)

-    if not _global_state.tokenizer_manager.server_status.is_healthy():
+    if _global_state.tokenizer_manager.server_status == ServerStatus.Starting:
         return Response(status_code=503)

     sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
@@ -317,7 +317,7 @@ async def health_generate(request: Request) -> Response:
         if _global_state.tokenizer_manager.last_receive_tstamp > tic:
             task.cancel()
             _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
-            _global_state.tokenizer_manager.health_check_failed = False
+            _global_state.tokenizer_manager.server_status = ServerStatus.Up
             return Response(status_code=200)

     task.cancel()
@@ -331,7 +331,7 @@ async def health_generate(request: Request) -> Response:
         f"last_heartbeat time: {last_receive_time}"
     )
     _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
-    _global_state.tokenizer_manager.health_check_failed = True
+    _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
     return Response(status_code=503)
...
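The hunks above collapse the old `health_check_failed` boolean into the single `server_status` enum that `/health_generate` now reads and writes. A minimal, self-contained sketch of the resulting control flow; `State`, `run_probe`, and `health_generate_sketch` are illustrative stand-ins for the tokenizer manager and its one-token generation probe, not sglang APIs:

import asyncio
import time
from enum import Enum

class ServerStatus(Enum):
    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"

class State:
    """Illustrative stand-in for the tokenizer manager fields the endpoint reads."""
    def __init__(self) -> None:
        self.gracefully_exit = False
        self.server_status = ServerStatus.Up
        self.last_receive_tstamp = 0.0

    async def run_probe(self) -> None:
        # Pretend a max_new_tokens=1 generation came back shortly afterwards.
        await asyncio.sleep(0.01)
        self.last_receive_tstamp = time.time()

async def health_generate_sketch(state: State) -> int:
    """Return the HTTP status code /health_generate would answer with."""
    if state.gracefully_exit:
        return 503                                  # shutting down
    if state.server_status == ServerStatus.Starting:
        return 503                                  # not ready yet, but not failed
    tic = time.time()
    await state.run_probe()                         # 1-token probe, as in the diff
    if state.last_receive_tstamp > tic:
        state.server_status = ServerStatus.Up       # probe came back: healthy
        return 200
    state.server_status = ServerStatus.UnHealthy    # probe timed out: unhealthy
    return 503

print(asyncio.run(health_generate_sketch(State())))  # -> 200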
@@ -99,25 +99,24 @@ class GenerateReqInput:
     stream: bool = False
     # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
     log_metrics: bool = True
+    # Whether to return hidden states
+    return_hidden_states: Union[List[bool], bool] = False

     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    # Session info for continual prompting
-    session_params: Optional[Union[List[Dict], Dict]] = None
     # The path to the LoRA adaptors
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     # The uid of LoRA adaptors, should be initialized by tokenizer manager
     lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    # Session info for continual prompting
+    session_params: Optional[Union[List[Dict], Dict]] = None
     # Custom logit processor for advanced sampling control. Must be a serialized instance
     # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
     # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
-    # Whether to return hidden states
-    return_hidden_states: Union[List[bool], bool] = False

     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
     bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
...
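This second file only reorders fields of the `GenerateReqInput` dataclass (moving `return_hidden_states` next to the other output options and `session_params` next to the LoRA fields); no behavior changes. For reference, these scalar-or-list fields take one value for a single request and a per-request list for a batch. A hedged example, assuming sglang is installed and `GenerateReqInput` is importable from `sglang.srt.managers.io_struct`; the prompts and adapter names are made up:

from sglang.srt.managers.io_struct import GenerateReqInput

# Single request: scalar values.
single = GenerateReqInput(
    text="What is the capital of France?",
    return_hidden_states=True,           # one bool for the one request
    lora_path="my-adapter",              # hypothetical adapter name
)

# Batch of two: lists with one entry per request.
batch = GenerateReqInput(
    text=["Hello", "World"],
    return_hidden_states=[True, False],  # List[bool], per request
    lora_path=["my-adapter", None],      # None = no adapter for that request
)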
@@ -269,10 +269,9 @@ class TokenizerManager:
         self.asyncio_tasks = set()

         # Health check
-        self.health_check_failed = False
+        self.server_status = ServerStatus.Starting
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
-        self.server_status = ServerStatus.Starting

         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -291,8 +290,8 @@ class TokenizerManager:
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self._is_updating = False
-        self._is_updating_cond = asyncio.Condition()
+        self.is_pause = False
+        self.is_pause_cond = asyncio.Condition()

         # LoRA
         # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
@@ -476,15 +475,15 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

-        async with self._is_updating_cond:
-            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
                 f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
             )

+        async with self.is_pause_cond:
+            await self.is_pause_cond.wait_for(lambda: not self.is_pause)
+
         async with self.model_update_lock.reader_lock:
             if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
@@ -982,14 +981,14 @@ class TokenizerManager:
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

     async def pause_generation(self):
-        async with self._is_updating_cond:
-            self._is_updating = True
+        async with self.is_pause_cond:
+            self.is_pause = True
             self.abort_request(abort_all=True)

     async def continue_generation(self):
-        async with self._is_updating_cond:
-            self._is_updating = False
-            self._is_updating_cond.notify_all()
+        async with self.is_pause_cond:
+            self.is_pause = False
+            self.is_pause_cond.notify_all()

     async def update_weights_from_disk(
         self,
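The renamed `is_pause` / `is_pause_cond` pair implements a standard asyncio pause gate: incoming requests wait on the condition until `is_pause` is cleared, and `continue_generation` wakes all waiters at once. A self-contained, runnable sketch of the same pattern outside sglang (the class and method names here are illustrative):

import asyncio

class PausableWorker:
    def __init__(self) -> None:
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

    async def handle_request(self, name: str) -> None:
        # Park new work while the server is paused (e.g. during weight updates).
        async with self.is_pause_cond:
            await self.is_pause_cond.wait_for(lambda: not self.is_pause)
        print(f"{name}: running")

    async def pause(self) -> None:
        async with self.is_pause_cond:
            self.is_pause = True

    async def resume(self) -> None:
        async with self.is_pause_cond:
            self.is_pause = False
            self.is_pause_cond.notify_all()  # wake every waiting request

async def main() -> None:
    w = PausableWorker()
    await w.pause()
    task = asyncio.create_task(w.handle_request("req-1"))
    await asyncio.sleep(0.1)  # req-1 is parked on the condition
    await w.resume()          # req-1 proceeds
    await task

asyncio.run(main())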
@@ -1474,7 +1473,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)

-            if self.health_check_failed:
+            if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
@@ -1965,10 +1964,6 @@ class ServerStatus(Enum):
     Up = "Up"
     Starting = "Starting"
     UnHealthy = "UnHealthy"
-    Crashed = "Crashed"
-
-    def is_healthy(self) -> bool:
-        return self == ServerStatus.Up


 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
...
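The final hunks route the SIGTERM drain loop through the same enum and delete the now-unused `Crashed` variant and `is_healthy()` helper, so callers compare `server_status` against the three remaining states directly. A sketch of that drain flow; `drain_on_sigterm`, the `manager` argument, and the 5-second poll interval are illustrative assumptions, not sglang's actual loop:

import asyncio
import logging
from enum import Enum

logger = logging.getLogger(__name__)

class ServerStatus(Enum):
    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"

async def drain_on_sigterm(manager) -> None:
    """Wait for in-flight requests to drain before exiting."""
    while True:
        remain_num_req = len(manager.rid_to_state)
        if manager.server_status == ServerStatus.UnHealthy:
            # A failed health probe means waiting is pointless; exit immediately.
            logger.error("Health check failed. Exiting with %d requests left.", remain_num_req)
            break
        if remain_num_req == 0:
            break                # fully drained; safe to shut down
        await asyncio.sleep(5)   # poll until drained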