Unverified Commit 2e7a1e6c authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

fix(trtllm): read runtime config from resolved engine args (#8217)


Signed-off-by: default avatarThomas Montfort <tmontfort@nvidia.com>
Signed-off-by: default avatartmontfort <tmontfort@nvidia.com>
parent 470bb48f
...@@ -170,6 +170,102 @@ def test_deep_update_adds_new_keys(): ...@@ -170,6 +170,102 @@ def test_deep_update_adds_new_keys():
assert target == {"a": 1, "b": 2, "c": {"nested": 3}} assert target == {"a": 1, "b": 2, "c": {"nested": 3}}
# ---- Tests for engine_args resolution with extra/override engine args ----
class EngineArgsCaptured(Exception):
"""Raised by mocked get_llm_engine to capture engine_args and stop execution."""
def __init__(self, engine_args):
self.engine_args = engine_args
def _mock_get_llm_engine(engine_args, *args, **kwargs):
"""Mock for get_llm_engine that captures engine_args and short-circuits."""
raise EngineArgsCaptured(engine_args)
@pytest.mark.asyncio
async def test_init_llm_worker_engine_args_without_overrides(monkeypatch):
"""Without overrides, engine_args passed to get_llm_engine use CLI defaults."""
monkeypatch.delenv("DYN_TRTLLM_MAX_NUM_TOKENS", raising=False)
monkeypatch.delenv("DYN_TRTLLM_MAX_BATCH_SIZE", raising=False)
config = parse_args(["--model", "fake-model"])
with (
mock.patch("dynamo.trtllm.workers.llm_worker.tokenizer_factory"),
mock.patch("dynamo.trtllm.workers.llm_worker.nixl_connect.Connector"),
mock.patch("dynamo.trtllm.workers.llm_worker.dump_config"),
mock.patch("dynamo.trtllm.workers.llm_worker.LLMBackendMetrics"),
mock.patch(
"dynamo.trtllm.workers.llm_worker.get_llm_engine",
side_effect=_mock_get_llm_engine,
),
):
with pytest.raises(EngineArgsCaptured) as exc_info:
await init_llm_worker(
runtime=mock.MagicMock(),
config=config,
shutdown_event=asyncio.Event(),
)
engine_args = exc_info.value.engine_args
assert engine_args["max_num_tokens"] == config.max_num_tokens
assert engine_args["max_batch_size"] == config.max_batch_size
@pytest.mark.asyncio
async def test_init_llm_worker_engine_args_with_extra_engine_args(
tmp_path, monkeypatch
):
"""--extra-engine-args YAML overrides are reflected in engine_args passed to get_llm_engine."""
monkeypatch.delenv("DYN_TRTLLM_MAX_NUM_TOKENS", raising=False)
monkeypatch.delenv("DYN_TRTLLM_MAX_BATCH_SIZE", raising=False)
yaml_file = tmp_path / "engine_config.yaml"
yaml_file.write_text("max_num_tokens: 32768\nmax_batch_size: 512\n")
config = parse_args(
[
"--model",
"fake-model",
"--extra-engine-args",
str(yaml_file),
]
)
# CLI config should NOT reflect the YAML values
assert config.max_num_tokens != 32768
assert config.max_batch_size != 512
with (
mock.patch("dynamo.trtllm.workers.llm_worker.tokenizer_factory"),
mock.patch("dynamo.trtllm.workers.llm_worker.nixl_connect.Connector"),
mock.patch("dynamo.trtllm.workers.llm_worker.dump_config"),
mock.patch("dynamo.trtllm.workers.llm_worker.LLMBackendMetrics"),
mock.patch(
"dynamo.trtllm.workers.llm_worker.get_llm_engine",
side_effect=_mock_get_llm_engine,
),
):
with pytest.raises(EngineArgsCaptured) as exc_info:
await init_llm_worker(
runtime=mock.MagicMock(),
config=config,
shutdown_event=asyncio.Event(),
)
engine_args = exc_info.value.engine_args
assert engine_args["max_num_tokens"] == 32768, (
f"Expected max_num_tokens=32768 from YAML override, "
f"got {engine_args['max_num_tokens']}"
)
assert engine_args["max_batch_size"] == 512, (
f"Expected max_batch_size=512 from YAML override, "
f"got {engine_args['max_batch_size']}"
)
class MultimodalProcessorInstantiated(Exception): class MultimodalProcessorInstantiated(Exception):
"""Custom exception for testing MultimodalRequestProcessor.""" """Custom exception for testing MultimodalRequestProcessor."""
...@@ -180,11 +276,15 @@ async def test_init_llm_worker_creates_multimodal_processor(): ...@@ -180,11 +276,15 @@ async def test_init_llm_worker_creates_multimodal_processor():
assert config.modality == Modality.MULTIMODAL assert config.modality == Modality.MULTIMODAL
# Mock everything init_llm_worker touches before MultimodalRequestProcessor. # Mock everything init_llm_worker touches before MultimodalRequestProcessor.
with mock.patch("dynamo.trtllm.workers.llm_worker.tokenizer_factory"), mock.patch( with (
"dynamo.trtllm.workers.llm_worker.AutoConfig.from_pretrained", mock.patch("dynamo.trtllm.workers.llm_worker.tokenizer_factory"),
), mock.patch( mock.patch(
"dynamo.trtllm.workers.llm_worker.MultimodalRequestProcessor", "dynamo.trtllm.workers.llm_worker.AutoConfig.from_pretrained",
side_effect=MultimodalProcessorInstantiated, ),
mock.patch(
"dynamo.trtllm.workers.llm_worker.MultimodalRequestProcessor",
side_effect=MultimodalProcessorInstantiated,
),
): ):
with pytest.raises(MultimodalProcessorInstantiated): with pytest.raises(MultimodalProcessorInstantiated):
await init_llm_worker( await init_llm_worker(
......
...@@ -53,7 +53,7 @@ from dynamo.llm import ( ...@@ -53,7 +53,7 @@ from dynamo.llm import (
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.args import Config from dynamo.trtllm.args import Config
from dynamo.trtllm.constants import DisaggregationMode, Modality from dynamo.trtllm.constants import DisaggregationMode, Modality
from dynamo.trtllm.engine import Backend, TensorRTLLMEngine, get_llm_engine from dynamo.trtllm.engine import Backend, get_llm_engine
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
from dynamo.trtllm.publisher import DYNAMO_COMPONENT_REGISTRY, get_publisher from dynamo.trtllm.publisher import DYNAMO_COMPONENT_REGISTRY, get_publisher
...@@ -67,37 +67,6 @@ from dynamo.trtllm.utils.trtllm_utils import deep_update ...@@ -67,37 +67,6 @@ from dynamo.trtllm.utils.trtllm_utils import deep_update
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024 DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
async def get_engine_runtime_config(
engine: TensorRTLLMEngine, config: Config
) -> ModelRuntimeConfig:
"""Retrieve runtime configuration from TensorRT-LLM engine."""
runtime_config = ModelRuntimeConfig()
try:
# Extract total_kv_blocks from engine stats
stats = engine.llm.get_stats_async(timeout=5)
stat = await anext(stats)
runtime_config.total_kv_blocks = stat["kvCacheStats"]["maxNumBlocks"]
logging.info(
f"Set runtime config total_kv_blocks: {runtime_config.total_kv_blocks}"
)
# Extract max number of sequences
runtime_config.max_num_seqs = config.max_batch_size
logging.info(f"Set runtime config max_num_seqs: {runtime_config.max_num_seqs}")
# Get max_num_batched_tokens from config
runtime_config.max_num_batched_tokens = config.max_num_tokens
logging.info(
f"Set runtime config max_num_batched_tokens: {runtime_config.max_num_batched_tokens}"
)
except Exception as e:
logging.error(f"Failed to get runtime config from TensorRT-LLM engine: {e}")
# Keep default/None values if retrieval fails
return runtime_config
def build_kv_connector_config(config: Config): def build_kv_connector_config(config: Config):
if config.connector: if config.connector:
if config.connector[0] == "kvbm": if config.connector[0] == "kvbm":
...@@ -499,8 +468,11 @@ async def init_llm_worker( ...@@ -499,8 +468,11 @@ async def init_llm_worker(
# - In vLLM: max_num_seqs = maximum concurrent requests (this is an unusual name due to vLLM's historic reasons) # - In vLLM: max_num_seqs = maximum concurrent requests (this is an unusual name due to vLLM's historic reasons)
# - In TensorRT-LLM: max_batch_size = maximum concurrent requests (clearer name) # - In TensorRT-LLM: max_batch_size = maximum concurrent requests (clearer name)
# Both parameters control the same thing: how many requests can be processed simultaneously # Both parameters control the same thing: how many requests can be processed simultaneously
runtime_config.max_num_seqs = config.max_batch_size
runtime_config.max_num_batched_tokens = config.max_num_tokens # Need to get max_num_seqs and max_num_batched_tokens from engine_args
# because they can be overridden by --extra-engine-args or --override-engine-args
runtime_config.max_num_seqs = engine_args["max_batch_size"]
runtime_config.max_num_batched_tokens = engine_args["max_num_tokens"]
runtime_config.reasoning_parser = config.dyn_reasoning_parser runtime_config.reasoning_parser = config.dyn_reasoning_parser
runtime_config.tool_call_parser = config.dyn_tool_call_parser runtime_config.tool_call_parser = config.dyn_tool_call_parser
runtime_config.exclude_tools_when_tool_choice_none = ( runtime_config.exclude_tools_when_tool_choice_none = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment