Unverified Commit 2e7a1e6c authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

fix(trtllm): read runtime config from resolved engine args (#8217)


Signed-off-by: default avatarThomas Montfort <tmontfort@nvidia.com>
Signed-off-by: default avatartmontfort <tmontfort@nvidia.com>
parent 470bb48f
......@@ -170,6 +170,102 @@ def test_deep_update_adds_new_keys():
assert target == {"a": 1, "b": 2, "c": {"nested": 3}}
# ---- Tests for engine_args resolution with extra/override engine args ----
class EngineArgsCaptured(Exception):
"""Raised by mocked get_llm_engine to capture engine_args and stop execution."""
def __init__(self, engine_args):
self.engine_args = engine_args
def _mock_get_llm_engine(engine_args, *args, **kwargs):
"""Mock for get_llm_engine that captures engine_args and short-circuits."""
raise EngineArgsCaptured(engine_args)
@pytest.mark.asyncio
async def test_init_llm_worker_engine_args_without_overrides(monkeypatch):
"""Without overrides, engine_args passed to get_llm_engine use CLI defaults."""
monkeypatch.delenv("DYN_TRTLLM_MAX_NUM_TOKENS", raising=False)
monkeypatch.delenv("DYN_TRTLLM_MAX_BATCH_SIZE", raising=False)
config = parse_args(["--model", "fake-model"])
with (
mock.patch("dynamo.trtllm.workers.llm_worker.tokenizer_factory"),
mock.patch("dynamo.trtllm.workers.llm_worker.nixl_connect.Connector"),
mock.patch("dynamo.trtllm.workers.llm_worker.dump_config"),
mock.patch("dynamo.trtllm.workers.llm_worker.LLMBackendMetrics"),
mock.patch(
"dynamo.trtllm.workers.llm_worker.get_llm_engine",
side_effect=_mock_get_llm_engine,
),
):
with pytest.raises(EngineArgsCaptured) as exc_info:
await init_llm_worker(
runtime=mock.MagicMock(),
config=config,
shutdown_event=asyncio.Event(),
)
engine_args = exc_info.value.engine_args
assert engine_args["max_num_tokens"] == config.max_num_tokens
assert engine_args["max_batch_size"] == config.max_batch_size
@pytest.mark.asyncio
async def test_init_llm_worker_engine_args_with_extra_engine_args(
tmp_path, monkeypatch
):
"""--extra-engine-args YAML overrides are reflected in engine_args passed to get_llm_engine."""
monkeypatch.delenv("DYN_TRTLLM_MAX_NUM_TOKENS", raising=False)
monkeypatch.delenv("DYN_TRTLLM_MAX_BATCH_SIZE", raising=False)
yaml_file = tmp_path / "engine_config.yaml"
yaml_file.write_text("max_num_tokens: 32768\nmax_batch_size: 512\n")
config = parse_args(
[
"--model",
"fake-model",
"--extra-engine-args",
str(yaml_file),
]
)
# CLI config should NOT reflect the YAML values
assert config.max_num_tokens != 32768
assert config.max_batch_size != 512
with (
mock.patch("dynamo.trtllm.workers.llm_worker.tokenizer_factory"),
mock.patch("dynamo.trtllm.workers.llm_worker.nixl_connect.Connector"),
mock.patch("dynamo.trtllm.workers.llm_worker.dump_config"),
mock.patch("dynamo.trtllm.workers.llm_worker.LLMBackendMetrics"),
mock.patch(
"dynamo.trtllm.workers.llm_worker.get_llm_engine",
side_effect=_mock_get_llm_engine,
),
):
with pytest.raises(EngineArgsCaptured) as exc_info:
await init_llm_worker(
runtime=mock.MagicMock(),
config=config,
shutdown_event=asyncio.Event(),
)
engine_args = exc_info.value.engine_args
assert engine_args["max_num_tokens"] == 32768, (
f"Expected max_num_tokens=32768 from YAML override, "
f"got {engine_args['max_num_tokens']}"
)
assert engine_args["max_batch_size"] == 512, (
f"Expected max_batch_size=512 from YAML override, "
f"got {engine_args['max_batch_size']}"
)
class MultimodalProcessorInstantiated(Exception):
"""Custom exception for testing MultimodalRequestProcessor."""
......@@ -180,11 +276,15 @@ async def test_init_llm_worker_creates_multimodal_processor():
assert config.modality == Modality.MULTIMODAL
# Mock everything init_llm_worker touches before MultimodalRequestProcessor.
with mock.patch("dynamo.trtllm.workers.llm_worker.tokenizer_factory"), mock.patch(
"dynamo.trtllm.workers.llm_worker.AutoConfig.from_pretrained",
), mock.patch(
"dynamo.trtllm.workers.llm_worker.MultimodalRequestProcessor",
side_effect=MultimodalProcessorInstantiated,
with (
mock.patch("dynamo.trtllm.workers.llm_worker.tokenizer_factory"),
mock.patch(
"dynamo.trtllm.workers.llm_worker.AutoConfig.from_pretrained",
),
mock.patch(
"dynamo.trtllm.workers.llm_worker.MultimodalRequestProcessor",
side_effect=MultimodalProcessorInstantiated,
),
):
with pytest.raises(MultimodalProcessorInstantiated):
await init_llm_worker(
......
......@@ -53,7 +53,7 @@ from dynamo.llm import (
from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.args import Config
from dynamo.trtllm.constants import DisaggregationMode, Modality
from dynamo.trtllm.engine import Backend, TensorRTLLMEngine, get_llm_engine
from dynamo.trtllm.engine import Backend, get_llm_engine
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
from dynamo.trtllm.publisher import DYNAMO_COMPONENT_REGISTRY, get_publisher
......@@ -67,37 +67,6 @@ from dynamo.trtllm.utils.trtllm_utils import deep_update
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
async def get_engine_runtime_config(
engine: TensorRTLLMEngine, config: Config
) -> ModelRuntimeConfig:
"""Retrieve runtime configuration from TensorRT-LLM engine."""
runtime_config = ModelRuntimeConfig()
try:
# Extract total_kv_blocks from engine stats
stats = engine.llm.get_stats_async(timeout=5)
stat = await anext(stats)
runtime_config.total_kv_blocks = stat["kvCacheStats"]["maxNumBlocks"]
logging.info(
f"Set runtime config total_kv_blocks: {runtime_config.total_kv_blocks}"
)
# Extract max number of sequences
runtime_config.max_num_seqs = config.max_batch_size
logging.info(f"Set runtime config max_num_seqs: {runtime_config.max_num_seqs}")
# Get max_num_batched_tokens from config
runtime_config.max_num_batched_tokens = config.max_num_tokens
logging.info(
f"Set runtime config max_num_batched_tokens: {runtime_config.max_num_batched_tokens}"
)
except Exception as e:
logging.error(f"Failed to get runtime config from TensorRT-LLM engine: {e}")
# Keep default/None values if retrieval fails
return runtime_config
def build_kv_connector_config(config: Config):
if config.connector:
if config.connector[0] == "kvbm":
......@@ -499,8 +468,11 @@ async def init_llm_worker(
# - In vLLM: max_num_seqs = maximum concurrent requests (this is an unusual name due to vLLM's historic reasons)
# - In TensorRT-LLM: max_batch_size = maximum concurrent requests (clearer name)
# Both parameters control the same thing: how many requests can be processed simultaneously
runtime_config.max_num_seqs = config.max_batch_size
runtime_config.max_num_batched_tokens = config.max_num_tokens
# Need to get max_num_seqs and max_num_batched_tokens from engine_args
# because they can be overridden by --extra-engine-args or --override-engine-args
runtime_config.max_num_seqs = engine_args["max_batch_size"]
runtime_config.max_num_batched_tokens = engine_args["max_num_tokens"]
runtime_config.reasoning_parser = config.dyn_reasoning_parser
runtime_config.tool_call_parser = config.dyn_tool_call_parser
runtime_config.exclude_tools_when_tool_choice_none = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment