Unverified Commit 34acb1f4 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

fix(sglang): populate total_kv_blocks via canonical Engine scheduler path (#8439)

parent ffa6e84a
...@@ -133,6 +133,38 @@ async def mm_encode(encoder: Any, mm_items: Any, modality: Any) -> tuple: ...@@ -133,6 +133,38 @@ async def mm_encode(encoder: Any, mm_items: Any, modality: Any) -> tuple:
return result return result
def get_scheduler_info(engine: Any) -> dict:
    """Return the scheduler-info dict for rank-0 of an ``sgl.Engine``.

    SGLang exposes per-rank scheduler stats (``max_total_num_tokens``,
    ``max_req_input_len``, ...) on the ``Engine`` via ``_scheduler_init_result``.
    We return the rank-0 dict, or ``{}`` if it is not reachable on this build.

    Probed locations, in order:

    - sglang 0.5.10+: ``engine._scheduler_init_result.scheduler_infos[0]``
      (canonical; also what ``Engine.get_server_info`` reads internally).
    - Older probed attributes (``engine.scheduler_info``,
      ``engine.tokenizer_manager.scheduler_info``) as a best-effort fallback
      for forks/experimental branches that surfaced the dict directly.

    Args:
        engine: The engine object to probe. Only ``getattr`` is used, so any
            object is accepted; missing attributes simply skip that probe.

    Returns:
        The first non-empty scheduler-info mapping found (the same object the
        engine holds, not a copy), or a fresh empty dict when none is found.
    """
    init_result = getattr(engine, "_scheduler_init_result", None)
    infos = getattr(init_result, "scheduler_infos", None)
    if infos:
        rank0_info = infos[0]
        # A present-but-empty (or None) rank-0 entry must not short-circuit the
        # fallbacks: callers invoke ``.get`` on the result and rely on a dict.
        if rank0_info:
            return rank0_info

    direct_info = getattr(engine, "scheduler_info", None)
    if direct_info:
        return direct_info

    tokenizer_manager = getattr(engine, "tokenizer_manager", None)
    tm_info = getattr(tokenizer_manager, "scheduler_info", None)
    if tm_info:
        return tm_info

    return {}
def enable_disjoint_streaming_output(server_args: Any) -> None: def enable_disjoint_streaming_output(server_args: Any) -> None:
""" """
Enable SGLang's disjoint streaming output across ServerArgs field renames. Enable SGLang's disjoint streaming output across ServerArgs field renames.
...@@ -170,6 +202,7 @@ __all__ = [ ...@@ -170,6 +202,7 @@ __all__ = [
"NetworkAddress", "NetworkAddress",
"enable_disjoint_streaming_output", "enable_disjoint_streaming_output",
"get_local_ip_auto", "get_local_ip_auto",
"get_scheduler_info",
"get_zmq_socket", "get_zmq_socket",
"mm_encode", "mm_encode",
] ]
...@@ -25,6 +25,7 @@ from dynamo.common.backend.engine import ( ...@@ -25,6 +25,7 @@ from dynamo.common.backend.engine import (
from dynamo.common.backend.worker import WorkerConfig from dynamo.common.backend.worker import WorkerConfig
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
from dynamo.llm import ModelInput from dynamo.llm import ModelInput
from dynamo.sglang._compat import get_scheduler_info
from dynamo.sglang.args import parse_args from dynamo.sglang.args import parse_args
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -73,7 +74,7 @@ class SglangLLMEngine(LLMEngine): ...@@ -73,7 +74,7 @@ class SglangLLMEngine(LLMEngine):
# Capacity fields -- sourced the same way as register.py in the # Capacity fields -- sourced the same way as register.py in the
# non-unified path so the Rust runtime gets consistent values. # non-unified path so the Rust runtime gets consistent values.
total_kv_blocks = None total_kv_blocks = None
scheduler_info = getattr(self.engine, "scheduler_info", None) or {} scheduler_info = get_scheduler_info(self.engine)
max_total_tokens = scheduler_info.get("max_total_num_tokens") max_total_tokens = scheduler_info.get("max_total_num_tokens")
page_size = self.server_args.page_size page_size = self.server_args.page_size
if max_total_tokens and page_size: if max_total_tokens and page_size:
......
...@@ -11,7 +11,7 @@ from sglang.srt.server_args import ServerArgs ...@@ -11,7 +11,7 @@ from sglang.srt.server_args import ServerArgs
from dynamo._core import Endpoint from dynamo._core import Endpoint
from dynamo.common.utils.output_modalities import get_output_modalities from dynamo.common.utils.output_modalities import get_output_modalities
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_model from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_model
from dynamo.sglang._compat import NetworkAddress, get_local_ip_auto from dynamo.sglang._compat import NetworkAddress, get_local_ip_auto, get_scheduler_info
from dynamo.sglang.args import DynamoConfig from dynamo.sglang.args import DynamoConfig
...@@ -170,39 +170,38 @@ async def _get_runtime_config( ...@@ -170,39 +170,38 @@ async def _get_runtime_config(
runtime_config.enable_eagle = True runtime_config.enable_eagle = True
try: try:
# Try to check if the engine has a scheduler attribute with the computed values scheduler_info = get_scheduler_info(engine)
if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None: max_total_tokens = scheduler_info.get("max_total_num_tokens")
# Get max_total_num_tokens from scheduler_info
max_total_tokens = engine.scheduler_info.get("max_total_num_tokens") if max_total_tokens:
if max_total_tokens and hasattr(engine.tokenizer_manager, "server_args"): page_size = server_args.page_size
page_size = engine.tokenizer_manager.server_args.page_size if page_size:
if page_size: runtime_config.total_kv_blocks = (
runtime_config.total_kv_blocks = ( max_total_tokens + page_size - 1
max_total_tokens + page_size - 1 ) // page_size
) // page_size logging.info(
logging.info( f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} "
f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} " f"(max_total_tokens={max_total_tokens}, page_size={page_size})"
f"(max_total_tokens={max_total_tokens}, page_size={page_size})" )
)
# When max_prefill_tokens is not explicitly set by the user, fall back # When max_prefill_tokens is not explicitly set by the user, fall back
# to max_total_num_tokens from the scheduler. This ensures the planner # to max_total_num_tokens from the scheduler. This ensures the planner
# always has a prefill load signal for aggregated scaling decisions. # always has a prefill load signal for aggregated scaling decisions.
if not max_prefill_tokens and max_total_tokens: if not max_prefill_tokens:
runtime_config.max_num_batched_tokens = max_total_tokens runtime_config.max_num_batched_tokens = max_total_tokens
logging.info( logging.info(
f"max_prefill_tokens not set, using max_total_num_tokens " f"max_prefill_tokens not set, using max_total_num_tokens "
f"from scheduler as max_num_batched_tokens: {max_total_tokens}" f"from scheduler as max_num_batched_tokens: {max_total_tokens}"
) )
else:
unpublished = "total_kv_blocks"
if not max_prefill_tokens:
unpublished += " and max_num_batched_tokens"
logging.warning(
f"Could not access scheduler info from SGLang engine. "
f"{unpublished} will not be published; SGLang will use its internal defaults."
)
return runtime_config
# If scheduler approach doesn't work, log and return None to indicate we'll skip runtime config
logging.warning(
"Could not access runtime config from SGLang engine. "
"The engine may compute these values internally after initialization. "
"Proceeding without runtime config - SGLang will use its internal defaults."
)
return runtime_config return runtime_config
except Exception as e: except Exception as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment