Unverified Commit 8064849b authored by nachiketb-nvidia's avatar nachiketb-nvidia Committed by GitHub
Browse files

feat: add and enable reasoning and tool parser flags for trtllm and sglang (#2713)

parent 50cd81f3
......@@ -10,7 +10,7 @@ import sys
from argparse import Namespace
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict
from typing import Any, Dict, Optional
from sglang.srt.server_args import ServerArgs
......@@ -39,6 +39,10 @@ class DynamoArgs:
endpoint: str
migration_limit: int
# tool and reasoning parser options
tool_call_parser: Optional[str] = None
reasoning_parser: Optional[str] = None
class DisaggregationMode(Enum):
AGGREGATED = "agg"
......@@ -71,6 +75,20 @@ def parse_args(args: list[str]) -> Config:
"--version", action="version", version=f"Dynamo Backend SGLang {__version__}"
)
# To avoid name conflicts with different backends, adoped prefix "dyn-" for dynamo specific args
parser.add_argument(
"--dyn-tool-call-parser",
type=str,
default=None,
help="Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
)
parser.add_argument(
"--dyn-reasoning-parser",
type=str,
default=None,
help="Reasoning parser name for the model. Available options: 'basic', 'deepseek_r1', 'gpt_oss'.",
)
# Dynamo args
for info in DYNAMO_ARGS.values():
parser.add_argument(
......@@ -123,6 +141,8 @@ def parse_args(args: list[str]) -> Config:
component=parsed_component_name,
endpoint=parsed_endpoint_name,
migration_limit=parsed_args.migration_limit,
tool_call_parser=parsed_args.dyn_tool_call_parser,
reasoning_parser=parsed_args.dyn_reasoning_parser,
)
logging.debug(f"Dynamo args: {dynamo_args}")
......
......@@ -97,7 +97,10 @@ async def init(runtime: DistributedRuntime, config: Config):
async def register_model():
"""Register the model and signal readiness"""
registration_success = await register_llm_with_runtime_config(
engine, generate_endpoint, server_args, dynamo_args.migration_limit
engine,
generate_endpoint,
server_args,
dynamo_args,
)
if not registration_success:
......
......@@ -9,20 +9,21 @@ from sglang.srt.server_args import ServerArgs
from dynamo._core import Endpoint
from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm
from dynamo.sglang.args import DynamoArgs
async def register_llm_with_runtime_config(
engine: sgl.Engine,
endpoint: Endpoint,
server_args: ServerArgs,
migration_limit: int,
dynamo_args: DynamoArgs,
) -> bool:
"""Register LLM with runtime config
Returns:
bool: True if registration succeeded, False if it failed
"""
runtime_config = await _get_runtime_config(engine)
runtime_config = await _get_runtime_config(engine, dynamo_args)
try:
await register_llm(
ModelType.Backend,
......@@ -30,7 +31,7 @@ async def register_llm_with_runtime_config(
server_args.model_path,
server_args.served_model_name,
kv_cache_block_size=server_args.page_size,
migration_limit=migration_limit,
migration_limit=dynamo_args.migration_limit,
runtime_config=runtime_config,
)
logging.info("Successfully registered LLM with runtime config")
......@@ -40,13 +41,17 @@ async def register_llm_with_runtime_config(
return False
async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]:
async def _get_runtime_config(
engine: sgl.Engine, dynamo_args: DynamoArgs
) -> Optional[ModelRuntimeConfig]:
"""Get runtime config from SGLang engine"""
runtime_config = ModelRuntimeConfig()
# set reasoning parser and tool call parser
runtime_config.reasoning_parser = dynamo_args.reasoning_parser
runtime_config.tool_call_parser = dynamo_args.tool_call_parser
try:
# Try to check if the engine has a scheduler attribute with the computed values
if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None:
runtime_config = ModelRuntimeConfig()
# Get max_total_num_tokens from scheduler_info
if "max_total_num_tokens" in engine.scheduler_info:
max_total_tokens = engine.scheduler_info["max_total_num_tokens"]
......@@ -73,8 +78,8 @@ async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig
"The engine may compute these values internally after initialization. "
"Proceeding without runtime config - SGLang will use its internal defaults."
)
return None
return runtime_config
except Exception as e:
logging.warning(f"Failed to get runtime config: {e}. Proceeding without it.")
return None
return runtime_config
......@@ -228,6 +228,17 @@ async def init(runtime: DistributedRuntime, config: Config):
async with get_llm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint)
# should ideally call get_engine_runtime_config
# this is because we don't have a good way to
# get total_kv_blocks from the engine yet without calling get_stats_async
# This causes an issue because get_stats_async doesn't work when no requests are sent to the engine
# So for now, we just set the parsers from the config
# TODO: fix this once we have a better way to get total_kv_blocks
runtime_config = ModelRuntimeConfig()
runtime_config.reasoning_parser = config.reasoning_parser
runtime_config.tool_call_parser = config.tool_call_parser
if is_first_worker(config):
# Register the model with runtime config
await register_llm(
......@@ -237,6 +248,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.served_model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
runtime_config=runtime_config,
)
# publisher will be set later if publishing is enabled.
handler_config = RequestHandlerConfig(
......
......@@ -49,6 +49,9 @@ class Config:
self.next_endpoint: str = ""
self.modality: str = "text"
self.reasoning_parser: Optional[str] = None
self.tool_call_parser: Optional[str] = None
def __str__(self) -> str:
return (
f"Config(namespace={self.namespace}, "
......@@ -73,6 +76,8 @@ class Config:
f"disaggregation_strategy={self.disaggregation_strategy}, "
f"next_endpoint={self.next_endpoint}, "
f"modality={self.modality})"
f"reasoning_parser={self.reasoning_parser})"
f"tool_call_parser={self.tool_call_parser})"
)
......@@ -234,6 +239,21 @@ def cmd_line_args():
default="",
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send requests to when running in disaggregation mode. Default: {DEFAULT_NEXT_ENDPOINT} if first worker, empty if next worker",
)
# To avoid name conflicts with different backends, adoped prefix "dyn-" for dynamo specific args
parser.add_argument(
"--dyn-tool-call-parser",
type=str,
default=None,
help="Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
)
parser.add_argument(
"--dyn-reasoning-parser",
type=str,
default=None,
help="Reasoning parser name for the model. Available options: 'basic', 'deepseek_r1', 'gpt_oss'.",
)
args = parser.parse_args()
config = Config()
......@@ -294,4 +314,7 @@ def cmd_line_args():
config.publish_events_and_metrics = args.publish_events_and_metrics
config.modality = args.modality
config.reasoning_parser = args.dyn_reasoning_parser
config.tool_call_parser = args.dyn_tool_call_parser
return config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment