Unverified Commit 6f9baf10 authored by ybyang's avatar ybyang Committed by GitHub
Browse files

[Improvements] Merge health check route (#8444)


Signed-off-by: default avatarybyang <ybyang7@iflytek.com>
Co-authored-by: default avatarLianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: default avatarKan Wu <wukanustc@gmail.com>
parent a31b7a70
...@@ -460,6 +460,7 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -460,6 +460,7 @@ class SchedulerDisaggregationPrefillMixin:
# We need to remove the sync in the following function for overlap schedule. # We need to remove the sync in the following function for overlap schedule.
self.set_next_batch_sampling_info_done(batch) self.set_next_batch_sampling_info_done(batch)
self.maybe_send_health_check_signal()
def process_disagg_prefill_inflight_queue( def process_disagg_prefill_inflight_queue(
self: Scheduler, rids_to_check: Optional[List[str]] = None self: Scheduler, rids_to_check: Optional[List[str]] = None
......
...@@ -45,6 +45,7 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse ...@@ -45,6 +45,7 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST, FAKE_BOOTSTRAP_HOST,
DisaggregationMode,
register_disaggregation_server, register_disaggregation_server,
) )
from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.entrypoints.engine import _launch_subprocesses
...@@ -88,7 +89,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -88,7 +89,7 @@ from sglang.srt.managers.io_struct import (
VertexGenerateReqInput, VertexGenerateReqInput,
) )
from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
...@@ -230,23 +231,28 @@ async def validate_json_request(raw_request: Request): ...@@ -230,23 +231,28 @@ async def validate_json_request(raw_request: Request):
@app.get("/health") @app.get("/health")
async def health() -> Response:
    """Liveness probe for the HTTP server.

    Returns:
        An empty HTTP 200 response unconditionally — no engine or scheduler
        state is consulted here. Deeper readiness is covered by the
        ``/health_generate`` endpoint, which generates one token.
    """
    # Always 200: reaching this handler at all proves the HTTP layer is alive.
    return Response(status_code=200)
@app.get("/health_generate") @app.get("/health_generate")
async def health_generate(request: Request) -> Response: async def health_generate(request: Request) -> Response:
"""Check the health of the inference server by generating one token.""" """
Check the health of the inference server by sending a special request to generate one token.
If the server is running something, this request will be ignored, so it creates zero overhead.
If the server is not running anything, this request will be run, so we know whether the server is healthy.
"""
if _global_state.tokenizer_manager.gracefully_exit: if _global_state.tokenizer_manager.gracefully_exit:
logger.info("Health check request received during shutdown. Returning 503.") logger.info("Health check request received during shutdown. Returning 503.")
return Response(status_code=503) return Response(status_code=503)
if not _global_state.tokenizer_manager.server_status.is_healthy():
return Response(status_code=503)
sampling_params = {"max_new_tokens": 1, "temperature": 0.0} sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
rid = f"HEALTH_CHECK_{time.time()}" rid = f"HEALTH_CHECK_{time.time()}"
if _global_state.tokenizer_manager.is_image_gen: if _global_state.tokenizer_manager.is_image_gen:
raise NotImplementedError() # Keep this branch for some internal use cases.
raise NotImplementedError("Image generation is not supported yet.")
elif _global_state.tokenizer_manager.is_generation: elif _global_state.tokenizer_manager.is_generation:
gri = GenerateReqInput( gri = GenerateReqInput(
rid=rid, rid=rid,
...@@ -254,6 +260,12 @@ async def health_generate(request: Request) -> Response: ...@@ -254,6 +260,12 @@ async def health_generate(request: Request) -> Response:
sampling_params=sampling_params, sampling_params=sampling_params,
log_metrics=False, log_metrics=False,
) )
if (
_global_state.tokenizer_manager.server_args.disaggregation_mode
!= DisaggregationMode.NULL
):
gri.bootstrap_host = FAKE_BOOTSTRAP_HOST
gri.bootstrap_room = 0
else: else:
gri = EmbeddingReqInput( gri = EmbeddingReqInput(
rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
...@@ -263,9 +275,6 @@ async def health_generate(request: Request) -> Response: ...@@ -263,9 +275,6 @@ async def health_generate(request: Request) -> Response:
async for _ in _global_state.tokenizer_manager.generate_request(gri, request): async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
break break
# This request is a special request.
# If the server already has something running, this request will be ignored, so it creates zero overhead.
# If the server is not running, this request will be run, so we know whether the server is healthy.
task = asyncio.create_task(gen()) task = asyncio.create_task(gen())
# As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy. # As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy.
...@@ -1032,8 +1041,10 @@ def _execute_server_warmup( ...@@ -1032,8 +1041,10 @@ def _execute_server_warmup(
timeout=600, timeout=600,
) )
assert res.status_code == 200, f"{res}" assert res.status_code == 200, f"{res}"
_global_state.tokenizer_manager.server_status = ServerStatus.Up
else: else:
logger.info(f"Start of prefill warmup ...") logger.info(f"Start of pd disaggregation warmup ...")
json_data = { json_data = {
"sampling_params": { "sampling_params": {
"temperature": 0.0, "temperature": 0.0,
...@@ -1055,9 +1066,18 @@ def _execute_server_warmup( ...@@ -1055,9 +1066,18 @@ def _execute_server_warmup(
headers=headers, headers=headers,
timeout=1800, # because of deep gemm precache is very long if not precache. timeout=1800, # because of deep gemm precache is very long if not precache.
) )
if res.status_code == 200:
logger.info( logger.info(
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}" f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
)
_global_state.tokenizer_manager.server_status = ServerStatus.Up
else:
logger.info(
"Prefill disaggregation mode warm Up Failed, status code: {}".format(
res.status_code
)
) )
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
except Exception: except Exception:
last_traceback = get_exception_traceback() last_traceback = get_exception_traceback()
......
...@@ -1781,6 +1781,9 @@ class Scheduler( ...@@ -1781,6 +1781,9 @@ class Scheduler(
elif batch.forward_mode.is_dummy_first(): elif batch.forward_mode.is_dummy_first():
self.set_next_batch_sampling_info_done(batch) self.set_next_batch_sampling_info_done(batch)
self.maybe_send_health_check_signal()
def maybe_send_health_check_signal(self):
if self.return_health_check_ct: if self.return_health_check_ct:
# Return some signal for the health check. # Return some signal for the health check.
# This is used to prevent the health check signal being blocked by long context prefill. # This is used to prevent the health check signal being blocked by long context prefill.
......
...@@ -29,6 +29,7 @@ import uuid ...@@ -29,6 +29,7 @@ import uuid
from collections import deque from collections import deque
from contextlib import nullcontext from contextlib import nullcontext
from datetime import datetime from datetime import datetime
from enum import Enum
from http import HTTPStatus from http import HTTPStatus
from typing import ( from typing import (
Any, Any,
...@@ -115,6 +116,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -115,6 +116,7 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.managers.mm_utils import TensorTransportMode from sglang.srt.managers.mm_utils import TensorTransportMode
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.managers.scheduler import is_health_check_generate_req
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
...@@ -270,6 +272,7 @@ class TokenizerManager: ...@@ -270,6 +272,7 @@ class TokenizerManager:
self.health_check_failed = False self.health_check_failed = False
self.gracefully_exit = False self.gracefully_exit = False
self.last_receive_tstamp = 0 self.last_receive_tstamp = 0
self.server_status = ServerStatus.Starting
# Dumping # Dumping
self.dump_requests_folder = "" # By default do not dump self.dump_requests_folder = "" # By default do not dump
...@@ -1804,6 +1807,8 @@ class TokenizerManager: ...@@ -1804,6 +1807,8 @@ class TokenizerManager:
asyncio.create_task(asyncio.to_thread(background_task)) asyncio.create_task(asyncio.to_thread(background_task))
def _handle_abort_req(self, recv_obj): def _handle_abort_req(self, recv_obj):
if is_health_check_generate_req(recv_obj):
return
state = self.rid_to_state[recv_obj.rid] state = self.rid_to_state[recv_obj.rid]
state.finished = True state.finished = True
if recv_obj.finished_reason: if recv_obj.finished_reason:
...@@ -1938,6 +1943,16 @@ class TokenizerManager: ...@@ -1938,6 +1943,16 @@ class TokenizerManager:
return scores return scores
class ServerStatus(Enum):
    """Coarse lifecycle state of the server process.

    Used by the HTTP health endpoints: only a server that has completed
    warmup (``Up``) is reported as healthy.
    """

    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"
    Crashed = "Crashed"

    def is_healthy(self) -> bool:
        """Return True only when the server is fully up."""
        # Enum members are singletons, so an identity check is equivalent
        # to equality here.
        return self is ServerStatus.Up
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode: def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
is_cross_node = server_args.dist_init_addr is_cross_node = server_args.dist_init_addr
......
...@@ -44,7 +44,6 @@ import traceback ...@@ -44,7 +44,6 @@ import traceback
import warnings import warnings
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum
from functools import lru_cache from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec from importlib.util import find_spec
...@@ -93,6 +92,7 @@ logger = logging.getLogger(__name__) ...@@ -93,6 +92,7 @@ logger = logging.getLogger(__name__)
show_time_cost = False show_time_cost = False
time_infos = {} time_infos = {}
HIP_FP8_E4M3_FNUZ_MAX = 224.0 HIP_FP8_E4M3_FNUZ_MAX = 224.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment