Unverified Commit 4ea9d74a authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Simplify health check (#9034)

parent dd949ace
......@@ -26,7 +26,7 @@ import os
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Callable, Dict, Optional
from typing import Any, AsyncIterator, Callable, Dict, List, Optional
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
......@@ -277,7 +277,7 @@ async def health_generate(request: Request) -> Response:
logger.info("Health check request received during shutdown. Returning 503.")
return Response(status_code=503)
if not _global_state.tokenizer_manager.server_status.is_healthy():
if _global_state.tokenizer_manager.server_status == ServerStatus.Starting:
return Response(status_code=503)
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
......@@ -317,7 +317,7 @@ async def health_generate(request: Request) -> Response:
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
task.cancel()
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
_global_state.tokenizer_manager.health_check_failed = False
_global_state.tokenizer_manager.server_status = ServerStatus.Up
return Response(status_code=200)
task.cancel()
......@@ -331,7 +331,7 @@ async def health_generate(request: Request) -> Response:
f"last_heartbeat time: {last_receive_time}"
)
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
_global_state.tokenizer_manager.health_check_failed = True
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
return Response(status_code=503)
......
......@@ -99,25 +99,24 @@ class GenerateReqInput:
stream: bool = False
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
log_metrics: bool = True
# Whether to return hidden states
return_hidden_states: Union[List[bool], bool] = False
# The modalities of the image data [image, multi-images, video]
modalities: Optional[List[str]] = None
# Session info for continual prompting
session_params: Optional[Union[List[Dict], Dict]] = None
# The path to the LoRA adaptors
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
# The uid of LoRA adaptors, should be initialized by tokenizer manager
lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
# Session info for continual prompting
session_params: Optional[Union[List[Dict], Dict]] = None
# Custom logit processor for advanced sampling control. Must be a serialized instance
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
# Use the processor's `to_str()` method to generate the serialized string.
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
# Whether to return hidden states
return_hidden_states: Union[List[bool], bool] = False
# For disaggregated inference
bootstrap_host: Optional[Union[List[str], str]] = None
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
......
......@@ -269,10 +269,9 @@ class TokenizerManager:
self.asyncio_tasks = set()
# Health check
self.health_check_failed = False
self.server_status = ServerStatus.Starting
self.gracefully_exit = False
self.last_receive_tstamp = 0
self.server_status = ServerStatus.Starting
# Dumping
self.dump_requests_folder = "" # By default do not dump
......@@ -291,8 +290,8 @@ class TokenizerManager:
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
None
)
self._is_updating = False
self._is_updating_cond = asyncio.Condition()
self.is_pause = False
self.is_pause_cond = asyncio.Condition()
# LoRA
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
......@@ -476,15 +475,15 @@ class TokenizerManager:
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
async with self._is_updating_cond:
await self._is_updating_cond.wait_for(lambda: not self._is_updating)
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(
f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
)
async with self.is_pause_cond:
await self.is_pause_cond.wait_for(lambda: not self.is_pause)
async with self.model_update_lock.reader_lock:
if obj.is_single:
tokenized_obj = await self._tokenize_one_request(obj)
......@@ -982,14 +981,14 @@ class TokenizerManager:
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
async def pause_generation(self):
async with self._is_updating_cond:
self._is_updating = True
async with self.is_pause_cond:
self.is_pause = True
self.abort_request(abort_all=True)
async def continue_generation(self):
async with self._is_updating_cond:
self._is_updating = False
self._is_updating_cond.notify_all()
async with self.is_pause_cond:
self.is_pause = False
self.is_pause_cond.notify_all()
async def update_weights_from_disk(
self,
......@@ -1474,7 +1473,7 @@ class TokenizerManager:
while True:
remain_num_req = len(self.rid_to_state)
if self.health_check_failed:
if self.server_status == ServerStatus.UnHealthy:
# if health check failed, we should exit immediately
logger.error(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
......@@ -1965,10 +1964,6 @@ class ServerStatus(Enum):
Up = "Up"
Starting = "Starting"
UnHealthy = "UnHealthy"
Crashed = "Crashed"
def is_healthy(self) -> bool:
return self == ServerStatus.Up
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment