"examples/simultaneous_translation/vscode:/vscode.git/clone" did not exist on "7df61696f57a11fbefb850c28acde501fd5b753f"
Unverified Commit 4540a466 authored by ybyang's avatar ybyang Committed by GitHub
Browse files

[Feature] Simple Improve Health Check Mechanism for Production-Grade Stability (#8115)


Signed-off-by: default avatarybyang <ybyang7@iflytek.com>
parent abda2542
......@@ -65,6 +65,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
MultiprocessingSerializer,
ServerStatus,
assert_pkg_version,
configure_logger,
get_zmq_socket,
......@@ -73,6 +74,7 @@ from sglang.srt.utils import (
launch_dummy_health_check_server,
maybe_set_triton_cache_manager,
prepare_model_and_tokenizer,
report_health,
set_prometheus_multiproc_dir,
set_ulimit,
)
......@@ -661,6 +663,7 @@ def _set_envs_and_config(server_args: ServerArgs):
def sigchld_handler(signum, frame):
pid, exitcode = os.waitpid(0, os.WNOHANG)
if exitcode != 0:
report_health(ServerStatus.Crashed, server_args.host, server_args.port)
logger.warning(
f"Child process unexpectedly failed with {exitcode=}. {pid=}"
)
......@@ -674,6 +677,7 @@ def _set_envs_and_config(server_args: ServerArgs):
logger.error(
"Received sigquit from a child process. It usually means the child failed."
)
report_health(ServerStatus.Crashed, server_args.host, server_args.port)
kill_process_tree(os.getpid())
signal.signal(signal.SIGQUIT, sigquit_handler)
......
......@@ -77,6 +77,7 @@ from sglang.srt.managers.io_struct import (
ParseFunctionCallReq,
ProfileReqInput,
ReleaseMemoryOccupationReqInput,
ReportHealthInput,
ResumeMemoryOccupationReqInput,
SeparateReasoningReqInput,
SetInternalStateReq,
......@@ -93,6 +94,7 @@ from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
ServerStatus,
add_api_key_middleware,
add_prometheus_middleware,
delete_directory,
......@@ -220,8 +222,31 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
@app.get("/health")
async def health() -> Response:
"""Check the health of the http server."""
return Response(status_code=200)
"""Check the status of the http server."""
code = HTTPStatus.SERVICE_UNAVAILABLE.value
if _global_state.tokenizer_manager.server_status == ServerStatus.Up:
code = HTTPStatus.OK.value
return Response(
status_code=code,
content=json.dumps(
{"status": _global_state.tokenizer_manager.server_status.value}
),
)
@app.post("/health")
async def health_update(obj: ReportHealthInput, request: Request) -> Response:
"""Update the Status of the http server."""
try:
server_status = ServerStatus(obj.status)
_global_state.tokenizer_manager.server_status = server_status
if server_status != ServerStatus.Up:
return Response(
status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, content=obj.msg
)
except Exception as e:
logger.error(e)
return Response(status_code=HTTPStatus.SERVICE_UNAVAILABLE.value)
@app.get("/health_generate")
......@@ -256,7 +281,7 @@ async def health_generate(request: Request) -> Response:
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
task.cancel()
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
_global_state.tokenizer_manager.health_check_failed = False
_global_state.tokenizer_manager.server_status = ServerStatus.Up
return Response(status_code=200)
task.cancel()
......@@ -270,7 +295,7 @@ async def health_generate(request: Request) -> Response:
f"last_heartbeat time: {last_receive_time}"
)
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
_global_state.tokenizer_manager.health_check_failed = True
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
return Response(status_code=503)
......@@ -1022,9 +1047,13 @@ def _execute_server_warmup(
headers=headers,
timeout=600,
)
assert res.status_code == 200, f"{res}"
if res.status_code == 200:
_global_state.tokenizer_manager.server_status = ServerStatus.Up
else:
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
logger.info(f"{res}")
else:
logger.info(f"Start of prefill warmup ...")
logger.info(f"Start of prefill/decode warmup ...")
json_data = {
"sampling_params": {
"temperature": 0.0,
......@@ -1046,15 +1075,25 @@ def _execute_server_warmup(
headers=headers,
timeout=1800, # because of deep gemm precache is very long if not precache.
)
logger.info(
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
)
if res.status_code == 200:
logger.info(
f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
)
_global_state.tokenizer_manager.server_status = ServerStatus.Up
else:
logger.info(
"Prefill disaggregation mode warm Up Failed, status code: {}".format(
res.status_code
)
)
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
except Exception:
last_traceback = get_exception_traceback()
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
_global_state.tokenizer_manager.server_status = ServerStatus.Crashed
kill_process_tree(os.getpid())
return False
......
......@@ -1083,3 +1083,9 @@ class LoRAUpdateResult:
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
@dataclass
class ReportHealthInput:
status: str
msg: Optional[str] = ""
......@@ -143,6 +143,7 @@ from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.utils import (
DeepEPMode,
DynamicGradMode,
ServerStatus,
broadcast_pyobj,
configure_gc_logger,
configure_logger,
......@@ -154,6 +155,7 @@ from sglang.srt.utils import (
kill_itself_when_parent_died,
point_to_point_pyobj,
pyspy_dump_schedulers,
report_health,
require_mlp_sync,
require_mlp_tp_gather,
set_gpu_proc_affinity,
......@@ -2964,4 +2966,5 @@ def run_scheduler_process(
except Exception:
traceback = get_exception_traceback()
logger.error(f"Scheduler hit an exception: {traceback}")
report_health(ServerStatus.Crashed, server_args.host, ServerArgs.port)
parent_process.send_signal(signal.SIGQUIT)
......@@ -116,6 +116,7 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
ServerStatus,
dataclass_to_string_truncated,
get_bool_env_var,
get_zmq_socket,
......@@ -173,6 +174,9 @@ class TokenizerManager:
server_args: ServerArgs,
port_args: PortArgs,
):
# Server Status
self.server_status = ServerStatus.Starting
# Parse args
self.server_args = server_args
self.enable_metrics = server_args.enable_metrics
......@@ -251,7 +255,6 @@ class TokenizerManager:
# Store states
self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {}
self.health_check_failed = False
self.gracefully_exit = False
self.last_receive_tstamp = 0
self.dump_requests_folder = "" # By default do not dump
......@@ -1332,7 +1335,7 @@ class TokenizerManager:
while True:
remain_num_req = len(self.rid_to_state)
if self.health_check_failed:
if not self.server_status.is_healthy():
# if health check failed, we should exit immediately
logger.error(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
......
......@@ -93,6 +93,22 @@ time_infos = {}
HIP_FP8_E4M3_FNUZ_MAX = 224.0
class ServerStatus(Enum):
Up = "Up"
Starting = "Starting"
UnHealthy = "UnHealthy"
Crashed = "Crashed"
def is_healthy(self) -> bool:
return self == ServerStatus.Up
def report_health(status: ServerStatus, host: str, http_port: int, msg: str = ""):
requests.post(
f"http://{host}:{http_port}/health", json={"status": status.value, "msg": msg}
)
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
def is_hip() -> bool:
return torch.version.hip is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment