Unverified Commit 3d4fe574 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

feat: Add standalone script for TRTLLM integration into dynamo-run (#1162)

parent 183f2b32
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// Source code of the TRTLLM sub-process
pub const PY: &str = include_str!("trtllm_inc.py");
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# TODO:
# - Add event and metrics publishers
# - Support default dynamo-run out=trtllm launch
# - Support disaggregated serving
#
# Can be used standalone: `python3 trtllm_inc.py` - lots of optional cmd line params
import argparse
import asyncio
import logging
import sys
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
import uvloop
# Import TRTLLM and related modules
from tensorrt_llm import LLM, LlmArgs, SamplingParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.llm import KvMetricsPublisher, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
# Qwen/Qwen3-0.6B is not supported by TRTLLM yet.
DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
logging.basicConfig(level=logging.DEBUG)
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str]
tensor_parallel_size: int
kv_block_size: int
extra_engine_args: str
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, component, engine, default_sampling_params):
self.engine = engine
self.component = component
self.default_sampling_params = default_sampling_params
self.metrics_publisher = KvMetricsPublisher()
def setup_kv_metrics(self):
# Initially send dummy metrics to kick start,
# TRTLLM will not update stat until forward pass is triggered
self.metrics_publisher.publish(
0, # request_active_slots
1024, # request_total_slots
0, # kv_active_blocks
1024, # kv_total_blocks
0, # num_requests_waiting
0.0, # gpu_cache_usage_perc
0.0, # gpu_prefix_cache_hit_rate
)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logging.debug("metrics publisher endpoint created")
)
async def create_metrics_publisher_endpoint(self):
logging.debug("Creating metrics publisher endpoint")
await self.metrics_publisher.create_endpoint(self.component)
async def generate(self, request):
inputs = request["token_ids"]
sampling_params = self.default_sampling_params
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
num_output_tokens_so_far = 0
# TODO: Disable streaming for context only requests when adding disagg support
async for res in self.engine.llm.generate_async(
inputs=inputs, sampling_params=sampling_params, streaming=True
):
if res.finished:
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
class AsyncLLMEngine:
def __init__(self, engine_args):
self.engine_args = engine_args
self._llm: Optional[LLM] = None
self._initialized = False
async def initialize(self):
if not self._initialized:
model = self.engine_args.pop("model")
self._llm = LLM(
model=model,
**self.engine_args,
)
self._initialized = True
async def cleanup(self):
if self._initialized:
try:
self._llm.shutdown()
except Exception as e:
logging.error(f"Error during cleanup: {e}")
finally:
self._initialized = False
@property
def llm(self):
if not self._initialized:
raise RuntimeError("Engine not initialized")
return self._llm
@asynccontextmanager
async def get_llm_engine(engine_args: LlmArgs) -> AsyncGenerator[AsyncLLMEngine, None]:
engine = AsyncLLMEngine(engine_args)
try:
await engine.initialize()
yield engine
except Exception as e:
logging.error(f"Error in engine context: {e}")
raise
finally:
await engine.cleanup()
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
# Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(config.model_path)
arg_map = {
"model": model_path,
"tensor_parallel_size": config.tensor_parallel_size,
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
}
if config.extra_engine_args != "":
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
logging.debug(f"TRTLLM engine args: {arg_map}")
engine_args = arg_map
# Populate default sampling params from the model
tokenizer = tokenizer_factory(arg_map["model"])
default_sampling_params = SamplingParams()
default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None
async with get_llm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint)
await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name
)
handler = RequestHandler(component, engine, default_sampling_params)
handler.setup_kv_metrics()
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="TensorRT-LLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=32, help="Size of a KV cache block."
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
)
args = parser.parse_args()
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.extra_engine_args = args.extra_engine_args
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment