Unverified Commit 34acb1f4 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

fix(sglang): populate total_kv_blocks via canonical Engine scheduler path (#8439)

parent ffa6e84a
......@@ -133,6 +133,38 @@ async def mm_encode(encoder: Any, mm_items: Any, modality: Any) -> tuple:
return result
def get_scheduler_info(engine: Any) -> dict:
"""Return the scheduler-info dict for rank-0 of an ``sgl.Engine``.
SGLang exposes per-rank scheduler stats (``max_total_num_tokens``,
``max_req_input_len``, ...) on the ``Engine`` via ``_scheduler_init_result``.
We return the rank-0 dict, or ``{}`` if it is not reachable on this build.
Covers:
- sglang 0.5.10+: ``engine._scheduler_init_result.scheduler_infos[0]``
(canonical; also what ``Engine.get_server_info`` reads internally).
- Older probed attributes (``engine.scheduler_info``,
``engine.tokenizer_manager.scheduler_info``) as a best-effort fallback
for forks/experimental branches that surfaced the dict directly.
"""
result = getattr(engine, "_scheduler_init_result", None)
if result is not None:
infos = getattr(result, "scheduler_infos", None)
if infos:
return infos[0]
direct = getattr(engine, "scheduler_info", None)
if direct:
return direct
tm = getattr(engine, "tokenizer_manager", None)
tm_info = getattr(tm, "scheduler_info", None) if tm is not None else None
if tm_info:
return tm_info
return {}
def enable_disjoint_streaming_output(server_args: Any) -> None:
"""
Enable SGLang's disjoint streaming output across ServerArgs field renames.
......@@ -170,6 +202,7 @@ __all__ = [
"NetworkAddress",
"enable_disjoint_streaming_output",
"get_local_ip_auto",
"get_scheduler_info",
"get_zmq_socket",
"mm_encode",
]
......@@ -25,6 +25,7 @@ from dynamo.common.backend.engine import (
from dynamo.common.backend.worker import WorkerConfig
from dynamo.common.utils.input_params import InputParamManager
from dynamo.llm import ModelInput
from dynamo.sglang._compat import get_scheduler_info
from dynamo.sglang.args import parse_args
logger = logging.getLogger(__name__)
......@@ -73,7 +74,7 @@ class SglangLLMEngine(LLMEngine):
# Capacity fields -- sourced the same way as register.py in the
# non-unified path so the Rust runtime gets consistent values.
total_kv_blocks = None
scheduler_info = getattr(self.engine, "scheduler_info", None) or {}
scheduler_info = get_scheduler_info(self.engine)
max_total_tokens = scheduler_info.get("max_total_num_tokens")
page_size = self.server_args.page_size
if max_total_tokens and page_size:
......
......@@ -11,7 +11,7 @@ from sglang.srt.server_args import ServerArgs
from dynamo._core import Endpoint
from dynamo.common.utils.output_modalities import get_output_modalities
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_model
from dynamo.sglang._compat import NetworkAddress, get_local_ip_auto
from dynamo.sglang._compat import NetworkAddress, get_local_ip_auto, get_scheduler_info
from dynamo.sglang.args import DynamoConfig
......@@ -170,12 +170,11 @@ async def _get_runtime_config(
runtime_config.enable_eagle = True
try:
# Try to check if the engine has a scheduler attribute with the computed values
if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None:
# Get max_total_num_tokens from scheduler_info
max_total_tokens = engine.scheduler_info.get("max_total_num_tokens")
if max_total_tokens and hasattr(engine.tokenizer_manager, "server_args"):
page_size = engine.tokenizer_manager.server_args.page_size
scheduler_info = get_scheduler_info(engine)
max_total_tokens = scheduler_info.get("max_total_num_tokens")
if max_total_tokens:
page_size = server_args.page_size
if page_size:
runtime_config.total_kv_blocks = (
max_total_tokens + page_size - 1
......@@ -188,21 +187,21 @@ async def _get_runtime_config(
# When max_prefill_tokens is not explicitly set by the user, fall back
# to max_total_num_tokens from the scheduler. This ensures the planner
# always has a prefill load signal for aggregated scaling decisions.
if not max_prefill_tokens and max_total_tokens:
if not max_prefill_tokens:
runtime_config.max_num_batched_tokens = max_total_tokens
logging.info(
f"max_prefill_tokens not set, using max_total_num_tokens "
f"from scheduler as max_num_batched_tokens: {max_total_tokens}"
)
return runtime_config
# If scheduler approach doesn't work, log and return None to indicate we'll skip runtime config
else:
unpublished = "total_kv_blocks"
if not max_prefill_tokens:
unpublished += " and max_num_batched_tokens"
logging.warning(
"Could not access runtime config from SGLang engine. "
"The engine may compute these values internally after initialization. "
"Proceeding without runtime config - SGLang will use its internal defaults."
f"Could not access scheduler info from SGLang engine. "
f"{unpublished} will not be published; SGLang will use its internal defaults."
)
return runtime_config
except Exception as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment