Unverified commit 03d976c7, authored by Tanmay Verma, committed by GitHub
Browse files

refactor: Refactor the TRTLLM example components and improve UI (#1654)


Signed-off-by: Tanmay Verma <tanmayv@nvidia.com>
parent 8a2d6529
......@@ -110,7 +110,7 @@ dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
#### Aggregated serving with KV Routing
```bash
cd /workspace/examples/tensorrt_llm
dynamo serve graphs.agg_router:Frontend -f ./configs/agg_router.yaml
dynamo serve graphs.agg:Frontend -f ./configs/agg_router.yaml
```
#### Disaggregated serving
......@@ -122,7 +122,7 @@ dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
#### Disaggregated serving with KV Routing
```bash
cd /workspace/examples/tensorrt_llm
dynamo serve graphs.disagg_router:Frontend -f ./configs/disagg_router.yaml
dynamo serve graphs.disagg:Frontend -f ./configs/disagg_router.yaml
```
#### Aggregated serving with Multi-Token Prediction (MTP) and DeepSeek R1
......
......@@ -12,515 +12,302 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import logging
import os
import signal
import threading
from contextlib import asynccontextmanager
from enum import Enum
from queue import Queue
from dataclasses import dataclass
from typing import Any, Optional
from common.parser import LLMAPIConfig
from common.protocol import DisaggregatedTypeConverter
from common.utils import ManagedThread, ServerType
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import LLM, SamplingParams
from tensorrt_llm.llmapi.disagg_utils import (
CtxGenServerConfig,
parse_disagg_config_file,
)
from common.protocol import DisaggregatedTypeConverter, TRTLLMWorkerRequest
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.serve.openai_protocol import DisaggregatedParams
from tensorrt_llm.serve.openai_protocol import (
DisaggregatedParams as OAIDisaggregatedParams,
)
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
from dynamo.sdk import dynamo_context
from dynamo.llm import get_tensorrtllm_engine, get_tensorrtllm_publisher
from dynamo.runtime import DistributedRuntime
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
class DisaggRequestType(Enum):
    """Request-type tags carried in the disaggregated params of a request."""

    # Tag placed on requests forwarded to the prefill worker.
    CONTEXT_ONLY = "context_only"
    # Tag used for the remainder of generation after remote prefill.
    GENERATION_ONLY = "generation_only"
def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
    """Split a Dynamo endpoint string into its three dotted components.

    Args:
        endpoint: ``namespace.component.endpoint``, optionally prefixed
            with ``dyn://``.

    Returns:
        The ``(namespace, component, endpoint)`` tuple.

    Raises:
        ValueError: If the string does not consist of exactly three
            non-empty, dot-separated parts.
    """
    # Strip the scheme only when it really is a prefix; ``str.replace`` would
    # also swallow a "dyn://" that appears in the middle of the string.
    endpoint_str = endpoint.removeprefix("dyn://")
    endpoint_parts = endpoint_str.split(".")
    # An empty segment (e.g. "ns..ep") can never name a real component.
    if len(endpoint_parts) != 3 or not all(endpoint_parts):
        raise ValueError(
            f"Invalid endpoint format: '{endpoint}'. "
            "Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
        )

    return (endpoint_parts[0], endpoint_parts[1], endpoint_parts[2])


def update_args_from_disagg_config(
    engine_config: "LLMAPIConfig", server_config: "CtxGenServerConfig"
):
    """Overlay one server's disaggregated-config args onto the engine config.

    ``engine_config`` is modified in place and also returned.
    """
    # Update the LLM API config with the disaggregated config
    # Allows for different configs for context and generation servers
    engine_config.extra_args.update(**server_config.other_args)
    engine_config.update_sub_configs(server_config.other_args)
    return engine_config
def _to_signed_i64(value: int | None) -> int | None:
"""Convert a Python int to signed 64-bit range by two's complement."""
if value is None:
return None
if value >= 2**63:
return value - 2**64
if value < -(2**63):
return ((value + 2**63) % 2**64) - 2**63
return value
@dataclass
class BaseEngineConfig:
    """Base engine configuration"""

    namespace: str
    component: str
    endpoint: str
    # Handed to the engine as its "model" argument.
    model_path: str
    served_model_name: Optional[str] = None
    kv_block_size: int = 32
    # Forwarded to update_llm_args_with_extra_options when non-empty.
    extra_engine_args: str = ""
    publish_events_and_metrics: bool = False
    # Compared against "prefill" and "decode" by the engine.
    disaggregation_mode: str = "prefill_and_decode"
    # Must be set when disaggregation_mode is "decode".
    remote_prefill_endpoint: Optional[str] = None
    lease_id: int = 0

    def __str__(self) -> str:
        """Render every field as ``Config(name=value, ...)``."""
        shown = (
            "namespace",
            "component",
            "endpoint",
            "model_path",
            "served_model_name",
            "kv_block_size",
            "extra_engine_args",
            "publish_events_and_metrics",
            "disaggregation_mode",
            "remote_prefill_endpoint",
            "lease_id",
        )
        rendered = ", ".join(f"{name}={getattr(self, name)}" for name in shown)
        return f"Config({rendered})"


def get_sampling_params(sampling_params_dict, default_sampling_params):
    """Return a deep copy of the defaults with request overrides applied.

    Entries whose value is ``None`` and keys that are not attributes of the
    defaults are ignored; ``default_sampling_params`` itself is not modified.
    """
    merged = copy.deepcopy(default_sampling_params)
    for name, override in sampling_params_dict.items():
        if override is not None and hasattr(merged, name):
            setattr(merged, name, override)
    return merged
class BaseTensorrtLLMEngine:
def __init__(
self,
namespace_str: str = "dynamo",
component_str: str = "tensorrt-llm",
worker_id: Optional[str] = None,
engine_config: LLMAPIConfig = None,
remote_prefill: bool = False,
min_workers: int = 0,
disagg_config_file: Optional[str] = None,
block_size: int = 32,
router: str = "round_robin",
server_type: ServerType = ServerType.GEN,
config: BaseEngineConfig,
):
self._namespace_str = namespace_str
self._component_str = component_str
self._worker_id = worker_id
self._remote_prefill = remote_prefill
self._min_workers = 0
self._kv_block_size = block_size
self._router = router
self._server_type = server_type
self._config = config
self._prefill_client = None
self._error_queue: Queue = Queue()
self._kv_metrics_publisher = None
if self._remote_prefill or self._server_type == ServerType.CTX:
self._min_workers = min_workers
if disagg_config_file is None or not os.path.exists(disagg_config_file):
self._llm_engine = None
self._llm_engine_context = None
self._llm_publisher = None
self._llm_publisher_context = None
self._runtime = None
self._first_generation = True
# Initialize default sampling params
self.default_sampling_params = SamplingParams()
async def initialize(self, runtime: DistributedRuntime):
"""Initialize the engine and prefill client if needed"""
self._runtime = runtime
# Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(self._config.model_path)
# Initialize the LLM engine
engine_args: dict[str, Any] = {
"model": model_path,
"tensor_parallel_size": 1,
"backend": "pytorch",
"skip_tokenizer_init": True,
}
if self._config.extra_engine_args:
# TODO: Support extra engine args from json file as well.
engine_args = update_llm_args_with_extra_options(
engine_args, self._config.extra_engine_args
)
# Update the model path in the config to the model path used by the engine.
self._config.model_path = str(engine_args["model"])
if not self._config.model_path:
raise ValueError(
"llmapi_disaggregated_config file does not exist or not provided"
"Model specification is required. Present neither in the config nor in the extra engine args."
)
disagg_config = parse_disagg_config_file(disagg_config_file)
server_config: CtxGenServerConfig = None
for config in disagg_config.server_configs:
# Select the first context server config
if config.type == server_type.value:
server_config = config
break
# Populate default sampling params from the model
tokenizer = tokenizer_factory(self._config.model_path)
self.default_sampling_params = SamplingParams()
self.default_sampling_params._setup(tokenizer)
self.default_sampling_params.stop = None
if self._config.publish_events_and_metrics:
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
kv_cache_config: dict[str, Any] | Any = None
if "kv_cache_config" not in engine_args:
kv_cache_config = {}
kv_cache_config[
"event_buffer_max_size"
] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
else:
kv_cache_config = engine_args["kv_cache_config"]
if (
hasattr(kv_cache_config, "event_buffer_max_size")
and not kv_cache_config.event_buffer_max_size
):
kv_cache_config.event_buffer_max_size = (
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
)
elif (
isinstance(kv_cache_config, dict)
and "event_buffer_max_size" not in kv_cache_config
):
kv_cache_config[
"event_buffer_max_size"
] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
engine_args["kv_cache_config"] = kv_cache_config
if server_config is None:
server_type_str = (
"generation" if server_type == ServerType.GEN else "context"
# Enable iter perf stats by default if we are publishing events and metrics.
if not engine_args.get("enable_iter_perf_stats"):
engine_args["enable_iter_perf_stats"] = True
# Only pytorch backend is supported for now to publish events and metrics.
if engine_args.get("backend") != "pytorch":
logging.error(
"Only pytorch backend is supported for now to publish events and metrics."
)
raise ValueError(
f"No {server_type_str} server config found. Please check the disaggregated config file."
raise RuntimeError(
"Only pytorch backend is supported for now to publish events and metrics. Hence, KV router is not supported."
)
engine_config = update_args_from_disagg_config(engine_config, server_config)
logging.info(f"TRTLLM engine args: {engine_args}")
if router == "kv":
self._publish_stats = True
self._publish_events = True
# Get the engine using the asynccontextmanager
self._llm_engine_context = get_tensorrtllm_engine(engine_args)
if self._llm_engine_context is not None:
self._llm_engine = await self._llm_engine_context.__aenter__()
else:
self._publish_stats = False
self._publish_events = False
if self._publish_stats:
self._kv_metrics_publisher = WorkerMetricsPublisher()
if self._publish_events:
if self._worker_id is None:
raise ValueError("Worker ID is None!")
raise RuntimeError("Failed to create LLM engine context")
runtime = dynamo_context["runtime"]
kv_listener = runtime.namespace(self._namespace_str).component(
self._component_str
)
self._kv_event_publisher = KvEventPublisher(
kv_listener, int(self._worker_id), self._kv_block_size
)
logger.info("KvEventPublisher is initialized")
self._engine_config = engine_config
def _init_engine(self):
logger.info("Initializing engine")
# Run the engine in a separate thread running the AsyncIO event loop.
self._llm_engine: Optional[Any] = None
self._llm_engine_start_cv = threading.Condition()
self._llm_engine_shutdown_event = asyncio.Event()
self._event_thread = threading.Thread(
target=asyncio.run, args=(self._run_llm_engine(),)
if (
self._config.publish_events_and_metrics
and self._config.disaggregation_mode != "prefill"
):
kv_listener = runtime.namespace(self._config.namespace).component(
self._config.component
)
self._llm_publisher_context = get_tensorrtllm_publisher(
kv_listener,
self._llm_engine,
kv_listener,
self._config.lease_id,
self._config.kv_block_size,
)
if self._llm_publisher_context is not None:
self._llm_publisher = await self._llm_publisher_context.__aenter__()
else:
raise RuntimeError("Failed to create LLM publisher context")
# Initialize prefill client if in decode mode
if self._config.disaggregation_mode == "decode":
if self._config.remote_prefill_endpoint is None:
raise ValueError("remote_prefill_endpoint is required for decode mode")
logging.info(
f"Initializing remote prefill client for endpoint: {self._config.remote_prefill_endpoint}"
)
(
parsed_namespace,
parsed_component_name,
parsed_endpoint_name,
) = parse_endpoint(self._config.remote_prefill_endpoint)
if self._runtime is not None:
self._prefill_client = (
await self._runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
else:
raise RuntimeError("Runtime not initialized")
# Populate default sampling params from the model
tokenizer = tokenizer_factory(self._engine_config.model_name)
self._default_sampling_params = SamplingParams()
self._default_sampling_params._setup(tokenizer)
self._default_sampling_params.stop = None
self.publish_kv_cache_events_thread = None
self.publish_stats_thread = None
self._event_thread.start()
with self._llm_engine_start_cv:
while self._llm_engine is None:
self._llm_engine_start_cv.wait()
# The 'threading.Thread()' will not raise the exception here should the engine
# failed to start, so the exception is passed back via the engine variable.
if isinstance(self._llm_engine, Exception):
e = self._llm_engine
logger.error(f"Failed to start engine: {e}")
if self._event_thread is not None:
self._event_thread.join()
self._event_thread = None
raise e
async def cleanup(self):
"""Cleanup resources"""
if self._llm_publisher_context:
try:
if self._publish_stats:
self._init_publish_metrics_thread()
await self._llm_publisher_context.__aexit__(None, None, None)
except Exception as e:
logger.error(f"Failed to initialize publish metrics threads: {e}")
raise e
logging.error(f"Error during publisher cleanup: {e}")
finally:
self._llm_publisher = None
self._llm_publisher_context = None
if self._llm_engine_context:
try:
if self._publish_events:
self._init_publish_kv_cache_events_thread()
await self._llm_engine_context.__aexit__(None, None, None)
except Exception as e:
logger.error(f"Failed to initialize publish events threads: {e}")
raise e
def _init_publish_metrics_thread(self):
    """Publish one placeholder stats sample and prepare the stats thread.

    The thread is only created here; it is not started.
    """
    publisher = self._kv_metrics_publisher
    if publisher is None:
        logger.error("KV metrics publisher not initialized!")
        return

    # Need to publish stats once so that worker can be selected.
    # Publishing some dummy values...
    placeholder_stats = (
        0,  # request_active_slots
        4,  # request_total_slots
        0,  # kv_active_block
        4,  # kv_total_blocks
        0,  # num_requests_waiting
        0.0,  # gpu_cache_usage_perc
        0.0,  # gpu_prefix_cache_hit_rate
    )
    publisher.publish(*placeholder_stats)

    # Prepare threads for publishing stats but don't start them yet.
    # TRTLLM needs to start generating tokens first before stats
    # can be retrieved.
    self.publish_stats_thread = ManagedThread(
        self.publish_stats_task,
        error_queue=self._error_queue,
        name="publish_stats_thread",
    )
def _init_publish_kv_cache_events_thread(self):
if self._kv_event_publisher is None:
logger.error("KV event publisher not initialized!")
return
# A set to store the block hash of partial block (i.e. block containing less than kv_block_size tokens) hashes.
# It is used to prevent sending remove event to kv router since partial blocks are not stored.
self._partial_block_hashes = set()
logging.error(f"Error during engine cleanup: {e}")
finally:
self._llm_engine = None
self._llm_engine_context = None
# Prepare threads for publishing kv cache events but don't start them yet.
# TRTLLM needs to start generating tokens first before kv cache events
# can be retrieved.
self.publish_kv_cache_events_thread = ManagedThread(
self.publish_kv_cache_events_task,
error_queue=self._error_queue,
name="publish_kv_cache_events_thread",
)
self._prefill_client = None
async def publish_stats_task(self):
"""
Publish stats to the metrics publisher.
async def remote_prefill(self, request: TRTLLMWorkerRequest):
"""
if self._llm_engine is None:
logger.error("LLM engine not initialized!")
return
Send a prefill request to the remote prefill worker.
if self._kv_metrics_publisher is None:
logger.error("KV metrics publisher not initialized!")
return False
stats = self._llm_engine.get_stats_async(timeout=5)
async for stat in stats:
request_active_slots = stat["numActiveRequests"]
request_total_slots = stat["maxNumActiveRequests"]
kv_active_block = stat["kvCacheStats"]["usedNumBlocks"]
kv_total_blocks = stat["kvCacheStats"]["maxNumBlocks"]
reused_blocks = stat["kvCacheStats"]["reusedBlocks"]
freeNumBlocks = stat["kvCacheStats"]["freeNumBlocks"]
allocTotalBlocks = stat["kvCacheStats"]["allocTotalBlocks"]
allocNewBlocks = stat["kvCacheStats"]["allocNewBlocks"]
# NOTE: num paused requests is always 0 when using guarantee no evict scheduler (default).
num_requests_waiting = (
stat["numQueuedRequests"]
+ stat["inflightBatchingStats"]["numPausedRequests"]
)
gpu_cache_usage_perc = allocTotalBlocks / kv_total_blocks
gpu_prefix_cache_hit_rate = stat["kvCacheStats"]["cacheHitRate"]
logger.debug(
f"Publishing stats: request_active_slots: {request_active_slots}, request_total_slots: {request_total_slots}, kv_active_block: {kv_active_block}, kv_total_blocks: {kv_total_blocks}, num_requests_waiting: {num_requests_waiting}, reused_blocks: {reused_blocks}, freeNumBlocks: {freeNumBlocks}, allocTotalBlocks: {allocTotalBlocks}, allocNewBlocks: {allocNewBlocks}, gpu_cache_usage_perc: {gpu_cache_usage_perc}, gpu_prefix_cache_hit_rate: {gpu_prefix_cache_hit_rate}"
)
Args:
request: The original request to be sent for prefill
self._kv_metrics_publisher.publish(
request_active_slots,
request_total_slots,
kv_active_block,
kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
)
Returns:
The response from the remote prefill worker
return True
async def publish_kv_cache_events_task(self):
"""
Publish kv cache events to the events publisher.
Raises:
ValueError: If prefill client is not initialized or multiple responses received
"""
if self._llm_engine is None:
logger.error("LLM engine not initialized!")
return
events = self._llm_engine.get_kv_cache_events_async(timeout=5)
async for event in events:
event_id = event["event_id"]
data = event["data"]
if data["type"] == "stored":
parent_hash = _to_signed_i64(data["parent_hash"])
token_ids = []
num_block_tokens = []
block_hashes = []
for block in data["blocks"]:
token_num_in_block = len(block["tokens"])
block_hash = _to_signed_i64(block["block_hash"])
if token_num_in_block > self._kv_block_size:
logger.error(
f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self._kv_block_size}"
)
return
if token_num_in_block < self._kv_block_size:
logger.debug(
f"Early stop when block {block_hash} containing {token_num_in_block} tokens not equal to kv_block_size {self._kv_block_size}"
)
self._partial_block_hashes.add(block_hash)
break
num_block_tokens.append(token_num_in_block)
block_hashes.append(block_hash)
for token in block["tokens"]:
token_ids.append(int(token["token_id"]))
# Note: Currently data does not have lora_id.
# Using 0 as default value. If later data has
# lora_id, we need to verify if this is correct.
lora_id = data.get("lora_id", 0)
logger.debug(
f"publish stored event: event_id: {event_id}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
)
self._kv_event_publisher.publish_stored(
event_id,
token_ids,
num_block_tokens,
block_hashes,
lora_id,
parent_hash,
)
elif data["type"] == "removed":
block_hashes = []
for block_hash in data["block_hashes"]:
block_hash = _to_signed_i64(block_hash)
if block_hash in self._partial_block_hashes:
logger.debug(
f"Skipping removing block hash {block_hash} since it is a partial block"
)
self._partial_block_hashes.remove(block_hash)
continue
block_hashes.append(block_hash)
logger.debug(
f"publish removed event: event_id: {event_id}, block_hashes: {block_hashes}"
)
self._kv_event_publisher.publish_removed(event_id, block_hashes)
return True
def _start_threads(self):
if (
self.publish_kv_cache_events_thread
and not self.publish_kv_cache_events_thread.is_alive()
):
# [NOTE:] TRTLLM needs the stats to be collected on the same loop as the request handler.
self._stats_loop = asyncio.get_running_loop()
self.publish_kv_cache_events_thread.set_loop(self._stats_loop)
self.publish_kv_cache_events_thread.start()
logger.debug("Started kv cache events thread")
if self.publish_stats_thread and not self.publish_stats_thread.is_alive():
self._stats_loop = asyncio.get_running_loop()
self.publish_stats_thread.set_loop(self._stats_loop)
self.publish_stats_thread.start()
logger.debug("Started stats thread")
async def _run_llm_engine(self):
    """Create the LLM engine and keep it alive until shutdown is requested.

    The engine (or the exception raised while creating it) is handed to
    whoever waits on ``_llm_engine_start_cv``. After
    ``_llm_engine_shutdown_event`` is set, the publisher threads are stopped,
    in-flight requests are awaited and the remaining tasks on this loop are
    cancelled.
    """
    # Counter to keep track of ongoing request counts.
    self._ongoing_request_count = 0

    @asynccontextmanager
    async def async_llm_wrapper():
        running_loop = asyncio.get_running_loop()

        def build_llm():
            return LLM(
                model=self._engine_config.model_name,
                **self._engine_config.to_dict(),
            )

        created = False
        try:
            # Create LLM in a thread to avoid blocking
            llm = await running_loop.run_in_executor(None, build_llm)
            created = True
            yield llm
        finally:
            if created:
                # Run shutdown in a thread to avoid blocking
                await running_loop.run_in_executor(None, llm.shutdown)

    try:
        async with async_llm_wrapper() as engine:
            # Capture the engine event loop and make it visible to other threads.
            self._event_loop = asyncio.get_running_loop()

            # Signal the engine is started and make it visible to other threads.
            with self._llm_engine_start_cv:
                self._llm_engine = engine
                self._llm_engine_start_cv.notify_all()

            logger.info("Engine loaded and ready to serve...")

            # Wait for the engine shutdown signal.
            await self._llm_engine_shutdown_event.wait()

            # Stop the publishing threads
            stats_thread = self.publish_stats_thread
            if stats_thread and stats_thread.is_alive():
                stats_thread.stop()
                stats_thread.join()

            events_thread = self.publish_kv_cache_events_thread
            if events_thread and events_thread.is_alive():
                events_thread.stop()
                events_thread.join()

            # Wait for the ongoing requests to complete.
            while self._ongoing_request_count > 0:
                logger.info(
                    f"Awaiting remaining {self._ongoing_request_count} requests"
                )
                await asyncio.sleep(1)

            # Cancel all tasks in the event loop.
            for task in asyncio.all_tasks(loop=self._event_loop):
                if task is asyncio.current_task():
                    continue
                task.cancel()

    except Exception as e:
        # Signal and pass the exception back via the engine variable if the engine
        # failed to start. If the engine has started, re-raise the exception.
        with self._llm_engine_start_cv:
            if self._llm_engine is None:
                self._llm_engine = e
                self._llm_engine_start_cv.notify_all()
                return
        raise e

    self._llm_engine = None
    logger.info("Shutdown complete")
async def _get_remote_prefill_response(self, request):
prefill_request = copy.deepcopy(request)
prefill_request = request.model_copy(deep=True)
# TRTLLM requires max_tokens to be set for prefill requests.
prefill_request.stop_conditions.max_tokens = 1
prefill_request.disaggregated_params = DisaggregatedParams(
request_type=DisaggRequestType.CONTEXT_ONLY.value
prefill_request.disaggregated_params = OAIDisaggregatedParams(
request_type="context_only"
)
if self._prefill_client is None:
raise ValueError("Prefill client not initialized")
try:
# TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
ctx_responses = [
ctx_response
async for ctx_response in await self._prefill_client.round_robin(
remote_prefill_responses = [
remote_prefill_response
async for remote_prefill_response in await self._prefill_client.round_robin(
prefill_request.model_dump_json()
)
]
if len(ctx_responses) > 1:
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
if len(remote_prefill_responses) > 1:
raise ValueError(
"Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
)
logger.debug(
f"Received response from prefill worker: {ctx_responses[0].data()}"
)
remote_prefill_response = ctx_responses[0]
if len(remote_prefill_responses) == 0:
raise ValueError("No response received from remote prefill worker")
remote_prefill_response = remote_prefill_responses[0]
return remote_prefill_response
async def generate(self, request):
async def generate(self, request: TRTLLMWorkerRequest):
if self._llm_engine is None:
raise RuntimeError("Engine not initialized")
if not self._error_queue.empty():
raise self._error_queue.get()
self._ongoing_request_count += 1
if self._llm_publisher:
publishers_error = self._llm_publisher.check_error_queue()
if publishers_error:
raise publishers_error
try:
worker_inputs = request.token_ids
inputs = request.token_ids
disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params(
# Decode the disaggregated params from the request
disaggregated_params = DisaggregatedTypeConverter.to_llm_disaggregated_params(
request.disaggregated_params
)
)
num_output_tokens_so_far = 0
if self._remote_prefill and self._server_type == ServerType.GEN:
ctx_response = await self._get_remote_prefill_response(request)
remote_prefill_response = ctx_response.data()
if self._config.disaggregation_mode == "decode":
# Run prefill/context phase remotely if disaggregation mode is decode.
try:
prefill_result = await self.remote_prefill(request)
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
remote_prefill_response = prefill_result.data()
if (
remote_prefill_response["finish_reason"] == "stop"
or remote_prefill_response["finish_reason"] == "error"
......@@ -529,10 +316,11 @@ class BaseTensorrtLLMEngine:
return
num_output_tokens_so_far = len(remote_prefill_response["token_ids"])
# Decode the disaggregated params from the remote prefill response
# Decode the disaggregated params from the remote prefill response
disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params(
DisaggregatedParams(
OAIDisaggregatedParams(
**remote_prefill_response["disaggregated_params"]
)
)
......@@ -543,57 +331,55 @@ class BaseTensorrtLLMEngine:
first_token_response.pop("disaggregated_params")
yield first_token_response
disaggregated_params.request_type = (
DisaggRequestType.GENERATION_ONLY.value
)
# Set the disaggregated params to generation_only for the rest of the generation
disaggregated_params.request_type = "generation_only"
logger.debug(
f"Worker inputs: {worker_inputs}, disaggregated params: {disaggregated_params}"
)
sampling_params = self.default_sampling_params
for key, value in request.sampling_options.model_dump().items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
sampling_params = get_sampling_params(
request.sampling_options.dict(), self._default_sampling_params
)
max_tokens = request.stop_conditions.max_tokens
if max_tokens:
sampling_params.max_tokens = max_tokens
async for response in self._llm_engine.generate_async(
inputs=worker_inputs,
# TODO: Disable streaming for context only requests when adding disagg support
async for res in self._llm_engine.llm.generate_async(
inputs=inputs,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=self._server_type != ServerType.CTX,
streaming=(self._config.disaggregation_mode != "prefill"),
):
if response.finished and self._server_type != ServerType.CTX:
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self._first_generation and self._llm_publisher:
self._llm_publisher.start()
self._first_generation = False
if res.finished and self._config.disaggregation_mode != "prefill":
yield {"finish_reason": "stop", "token_ids": []}
break
if not response.outputs:
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = response.outputs[0]
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self._server_type == ServerType.CTX:
if self._config.disaggregation_mode == "prefill":
# Return the disaggregated params only when operating in prefill mode.
out[
"disaggregated_params"
] = DisaggregatedTypeConverter.to_oai_disaggregated_params(
output.disaggregated_params
).dict()
).model_dump()
yield out
num_output_tokens_so_far = next_total_toks
except CppExecutorError:
signal.raise_signal(signal.SIGINT)
except Exception as e:
raise RuntimeError("Failed to generate: " + str(e))
self._start_threads()
self._ongoing_request_count -= 1
......@@ -14,136 +14,28 @@
# limitations under the License.
import argparse
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Tuple
import yaml
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig
@dataclass
class LLMAPIConfig:
    """Container for the arguments used to build a TRT-LLM ``LLM`` instance.

    Known sub-configs are kept as attributes; every other keyword argument is
    collected in ``extra_args``.
    """

    # NOTE(review): the class declares no annotated fields, so @dataclass only
    # contributes field-less __repr__/__eq__; the hand-written __init__ is kept.
    def __init__(
        self,
        model_name: str,
        model_path: str | None = None,
        pytorch_backend_config: PyTorchConfig | None = None,
        kv_cache_config: KvCacheConfig | None = None,
        speculative_config: DecodingBaseConfig | None = None,
        **kwargs,
    ):
        self.model_name = model_name
        self.model_path = model_path
        self.pytorch_backend_config = pytorch_backend_config
        self.kv_cache_config = kv_cache_config
        self.speculative_config = speculative_config
        self.extra_args = kwargs

        # Hardcoded to skip tokenizer init for now.
        # We will handle the tokenization/detokenization
        # in the base engine.
        self.extra_args.pop("skip_tokenizer_init", None)
        self.skip_tokenizer_init = True

    def to_dict(self) -> Dict[str, Any]:
        """Return the keyword arguments for the engine; extra args win on clashes."""
        merged = {
            "kv_cache_config": self.kv_cache_config,
            "speculative_config": self.speculative_config,
            "skip_tokenizer_init": self.skip_tokenizer_init,
        }
        merged.update(self.extra_args or {})
        return merged

    def update_sub_configs(self, other_config: Dict[str, Any]):
        """Build typed sub-configs from ``other_config`` and drop them from ``extra_args``."""
        # TODO: Consider removing pytorch_backend_config parsing as this section
        # was collapsed to top level config fields in recent TRTLLM versions.
        builders = (
            ("pytorch_backend_config", lambda raw: PyTorchConfig(**raw)),
            ("kv_cache_config", lambda raw: KvCacheConfig(**raw)),
            ("speculative_config", lambda raw: DecodingBaseConfig.from_dict(raw)),
        )
        for key, build in builders:
            if key not in other_config:
                continue
            setattr(self, key, build(other_config[key]))
            self.extra_args.pop(key, None)
def _get_llm_args(engine_config):
# Only do model validation checks and leave other checks to LLMAPI
if "model_name" not in engine_config:
raise ValueError("Model name is required in the TRT-LLM engine config.")
if engine_config.get("model_path", ""):
if os.path.exists(engine_config.get("model_path", "")):
engine_config["model_path"] = Path(engine_config["model_path"])
else:
raise ValueError(f"Model path {engine_config['model_path']} does not exist")
model_name = engine_config["model_name"]
model_path = engine_config.get("model_path", None)
engine_config.pop("model_name")
engine_config.pop("model_path", None)
# Store all other args as kwargs
llm_api_config = LLMAPIConfig(
model_name=model_name,
model_path=model_path,
**engine_config,
)
# Parse supported sub configs and remove from kwargs
llm_api_config.update_sub_configs(engine_config)
return llm_api_config
def _init_engine_args(engine_args_filepath):
"""Initialize engine arguments from config file."""
if not os.path.isfile(engine_args_filepath):
raise ValueError(
"'YAML file containing TRT-LLM engine args must be provided in when launching the worker."
)
try:
with open(engine_args_filepath) as file:
trtllm_engine_config = yaml.safe_load(file)
except yaml.YAMLError as e:
raise RuntimeError(f"Failed to parse engine config: {e}")
return _get_llm_args(trtllm_engine_config)
def parse_tensorrt_llm_args(
config_args,
) -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]:
) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="A TensorRT-LLM Worker parser")
parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file"
"--extra-engine-args",
type=str,
default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
)
parser.add_argument(
"--served_model_name",
"--model-path",
type=str,
help="Name of the model to serve",
default=None,
help="Path to disk model or HuggingFace model identifier to load.",
)
parser.add_argument(
"--llmapi-disaggregated-config",
"-c",
"--served_model_name",
type=str,
help="Path to the llmapi disaggregated config file",
default=None,
help="Name to serve the model under.",
)
parser.add_argument(
"--router",
......@@ -152,46 +44,19 @@ def parse_tensorrt_llm_args(
default="random",
help="Router type to use for scheduling requests to workers",
)
parser.add_argument(
"--min-workers",
type=int,
default=1,
help="Minimum number of workers for aggregated (monolith) server",
)
parser.add_argument(
"--min-prefill-workers",
type=int,
default=1,
help="Minimum number of prefill workers for disaggregated server",
)
parser.add_argument(
"--block-size",
"--kv-block-size",
type=int,
default=32,
help="Number of tokens per KV block in TRTLLM worker. Default is 32 for pytorch backend.",
)
parser.add_argument(
"--remote-prefill",
action="store_true",
help="Use remote prefill workers for generation server in Disaggregated mode.",
)
args = parser.parse_args(config_args)
return (args, _init_engine_args(args.engine_args))
def parse_dynamo_run_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]:
parser = argparse.ArgumentParser(
description="A TensorRT-LLM Dynamo-run engine parser"
)
parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file"
)
parser.add_argument(
"--publish-kv-cache-events",
"--enable-disagg",
action="store_true",
help="Publish KV cache events from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
help="Enable remote prefill for the worker",
)
args, _ = parser.parse_known_args()
return (args, _init_engine_args(args.engine_args))
args = parser.parse_args(config_args)
return args
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import threading
import traceback
import weakref
from enum import Enum
from queue import Queue
from typing import Any, Callable, Coroutine, Optional, TypedDict, Union
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type of the task a ManagedThread runs: either a callable that returns a
# coroutine resolving to a bool, or a weakref.WeakMethod that ManagedThread
# dereferences to obtain such a callable.
AsyncTask = Union[Callable[..., Coroutine[Any, Any, bool]], weakref.WeakMethod]
class RoutingStrategy(Enum):
    """Enumeration of request routing strategies, keyed by their string names."""

    ROUND_ROBIN = "round_robin"
    RANDOM = "random"
    PREFIX = "prefix"
class RequestType(Enum):
    """Enumeration of request kinds: chat or completion."""

    CHAT = "chat"
    COMPLETION = "completion"
class ServerType(Enum):
    """Enumeration of server kinds, distinguished by the requests they serve."""

    # Generation server used for disaggregated and aggregated requests
    GEN = "gen"
    # Context server used for disaggregated requests
    CTX = "ctx"
    # Dynamo run server used for Dynamo run requests
    DYN_RUN = "dyn_run"
class ConversationMessage(TypedDict):
    """Typed dictionary for one conversation message with string fields."""

    # Value stored under the "role" key.
    role: str
    # Value stored under the "content" key.
    content: str
class ManagedThread(threading.Thread):
    """Daemon thread that repeatedly runs an async task on an event loop.

    Each iteration of ``run`` resolves the configured task, schedules the
    coroutine it returns on ``self.loop`` with
    ``asyncio.run_coroutine_threadsafe`` and blocks until it finishes. The
    coroutine's result is discarded. The loop ends when ``stop()`` has been
    called, when there is no task to run, or when no event loop is set.

    Args:
        task: Callable returning a coroutine, a ``weakref.WeakMethod``
            wrapping one, or ``None``.
        error_queue: Optional queue that receives every exception caught
            while running the task.
        name: Optional thread name, forwarded to ``threading.Thread``.
        loop: Event loop the coroutine is scheduled on; may also be provided
            later through ``set_loop``.
        **kwargs: Keyword arguments passed to the task on every invocation.
    """

    def __init__(
        self,
        task: Optional[AsyncTask],
        error_queue: Optional[Queue] = None,
        name: Optional[str] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ):
        super().__init__(name=name)
        # Thread-level settings first, then the task configuration.
        self.daemon = True
        self.stop_event = threading.Event()
        self.loop = loop
        self.task = task
        self.kwargs = kwargs
        self.error_queue = error_queue

    def set_loop(self, loop: asyncio.AbstractEventLoop):
        """Set the event loop on which the task's coroutine is scheduled."""
        self.loop = loop

    def run(self):
        """Run the task in a loop until stopped or until it cannot be run."""
        while not self.stop_event.is_set():
            runnable: Optional[AsyncTask] = self.task
            if isinstance(runnable, weakref.WeakMethod):
                # Dereference the weak method; None means its owner is gone.
                runnable = runnable()
                if runnable is None:
                    # Normally, this should not happen.
                    logger.warning("WeakMethod is expired.")
                    break
            elif runnable is None:
                # No task configured at all: exit without a warning.
                break

            try:
                if self.loop is None:
                    logger.error("[ManagedThread] Loop not initialized!")
                    break
                coroutine = runnable(**self.kwargs)
                # Block this thread until the coroutine completes on the
                # loop; its return value is intentionally ignored.
                asyncio.run_coroutine_threadsafe(coroutine, self.loop).result()
            except Exception as e:
                logger.error(
                    f"Error in thread {self.name}: {e}\n{traceback.format_exc()}"
                )
                # Report the failure to the owner, then keep looping.
                if self.error_queue is not None:
                    self.error_queue.put(e)

        logger.info(f"Thread {self.name} stopped.")

    def stop(self):
        """Set the stop event; ``run`` checks it before starting each iteration."""
        self.stop_event.set()
......@@ -12,15 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from common.base_engine import BaseTensorrtLLMEngine
from common.base_engine import BaseEngineConfig, BaseTensorrtLLMEngine
from common.parser import parse_tensorrt_llm_args
from common.protocol import TRTLLMWorkerRequest
from common.utils import ServerType
from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
from dynamo.sdk import async_on_start, dynamo_context, endpoint, on_shutdown, service
from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__)
......@@ -39,34 +37,37 @@ class TensorRTLLMPrefillWorker(BaseTensorrtLLMEngine):
class_name = self.__class__.__name__
config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="")
args, engine_config = parse_tensorrt_llm_args(config_args)
worker_id = dynamo_context["endpoints"][0].lease_id()
super().__init__(
namespace_str="dynamo",
component_str=class_name,
worker_id=worker_id,
engine_config=engine_config,
remote_prefill=args.remote_prefill,
min_workers=args.min_workers,
disagg_config_file=args.llmapi_disaggregated_config,
block_size=args.block_size,
router=args.router,
server_type=ServerType.CTX,
args = parse_tensorrt_llm_args(config_args)
lease_id = dynamo_context["endpoints"][0].lease_id()
namespace, _ = TensorRTLLMPrefillWorker.dynamo_address() # type: ignore
engine_config = BaseEngineConfig(
namespace=namespace,
component=class_name,
endpoint="generate",
model_path=args.model_path,
served_model_name=args.served_model_name,
kv_block_size=args.kv_block_size,
extra_engine_args=args.extra_engine_args,
publish_events_and_metrics=False,
disaggregation_mode="prefill",
remote_prefill_endpoint=None,
lease_id=lease_id,
)
super().__init__(config=engine_config)
@async_on_start
async def async_init(self):
self._init_engine()
if self._kv_metrics_publisher is not None:
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logger.info("metrics publisher endpoint created")
)
runtime = dynamo_context["runtime"]
await self.initialize(runtime)
logger.info("TensorRT-LLM Prefill Worker initialized")
async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"]
await self.kv_metrics_publisher.create_endpoint(component)
@on_shutdown
async def async_cleanup(self):
logger.info("Cleaning up TensorRT-LLM Prefill Worker")
await self.cleanup()
logger.info("TensorRT-LLM Prefill Worker cleanup completed")
@endpoint()
async def generate(self, request: TRTLLMWorkerRequest):
......
......@@ -12,17 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from common.base_engine import BaseTensorrtLLMEngine
from common.base_engine import BaseEngineConfig, BaseTensorrtLLMEngine
from common.parser import parse_tensorrt_llm_args
from common.protocol import TRTLLMWorkerRequest
from common.utils import ServerType
from components.prefill_worker import TensorRTLLMPrefillWorker
from dynamo.llm import ModelType, register_llm
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk import (
async_on_start,
depends,
dynamo_context,
endpoint,
on_shutdown,
service,
)
from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__)
......@@ -43,74 +48,66 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine):
class_name = self.__class__.__name__
config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="")
args, engine_config = parse_tensorrt_llm_args(config_args)
self.served_model_name = args.served_model_name
worker_id = dynamo_context["endpoints"][0].lease_id()
args = parse_tensorrt_llm_args(config_args)
lease_id = dynamo_context["endpoints"][0].lease_id()
namespace, _ = TensorRTLLMWorker.dynamo_address() # type: ignore
self._min_prefill_workers = args.min_prefill_workers
super().__init__(
namespace_str=namespace,
component_str=class_name,
worker_id=worker_id,
engine_config=engine_config,
remote_prefill=args.remote_prefill,
min_workers=args.min_workers,
disagg_config_file=args.llmapi_disaggregated_config,
block_size=args.block_size,
router=args.router,
server_type=ServerType.GEN,
endpoint_name = "generate"
publish_events_and_metrics = args.router == "kv"
prefill_class_name = "TensorRTLLMPrefillWorker"
if args.enable_disagg:
disaggregation_mode = "decode"
else:
disaggregation_mode = "prefill_and_decode"
engine_config = BaseEngineConfig(
namespace=namespace,
component=class_name,
endpoint=endpoint_name,
model_path=args.model_path,
served_model_name=args.served_model_name,
kv_block_size=args.kv_block_size,
extra_engine_args=args.extra_engine_args,
publish_events_and_metrics=publish_events_and_metrics,
disaggregation_mode=disaggregation_mode,
remote_prefill_endpoint=f"dyn://{namespace}.{prefill_class_name}.generate",
lease_id=lease_id,
)
super().__init__(config=engine_config)
@async_on_start
async def async_init(self):
self._init_engine()
runtime = dynamo_context["runtime"]
await self.initialize(runtime)
logger.info("Registering LLM for discovery")
comp_ns, comp_name = TensorRTLLMWorker.dynamo_address() # type: ignore
endpoint = runtime.namespace(comp_ns).component(comp_name).endpoint("generate")
endpoint = (
runtime.namespace(self._config.namespace)
.component(self._config.component)
.endpoint(self._config.endpoint)
)
try:
await register_llm(
ModelType.Backend,
endpoint,
self._engine_config.model_name,
self.served_model_name,
kv_cache_block_size=self._kv_block_size,
self._config.model_path,
self._config.served_model_name,
kv_cache_block_size=self._config.kv_block_size,
)
logger.info("Successfully registered LLM for discovery")
except Exception as e:
logger.error(f"Failed to register LLM for discovery: {e}")
raise
if self._remote_prefill:
runtime = dynamo_context["runtime"]
comp_ns, comp_name = TensorRTLLMPrefillWorker.dynamo_address() # type: ignore
self._prefill_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
while len(self._prefill_client.instance_ids()) < self._min_prefill_workers:
logger.info(
f"Waiting for prefill workers to be ready.\n"
f" Current: {len(self._prefill_client.instance_ids())},"
f" Required: {self._min_prefill_workers}"
)
await asyncio.sleep(30)
if self._kv_metrics_publisher is not None:
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logger.info("metrics publisher endpoint created")
)
logger.info("TensorRT-LLM Worker initialized")
async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"]
await self._kv_metrics_publisher.create_endpoint(component)
@on_shutdown
async def async_cleanup(self):
logger.info("Cleaning up TensorRT-LLM Worker")
await self.cleanup()
logger.info("TensorRT-LLM Worker cleanup completed")
@endpoint()
async def generate(self, request: TRTLLMWorkerRequest):
......
......@@ -20,8 +20,13 @@ Frontend:
router: round-robin
TensorRTLLMWorker:
# Path to disk model or HuggingFace model identifier to load
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Name to serve the model under
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
engine_args: "configs/llm_api_config.yaml"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# Fields in `extra-engine-args` take priority over the TRTLLM engine fields above.
extra-engine-args: "configs/engine_configs/agg_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
......
......@@ -20,7 +20,13 @@ Frontend:
router: kv
TensorRTLLMWorker:
engine_args: "configs/llm_api_config_router.yaml"
# Path to disk model or HuggingFace model identifier to load
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Name to serve the model under
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# Fields in `extra-engine-args` take priority over the TRTLLM engine fields above.
extra-engine-args: "configs/engine_configs/agg_config.yaml"
router: kv
ServiceArgs:
workers: 1
......
......@@ -22,7 +22,12 @@ Frontend:
TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
extra-engine-args: "configs/deepseek_r1/engine_configs/agg_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
......
......@@ -22,14 +22,13 @@ Frontend:
TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
llmapi-disaggregated-config: "configs/deepseek_r1/disagg_llm_api_config.yaml"
remote-prefill: true
# NOTE: When testing/benchmarking multiple prefill workers, you can set
# this number to the exact amount of prefill workers if you want Dynamo to
# wait until all the prefill workers are ready before marking the decode
# worker ready.
min-prefill-workers: 1
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
extra-engine-args: "configs/deepseek_r1/engine_configs/decode_config.yaml"
enable-disagg: true
router: round-robin
ServiceArgs:
workers: 1
......@@ -37,8 +36,12 @@ TensorRTLLMWorker:
gpu: 4
TensorRTLLMPrefillWorker:
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
llmapi-disaggregated-config: "configs/deepseek_r1/disagg_llm_api_config.yaml"
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
extra-engine-args: "configs/deepseek_r1/engine_configs/prefill_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Example Configs for Context & Generation on GB200 nodes
# - Context on 1xGB200 (4xB200)
# - Generation on 1xGB200 (4xB200)
# NOTE: Fields like hostname, ports, urls, num_instances, etc. only used by trtllm-serve, not by dynamo
backend: pytorch
context_servers:
# Context/prefill processes many tokens at once, so for a large ISL, a large
# batch size may not be needed to saturate GPU utilization.
max_batch_size: 1
max_num_tokens: 8192
max_seq_len: 8192
# TP/EP/PP/DP
tensor_parallel_size: 4
moe_expert_parallel_size: 4
pipeline_parallel_size: 1
enable_attention_dp: true
kv_cache_config:
free_gpu_memory_fraction: 0.75
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: This field is called 'enable_overlap_scheduler' in older TRTLLM versions
# Overlap scheduler not currently supported in context-only
disable_overlap_scheduler: true
print_iter_log: true
# NOTE: This dtype must match in both context/generation configs
kv_cache_dtype: fp8
generation_servers:
# Generation/decode processes one token per request at a time, so a larger
# batch size helps to saturate GPU utilization.
max_batch_size: 256
max_num_tokens: 256
# 8448 = 8192 ISL + 256 OSL
max_seq_len: 8448
# TP/EP/PP/DP
tensor_parallel_size: 4
moe_expert_parallel_size: 4
pipeline_parallel_size: 1
enable_attention_dp: false
kv_cache_config:
# With dp attention disabled: high free_gpu_memory_fraction is fine.
free_gpu_memory_fraction: 0.85
# With dp attention enabled: large ISL at high concurrency may need
# free_gpu_memory_fraction low to have enough available memory.
# free_gpu_memory_fraction: 0.30
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: This field is called 'enable_overlap_scheduler' in older TRTLLM versions
disable_overlap_scheduler: false
use_cuda_graph: true
cuda_graph_padding_enabled: true
# NOTE: For larger max batch size, you may want to add larger cuda graph
# batch sizes below to match.
cuda_graph_batch_sizes:
- 1
- 2
- 4
- 8
- 16
- 32
- 64
- 128
- 256
print_iter_log: true
# NOTE: This dtype must match in both context/generation configs
kv_cache_dtype: fp8
......@@ -12,12 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model_name: "nvidia/DeepSeek-R1-FP4"
backend: pytorch
# TP/EP/PP/DP
......
......@@ -12,32 +12,44 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
backend: pytorch
# In the case of disaggregated deployment, this config will apply to each server
# and will be overwritten by the disaggregated config file
# TODO: figure out how to generate this from the service config or vice versa
model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path: null
tensor_parallel_size: 1
moe_expert_parallel_size: 1
# TP/EP/PP/DP
tensor_parallel_size: 4
moe_expert_parallel_size: 4
pipeline_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
max_batch_size: 256
max_num_tokens: 256
# 8448 = 8192 ISL + 256 OSL
max_seq_len: 8448
kv_cache_config:
free_gpu_memory_fraction: 0.95
event_buffer_max_size: 1024
enable_block_reuse: true
# With dp attention disabled: high free_gpu_memory_fraction is fine.
free_gpu_memory_fraction: 0.85
# With dp attention enabled: large ISL at high concurrency may need
# free_gpu_memory_fraction low to have enough available memory.
# free_gpu_memory_fraction: 0.30
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: overlap_scheduler enabled by default since this commit and changed
# config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler':
# https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428
disable_overlap_scheduler: false
use_cuda_graph: true
enable_iter_perf_stats: true
cuda_graph_padding_enabled: true
# NOTE: For larger max batch size, you may want to add larger cuda graph
# batch sizes below to match.
cuda_graph_batch_sizes:
- 1
- 2
- 4
- 8
- 16
- 32
- 64
- 128
- 256
print_iter_log: true
kv_cache_dtype: fp8
......@@ -12,32 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
backend: pytorch
# TP/EP/PP/DP
tensor_parallel_size: 4
moe_expert_parallel_size: 4
pipeline_parallel_size: 1
enable_attention_dp: true
# In the case of disaggregated deployment, this config will apply to each server
# and will be overwritten by the disaggregated config file
# TODO: figure out how to generate this from the service config or vice versa
model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path: null
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_batch_size: 1
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
max_seq_len: 8192
kv_cache_config:
free_gpu_memory_fraction: 0.95
event_buffer_max_size: 1024
enable_block_reuse: true
# With dp attention disabled: high free_gpu_memory_fraction is fine.
free_gpu_memory_fraction: 0.75
# With dp attention enabled: large ISL at high concurrency may need
# free_gpu_memory_fraction low to have enough available memory.
# free_gpu_memory_fraction: 0.30
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: overlap_scheduler enabled by default since this commit and changed
# config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler':
# https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428
use_cuda_graph: true
enable_iter_perf_stats: true
disable_overlap_scheduler: true
print_iter_log: true
# NOTE: This dtype must match in both prefill/decode configs
kv_cache_dtype: fp8
......@@ -18,7 +18,6 @@
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model_name: "nvidia/DeepSeek-R1-FP4"
backend: pytorch
tensor_parallel_size: 4
moe_expert_parallel_size: 4
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
backend: pytorch
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: false
max_batch_size: 256
# Note: When MTP is enabled and `cuda_graph_batch_sizes` is specified, `max_num_tokens` must satisfy the following formula:
# max_num_tokens >= max(cuda_graph_batch_sizes) * (num_nextn_predict_layers + 1)
# This is a known issue in TensorRT-LLM and will be resolved in the next release.
max_num_tokens: 512
# 8704 = 8192 ISL + 512 OSL
max_seq_len: 8704
kv_cache_config:
free_gpu_memory_fraction: 0.85
# Enable the MTP(Multi-Token Prediction) in decode model engine
speculative_config:
decoding_type: MTP
num_nextn_predict_layers: 1
use_cuda_graph: true
cuda_graph_padding_enabled: true
cuda_graph_batch_sizes:
- 1
- 2
- 4
- 8
- 16
- 32
- 64
- 128
- 256
print_iter_log: true
kv_cache_dtype: fp8
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
backend: pytorch
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: true
max_batch_size: 1
max_num_tokens: 8192
max_seq_len: 8192
kv_cache_config:
free_gpu_memory_fraction: 0.75
print_iter_log: true
kv_cache_dtype: fp8
disable_overlap_scheduler: true
# Enable the MTP(Multi-Token Prediction) in the prefill model engine
speculative_config:
decoding_type: MTP
num_nextn_predict_layers: 1
......@@ -21,7 +21,14 @@ Frontend:
TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/mtp/mtp_agg_llm_api_config.yaml"
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# Fields in `extra-engine-args` take priority over the TRTLLM engine fields above.
extra-engine-args: "configs/deepseek_r1/mtp/engine_configs/agg_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
......
......@@ -21,19 +21,30 @@ Frontend:
TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
llmapi-disaggregated-config: "configs/deepseek_r1/mtp/mtp_disagg_llm_api_config.yaml"
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# Fields in `extra-engine-args` take priority over the TRTLLM engine fields above.
extra-engine-args: "configs/deepseek_r1/mtp/engine_configs/decode_config.yaml"
router: round-robin
remote-prefill: true
min-prefill-workers: 1
enable-disagg: true
ServiceArgs:
workers: 1
resources:
gpu: 4
TensorRTLLMPrefillWorker:
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml"
llmapi-disaggregated-config: "configs/deepseek_r1/mtp/mtp_disagg_llm_api_config.yaml"
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# Fields in `extra-engine-args` take priority over the TRTLLM engine fields above.
extra-engine-args: "configs/deepseek_r1/mtp/engine_configs/prefill_config.yaml"
router: round-robin
ServiceArgs:
workers: 1
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
backend: pytorch
context_servers:
num_instances: 1
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: true
max_batch_size: 1
max_num_tokens: 8192
max_seq_len: 8192
kv_cache_config:
free_gpu_memory_fraction: 0.75
print_iter_log: true
kv_cache_dtype: fp8
disable_overlap_scheduler: true
# Enable the MTP(Multi-Token Prediction) in the prefill model engine
speculative_config:
decoding_type: MTP
num_nextn_predict_layers: 1
generation_servers:
num_instances: 1
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: false
max_batch_size: 256
# Note: When MTP is enabled and `cuda_graph_batch_sizes` is specified, `max_num_tokens` must satisfy the following formula:
# max_num_tokens >= max(cuda_graph_batch_sizes) * (num_nextn_predict_layers + 1)
# This is a known issue in TensorRT-LLM and will be resolved in the next release.
max_num_tokens: 512
# 8704 = 8192 ISL + 512 OSL
max_seq_len: 8704
kv_cache_config:
free_gpu_memory_fraction: 0.85
# Enable the MTP(Multi-Token Prediction) in the decode model engine
speculative_config:
decoding_type: MTP
num_nextn_predict_layers: 1
use_cuda_graph: true
cuda_graph_padding_enabled: true
cuda_graph_batch_sizes:
- 1
- 2
- 4
- 8
- 16
- 32
- 64
- 128
- 256
print_iter_log: true
kv_cache_dtype: fp8
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment