Unverified Commit 8064849b authored by nachiketb-nvidia's avatar nachiketb-nvidia Committed by GitHub
Browse files

feat: add and enable reasoning and tool parser flags for trtllm and sglang (#2713)

parent 50cd81f3
...@@ -10,7 +10,7 @@ import sys ...@@ -10,7 +10,7 @@ import sys
from argparse import Namespace from argparse import Namespace
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Any, Dict from typing import Any, Dict, Optional
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
...@@ -39,6 +39,10 @@ class DynamoArgs: ...@@ -39,6 +39,10 @@ class DynamoArgs:
endpoint: str endpoint: str
migration_limit: int migration_limit: int
# tool and reasoning parser options
tool_call_parser: Optional[str] = None
reasoning_parser: Optional[str] = None
class DisaggregationMode(Enum): class DisaggregationMode(Enum):
AGGREGATED = "agg" AGGREGATED = "agg"
...@@ -71,6 +75,20 @@ def parse_args(args: list[str]) -> Config: ...@@ -71,6 +75,20 @@ def parse_args(args: list[str]) -> Config:
"--version", action="version", version=f"Dynamo Backend SGLang {__version__}" "--version", action="version", version=f"Dynamo Backend SGLang {__version__}"
) )
# To avoid name conflicts with different backends, we adopted the prefix "dyn-" for Dynamo-specific args
parser.add_argument(
"--dyn-tool-call-parser",
type=str,
default=None,
help="Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
)
parser.add_argument(
"--dyn-reasoning-parser",
type=str,
default=None,
help="Reasoning parser name for the model. Available options: 'basic', 'deepseek_r1', 'gpt_oss'.",
)
# Dynamo args # Dynamo args
for info in DYNAMO_ARGS.values(): for info in DYNAMO_ARGS.values():
parser.add_argument( parser.add_argument(
...@@ -123,6 +141,8 @@ def parse_args(args: list[str]) -> Config: ...@@ -123,6 +141,8 @@ def parse_args(args: list[str]) -> Config:
component=parsed_component_name, component=parsed_component_name,
endpoint=parsed_endpoint_name, endpoint=parsed_endpoint_name,
migration_limit=parsed_args.migration_limit, migration_limit=parsed_args.migration_limit,
tool_call_parser=parsed_args.dyn_tool_call_parser,
reasoning_parser=parsed_args.dyn_reasoning_parser,
) )
logging.debug(f"Dynamo args: {dynamo_args}") logging.debug(f"Dynamo args: {dynamo_args}")
......
...@@ -97,7 +97,10 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -97,7 +97,10 @@ async def init(runtime: DistributedRuntime, config: Config):
async def register_model(): async def register_model():
"""Register the model and signal readiness""" """Register the model and signal readiness"""
registration_success = await register_llm_with_runtime_config( registration_success = await register_llm_with_runtime_config(
engine, generate_endpoint, server_args, dynamo_args.migration_limit engine,
generate_endpoint,
server_args,
dynamo_args,
) )
if not registration_success: if not registration_success:
......
...@@ -9,20 +9,21 @@ from sglang.srt.server_args import ServerArgs ...@@ -9,20 +9,21 @@ from sglang.srt.server_args import ServerArgs
from dynamo._core import Endpoint from dynamo._core import Endpoint
from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm
from dynamo.sglang.args import DynamoArgs
async def register_llm_with_runtime_config( async def register_llm_with_runtime_config(
engine: sgl.Engine, engine: sgl.Engine,
endpoint: Endpoint, endpoint: Endpoint,
server_args: ServerArgs, server_args: ServerArgs,
migration_limit: int, dynamo_args: DynamoArgs,
) -> bool: ) -> bool:
"""Register LLM with runtime config """Register LLM with runtime config
Returns: Returns:
bool: True if registration succeeded, False if it failed bool: True if registration succeeded, False if it failed
""" """
runtime_config = await _get_runtime_config(engine) runtime_config = await _get_runtime_config(engine, dynamo_args)
try: try:
await register_llm( await register_llm(
ModelType.Backend, ModelType.Backend,
...@@ -30,7 +31,7 @@ async def register_llm_with_runtime_config( ...@@ -30,7 +31,7 @@ async def register_llm_with_runtime_config(
server_args.model_path, server_args.model_path,
server_args.served_model_name, server_args.served_model_name,
kv_cache_block_size=server_args.page_size, kv_cache_block_size=server_args.page_size,
migration_limit=migration_limit, migration_limit=dynamo_args.migration_limit,
runtime_config=runtime_config, runtime_config=runtime_config,
) )
logging.info("Successfully registered LLM with runtime config") logging.info("Successfully registered LLM with runtime config")
...@@ -40,13 +41,17 @@ async def register_llm_with_runtime_config( ...@@ -40,13 +41,17 @@ async def register_llm_with_runtime_config(
return False return False
async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]: async def _get_runtime_config(
engine: sgl.Engine, dynamo_args: DynamoArgs
) -> Optional[ModelRuntimeConfig]:
"""Get runtime config from SGLang engine""" """Get runtime config from SGLang engine"""
runtime_config = ModelRuntimeConfig()
# set reasoning parser and tool call parser
runtime_config.reasoning_parser = dynamo_args.reasoning_parser
runtime_config.tool_call_parser = dynamo_args.tool_call_parser
try: try:
# Try to check if the engine has a scheduler attribute with the computed values # Try to check if the engine has a scheduler attribute with the computed values
if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None: if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None:
runtime_config = ModelRuntimeConfig()
# Get max_total_num_tokens from scheduler_info # Get max_total_num_tokens from scheduler_info
if "max_total_num_tokens" in engine.scheduler_info: if "max_total_num_tokens" in engine.scheduler_info:
max_total_tokens = engine.scheduler_info["max_total_num_tokens"] max_total_tokens = engine.scheduler_info["max_total_num_tokens"]
...@@ -73,8 +78,8 @@ async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig ...@@ -73,8 +78,8 @@ async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig
"The engine may compute these values internally after initialization. " "The engine may compute these values internally after initialization. "
"Proceeding without runtime config - SGLang will use its internal defaults." "Proceeding without runtime config - SGLang will use its internal defaults."
) )
return None return runtime_config
except Exception as e: except Exception as e:
logging.warning(f"Failed to get runtime config: {e}. Proceeding without it.") logging.warning(f"Failed to get runtime config: {e}. Proceeding without it.")
return None return runtime_config
...@@ -228,6 +228,17 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -228,6 +228,17 @@ async def init(runtime: DistributedRuntime, config: Config):
async with get_llm_engine(engine_args) as engine: async with get_llm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint) endpoint = component.endpoint(config.endpoint)
# Ideally we would call get_engine_runtime_config here. We cannot yet, because
# there is no good way to get total_kv_blocks from the engine without calling
# get_stats_async, and get_stats_async doesn't work when no requests have been
# sent to the engine.
# So for now, we just set the parsers from the config.
# TODO: fix this once we have a better way to get total_kv_blocks
runtime_config = ModelRuntimeConfig()
runtime_config.reasoning_parser = config.reasoning_parser
runtime_config.tool_call_parser = config.tool_call_parser
if is_first_worker(config): if is_first_worker(config):
# Register the model with runtime config # Register the model with runtime config
await register_llm( await register_llm(
...@@ -237,6 +248,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -237,6 +248,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.served_model_name, config.served_model_name,
kv_cache_block_size=config.kv_block_size, kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit, migration_limit=config.migration_limit,
runtime_config=runtime_config,
) )
# publisher will be set later if publishing is enabled. # publisher will be set later if publishing is enabled.
handler_config = RequestHandlerConfig( handler_config = RequestHandlerConfig(
......
...@@ -49,6 +49,9 @@ class Config: ...@@ -49,6 +49,9 @@ class Config:
self.next_endpoint: str = "" self.next_endpoint: str = ""
self.modality: str = "text" self.modality: str = "text"
self.reasoning_parser: Optional[str] = None
self.tool_call_parser: Optional[str] = None
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
f"Config(namespace={self.namespace}, " f"Config(namespace={self.namespace}, "
...@@ -73,6 +76,8 @@ class Config: ...@@ -73,6 +76,8 @@ class Config:
f"disaggregation_strategy={self.disaggregation_strategy}, " f"disaggregation_strategy={self.disaggregation_strategy}, "
f"next_endpoint={self.next_endpoint}, " f"next_endpoint={self.next_endpoint}, "
f"modality={self.modality})" f"modality={self.modality})"
f"reasoning_parser={self.reasoning_parser})"
f"tool_call_parser={self.tool_call_parser})"
) )
...@@ -234,6 +239,21 @@ def cmd_line_args(): ...@@ -234,6 +239,21 @@ def cmd_line_args():
default="", default="",
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send requests to when running in disaggregation mode. Default: {DEFAULT_NEXT_ENDPOINT} if first worker, empty if next worker", help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send requests to when running in disaggregation mode. Default: {DEFAULT_NEXT_ENDPOINT} if first worker, empty if next worker",
) )
# To avoid name conflicts with different backends, we adopted the prefix "dyn-" for Dynamo-specific args
parser.add_argument(
"--dyn-tool-call-parser",
type=str,
default=None,
help="Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
)
parser.add_argument(
"--dyn-reasoning-parser",
type=str,
default=None,
help="Reasoning parser name for the model. Available options: 'basic', 'deepseek_r1', 'gpt_oss'.",
)
args = parser.parse_args() args = parser.parse_args()
config = Config() config = Config()
...@@ -294,4 +314,7 @@ def cmd_line_args(): ...@@ -294,4 +314,7 @@ def cmd_line_args():
config.publish_events_and_metrics = args.publish_events_and_metrics config.publish_events_and_metrics = args.publish_events_and_metrics
config.modality = args.modality config.modality = args.modality
config.reasoning_parser = args.dyn_reasoning_parser
config.tool_call_parser = args.dyn_tool_call_parser
return config return config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment