Unverified Commit 03d976c7 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

refactor: Refactor the TRTLLM example components and improve UI (#1654)


Signed-off-by: default avatarTanmay Verma <tanmayv@nvidia.com>
parent 8a2d6529
...@@ -110,7 +110,7 @@ dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml ...@@ -110,7 +110,7 @@ dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
#### Aggregated serving with KV Routing #### Aggregated serving with KV Routing
```bash ```bash
cd /workspace/examples/tensorrt_llm cd /workspace/examples/tensorrt_llm
dynamo serve graphs.agg_router:Frontend -f ./configs/agg_router.yaml dynamo serve graphs.agg:Frontend -f ./configs/agg_router.yaml
``` ```
#### Disaggregated serving #### Disaggregated serving
...@@ -122,7 +122,7 @@ dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml ...@@ -122,7 +122,7 @@ dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
#### Disaggregated serving with KV Routing #### Disaggregated serving with KV Routing
```bash ```bash
cd /workspace/examples/tensorrt_llm cd /workspace/examples/tensorrt_llm
dynamo serve graphs.disagg_router:Frontend -f ./configs/disagg_router.yaml dynamo serve graphs.disagg:Frontend -f ./configs/disagg_router.yaml
``` ```
#### Aggregated serving with Multi-Token Prediction (MTP) and DeepSeek R1 #### Aggregated serving with Multi-Token Prediction (MTP) and DeepSeek R1
......
...@@ -12,588 +12,374 @@ ...@@ -12,588 +12,374 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import copy
import logging import logging
import os from dataclasses import dataclass
import signal
import threading
from contextlib import asynccontextmanager
from enum import Enum
from queue import Queue
from typing import Any, Optional from typing import Any, Optional
from common.parser import LLMAPIConfig from common.protocol import DisaggregatedTypeConverter, TRTLLMWorkerRequest
from common.protocol import DisaggregatedTypeConverter from tensorrt_llm import SamplingParams
from common.utils import ManagedThread, ServerType from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import LLM, SamplingParams
from tensorrt_llm.llmapi.disagg_utils import (
CtxGenServerConfig,
parse_disagg_config_file,
)
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.serve.openai_protocol import DisaggregatedParams from tensorrt_llm.serve.openai_protocol import (
DisaggregatedParams as OAIDisaggregatedParams,
)
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher from dynamo.llm import get_tensorrtllm_engine, get_tensorrtllm_publisher
from dynamo.sdk import dynamo_context from dynamo.runtime import DistributedRuntime
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
# Default buffer size for kv cache events.
class DisaggRequestType(Enum): DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
CONTEXT_ONLY = "context_only"
GENERATION_ONLY = "generation_only"
def update_args_from_disagg_config(
engine_config: LLMAPIConfig, server_config: CtxGenServerConfig
):
# Update the LLM API config with the disaggregated config
# Allows for different configs for context and generation servers
engine_config.extra_args.update(**server_config.other_args)
engine_config.update_sub_configs(server_config.other_args)
return engine_config
def _to_signed_i64(value: int | None) -> int | None: def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
"""Convert a Python int to signed 64-bit range by two's complement.""" endpoint_str = endpoint.replace("dyn://", "", 1)
if value is None: endpoint_parts = endpoint_str.split(".")
return None if len(endpoint_parts) != 3:
raise ValueError(
if value >= 2**63: f"Invalid endpoint format: '{endpoint}'. "
return value - 2**64 "Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
if value < -(2**63): )
return ((value + 2**63) % 2**64) - 2**63
return value
def get_sampling_params(sampling_params_dict, default_sampling_params): return (endpoint_parts[0], endpoint_parts[1], endpoint_parts[2])
sampling_params = copy.deepcopy(default_sampling_params)
for key, value in sampling_params_dict.items():
if value is None: @dataclass
continue class BaseEngineConfig:
if hasattr(sampling_params, key): """Base engine configuration"""
setattr(sampling_params, key, value)
return sampling_params namespace: str
component: str
endpoint: str
model_path: str
served_model_name: Optional[str] = None
kv_block_size: int = 32
extra_engine_args: str = ""
publish_events_and_metrics: bool = False
disaggregation_mode: str = "prefill_and_decode"
remote_prefill_endpoint: Optional[str] = None
lease_id: int = 0
def __str__(self) -> str:
return (
f"Config(namespace={self.namespace}, "
f"component={self.component}, "
f"endpoint={self.endpoint}, "
f"model_path={self.model_path}, "
f"served_model_name={self.served_model_name}, "
f"kv_block_size={self.kv_block_size}, "
f"extra_engine_args={self.extra_engine_args}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
f"remote_prefill_endpoint={self.remote_prefill_endpoint}, "
f"lease_id={self.lease_id})"
)
class BaseTensorrtLLMEngine: class BaseTensorrtLLMEngine:
def __init__( def __init__(
self, self,
namespace_str: str = "dynamo", config: BaseEngineConfig,
component_str: str = "tensorrt-llm",
worker_id: Optional[str] = None,
engine_config: LLMAPIConfig = None,
remote_prefill: bool = False,
min_workers: int = 0,
disagg_config_file: Optional[str] = None,
block_size: int = 32,
router: str = "round_robin",
server_type: ServerType = ServerType.GEN,
): ):
self._namespace_str = namespace_str self._config = config
self._component_str = component_str
self._worker_id = worker_id
self._remote_prefill = remote_prefill
self._min_workers = 0
self._kv_block_size = block_size
self._router = router
self._server_type = server_type
self._prefill_client = None self._prefill_client = None
self._error_queue: Queue = Queue() self._llm_engine = None
self._kv_metrics_publisher = None self._llm_engine_context = None
self._llm_publisher = None
if self._remote_prefill or self._server_type == ServerType.CTX: self._llm_publisher_context = None
self._min_workers = min_workers self._runtime = None
if disagg_config_file is None or not os.path.exists(disagg_config_file): self._first_generation = True
raise ValueError( # Initialize default sampling params
"llmapi_disaggregated_config file does not exist or not provided" self.default_sampling_params = SamplingParams()
)
disagg_config = parse_disagg_config_file(disagg_config_file) async def initialize(self, runtime: DistributedRuntime):
server_config: CtxGenServerConfig = None """Initialize the engine and prefill client if needed"""
self._runtime = runtime
for config in disagg_config.server_configs:
# Select the first context server config # Convert model path to Path object if it's a local path, otherwise keep as string
if config.type == server_type.value: model_path = str(self._config.model_path)
server_config = config
break # Initialize the LLM engine
engine_args: dict[str, Any] = {
if server_config is None: "model": model_path,
server_type_str = ( "tensor_parallel_size": 1,
"generation" if server_type == ServerType.GEN else "context" "backend": "pytorch",
) "skip_tokenizer_init": True,
raise ValueError( }
f"No {server_type_str} server config found. Please check the disaggregated config file."
) if self._config.extra_engine_args:
# TODO: Support extra engine args from json file as well.
engine_config = update_args_from_disagg_config(engine_config, server_config) engine_args = update_llm_args_with_extra_options(
engine_args, self._config.extra_engine_args
if router == "kv":
self._publish_stats = True
self._publish_events = True
else:
self._publish_stats = False
self._publish_events = False
if self._publish_stats:
self._kv_metrics_publisher = WorkerMetricsPublisher()
if self._publish_events:
if self._worker_id is None:
raise ValueError("Worker ID is None!")
runtime = dynamo_context["runtime"]
kv_listener = runtime.namespace(self._namespace_str).component(
self._component_str
) )
self._kv_event_publisher = KvEventPublisher( # Update the model path in the config to the model path used by the engine.
kv_listener, int(self._worker_id), self._kv_block_size self._config.model_path = str(engine_args["model"])
if not self._config.model_path:
raise ValueError(
"Model specification is required. Present neither in the config nor in the extra engine args."
) )
logger.info("KvEventPublisher is initialized")
self._engine_config = engine_config
def _init_engine(self):
logger.info("Initializing engine")
# Run the engine in a separate thread running the AsyncIO event loop.
self._llm_engine: Optional[Any] = None
self._llm_engine_start_cv = threading.Condition()
self._llm_engine_shutdown_event = asyncio.Event()
self._event_thread = threading.Thread(
target=asyncio.run, args=(self._run_llm_engine(),)
)
# Populate default sampling params from the model # Populate default sampling params from the model
tokenizer = tokenizer_factory(self._engine_config.model_name) tokenizer = tokenizer_factory(self._config.model_path)
self._default_sampling_params = SamplingParams() self.default_sampling_params = SamplingParams()
self._default_sampling_params._setup(tokenizer) self.default_sampling_params._setup(tokenizer)
self._default_sampling_params.stop = None self.default_sampling_params.stop = None
self.publish_kv_cache_events_thread = None if self._config.publish_events_and_metrics:
self.publish_stats_thread = None # 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
kv_cache_config: dict[str, Any] | Any = None
self._event_thread.start() if "kv_cache_config" not in engine_args:
with self._llm_engine_start_cv: kv_cache_config = {}
while self._llm_engine is None: kv_cache_config[
self._llm_engine_start_cv.wait() "event_buffer_max_size"
] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
# The 'threading.Thread()' will not raise the exception here should the engine else:
# failed to start, so the exception is passed back via the engine variable. kv_cache_config = engine_args["kv_cache_config"]
if isinstance(self._llm_engine, Exception): if (
e = self._llm_engine hasattr(kv_cache_config, "event_buffer_max_size")
logger.error(f"Failed to start engine: {e}") and not kv_cache_config.event_buffer_max_size
if self._event_thread is not None: ):
self._event_thread.join() kv_cache_config.event_buffer_max_size = (
self._event_thread = None DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
raise e )
elif (
try: isinstance(kv_cache_config, dict)
if self._publish_stats: and "event_buffer_max_size" not in kv_cache_config
self._init_publish_metrics_thread() ):
except Exception as e: kv_cache_config[
logger.error(f"Failed to initialize publish metrics threads: {e}") "event_buffer_max_size"
raise e ] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
engine_args["kv_cache_config"] = kv_cache_config
try:
if self._publish_events: # Enable iter perf stats by default if we are publishing events and metrics.
self._init_publish_kv_cache_events_thread() if not engine_args.get("enable_iter_perf_stats"):
except Exception as e: engine_args["enable_iter_perf_stats"] = True
logger.error(f"Failed to initialize publish events threads: {e}")
raise e # Only pytorch backend is supported for now to publish events and metrics.
if engine_args.get("backend") != "pytorch":
def _init_publish_metrics_thread(self): logging.error(
# Need to publish stats once so that worker can be selected. "Only pytorch backend is supported for now to publish events and metrics."
# Publishing some dummy values... )
request_active_slots = 0 raise RuntimeError(
request_total_slots = 4 "Only pytorch backend is supported for now to publish events and metrics. Hence, KV router is not supported."
kv_active_block = 0 )
kv_total_blocks = 4
num_requests_waiting = 0
gpu_cache_usage_perc = 0.0
gpu_prefix_cache_hit_rate = 0.0
num_requests_waiting = 0
gpu_cache_usage_perc = 0.0
gpu_prefix_cache_hit_rate = 0.0
if self._kv_metrics_publisher is None:
logger.error("KV metrics publisher not initialized!")
return
self._kv_metrics_publisher.publish(
request_active_slots,
request_total_slots,
kv_active_block,
kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
)
# Prepare threads for publishing stats but don't start them yet. logging.info(f"TRTLLM engine args: {engine_args}")
# TRTLLM needs to start generating tokens first before stats
# can be retrieved.
self.publish_stats_thread = ManagedThread(
self.publish_stats_task,
error_queue=self._error_queue,
name="publish_stats_thread",
)
def _init_publish_kv_cache_events_thread(self): # Get the engine using the asynccontextmanager
if self._kv_event_publisher is None: self._llm_engine_context = get_tensorrtllm_engine(engine_args)
logger.error("KV event publisher not initialized!") if self._llm_engine_context is not None:
return self._llm_engine = await self._llm_engine_context.__aenter__()
else:
# A set to store the block hash of partial block (i.e. block containing less than kv_block_size tokens) hashes. raise RuntimeError("Failed to create LLM engine context")
# It is used to prevent sending remove event to kv router since partial blocks are not stored.
self._partial_block_hashes = set()
# Prepare threads for publishing kv cache events but don't start them yet.
# TRTLLM needs to start generating tokens first before kv cache events
# can be retrieved.
self.publish_kv_cache_events_thread = ManagedThread(
self.publish_kv_cache_events_task,
error_queue=self._error_queue,
name="publish_kv_cache_events_thread",
)
async def publish_stats_task(self): if (
""" self._config.publish_events_and_metrics
Publish stats to the metrics publisher. and self._config.disaggregation_mode != "prefill"
""" ):
if self._llm_engine is None: kv_listener = runtime.namespace(self._config.namespace).component(
logger.error("LLM engine not initialized!") self._config.component
return
if self._kv_metrics_publisher is None:
logger.error("KV metrics publisher not initialized!")
return False
stats = self._llm_engine.get_stats_async(timeout=5)
async for stat in stats:
request_active_slots = stat["numActiveRequests"]
request_total_slots = stat["maxNumActiveRequests"]
kv_active_block = stat["kvCacheStats"]["usedNumBlocks"]
kv_total_blocks = stat["kvCacheStats"]["maxNumBlocks"]
reused_blocks = stat["kvCacheStats"]["reusedBlocks"]
freeNumBlocks = stat["kvCacheStats"]["freeNumBlocks"]
allocTotalBlocks = stat["kvCacheStats"]["allocTotalBlocks"]
allocNewBlocks = stat["kvCacheStats"]["allocNewBlocks"]
# NOTE: num paused requests is always 0 when using guarantee no evict scheduler (default).
num_requests_waiting = (
stat["numQueuedRequests"]
+ stat["inflightBatchingStats"]["numPausedRequests"]
) )
gpu_cache_usage_perc = allocTotalBlocks / kv_total_blocks self._llm_publisher_context = get_tensorrtllm_publisher(
gpu_prefix_cache_hit_rate = stat["kvCacheStats"]["cacheHitRate"] kv_listener,
self._llm_engine,
logger.debug( kv_listener,
f"Publishing stats: request_active_slots: {request_active_slots}, request_total_slots: {request_total_slots}, kv_active_block: {kv_active_block}, kv_total_blocks: {kv_total_blocks}, num_requests_waiting: {num_requests_waiting}, reused_blocks: {reused_blocks}, freeNumBlocks: {freeNumBlocks}, allocTotalBlocks: {allocTotalBlocks}, allocNewBlocks: {allocNewBlocks}, gpu_cache_usage_perc: {gpu_cache_usage_perc}, gpu_prefix_cache_hit_rate: {gpu_prefix_cache_hit_rate}" self._config.lease_id,
self._config.kv_block_size,
) )
if self._llm_publisher_context is not None:
self._kv_metrics_publisher.publish( self._llm_publisher = await self._llm_publisher_context.__aenter__()
request_active_slots, else:
request_total_slots, raise RuntimeError("Failed to create LLM publisher context")
kv_active_block,
kv_total_blocks, # Initialize prefill client if in decode mode
num_requests_waiting, if self._config.disaggregation_mode == "decode":
gpu_cache_usage_perc, if self._config.remote_prefill_endpoint is None:
gpu_prefix_cache_hit_rate, raise ValueError("remote_prefill_endpoint is required for decode mode")
logging.info(
f"Initializing remote prefill client for endpoint: {self._config.remote_prefill_endpoint}"
) )
(
return True parsed_namespace,
parsed_component_name,
async def publish_kv_cache_events_task(self): parsed_endpoint_name,
""" ) = parse_endpoint(self._config.remote_prefill_endpoint)
Publish kv cache events to the events publisher. if self._runtime is not None:
""" self._prefill_client = (
if self._llm_engine is None: await self._runtime.namespace(parsed_namespace)
logger.error("LLM engine not initialized!") .component(parsed_component_name)
return .endpoint(parsed_endpoint_name)
.client()
events = self._llm_engine.get_kv_cache_events_async(timeout=5)
async for event in events:
event_id = event["event_id"]
data = event["data"]
if data["type"] == "stored":
parent_hash = _to_signed_i64(data["parent_hash"])
token_ids = []
num_block_tokens = []
block_hashes = []
for block in data["blocks"]:
token_num_in_block = len(block["tokens"])
block_hash = _to_signed_i64(block["block_hash"])
if token_num_in_block > self._kv_block_size:
logger.error(
f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self._kv_block_size}"
)
return
if token_num_in_block < self._kv_block_size:
logger.debug(
f"Early stop when block {block_hash} containing {token_num_in_block} tokens not equal to kv_block_size {self._kv_block_size}"
)
self._partial_block_hashes.add(block_hash)
break
num_block_tokens.append(token_num_in_block)
block_hashes.append(block_hash)
for token in block["tokens"]:
token_ids.append(int(token["token_id"]))
# Note: Currently data does not have lora_id.
# Using 0 as default value. If later data has
# lora_id, we need to verify if this is correct.
lora_id = data.get("lora_id", 0)
logger.debug(
f"publish stored event: event_id: {event_id}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
)
self._kv_event_publisher.publish_stored(
event_id,
token_ids,
num_block_tokens,
block_hashes,
lora_id,
parent_hash,
) )
elif data["type"] == "removed": else:
block_hashes = [] raise RuntimeError("Runtime not initialized")
for block_hash in data["block_hashes"]:
block_hash = _to_signed_i64(block_hash)
if block_hash in self._partial_block_hashes:
logger.debug(
f"Skipping removing block hash {block_hash} since it is a partial block"
)
self._partial_block_hashes.remove(block_hash)
continue
block_hashes.append(block_hash)
logger.debug(
f"publish removed event: event_id: {event_id}, block_hashes: {block_hashes}"
)
self._kv_event_publisher.publish_removed(event_id, block_hashes)
return True
def _start_threads(self): async def cleanup(self):
if ( """Cleanup resources"""
self.publish_kv_cache_events_thread if self._llm_publisher_context:
and not self.publish_kv_cache_events_thread.is_alive()
):
# [NOTE:] TRTLLM needs the stats to be collected on the same loop as the request handler.
self._stats_loop = asyncio.get_running_loop()
self.publish_kv_cache_events_thread.set_loop(self._stats_loop)
self.publish_kv_cache_events_thread.start()
logger.debug("Started kv cache events thread")
if self.publish_stats_thread and not self.publish_stats_thread.is_alive():
self._stats_loop = asyncio.get_running_loop()
self.publish_stats_thread.set_loop(self._stats_loop)
self.publish_stats_thread.start()
logger.debug("Started stats thread")
async def _run_llm_engine(self):
# Counter to keep track of ongoing request counts.
self._ongoing_request_count = 0
@asynccontextmanager
async def async_llm_wrapper():
# Create LLM in a thread to avoid blocking
loop = asyncio.get_running_loop()
try: try:
llm = await loop.run_in_executor( await self._llm_publisher_context.__aexit__(None, None, None)
None, except Exception as e:
lambda: LLM( logging.error(f"Error during publisher cleanup: {e}")
model=self._engine_config.model_name,
**self._engine_config.to_dict(),
),
)
yield llm
finally: finally:
if "llm" in locals(): self._llm_publisher = None
# Run shutdown in a thread to avoid blocking self._llm_publisher_context = None
await loop.run_in_executor(None, llm.shutdown)
try:
async with async_llm_wrapper() as engine:
# Capture the engine event loop and make it visible to other threads.
self._event_loop = asyncio.get_running_loop()
# Signal the engine is started and make it visible to other threads. if self._llm_engine_context:
with self._llm_engine_start_cv: try:
self._llm_engine = engine await self._llm_engine_context.__aexit__(None, None, None)
self._llm_engine_start_cv.notify_all() except Exception as e:
logging.error(f"Error during engine cleanup: {e}")
logger.info("Engine loaded and ready to serve...") finally:
self._llm_engine = None
# Wait for the engine shutdown signal. self._llm_engine_context = None
await self._llm_engine_shutdown_event.wait()
# Stop the publishing threads self._prefill_client = None
if self.publish_stats_thread and self.publish_stats_thread.is_alive():
self.publish_stats_thread.stop()
self.publish_stats_thread.join()
if (
self.publish_kv_cache_events_thread
and self.publish_kv_cache_events_thread.is_alive()
):
self.publish_kv_cache_events_thread.stop()
self.publish_kv_cache_events_thread.join()
# Wait for the ongoing requests to complete.
while self._ongoing_request_count > 0:
logger.info(
"Awaiting remaining {} requests".format(
self._ongoing_request_count
)
)
await asyncio.sleep(1)
# Cancel all tasks in the event loop. async def remote_prefill(self, request: TRTLLMWorkerRequest):
for task in asyncio.all_tasks(loop=self._event_loop): """
if task is not asyncio.current_task(): Send a prefill request to the remote prefill worker.
task.cancel()
except Exception as e: Args:
# Signal and pass the exception back via the engine variable if the engine request: The original request to be sent for prefill
# failed to start. If the engine has started, re-raise the exception.
with self._llm_engine_start_cv:
if self._llm_engine is None:
self._llm_engine = e
self._llm_engine_start_cv.notify_all()
return
raise e
self._llm_engine = None Returns:
logger.info("Shutdown complete") The response from the remote prefill worker
async def _get_remote_prefill_response(self, request): Raises:
prefill_request = copy.deepcopy(request) ValueError: If prefill client is not initialized or multiple responses received
"""
prefill_request = request.model_copy(deep=True)
# TRTLLM requires max_tokens to be set for prefill requests. # TRTLLM requires max_tokens to be set for prefill requests.
prefill_request.stop_conditions.max_tokens = 1 prefill_request.stop_conditions.max_tokens = 1
prefill_request.disaggregated_params = DisaggregatedParams( prefill_request.disaggregated_params = OAIDisaggregatedParams(
request_type=DisaggRequestType.CONTEXT_ONLY.value request_type="context_only"
) )
if self._prefill_client is None: if self._prefill_client is None:
raise ValueError("Prefill client not initialized") raise ValueError("Prefill client not initialized")
try:
# TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
remote_prefill_responses = [
remote_prefill_response
async for remote_prefill_response in await self._prefill_client.round_robin(
prefill_request.model_dump_json()
)
]
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
# TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers. if len(remote_prefill_responses) > 1:
ctx_responses = [
ctx_response
async for ctx_response in await self._prefill_client.round_robin(
prefill_request.model_dump_json()
)
]
if len(ctx_responses) > 1:
raise ValueError( raise ValueError(
"Prefill worker returned more than one response. This is currently not supported in remote prefill mode." "Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
) )
logger.debug(
f"Received response from prefill worker: {ctx_responses[0].data()}" if len(remote_prefill_responses) == 0:
) raise ValueError("No response received from remote prefill worker")
remote_prefill_response = ctx_responses[0]
remote_prefill_response = remote_prefill_responses[0]
return remote_prefill_response return remote_prefill_response
async def generate(self, request): async def generate(self, request: TRTLLMWorkerRequest):
if self._llm_engine is None: if self._llm_engine is None:
raise RuntimeError("Engine not initialized") raise RuntimeError("Engine not initialized")
if not self._error_queue.empty(): if self._llm_publisher:
raise self._error_queue.get() publishers_error = self._llm_publisher.check_error_queue()
if publishers_error:
raise publishers_error
self._ongoing_request_count += 1 inputs = request.token_ids
try: # Decode the disaggregated params from the request
worker_inputs = request.token_ids disaggregated_params = DisaggregatedTypeConverter.to_llm_disaggregated_params(
request.disaggregated_params
)
num_output_tokens_so_far = 0
if self._config.disaggregation_mode == "decode":
# Run prefill/context phase remotely if disaggregation mode is decode.
try:
prefill_result = await self.remote_prefill(request)
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
remote_prefill_response = prefill_result.data()
if (
remote_prefill_response["finish_reason"] == "stop"
or remote_prefill_response["finish_reason"] == "error"
):
yield remote_prefill_response
return
num_output_tokens_so_far = len(remote_prefill_response["token_ids"])
# Decode the disaggregated params from the remote prefill response
# Decode the disaggregated params from the remote prefill response
disaggregated_params = ( disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params( DisaggregatedTypeConverter.to_llm_disaggregated_params(
request.disaggregated_params OAIDisaggregatedParams(
) **remote_prefill_response["disaggregated_params"]
)
num_output_tokens_so_far = 0
if self._remote_prefill and self._server_type == ServerType.GEN:
ctx_response = await self._get_remote_prefill_response(request)
remote_prefill_response = ctx_response.data()
if (
remote_prefill_response["finish_reason"] == "stop"
or remote_prefill_response["finish_reason"] == "error"
):
yield remote_prefill_response
return
num_output_tokens_so_far = len(remote_prefill_response["token_ids"])
# Decode the disaggregated params from the remote prefill response
disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params(
DisaggregatedParams(
**remote_prefill_response["disaggregated_params"]
)
) )
) )
# Send the first token response to the client
first_token_response = remote_prefill_response
first_token_response.pop("disaggregated_params")
yield first_token_response
disaggregated_params.request_type = (
DisaggRequestType.GENERATION_ONLY.value
)
logger.debug(
f"Worker inputs: {worker_inputs}, disaggregated params: {disaggregated_params}"
)
sampling_params = get_sampling_params(
request.sampling_options.dict(), self._default_sampling_params
) )
max_tokens = request.stop_conditions.max_tokens
if max_tokens:
sampling_params.max_tokens = max_tokens
async for response in self._llm_engine.generate_async(
inputs=worker_inputs,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=self._server_type != ServerType.CTX,
):
if response.finished and self._server_type != ServerType.CTX:
yield {"finish_reason": "stop", "token_ids": []}
break
if not response.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = response.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self._server_type == ServerType.CTX:
# Return the disaggregated params only when operating in prefill mode.
out[
"disaggregated_params"
] = DisaggregatedTypeConverter.to_oai_disaggregated_params(
output.disaggregated_params
).dict()
yield out
num_output_tokens_so_far = next_total_toks
except CppExecutorError:
signal.raise_signal(signal.SIGINT)
except Exception as e:
raise RuntimeError("Failed to generate: " + str(e))
self._start_threads() # Send the first token response to the client
self._ongoing_request_count -= 1 first_token_response = remote_prefill_response
first_token_response.pop("disaggregated_params")
yield first_token_response
# Set the disaggregated params to generation_only for the rest of the generation
disaggregated_params.request_type = "generation_only"
sampling_params = self.default_sampling_params
for key, value in request.sampling_options.model_dump().items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request.stop_conditions.max_tokens
if max_tokens:
sampling_params.max_tokens = max_tokens
# TODO: Disable streaming for context only requests when adding disagg support
async for res in self._llm_engine.llm.generate_async(
inputs=inputs,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=(self._config.disaggregation_mode != "prefill"),
):
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self._first_generation and self._llm_publisher:
self._llm_publisher.start()
self._first_generation = False
if res.finished and self._config.disaggregation_mode != "prefill":
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self._config.disaggregation_mode == "prefill":
# Return the disaggregated params only when operating in prefill mode.
out[
"disaggregated_params"
] = DisaggregatedTypeConverter.to_oai_disaggregated_params(
output.disaggregated_params
).model_dump()
yield out
num_output_tokens_so_far = next_total_toks
...@@ -14,136 +14,28 @@ ...@@ -14,136 +14,28 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Tuple
import yaml
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig
@dataclass
class LLMAPIConfig:
def __init__(
self,
model_name: str,
model_path: str | None = None,
pytorch_backend_config: PyTorchConfig | None = None,
kv_cache_config: KvCacheConfig | None = None,
speculative_config: DecodingBaseConfig | None = None,
**kwargs,
):
self.model_name = model_name
self.model_path = model_path
self.pytorch_backend_config = pytorch_backend_config
self.kv_cache_config = kv_cache_config
self.speculative_config = speculative_config
self.extra_args = kwargs
# Hardcoded to skip tokenizer init for now.
# We will handle the tokenization/detokenization
# in the base engine.
if "skip_tokenizer_init" in self.extra_args:
self.extra_args.pop("skip_tokenizer_init")
self.skip_tokenizer_init = True
def to_dict(self) -> Dict[str, Any]:
data = {
"kv_cache_config": self.kv_cache_config,
"speculative_config": self.speculative_config,
"skip_tokenizer_init": self.skip_tokenizer_init,
}
if self.extra_args:
data.update(self.extra_args)
return data
def update_sub_configs(self, other_config: Dict[str, Any]):
# TODO: Consider removing pytorch_backend_config parsing as this section
# was collapsed to top level config fields in recent TRTLLM versions.
if "pytorch_backend_config" in other_config:
self.pytorch_backend_config = PyTorchConfig(
**other_config["pytorch_backend_config"]
)
self.extra_args.pop("pytorch_backend_config", None)
if "kv_cache_config" in other_config:
self.kv_cache_config = KvCacheConfig(**other_config["kv_cache_config"])
self.extra_args.pop("kv_cache_config", None)
if "speculative_config" in other_config:
self.speculative_config = DecodingBaseConfig.from_dict(
other_config["speculative_config"]
)
self.extra_args.pop("speculative_config", None)
def _get_llm_args(engine_config):
# Only do model validation checks and leave other checks to LLMAPI
if "model_name" not in engine_config:
raise ValueError("Model name is required in the TRT-LLM engine config.")
if engine_config.get("model_path", ""):
if os.path.exists(engine_config.get("model_path", "")):
engine_config["model_path"] = Path(engine_config["model_path"])
else:
raise ValueError(f"Model path {engine_config['model_path']} does not exist")
model_name = engine_config["model_name"]
model_path = engine_config.get("model_path", None)
engine_config.pop("model_name")
engine_config.pop("model_path", None)
# Store all other args as kwargs
llm_api_config = LLMAPIConfig(
model_name=model_name,
model_path=model_path,
**engine_config,
)
# Parse supported sub configs and remove from kwargs
llm_api_config.update_sub_configs(engine_config)
return llm_api_config
def _init_engine_args(engine_args_filepath):
"""Initialize engine arguments from config file."""
if not os.path.isfile(engine_args_filepath):
raise ValueError(
"'YAML file containing TRT-LLM engine args must be provided in when launching the worker."
)
try:
with open(engine_args_filepath) as file:
trtllm_engine_config = yaml.safe_load(file)
except yaml.YAMLError as e:
raise RuntimeError(f"Failed to parse engine config: {e}")
return _get_llm_args(trtllm_engine_config)
def parse_tensorrt_llm_args( def parse_tensorrt_llm_args(
config_args, config_args,
) -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]: ) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="A TensorRT-LLM Worker parser") parser = argparse.ArgumentParser(description="A TensorRT-LLM Worker parser")
parser.add_argument( parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file" "--extra-engine-args",
type=str,
default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
) )
parser.add_argument( parser.add_argument(
"--served_model_name", "--model-path",
type=str, type=str,
help="Name of the model to serve",
default=None, default=None,
help="Path to disk model or HuggingFace model identifier to load.",
) )
parser.add_argument( parser.add_argument(
"--llmapi-disaggregated-config", "--served_model_name",
"-c",
type=str, type=str,
help="Path to the llmapi disaggregated config file", help="Name to serve the model under.",
default=None,
) )
parser.add_argument( parser.add_argument(
"--router", "--router",
...@@ -152,46 +44,19 @@ def parse_tensorrt_llm_args( ...@@ -152,46 +44,19 @@ def parse_tensorrt_llm_args(
default="random", default="random",
help="Router type to use for scheduling requests to workers", help="Router type to use for scheduling requests to workers",
) )
parser.add_argument( parser.add_argument(
"--min-workers", "--kv-block-size",
type=int,
default=1,
help="Minimum number of workers for aggregated (monolith) server",
)
parser.add_argument(
"--min-prefill-workers",
type=int,
default=1,
help="Minimum number of prefill workers for disaggregated server",
)
parser.add_argument(
"--block-size",
type=int, type=int,
default=32, default=32,
help="Number of tokens per KV block in TRTLLM worker. Default is 32 for pytorch backend.", help="Number of tokens per KV block in TRTLLM worker. Default is 32 for pytorch backend.",
) )
parser.add_argument(
"--remote-prefill",
action="store_true",
help="Use remote prefill workers for generation server in Disaggregated mode.",
)
args = parser.parse_args(config_args)
return (args, _init_engine_args(args.engine_args))
def parse_dynamo_run_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]:
parser = argparse.ArgumentParser(
description="A TensorRT-LLM Dynamo-run engine parser"
)
parser.add_argument( parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file" "--enable-disagg",
)
parser.add_argument(
"--publish-kv-cache-events",
action="store_true", action="store_true",
help="Publish KV cache events from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.", help="Enable remote prefill for the worker",
) )
args, _ = parser.parse_known_args() args = parser.parse_args(config_args)
return (args, _init_engine_args(args.engine_args)) return args
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import threading
import traceback
import weakref
from enum import Enum
from queue import Queue
from typing import Any, Callable, Coroutine, Optional, TypedDict, Union
logger = logging.getLogger(__name__)
AsyncTask = Union[Callable[..., Coroutine[Any, Any, bool]], weakref.WeakMethod]
class RoutingStrategy(Enum):
ROUND_ROBIN = "round_robin"
RANDOM = "random"
PREFIX = "prefix"
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
class ServerType(Enum):
# Generation server used for disaggregated and aggregated requests
GEN = "gen"
# Context server used for disaggregated requests
CTX = "ctx"
# Dynamo run server used for Dynamo run requests
DYN_RUN = "dyn_run"
class ConversationMessage(TypedDict):
role: str
content: str
class ManagedThread(threading.Thread):
def __init__(
self,
task: Optional[AsyncTask],
error_queue: Optional[Queue] = None,
name: Optional[str] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
**kwargs,
):
super().__init__(name=name)
self.task = task
self.error_queue = error_queue
self.kwargs = kwargs
self.loop = loop
self.daemon = True
self.stop_event = threading.Event()
def set_loop(self, loop: asyncio.AbstractEventLoop):
self.loop = loop
def run(self):
while not self.stop_event.is_set():
task: Optional[AsyncTask] = self.task
if isinstance(task, weakref.WeakMethod):
task = task()
if task is None:
# Normally, this should not happen.
logger.warning("WeakMethod is expired.")
break
if task is None:
break
try:
if self.loop is None:
logger.error("[ManagedThread] Loop not initialized!")
break
future = asyncio.run_coroutine_threadsafe(
task(**self.kwargs), self.loop
)
_ = future.result()
except Exception as e:
logger.error(
f"Error in thread {self.name}: {e}\n{traceback.format_exc()}"
)
if self.error_queue is not None:
self.error_queue.put(e)
logger.info(f"Thread {self.name} stopped.")
def stop(self):
self.stop_event.set()
...@@ -12,15 +12,13 @@ ...@@ -12,15 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging import logging
from common.base_engine import BaseTensorrtLLMEngine from common.base_engine import BaseEngineConfig, BaseTensorrtLLMEngine
from common.parser import parse_tensorrt_llm_args from common.parser import parse_tensorrt_llm_args
from common.protocol import TRTLLMWorkerRequest from common.protocol import TRTLLMWorkerRequest
from common.utils import ServerType
from dynamo.sdk import async_on_start, dynamo_context, endpoint, service from dynamo.sdk import async_on_start, dynamo_context, endpoint, on_shutdown, service
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -39,34 +37,37 @@ class TensorRTLLMPrefillWorker(BaseTensorrtLLMEngine): ...@@ -39,34 +37,37 @@ class TensorRTLLMPrefillWorker(BaseTensorrtLLMEngine):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
config = ServiceConfig.get_instance() config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="") config_args = config.as_args(class_name, prefix="")
args, engine_config = parse_tensorrt_llm_args(config_args) args = parse_tensorrt_llm_args(config_args)
worker_id = dynamo_context["endpoints"][0].lease_id() lease_id = dynamo_context["endpoints"][0].lease_id()
super().__init__( namespace, _ = TensorRTLLMPrefillWorker.dynamo_address() # type: ignore
namespace_str="dynamo",
component_str=class_name, engine_config = BaseEngineConfig(
worker_id=worker_id, namespace=namespace,
engine_config=engine_config, component=class_name,
remote_prefill=args.remote_prefill, endpoint="generate",
min_workers=args.min_workers, model_path=args.model_path,
disagg_config_file=args.llmapi_disaggregated_config, served_model_name=args.served_model_name,
block_size=args.block_size, kv_block_size=args.kv_block_size,
router=args.router, extra_engine_args=args.extra_engine_args,
server_type=ServerType.CTX, publish_events_and_metrics=False,
disaggregation_mode="prefill",
remote_prefill_endpoint=None,
lease_id=lease_id,
) )
super().__init__(config=engine_config)
@async_on_start @async_on_start
async def async_init(self): async def async_init(self):
self._init_engine() runtime = dynamo_context["runtime"]
if self._kv_metrics_publisher is not None: await self.initialize(runtime)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logger.info("metrics publisher endpoint created")
)
logger.info("TensorRT-LLM Prefill Worker initialized") logger.info("TensorRT-LLM Prefill Worker initialized")
async def create_metrics_publisher_endpoint(self): @on_shutdown
component = dynamo_context["component"] async def async_cleanup(self):
await self.kv_metrics_publisher.create_endpoint(component) logger.info("Cleaning up TensorRT-LLM Prefill Worker")
await self.cleanup()
logger.info("TensorRT-LLM Prefill Worker cleanup completed")
@endpoint() @endpoint()
async def generate(self, request: TRTLLMWorkerRequest): async def generate(self, request: TRTLLMWorkerRequest):
......
...@@ -12,17 +12,22 @@ ...@@ -12,17 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging import logging
from common.base_engine import BaseTensorrtLLMEngine from common.base_engine import BaseEngineConfig, BaseTensorrtLLMEngine
from common.parser import parse_tensorrt_llm_args from common.parser import parse_tensorrt_llm_args
from common.protocol import TRTLLMWorkerRequest from common.protocol import TRTLLMWorkerRequest
from common.utils import ServerType
from components.prefill_worker import TensorRTLLMPrefillWorker from components.prefill_worker import TensorRTLLMPrefillWorker
from dynamo.llm import ModelType, register_llm from dynamo.llm import ModelType, register_llm
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service from dynamo.sdk import (
async_on_start,
depends,
dynamo_context,
endpoint,
on_shutdown,
service,
)
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -43,74 +48,66 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine): ...@@ -43,74 +48,66 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
config = ServiceConfig.get_instance() config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="") config_args = config.as_args(class_name, prefix="")
args, engine_config = parse_tensorrt_llm_args(config_args) args = parse_tensorrt_llm_args(config_args)
self.served_model_name = args.served_model_name lease_id = dynamo_context["endpoints"][0].lease_id()
worker_id = dynamo_context["endpoints"][0].lease_id()
namespace, _ = TensorRTLLMWorker.dynamo_address() # type: ignore namespace, _ = TensorRTLLMWorker.dynamo_address() # type: ignore
self._min_prefill_workers = args.min_prefill_workers endpoint_name = "generate"
super().__init__( publish_events_and_metrics = args.router == "kv"
namespace_str=namespace, prefill_class_name = "TensorRTLLMPrefillWorker"
component_str=class_name,
worker_id=worker_id, if args.enable_disagg:
engine_config=engine_config, disaggregation_mode = "decode"
remote_prefill=args.remote_prefill, else:
min_workers=args.min_workers, disaggregation_mode = "prefill_and_decode"
disagg_config_file=args.llmapi_disaggregated_config,
block_size=args.block_size, engine_config = BaseEngineConfig(
router=args.router, namespace=namespace,
server_type=ServerType.GEN, component=class_name,
endpoint=endpoint_name,
model_path=args.model_path,
served_model_name=args.served_model_name,
kv_block_size=args.kv_block_size,
extra_engine_args=args.extra_engine_args,
publish_events_and_metrics=publish_events_and_metrics,
disaggregation_mode=disaggregation_mode,
remote_prefill_endpoint=f"dyn://{namespace}.{prefill_class_name}.generate",
lease_id=lease_id,
) )
super().__init__(config=engine_config)
@async_on_start @async_on_start
async def async_init(self): async def async_init(self):
self._init_engine()
runtime = dynamo_context["runtime"] runtime = dynamo_context["runtime"]
await self.initialize(runtime)
logger.info("Registering LLM for discovery") logger.info("Registering LLM for discovery")
comp_ns, comp_name = TensorRTLLMWorker.dynamo_address() # type: ignore endpoint = (
endpoint = runtime.namespace(comp_ns).component(comp_name).endpoint("generate") runtime.namespace(self._config.namespace)
.component(self._config.component)
.endpoint(self._config.endpoint)
)
try: try:
await register_llm( await register_llm(
ModelType.Backend, ModelType.Backend,
endpoint, endpoint,
self._engine_config.model_name, self._config.model_path,
self.served_model_name, self._config.served_model_name,
kv_cache_block_size=self._kv_block_size, kv_cache_block_size=self._config.kv_block_size,
) )
logger.info("Successfully registered LLM for discovery") logger.info("Successfully registered LLM for discovery")
except Exception as e: except Exception as e:
logger.error(f"Failed to register LLM for discovery: {e}") logger.error(f"Failed to register LLM for discovery: {e}")
raise raise
if self._remote_prefill:
runtime = dynamo_context["runtime"]
comp_ns, comp_name = TensorRTLLMPrefillWorker.dynamo_address() # type: ignore
self._prefill_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
while len(self._prefill_client.instance_ids()) < self._min_prefill_workers:
logger.info(
f"Waiting for prefill workers to be ready.\n"
f" Current: {len(self._prefill_client.instance_ids())},"
f" Required: {self._min_prefill_workers}"
)
await asyncio.sleep(30)
if self._kv_metrics_publisher is not None:
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logger.info("metrics publisher endpoint created")
)
logger.info("TensorRT-LLM Worker initialized") logger.info("TensorRT-LLM Worker initialized")
async def create_metrics_publisher_endpoint(self): @on_shutdown
component = dynamo_context["component"] async def async_cleanup(self):
await self._kv_metrics_publisher.create_endpoint(component) logger.info("Cleaning up TensorRT-LLM Worker")
await self.cleanup()
logger.info("TensorRT-LLM Worker cleanup completed")
@endpoint() @endpoint()
async def generate(self, request: TRTLLMWorkerRequest): async def generate(self, request: TRTLLMWorkerRequest):
......
...@@ -20,8 +20,13 @@ Frontend: ...@@ -20,8 +20,13 @@ Frontend:
router: round-robin router: round-robin
TensorRTLLMWorker: TensorRTLLMWorker:
# Path to disk model or HuggingFace model identifier to load
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Name to serve the model under
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
engine_args: "configs/llm_api_config.yaml" # Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/engine_configs/agg_config.yaml"
router: round-robin router: round-robin
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
......
...@@ -20,9 +20,15 @@ Frontend: ...@@ -20,9 +20,15 @@ Frontend:
router: kv router: kv
TensorRTLLMWorker: TensorRTLLMWorker:
engine_args: "configs/llm_api_config_router.yaml" # Path to disk model or HuggingFace model identifier to load
model-path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Name to serve the model under
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/engine_configs/agg_config.yaml"
router: kv router: kv
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
resources: resources:
gpu: 1 gpu: 1
\ No newline at end of file
...@@ -22,7 +22,12 @@ Frontend: ...@@ -22,7 +22,12 @@ Frontend:
TensorRTLLMWorker: TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4" served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml" # NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
extra-engine-args: "configs/deepseek_r1/engine_configs/agg_config.yaml"
router: round-robin router: round-robin
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
......
...@@ -22,14 +22,13 @@ Frontend: ...@@ -22,14 +22,13 @@ Frontend:
TensorRTLLMWorker: TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4" served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml" # NOTE: FP4 only supported starting with Blackwell GPUs.
llmapi-disaggregated-config: "configs/deepseek_r1/disagg_llm_api_config.yaml" # https://huggingface.co/nvidia/DeepSeek-R1-FP4
remote-prefill: true # You can also specify the full path to locally downloaded weights
# NOTE: When testing/benchmarking multiple prefill workers, you can set # instead of a HuggingFace ID here.
# this number to the exact amount of prefill workers if you want Dynamo to model-path: "nvidia/DeepSeek-R1-FP4"
# wait until all the prefill workers are ready before marking the decode extra-engine-args: "configs/deepseek_r1/engine_configs/decode_config.yaml"
# worker ready. enable-disagg: true
min-prefill-workers: 1
router: round-robin router: round-robin
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
...@@ -37,8 +36,12 @@ TensorRTLLMWorker: ...@@ -37,8 +36,12 @@ TensorRTLLMWorker:
gpu: 4 gpu: 4
TensorRTLLMPrefillWorker: TensorRTLLMPrefillWorker:
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml" # NOTE: FP4 only supported starting with Blackwell GPUs.
llmapi-disaggregated-config: "configs/deepseek_r1/disagg_llm_api_config.yaml" # https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
extra-engine-args: "configs/deepseek_r1/engine_configs/prefill_config.yaml"
router: round-robin router: round-robin
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Example Configs for Context & Generation on GB200 nodes
# - Context on 1xGB200 (4xB00)
# - Generation on 1xGB200 (4xB200)
# NOTE: Fields like hostname, ports, urls, num_instances, etc. only used by trtllm-serve, not by dynamo
backend: pytorch
context_servers:
# Context/prefill processes many tokens at once, so for a large ISL, a large
# batch size may not be needed to saturate GPU utilization.
max_batch_size: 1
max_num_tokens: 8192
max_seq_len: 8192
# TP/EP/PP/DP
tensor_parallel_size: 4
moe_expert_parallel_size: 4
pipeline_parallel_size: 1
enable_attention_dp: true
kv_cache_config:
free_gpu_memory_fraction: 0.75
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: This field is called 'enable_overlap_scheduler' in older TRTLLM versions
# Overlap scheduler not currently supported in context-only
disable_overlap_scheduler: true
print_iter_log: true
# NOTE: This dtype must match in both context/generation configs
kv_cache_dtype: fp8
generation_servers:
# Generation/decode processes one token per request at a time, so a larger
# batch size helps to saturate GPU utilization.
max_batch_size: 256
max_num_tokens: 256
# 8448 = 8192 ISL + 256 OSL
max_seq_len: 8448
# TP/EP/PP/DP
tensor_parallel_size: 4
moe_expert_parallel_size: 4
pipeline_parallel_size: 1
enable_attention_dp: false
kv_cache_config:
# With dp attention disabled: high free_gpu_memory_fraction is fine.
free_gpu_memory_fraction: 0.85
# With dp attention enabled: large ISL at high concurrency may need
# free_gpu_memory_fraction low to have enough available memory.
# free_gpu_memory_fraction: 0.30
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: This field is called 'enable_overlap_scheduler' in older TRTLLM versions
disable_overlap_scheduler: false
use_cuda_graph: true
cuda_graph_padding_enabled: true
# NOTE: For larger max batch size, you may want to add larger cuda graph
# batch sizes below to match.
cuda_graph_batch_sizes:
- 1
- 2
- 4
- 8
- 16
- 32
- 64
- 128
- 256
print_iter_log: true
# NOTE: This dtype must match in both context/generation configs
kv_cache_dtype: fp8
...@@ -12,12 +12,6 @@ ...@@ -12,12 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model_name: "nvidia/DeepSeek-R1-FP4"
backend: pytorch backend: pytorch
# TP/EP/PP/DP # TP/EP/PP/DP
......
...@@ -12,32 +12,44 @@ ...@@ -12,32 +12,44 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
backend: pytorch
# TP/EP/PP/DP
# In the case of disaggregated deployment, this config will apply to each server tensor_parallel_size: 4
# and will be overwritten by the disaggregated config file moe_expert_parallel_size: 4
pipeline_parallel_size: 1
# TODO: figure out how to generate this from the service config or vice versa
model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path: null
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16 max_batch_size: 256
trust_remote_code: true max_num_tokens: 256
backend: pytorch # 8448 = 8192 ISL + 256 OSL
enable_chunked_prefill: true max_seq_len: 8448
kv_cache_config: kv_cache_config:
free_gpu_memory_fraction: 0.95 # With dp attention disabled: high free_gpu_memory_fraction is fine.
event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.85
enable_block_reuse: true # With dp attention enabled: large ISL at high concurrency may need
# free_gpu_memory_fraction low to have enough available memory.
# free_gpu_memory_fraction: 0.30
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603 # NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: overlap_scheduler enabled by default since this commit and changed # NOTE: overlap_scheduler enabled by default since this commit and changed
# config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler': # config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler':
# https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428 # https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428
disable_overlap_scheduler: false
use_cuda_graph: true use_cuda_graph: true
enable_iter_perf_stats: true cuda_graph_padding_enabled: true
# NOTE: For larger max batch size, you may want to add larger cuda graph
# batch sizes below to match.
cuda_graph_batch_sizes:
- 1
- 2
- 4
- 8
- 16
- 32
- 64
- 128
- 256
print_iter_log: true
kv_cache_dtype: fp8
...@@ -12,32 +12,30 @@ ...@@ -12,32 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
backend: pytorch
# TP/EP/PP/DP
tensor_parallel_size: 4
moe_expert_parallel_size: 4
pipeline_parallel_size: 1
enable_attention_dp: true
# In the case of disaggregated deployment, this config will apply to each server max_batch_size: 1
# and will be overwritten by the disaggregated config file
# TODO: figure out how to generate this from the service config or vice versa
model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path: null
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192 max_num_tokens: 8192
max_batch_size: 16 max_seq_len: 8192
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
kv_cache_config: kv_cache_config:
free_gpu_memory_fraction: 0.95 # With dp attention disabled: high free_gpu_memory_fraction is fine.
event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.75
enable_block_reuse: true # With dp attention enabled: large ISL at high concurrency may need
# free_gpu_memory_fraction low to have enough available memory.
# free_gpu_memory_fraction: 0.30
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603 # NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: overlap_scheduler enabled by default since this commit and changed # NOTE: overlap_scheduler enabled by default since this commit and changed
# config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler': # config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler':
# https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428 # https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428
use_cuda_graph: true disable_overlap_scheduler: true
enable_iter_perf_stats: true print_iter_log: true
# NOTE: This dtype must match in both prefill/decode configs
kv_cache_dtype: fp8
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
# You can also specify the full path to locally downloaded weights # You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here. # instead of a HuggingFace ID here.
model_name: "nvidia/DeepSeek-R1-FP4"
backend: pytorch backend: pytorch
tensor_parallel_size: 4 tensor_parallel_size: 4
moe_expert_parallel_size: 4 moe_expert_parallel_size: 4
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
backend: pytorch
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: false
max_batch_size: 256
# Note: When MPT is enabled and `cuda_graph_batch_sizes` is specified, `max_num_tokens` must satisfy the following formula:
# max_num_tokens >= max(cuda_graph_batch_sizes) * (num_nextn_predict_layers + 1)
# This is a known issue in TensorRT-LLM and will be resolved in the next release.
max_num_tokens: 512
# 8704 = 8192 ISL + 512 OSL
max_seq_len: 8704
kv_cache_config:
free_gpu_memory_fraction: 0.85
# Enable the MTP(Multi-Token Prediction) in decode model engine
speculative_config:
decoding_type: MTP
num_nextn_predict_layers: 1
use_cuda_graph: true
cuda_graph_padding_enabled: true
cuda_graph_batch_sizes:
- 1
- 2
- 4
- 8
- 16
- 32
- 64
- 128
- 256
print_iter_log: true
kv_cache_dtype: fp8
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
backend: pytorch
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: true
max_batch_size: 1
max_num_tokens: 8192
max_seq_len: 8192
kv_cache_config:
free_gpu_memory_fraction: 0.75
print_iter_log: true
kv_cache_dtype: fp8
disable_overlap_scheduler: true
# Enable the MTP(Multi-Token Prediction) in the prefill model engine
speculative_config:
decoding_type: MTP
num_nextn_predict_layers: 1
...@@ -21,7 +21,14 @@ Frontend: ...@@ -21,7 +21,14 @@ Frontend:
TensorRTLLMWorker: TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4" served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/mtp/mtp_agg_llm_api_config.yaml" # NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/deepseek_r1/mtp/engine_configs/agg_config.yaml"
router: round-robin router: round-robin
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
......
...@@ -21,19 +21,30 @@ Frontend: ...@@ -21,19 +21,30 @@ Frontend:
TensorRTLLMWorker: TensorRTLLMWorker:
served_model_name: "nvidia/DeepSeek-R1-FP4" served_model_name: "nvidia/DeepSeek-R1-FP4"
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml" # NOTE: FP4 only supported starting with Blackwell GPUs.
llmapi-disaggregated-config: "configs/deepseek_r1/mtp/mtp_disagg_llm_api_config.yaml" # https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/deepseek_r1/mtp/engine_configs/decode_config.yaml"
router: round-robin router: round-robin
remote-prefill: true enable-disagg: true
min-prefill-workers: 1
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
resources: resources:
gpu: 4 gpu: 4
TensorRTLLMPrefillWorker: TensorRTLLMPrefillWorker:
engine_args: "configs/deepseek_r1/agg_llm_api_config.yaml" # NOTE: FP4 only supported starting with Blackwell GPUs.
llmapi-disaggregated-config: "configs/deepseek_r1/mtp/mtp_disagg_llm_api_config.yaml" # https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path: "nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args: "configs/deepseek_r1/mtp/engine_configs/prefill_config.yaml"
router: round-robin router: round-robin
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
backend: pytorch
context_servers:
num_instances: 1
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: true
max_batch_size: 1
max_num_tokens: 8192
max_seq_len: 8192
kv_cache_config:
free_gpu_memory_fraction: 0.75
print_iter_log: true
kv_cache_dtype: fp8
disable_overlap_scheduler: true
# Enable the MTP(Multi-Token Prediction) in the prefill model engine
speculative_config:
decoding_type: MTP
num_nextn_predict_layers: 1
generation_servers:
num_instances: 1
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: false
max_batch_size: 256
# Note: When MPT is enabled and `cuda_graph_batch_sizes` is specified, `max_num_tokens` must satisfy the following formula:
# max_num_tokens >= max(cuda_graph_batch_sizes) * (num_nextn_predict_layers + 1)
# This is a known issue in TensorRT-LLM and will be resolved in the next release.
max_num_tokens: 512
# 8704 = 8192 ISL + 512 OSL
max_seq_len: 8704
kv_cache_config:
free_gpu_memory_fraction: 0.85
# Enable the MTP(Multi-Token Prediction) in the decode model engine
speculative_config:
decoding_type: MTP
num_nextn_predict_layers: 1
use_cuda_graph: true
cuda_graph_padding_enabled: true
cuda_graph_batch_sizes:
- 1
- 2
- 4
- 8
- 16
- 32
- 64
- 128
- 256
print_iter_log: true
kv_cache_dtype: fp8
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment