Unverified Commit 6f9baf10 authored by ybyang's avatar ybyang Committed by GitHub
Browse files

[Improvements] Merge health check route (#8444)


Signed-off-by: ybyang <ybyang7@iflytek.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Kan Wu <wukanustc@gmail.com>
parent a31b7a70
......@@ -460,6 +460,7 @@ class SchedulerDisaggregationPrefillMixin:
# We need to remove the sync in the following function for overlap schedule.
self.set_next_batch_sampling_info_done(batch)
self.maybe_send_health_check_signal()
def process_disagg_prefill_inflight_queue(
self: Scheduler, rids_to_check: Optional[List[str]] = None
......
......@@ -45,6 +45,7 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
DisaggregationMode,
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import _launch_subprocesses
......@@ -88,7 +89,7 @@ from sglang.srt.managers.io_struct import (
VertexGenerateReqInput,
)
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs
......@@ -230,23 +231,28 @@ async def validate_json_request(raw_request: Request):
@app.get("/health")
async def health() -> Response:
    """Check the health of the http server.

    Liveness-only probe: always returns HTTP 200 while the process is able
    to serve requests. It does not exercise the inference path — use the
    /health_generate route for an end-to-end readiness check.
    """
    return Response(status_code=200)
@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
"""Check the health of the inference server by generating one token."""
"""
Check the health of the inference server by sending a special request to generate one token.
If the server is running something, this request will be ignored, so it creates zero overhead.
If the server is not running anything, this request will be run, so we know whether the server is healthy.
"""
if _global_state.tokenizer_manager.gracefully_exit:
logger.info("Health check request received during shutdown. Returning 503.")
return Response(status_code=503)
if not _global_state.tokenizer_manager.server_status.is_healthy():
return Response(status_code=503)
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
rid = f"HEALTH_CHECK_{time.time()}"
if _global_state.tokenizer_manager.is_image_gen:
raise NotImplementedError()
# Keep this branch for some internal use cases.
raise NotImplementedError("Image generation is not supported yet.")
elif _global_state.tokenizer_manager.is_generation:
gri = GenerateReqInput(
rid=rid,
......@@ -254,6 +260,12 @@ async def health_generate(request: Request) -> Response:
sampling_params=sampling_params,
log_metrics=False,
)
if (
_global_state.tokenizer_manager.server_args.disaggregation_mode
!= DisaggregationMode.NULL
):
gri.bootstrap_host = FAKE_BOOTSTRAP_HOST
gri.bootstrap_room = 0
else:
gri = EmbeddingReqInput(
rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
......@@ -263,9 +275,6 @@ async def health_generate(request: Request) -> Response:
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
break
# This request is a special request.
# If the server already has something running, this request will be ignored, so it creates zero overhead.
# If the server is not running, this request will be run, so we know whether the server is healthy.
task = asyncio.create_task(gen())
# As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy.
......@@ -1032,8 +1041,10 @@ def _execute_server_warmup(
timeout=600,
)
assert res.status_code == 200, f"{res}"
_global_state.tokenizer_manager.server_status = ServerStatus.Up
else:
logger.info(f"Start of prefill warmup ...")
logger.info(f"Start of pd disaggregation warmup ...")
json_data = {
"sampling_params": {
"temperature": 0.0,
......@@ -1055,9 +1066,18 @@ def _execute_server_warmup(
headers=headers,
timeout=1800, # because of deep gemm precache is very long if not precache.
)
logger.info(
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
)
if res.status_code == 200:
logger.info(
f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
)
_global_state.tokenizer_manager.server_status = ServerStatus.Up
else:
logger.info(
"Prefill disaggregation mode warm Up Failed, status code: {}".format(
res.status_code
)
)
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
except Exception:
last_traceback = get_exception_traceback()
......
......@@ -1781,6 +1781,9 @@ class Scheduler(
elif batch.forward_mode.is_dummy_first():
self.set_next_batch_sampling_info_done(batch)
self.maybe_send_health_check_signal()
def maybe_send_health_check_signal(self):
if self.return_health_check_ct:
# Return some signal for the health check.
# This is used to prevent the health check signal being blocked by long context prefill.
......
......@@ -29,6 +29,7 @@ import uuid
from collections import deque
from contextlib import nullcontext
from datetime import datetime
from enum import Enum
from http import HTTPStatus
from typing import (
Any,
......@@ -115,6 +116,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.mm_utils import TensorTransportMode
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.managers.scheduler import is_health_check_generate_req
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
......@@ -270,6 +272,7 @@ class TokenizerManager:
self.health_check_failed = False
self.gracefully_exit = False
self.last_receive_tstamp = 0
self.server_status = ServerStatus.Starting
# Dumping
self.dump_requests_folder = "" # By default do not dump
......@@ -1804,6 +1807,8 @@ class TokenizerManager:
asyncio.create_task(asyncio.to_thread(background_task))
def _handle_abort_req(self, recv_obj):
if is_health_check_generate_req(recv_obj):
return
state = self.rid_to_state[recv_obj.rid]
state.finished = True
if recv_obj.finished_reason:
......@@ -1938,6 +1943,16 @@ class TokenizerManager:
return scores
class ServerStatus(Enum):
    """Coarse lifecycle state of the server, as tracked on the TokenizerManager.

    Health endpoints consult this state: only ``Up`` is treated as healthy;
    every other state causes readiness checks to fail.
    """

    Up = "Up"                # warmup finished successfully; serving traffic
    Starting = "Starting"    # initial state before warmup completes
    UnHealthy = "UnHealthy"  # warmup (or a later check) failed
    Crashed = "Crashed"      # an unrecoverable failure was detected

    def is_healthy(self) -> bool:
        """Return True only for the ``Up`` state."""
        # Enum members are singletons, so identity comparison is equivalent
        # to equality here.
        return self is ServerStatus.Up
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
is_cross_node = server_args.dist_init_addr
......
......@@ -44,7 +44,6 @@ import traceback
import warnings
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from enum import Enum
from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec
......@@ -93,6 +92,7 @@ logger = logging.getLogger(__name__)
show_time_cost = False
time_infos = {}
HIP_FP8_E4M3_FNUZ_MAX = 224.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment