Unverified Commit 119149f2 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

fix(sglang): always populate max_num_batched_tokens in MDC (#8220)

parent 134d484d
...@@ -79,6 +79,12 @@ class SglangLLMEngine(LLMEngine): ...@@ -79,6 +79,12 @@ class SglangLLMEngine(LLMEngine):
if max_total_tokens and page_size: if max_total_tokens and page_size:
total_kv_blocks = (max_total_tokens + page_size - 1) // page_size total_kv_blocks = (max_total_tokens + page_size - 1) // page_size
# Prefer explicit max_prefill_tokens; fall back to max_total_num_tokens
# from the scheduler so the planner always has a prefill load signal.
max_num_batched_tokens = (
getattr(self.server_args, "max_prefill_tokens", None) or max_total_tokens
)
return EngineConfig( return EngineConfig(
model=self.server_args.model_path, model=self.server_args.model_path,
served_model_name=self.server_args.served_model_name, served_model_name=self.server_args.served_model_name,
...@@ -86,9 +92,7 @@ class SglangLLMEngine(LLMEngine): ...@@ -86,9 +92,7 @@ class SglangLLMEngine(LLMEngine):
kv_cache_block_size=page_size, kv_cache_block_size=page_size,
total_kv_blocks=total_kv_blocks, total_kv_blocks=total_kv_blocks,
max_num_seqs=getattr(self.server_args, "max_running_requests", None), max_num_seqs=getattr(self.server_args, "max_running_requests", None),
max_num_batched_tokens=getattr( max_num_batched_tokens=max_num_batched_tokens,
self.server_args, "max_prefill_tokens", None
),
) )
async def generate( async def generate(
......
...@@ -173,25 +173,27 @@ async def _get_runtime_config( ...@@ -173,25 +173,27 @@ async def _get_runtime_config(
# Try to check if the engine has a scheduler attribute with the computed values # Try to check if the engine has a scheduler attribute with the computed values
if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None: if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None:
# Get max_total_num_tokens from scheduler_info # Get max_total_num_tokens from scheduler_info
if "max_total_num_tokens" in engine.scheduler_info: max_total_tokens = engine.scheduler_info.get("max_total_num_tokens")
max_total_tokens = engine.scheduler_info["max_total_num_tokens"] if max_total_tokens and hasattr(engine.tokenizer_manager, "server_args"):
if max_total_tokens and hasattr( page_size = engine.tokenizer_manager.server_args.page_size
engine.tokenizer_manager, "server_args" if page_size:
): runtime_config.total_kv_blocks = (
page_size = engine.tokenizer_manager.server_args.page_size max_total_tokens + page_size - 1
if page_size: ) // page_size
runtime_config.total_kv_blocks = ( logging.info(
max_total_tokens + page_size - 1 f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} "
) // page_size f"(max_total_tokens={max_total_tokens}, page_size={page_size})"
logging.info( )
f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} "
f"(max_total_tokens={max_total_tokens}, page_size={page_size})" # When max_prefill_tokens is not explicitly set by the user, fall back
) # to max_total_num_tokens from the scheduler. This ensures the planner
# always has a prefill load signal for aggregated scaling decisions.
# Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info. if not max_prefill_tokens and max_total_tokens:
# SGLang separates configuration (server_args) from runtime stats (scheduler_info). runtime_config.max_num_batched_tokens = max_total_tokens
# In contrast, vLLM exposes both config and runtime values through engine config. logging.info(
# These are config parameters, so they must be retrieved from server_args only. f"max_prefill_tokens not set, using max_total_num_tokens "
f"from scheduler as max_num_batched_tokens: {max_total_tokens}"
)
return runtime_config return runtime_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment