Unverified Commit 4540a466 authored by ybyang, committed by GitHub

[Feature] Simple Improvement to the Health Check Mechanism for Production-Grade Stability (#8115)


Signed-off-by: ybyang <ybyang7@iflytek.com>
parent abda2542
@@ -65,6 +65,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
+    ServerStatus,
     assert_pkg_version,
     configure_logger,
     get_zmq_socket,
@@ -73,6 +74,7 @@ from sglang.srt.utils import (
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
+    report_health,
     set_prometheus_multiproc_dir,
     set_ulimit,
 )
@@ -661,6 +663,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     def sigchld_handler(signum, frame):
         pid, exitcode = os.waitpid(0, os.WNOHANG)
         if exitcode != 0:
+            report_health(ServerStatus.Crashed, server_args.host, server_args.port)
             logger.warning(
                 f"Child process unexpectedly failed with {exitcode=}. {pid=}"
             )
@@ -674,6 +677,7 @@ def _set_envs_and_config(server_args: ServerArgs):
         logger.error(
             "Received sigquit from a child process. It usually means the child failed."
         )
+        report_health(ServerStatus.Crashed, server_args.host, server_args.port)
         kill_process_tree(os.getpid())

     signal.signal(signal.SIGQUIT, sigquit_handler)
...
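A note on the crash path above: report_health (added to sglang/srt/utils.py later in this diff) simply POSTs the new status to the server's own /health endpoint, so a crashed child process becomes visible to external probes instead of the HTTP server silently continuing to answer 200. A minimal sketch of the equivalent request, assuming a server listening on 127.0.0.1:30000 (hypothetical values; the real handlers take host and port from server_args):

import requests

# Hypothetical host/port; the signal handlers pass server_args.host / server_args.port.
HOST, PORT = "127.0.0.1", 30000

# Equivalent of report_health(ServerStatus.Crashed, HOST, PORT):
# a plain POST of the new status to the /health endpoint.
requests.post(
    f"http://{HOST}:{PORT}/health",
    json={"status": "Crashed", "msg": "child process exited unexpectedly"},
    timeout=5,  # not part of report_health itself; added here so the sketch cannot hang
)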
@@ -77,6 +77,7 @@ from sglang.srt.managers.io_struct import (
     ParseFunctionCallReq,
     ProfileReqInput,
     ReleaseMemoryOccupationReqInput,
+    ReportHealthInput,
     ResumeMemoryOccupationReqInput,
     SeparateReasoningReqInput,
     SetInternalStateReq,
@@ -93,6 +94,7 @@ from sglang.srt.metrics.func_timer import enable_func_timer
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    ServerStatus,
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
@@ -220,8 +222,31 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
 @app.get("/health")
 async def health() -> Response:
-    """Check the health of the http server."""
-    return Response(status_code=200)
+    """Check the status of the http server."""
+    code = HTTPStatus.SERVICE_UNAVAILABLE.value
+    if _global_state.tokenizer_manager.server_status == ServerStatus.Up:
+        code = HTTPStatus.OK.value
+    return Response(
+        status_code=code,
+        content=json.dumps(
+            {"status": _global_state.tokenizer_manager.server_status.value}
+        ),
+    )
+
+
+@app.post("/health")
+async def health_update(obj: ReportHealthInput, request: Request) -> Response:
+    """Update the status of the http server."""
+    try:
+        server_status = ServerStatus(obj.status)
+        _global_state.tokenizer_manager.server_status = server_status
+        if server_status != ServerStatus.Up:
+            return Response(
+                status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, content=obj.msg
+            )
+    except Exception as e:
+        logger.error(e)
+        return Response(status_code=HTTPStatus.SERVICE_UNAVAILABLE.value)


 @app.get("/health_generate")
@@ -256,7 +281,7 @@ async def health_generate(request: Request) -> Response:
         if _global_state.tokenizer_manager.last_receive_tstamp > tic:
             task.cancel()
             _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
-            _global_state.tokenizer_manager.health_check_failed = False
+            _global_state.tokenizer_manager.server_status = ServerStatus.Up
             return Response(status_code=200)

     task.cancel()
@@ -270,7 +295,7 @@ async def health_generate(request: Request) -> Response:
         f"last_heartbeat time: {last_receive_time}"
     )
     _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
-    _global_state.tokenizer_manager.health_check_failed = True
+    _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
     return Response(status_code=503)
@@ -1022,9 +1047,13 @@ def _execute_server_warmup(
                 headers=headers,
                 timeout=600,
             )
-            assert res.status_code == 200, f"{res}"
+            if res.status_code == 200:
+                _global_state.tokenizer_manager.server_status = ServerStatus.Up
+            else:
+                _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
+                logger.info(f"{res}")
         else:
-            logger.info(f"Start of prefill warmup ...")
+            logger.info(f"Start of prefill/decode warmup ...")
             json_data = {
                 "sampling_params": {
                     "temperature": 0.0,
@@ -1046,15 +1075,25 @@ def _execute_server_warmup(
                 headers=headers,
                 timeout=1800,  # because of deep gemm precache is very long if not precache.
             )
-            logger.info(
-                f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
-            )
+            if res.status_code == 200:
+                logger.info(
+                    f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
+                )
+                _global_state.tokenizer_manager.server_status = ServerStatus.Up
+            else:
+                logger.info(
+                    "Prefill disaggregation mode warmup failed, status code: {}".format(
+                        res.status_code
+                    )
+                )
+                _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy

     except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        _global_state.tokenizer_manager.server_status = ServerStatus.Crashed
         kill_process_tree(os.getpid())
         return False
...
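With the GET /health change above, a 200 now means ServerStatus.Up and anything else answers 503 with the status name in a JSON body, so load balancers and Kubernetes-style probes can tell Starting, UnHealthy, and Crashed apart from a healthy server. A rough readiness probe built on top of it, assuming the server runs at 127.0.0.1:30000 (hypothetical values):

import requests


def is_server_ready(host: str = "127.0.0.1", port: int = 30000) -> bool:
    """Return True only when GET /health reports the "Up" status."""
    try:
        res = requests.get(f"http://{host}:{port}/health", timeout=3)
    except requests.RequestException:
        return False
    if res.status_code != 200:
        # Body looks like {"status": "Starting"} / {"status": "UnHealthy"} / {"status": "Crashed"}.
        print("server not ready:", res.text)
        return False
    return res.json().get("status") == "Up"


if __name__ == "__main__":
    print(is_server_ready())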
@@ -1083,3 +1083,9 @@ class LoRAUpdateResult:
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
+
+
+@dataclass
+class ReportHealthInput:
+    status: str
+    msg: Optional[str] = ""
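ReportHealthInput is the request body parsed by the new POST /health handler: status must be one of the ServerStatus member names and msg is optional. Roughly, the accepted payload shapes (values here are illustrative only):

# Valid "status" values mirror the ServerStatus enum: "Up", "Starting", "UnHealthy", "Crashed".
payload_with_msg = {"status": "UnHealthy", "msg": "warmup request failed"}
payload_minimal = {"status": "Up"}  # msg defaults to ""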
@@ -143,6 +143,7 @@ from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
     DeepEPMode,
     DynamicGradMode,
+    ServerStatus,
     broadcast_pyobj,
     configure_gc_logger,
     configure_logger,
@@ -154,6 +155,7 @@ from sglang.srt.utils import (
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
+    report_health,
     require_mlp_sync,
     require_mlp_tp_gather,
     set_gpu_proc_affinity,
@@ -2964,4 +2966,5 @@ def run_scheduler_process(
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"Scheduler hit an exception: {traceback}")
+        report_health(ServerStatus.Crashed, server_args.host, server_args.port)
         parent_process.send_signal(signal.SIGQUIT)
@@ -116,6 +116,7 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    ServerStatus,
     dataclass_to_string_truncated,
     get_bool_env_var,
     get_zmq_socket,
@@ -173,6 +174,9 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Server Status
+        self.server_status = ServerStatus.Starting
+
         # Parse args
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
@@ -251,7 +255,6 @@ class TokenizerManager:
         # Store states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
-        self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
         self.dump_requests_folder = ""  # By default do not dump
@@ -1332,7 +1335,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)

-            if self.health_check_failed:
+            if not self.server_status.is_healthy():
                 # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
...
@@ -93,6 +93,22 @@ time_infos = {}
 HIP_FP8_E4M3_FNUZ_MAX = 224.0


+class ServerStatus(Enum):
+    Up = "Up"
+    Starting = "Starting"
+    UnHealthy = "UnHealthy"
+    Crashed = "Crashed"
+
+    def is_healthy(self) -> bool:
+        return self == ServerStatus.Up
+
+
+def report_health(status: ServerStatus, host: str, http_port: int, msg: str = ""):
+    requests.post(
+        f"http://{host}:{http_port}/health", json={"status": status.value, "msg": msg}
+    )
+
+
 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
     return torch.version.hip is not None
...
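A small usage sketch of the new helpers above, assuming a local server on 127.0.0.1:30000 (hypothetical values): ServerStatus can be constructed from the raw string carried in ReportHealthInput.status, only Up counts as healthy, and report_health pushes a status change to the HTTP server the same way the scheduler and the signal handlers in this diff do. Note that report_health itself issues a blocking requests.post with no timeout.

from sglang.srt.utils import ServerStatus, report_health

# Parse a status from its string form, as POST /health does with ReportHealthInput.status.
status = ServerStatus("UnHealthy")
print(status.is_healthy())  # False; only ServerStatus.Up is considered healthy

# Push a crash report to the local HTTP server (hypothetical host/port).
report_health(ServerStatus.Crashed, "127.0.0.1", 30000, msg="scheduler hit an exception")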