Unverified Commit 66963b70 authored by Indrajit Bhosale's avatar Indrajit Bhosale Committed by GitHub
Browse files
parent 6dba119d
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from enum import Enum
class DisaggregationMode(Enum):
    """Role a TRT-LLM worker plays in an aggregated or disaggregated deployment.

    The member values are the string identifiers for each role.
    """

    # A single worker handles both prefill and decode.
    AGGREGATED = "prefill_and_decode"
    # Worker handles only the prefill phase.
    PREFILL = "prefill"
    # Worker handles only the decode phase.
    DECODE = "decode"
    # Worker handles only the encode phase.
    ENCODE = "encode"
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio
import logging import logging
from typing import Any, Dict, Union from dataclasses import asdict
from typing import Any, Dict, Optional, Union
import torch import torch
from tensorrt_llm.inputs import default_multimodal_input_loader
import dynamo.nixl_connect as nixl_connect import dynamo.nixl_connect as nixl_connect
from dynamo.trtllm.utils.disagg_utils import DisaggregatedParamsCodec
class EncodeHelper: class EncodeHelper:
...@@ -184,68 +188,70 @@ class EncodeHelper: ...@@ -184,68 +188,70 @@ class EncodeHelper:
# Return just the tensor # Return just the tensor
return encodings_tensor return encodings_tensor
# =========================================================================
# ENCODE REQUEST PROCESSING
# =========================================================================
#
# Two supported flows:
#
# 1. EMBEDDING-PATH FLOW (Pre-computed embeddings via NIXL)
# - User sends URL ending in .pt/.pth/.bin
# - Encode worker loads tensor, creates NIXL readable op
# - Prefill worker reads embeddings via RDMA
# - Use case: Customer has pre-computed embeddings from custom encoder
#
# 2. FULL EPD FLOW (Image URLs via MultimodalEncoder)
# - User sends image URL (http/https/base64)
# - Encode worker runs TRT-LLM's MultimodalEncoder.generate()
# - Returns disaggregated_params to prefill worker
# - Use case: Standard VLM inference with TRT-LLM's encoder
#
# =========================================================================
@staticmethod @staticmethod
async def process_embedding_request( async def _process_embedding_path_flow(
request: Dict[str, Any], embedding_paths: list,
multimodal_processor, multimodal_processor,
connector: nixl_connect.Connector, connector: nixl_connect.Connector,
): ):
""" """
Process embedding request by loading embeddings and creating NIXL readable operation. Process pre-computed embeddings via NIXL transfer.
Loads embeddings from a file path/URL and creates a NIXL readable operation
for the prefill worker to read via RDMA.
Args: Args:
request: Request containing messages with embedding paths embedding_paths: List of paths to embedding files (.pt/.pth/.bin)
multimodal_processor: Multimodal processor for loading embeddings multimodal_processor: Processor to load embeddings
connector: NIXL connector for creating readable operations connector: NIXL connector for RDMA transfer
Yields: Yields:
Response dictionary with NIXL metadata and embeddings info, or error response Response with NIXL metadata, shape, dtype, and auxiliary data
""" """
# Load embeddings first to get the actual shape logging.info(f"EncodeHelper: loading embeddings from {embedding_paths[0]}")
# Extract messages from extra_args (set by Rust preprocessor for multimodal) or fall back to direct field
messages = request.get("extra_args", {}).get(
"messages", request.get("messages", [])
)
_, _, embedding_paths = multimodal_processor.extract_prompt_and_media(messages)
if not embedding_paths:
# Placeholder for TRTLLM Encoder to be called
# TRTLLM Encoder will return a memory handler on the encoder GPU with the encodings
logging.warning(
"No embedding paths found, NIXL transfer for image urls not supported by TRTLLM Encoder yet"
)
yield {"error": "No embedding paths found"}
return
# Load the embeddings data
loaded_data = multimodal_processor.load_tensor_from_path_or_url( loaded_data = multimodal_processor.load_tensor_from_path_or_url(
embedding_paths[0] embedding_paths[0]
) )
# Handle both tensor and dictionary formats # Handle both tensor and dictionary formats
if isinstance(loaded_data, dict): if isinstance(loaded_data, dict):
# Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt) # Dictionary format: contains 'mm_embeddings' key plus auxiliary data
encodings = loaded_data.get("mm_embeddings") encodings = loaded_data.get("mm_embeddings")
if encodings is None: if encodings is None:
yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"} yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"}
return return
# Store auxiliary data for later transmission
auxiliary_data = { auxiliary_data = {
k: v for k, v in loaded_data.items() if k != "mm_embeddings" k: v for k, v in loaded_data.items() if k != "mm_embeddings"
} }
else: else:
# Tensor format (e.g., llava_next_mm_embed_seashore.pt) # Tensor format: raw embeddings tensor
encodings = loaded_data encodings = loaded_data
auxiliary_data = {} auxiliary_data = {}
# Create readable operation with main embeddings tensor (works for both formats) # Create NIXL readable operation for prefill worker to read
descriptor = nixl_connect.Descriptor(encodings) descriptor = nixl_connect.Descriptor(encodings)
with await connector.create_readable(descriptor) as readable_op: with await connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation
op_metadata = readable_op.metadata() op_metadata = readable_op.metadata()
# Send back shape info, readable metadata, and serialized auxiliary data
response = { response = {
"nixl_readable_metadata": op_metadata.model_dump(), "nixl_readable_metadata": op_metadata.model_dump(),
"embeddings_shape": list(encodings.shape), "embeddings_shape": list(encodings.shape),
...@@ -254,9 +260,180 @@ class EncodeHelper: ...@@ -254,9 +260,180 @@ class EncodeHelper:
} }
yield response yield response
# Wait for the prefill worker to complete the read operation # Wait for prefill worker to complete the read
logging.debug( logging.debug(
"EncodeHelper waiting for PrefillHandler to read embeddings..." "EncodeHelper waiting for PrefillHandler to read embeddings..."
) )
await readable_op.wait_for_completion() await readable_op.wait_for_completion()
logging.debug("EncodeHelper completed readable operation.") logging.debug("EncodeHelper completed readable operation.")
@staticmethod
async def _process_full_epd_flow(
    text_prompt: str,
    image_urls: list,
    tokenizer,
    model_dir: str,
    model_type: str,
    engine,
):
    """
    Run the full EPD encode step for an image request.

    Builds encoder inputs with ``default_multimodal_input_loader``, runs
    ``engine.llm.generate`` on a worker thread, and yields the encoded
    disaggregated params together with the processed prompt and its token ids.

    Args:
        text_prompt: Text portion of the prompt.
        image_urls: Image URLs from the request. Only ``image_urls[0]`` is
            forwarded to the input loader.
        tokenizer: Pre-initialized tokenizer, reused by the input loader and
            used here to tokenize the processed prompt. May be None, in which
            case no token ids are produced.
        model_dir: Model directory passed to the input loader.
        model_type: Model type string passed to the input loader.
        engine: Object exposing ``.llm.generate(inputs)``.

    Yields:
        A single dict. On success it has ``ep_disaggregated_params`` (dict),
        ``processed_prompt`` and ``prompt_token_ids``. When the encoder returns
        nothing usable it is ``{"ep_disaggregated_params": None}``.
    """
    # The loader gets the existing tokenizer so it is not rebuilt per request.
    loader_inputs = default_multimodal_input_loader(
        tokenizer=tokenizer,
        model_dir=model_dir,
        model_type=model_type,
        modality="image",
        prompts=[text_prompt],
        media=image_urls[0],
    )

    def _run_encoder():
        # generate() is synchronous; materialize its results off the event loop.
        return list(engine.llm.generate(loader_inputs))

    results = await asyncio.to_thread(_run_encoder)

    if not results:
        logging.error("ENCODE WORKER: encoder_outputs is empty")
        yield {"ep_disaggregated_params": None}
        return

    ep_params = results[0].disaggregated_params
    if ep_params is None:
        logging.error(
            "ENCODE WORKER: encoder_outputs[0].disaggregated_params is None"
        )
        yield {"ep_disaggregated_params": None}
        return

    # Missing handles are only reported; the params are still sent on.
    if ep_params.multimodal_embedding_handles is None:
        logging.warning(
            "ENCODE WORKER: ep_disaggregated_params.multimodal_embedding_handles is None"
        )

    # Convert to a plain dict so it can be sent over the network.
    params_dict = asdict(DisaggregatedParamsCodec.encode(ep_params))

    # Recover the prompt produced by the loader (it carries the placeholder
    # tokens) so later stages work from the same text.
    processed_prompt = None
    prompt_token_ids = None
    if isinstance(loader_inputs, list) and loader_inputs:
        head = loader_inputs[0]
        processed_prompt = (
            head.get("prompt")
            if isinstance(head, dict)
            else getattr(head, "prompt", None)
        )
        if processed_prompt and tokenizer is not None:
            # The processed prompt already contains template/placeholder
            # tokens, so special tokens are deliberately not added again.
            prompt_token_ids = tokenizer.encode(
                processed_prompt, add_special_tokens=False
            )

    logging.debug(
        "ENCODE WORKER: Extracted processed_prompt (len=%s)",
        None if processed_prompt is None else len(processed_prompt),
    )

    yield {
        "ep_disaggregated_params": params_dict,
        "processed_prompt": processed_prompt,
        "prompt_token_ids": prompt_token_ids,
    }
@staticmethod
async def process_encode_request(
    request: Dict[str, Any],
    multimodal_processor,
    connector: Optional[nixl_connect.Connector],
    tokenizer=None,
    model_dir=None,
    model_type=None,
    engine=None,
):
    """
    Process an ENCODE-mode request by dispatching to the matching flow.

    Args:
        request: Request dict. Messages are read from
            ``request["extra_args"]["messages"]`` when present, otherwise from
            ``request["messages"]``.
        multimodal_processor: Processor used to split the messages into text
            prompt, image URLs and embedding paths.
        connector: NIXL connector (required only for the embedding-path flow).
        tokenizer: Tokenizer forwarded to the full EPD flow.
        model_dir: Model directory (required for the full EPD flow).
        model_type: Model type string (required for the full EPD flow).
        engine: Engine instance (required for the full EPD flow).

    Yields:
        Response dictionaries from the selected flow, or a single
        ``{"error": ...}`` dict when a prerequisite is missing or the request
        has no usable multimodal content.
    """
    if multimodal_processor is None:
        yield {"error": "No multimodal_processor configured on encode worker"}
        return

    # `extra_args` may be present but null; dict.get's default only applies
    # when the key is missing, so normalize None to an empty dict.
    extra_args = request.get("extra_args") or {}
    messages = extra_args.get("messages", request.get("messages", []))
    (
        text_prompt,
        image_urls,
        embedding_paths,
    ) = multimodal_processor.extract_prompt_and_media(messages)

    # Flow 1: embedding-path flow (pre-computed embeddings via NIXL).
    # Takes precedence over image URLs when both are present.
    if embedding_paths:
        if connector is None:
            yield {"error": "NIXL connector is required for embedding_paths encode"}
            return
        async for response in EncodeHelper._process_embedding_path_flow(
            embedding_paths, multimodal_processor, connector
        ):
            yield response
    # Flow 2: full EPD flow (image URLs via the encoder). Needs both an image
    # URL and a non-empty text prompt.
    elif image_urls and text_prompt:
        if model_dir is None or model_type is None:
            yield {
                "error": "model_dir and model_type are required for full EPD encode"
            }
            return
        if engine is None:
            yield {"error": "No engine configured on encode worker for full EPD"}
            return
        async for response in EncodeHelper._process_full_epd_flow(
            text_prompt, image_urls, tokenizer, model_dir, model_type, engine
        ):
            yield response
    # No valid multimodal content found. This branch is also reached when
    # image URLs are present but the text prompt is empty.
    else:
        yield {"error": "No embedding_paths or image_urls found in request"}
...@@ -6,7 +6,10 @@ import logging ...@@ -6,7 +6,10 @@ import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional from typing import AsyncGenerator, Optional
from tensorrt_llm import LLM from tensorrt_llm import LLM, MultimodalEncoder
from tensorrt_llm.llmapi.llm import BaseLLM
from dynamo.trtllm.constants import DisaggregationMode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -19,8 +22,20 @@ class Backend(str, enum.Enum): ...@@ -19,8 +22,20 @@ class Backend(str, enum.Enum):
class TensorRTLLMEngine: class TensorRTLLMEngine:
def __init__(self, engine_args): def __init__(
self,
engine_args,
disaggregation_mode: Optional[DisaggregationMode] = None,
):
self._llm: Optional[LLM] = None self._llm: Optional[LLM] = None
self.disaggregation_mode = (
disaggregation_mode
if disaggregation_mode is not None
else DisaggregationMode.AGGREGATED
)
# NOTE: `engine_args` may be reused by callers (e.g., for logging or other workers).
# Copy it so that our internal `pop()` / pruning doesn't leak side effects.
engine_args = dict(engine_args)
backend = engine_args.pop("backend", Backend.PYTORCH) backend = engine_args.pop("backend", Backend.PYTORCH)
if backend == Backend.PYTORCH: if backend == Backend.PYTORCH:
self._llm_cls = LLM self._llm_cls = LLM
...@@ -38,7 +53,27 @@ class TensorRTLLMEngine: ...@@ -38,7 +53,27 @@ class TensorRTLLMEngine:
async def initialize(self): async def initialize(self):
if not self._llm: if not self._llm:
self._llm = self._llm_cls(**self.engine_args) if self.disaggregation_mode == DisaggregationMode.ENCODE:
# Initialize the multimodal encoder for full EPD
# Prefill/decode workers initialize the standard TRT-LLM `LLM` from `engine_args`
# (model, backend settings, kv cache config, etc.). ENCODE workers instead use
# TRT-LLM's `MultimodalEncoder`, which has a different constructor surface.
# We intentionally pass only the supported parameters to avoid unexpected kwargs.
max_batch_size = self.engine_args.get("max_batch_size", 1)
model = self.engine_args.get("model")
logging.info(
f"Initializing multimodal encoder with max_batch_size: {max_batch_size}"
)
# MultimodalEncoder and LLM both inherit from BaseLLM in TRT-LLM,
# so storing either in self._llm is valid.
self._llm = MultimodalEncoder(
model=model,
max_batch_size=max_batch_size,
)
else:
# Prefill/decode workers: initialize standard TRT-LLM `LLM` with full engine_args
# (model path, backend settings, KV cache config, disaggregation settings, etc.)
self._llm = self._llm_cls(**self.engine_args)
async def cleanup(self): async def cleanup(self):
if self._llm: if self._llm:
...@@ -50,7 +85,7 @@ class TensorRTLLMEngine: ...@@ -50,7 +85,7 @@ class TensorRTLLMEngine:
self._llm = None self._llm = None
@property @property
def llm(self): def llm(self) -> BaseLLM:
if not self._llm: if not self._llm:
raise RuntimeError("Engine not initialized") raise RuntimeError("Engine not initialized")
return self._llm return self._llm
...@@ -91,8 +126,11 @@ class TensorRTLLMEngine: ...@@ -91,8 +126,11 @@ class TensorRTLLMEngine:
@asynccontextmanager @asynccontextmanager
async def get_llm_engine(engine_args) -> AsyncGenerator[TensorRTLLMEngine, None]: async def get_llm_engine(
engine = TensorRTLLMEngine(engine_args) engine_args,
disaggregation_mode: Optional[DisaggregationMode] = None,
) -> AsyncGenerator[TensorRTLLMEngine, None]:
engine = TensorRTLLMEngine(engine_args, disaggregation_mode)
try: try:
await engine.initialize() await engine.initialize()
yield engine yield engine
......
...@@ -352,7 +352,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -352,7 +352,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config} config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config}
) )
async with get_llm_engine(engine_args) as engine: async with get_llm_engine(engine_args, config.disaggregation_mode) as engine:
endpoint = component.endpoint(config.endpoint) endpoint = component.endpoint(config.endpoint)
# should ideally call get_engine_runtime_config # should ideally call get_engine_runtime_config
......
...@@ -59,6 +59,8 @@ class MultimodalRequestProcessor: ...@@ -59,6 +59,8 @@ class MultimodalRequestProcessor:
self.allowed_local_media_path = allowed_local_media_path self.allowed_local_media_path = allowed_local_media_path
self.max_file_size_mb = max_file_size_mb self.max_file_size_mb = max_file_size_mb
self.max_file_size_bytes = max_file_size_mb * 1024 * 1024 self.max_file_size_bytes = max_file_size_mb * 1024 * 1024
# Used for streaming delta computation in create_response_chunk()
self.previous_decoded_text = ""
# Initialize tokenizer ONCE at startup to avoid per-request overhead # Initialize tokenizer ONCE at startup to avoid per-request overhead
if tokenizer is not None: if tokenizer is not None:
...@@ -163,27 +165,42 @@ class MultimodalRequestProcessor: ...@@ -163,27 +165,42 @@ class MultimodalRequestProcessor:
return " ".join(text_parts), image_urls, embedding_paths return " ".join(text_parts), image_urls, embedding_paths
async def process_openai_request( async def process_openai_request(
self, request: Dict, embeddings: Any self, request: Dict, embeddings: Any, ep_disaggregated_params: Any
) -> Optional[Any]: ) -> Optional[Any]:
"""Process OpenAI request and return with multimodal data.""" """Process OpenAI request and return with multimodal data."""
# Extract messages - check extra_args first (from Rust preprocessor for multimodal) # Extract messages - check extra_args first (from Rust preprocessor for multimodal)
# Fall back to direct messages field for backward compatibility # Fall back to direct messages field for backward compatibility
self.previous_decoded_text = ""
messages = request.get("extra_args", {}).get( messages = request.get("extra_args", {}).get(
"messages", request.get("messages", []) "messages", request.get("messages", [])
) )
text_prompt, image_urls, embedding_paths = self.extract_prompt_and_media( text_prompt, image_urls, embedding_paths = self.extract_prompt_and_media(
messages messages
) )
if not image_urls and not embedding_paths and not ep_disaggregated_params:
if not image_urls and not embedding_paths:
logging.warning("No multimodal content, returning None") logging.warning("No multimodal content, returning None")
return None return None
processed_prompt_from_encoder = request.get("_epd_processed_prompt")
# Only use EPD flow if we actually have encoder data
# For PD flow (no encoder), fall through to embedding_paths handling
if processed_prompt_from_encoder is not None:
text_prompt = processed_prompt_from_encoder
result = {"prompt": text_prompt}
prompt_token_ids = request.get("_epd_prompt_token_ids")
if prompt_token_ids:
result["prompt_token_ids"] = prompt_token_ids
else:
logging.warning("MM PROCESSOR: No prompt_token_ids from encoder")
return result
loader_kwargs = {} loader_kwargs = {}
if embeddings is not None: if embeddings is not None:
# EPD flow # EPD flow - embeddings received from encode worker via NIXL
loader_kwargs["mm_embeddings"] = [embeddings] loader_kwargs["mm_embeddings"] = [embeddings]
logging.debug(f"Using NIXL embeddings in prefill worker: {embeddings}") logging.info(
f"Using NIXL embeddings: shape={embeddings.shape if hasattr(embeddings, 'shape') else 'N/A'}"
)
elif image_urls: elif image_urls:
# Image-only flow # Image-only flow
loader_kwargs["media"] = [image_urls] loader_kwargs["media"] = [image_urls]
...@@ -192,7 +209,7 @@ class MultimodalRequestProcessor: ...@@ -192,7 +209,7 @@ class MultimodalRequestProcessor:
loader_kwargs["mm_embeddings"] = [ loader_kwargs["mm_embeddings"] = [
self.load_tensor_from_path_or_url(path) for path in embedding_paths self.load_tensor_from_path_or_url(path) for path in embedding_paths
] ]
logging.debug(f"Using embedding paths in prefill worker: {embedding_paths}") logging.info(f"Using embedding paths: {embedding_paths}")
# Process with default_multimodal_input_loader # Process with default_multimodal_input_loader
# Pass self.tokenizer to reuse the pre-initialized tokenizer instead of # Pass self.tokenizer to reuse the pre-initialized tokenizer instead of
...@@ -225,10 +242,20 @@ class MultimodalRequestProcessor: ...@@ -225,10 +242,20 @@ class MultimodalRequestProcessor:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError("Tokenizer must be provided for creating response chunks.") raise ValueError("Tokenizer must be provided for creating response chunks.")
new_tokens = output.token_ids[num_output_tokens_so_far:] all_tokens = output.token_ids
# Decode the new token IDs into a string. This is the incremental piece current_text = self.tokenizer.decode(
# of text to be sent to the client. all_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
delta_text = self.tokenizer.decode(new_tokens) )
if num_output_tokens_so_far == 0:
# First chunk: use all decoded text
delta_text = current_text
# Store for next iteration
self.previous_decoded_text = current_text
else:
# Incremental chunk: extract delta using cached previous text
delta_text = current_text[len(self.previous_decoded_text) :]
# Update cache for next iteration
self.previous_decoded_text = current_text
# Assemble the delta payload for the response chunk. # Assemble the delta payload for the response chunk.
delta = {"content": delta_text if delta_text else ""} delta = {"content": delta_text if delta_text else ""}
if num_output_tokens_so_far == 0: if num_output_tokens_so_far == 0:
......
...@@ -19,7 +19,6 @@ import logging ...@@ -19,7 +19,6 @@ import logging
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, AsyncGenerator, Optional, Union from typing import Any, AsyncGenerator, Optional, Union
import torch import torch
...@@ -34,6 +33,7 @@ from dynamo.logits_processing.examples import HelloWorldLogitsProcessor ...@@ -34,6 +33,7 @@ from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
from dynamo.nixl_connect import Connector from dynamo.nixl_connect import Connector
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.engine import TensorRTLLMEngine from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
...@@ -46,13 +46,6 @@ from dynamo.trtllm.utils.disagg_utils import ( ...@@ -46,13 +46,6 @@ from dynamo.trtllm.utils.disagg_utils import (
configure_dynamo_logging() configure_dynamo_logging()
class DisaggregationMode(Enum):
AGGREGATED = "prefill_and_decode"
PREFILL = "prefill"
DECODE = "decode"
ENCODE = "encode"
@dataclass @dataclass
class RequestHandlerConfig: class RequestHandlerConfig:
""" """
...@@ -219,6 +212,296 @@ class HandlerBase: ...@@ -219,6 +212,296 @@ class HandlerBase:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
def _decode_disaggregated_params_from_prefill(
    self, prefill_result: dict
) -> tuple[Any, dict]:
    """
    Decode the disaggregated params carried in a prefill result.

    NOTE: ``prefill_result["disaggregated_params"]`` is modified in place:
    the ``worker_id`` and ``_epd_metadata`` entries are removed from it.

    Args:
        prefill_result: Prefill result dict with an encoded
            ``disaggregated_params`` mapping.

    Returns:
        Tuple ``(disaggregated_params, epd_metadata)``:
        - disaggregated_params: decoded params with ``request_type`` set to
          ``"generation_only"`` and any multimodal embedding handles cleared
        - epd_metadata: the ``_epd_metadata`` entry, or ``{}`` if absent
    """
    encoded = prefill_result["disaggregated_params"]

    # worker_id is not a field the decode side needs; drop it if present.
    encoded.pop("worker_id", None)

    # Pull out the metadata packed alongside the params before constructing
    # DisaggregatedParams from the remaining keys.
    if "_epd_metadata" in encoded:
        epd_metadata = encoded.pop("_epd_metadata")
        logging.debug(
            f"DECODE: Extracted _epd_metadata with {len(epd_metadata)} fields"
        )
    else:
        epd_metadata = {}

    decoded = DisaggregatedParamsCodec.decode(DisaggregatedParams(**encoded))

    # The decode phase runs in generation-only mode.
    decoded.request_type = "generation_only"

    # Clear embedding handles when present and non-empty; the attribute may
    # not exist on every params object, hence the getattr default.
    if getattr(decoded, "multimodal_embedding_handles", None):
        decoded.multimodal_embedding_handles = None

    logging.debug("DECODE: Set request_type to generation_only")
    return decoded, epd_metadata
def _encode_and_pack_disaggregated_params(
    self,
    output: GenerationResult,
    disaggregated_params: Any,
    request: dict,
    res: Any,
    processed_input: Any = None,
) -> Optional[dict]:
    """
    Encode disaggregated params for a PREFILL response and attach prompt metadata.

    Steps:
    - Pick ``output.disaggregated_params`` when set, otherwise the input params
    - Copy ``multimodal_embedding_handles`` (and hashes) from the input params
      when the chosen params lack handles
    - Encode the params and convert them to a dict
    - Attach prompt / token-id metadata under the ``_epd_metadata`` key

    NOTE(review): the source this was recovered from had lost its indentation;
    the hashes copy is assumed to be nested under the handles copy — confirm
    against the original file.

    Args:
        output: GenerationResult from the engine
        disaggregated_params: Input disaggregated params (may be None)
        request: Original request dict; only checked for the presence of the
            ``_epd_processed_prompt`` / ``_epd_prompt_token_ids`` keys
        res: Object exposing ``prompt`` and ``prompt_token_ids`` attributes
        processed_input: Processed input from process_openai_request; when it is a
            dict with a non-empty ``prompt`` that prompt is preferred over ``res.prompt``

    Returns:
        Dictionary with the encoded disaggregated params (plus ``_epd_metadata``
        when any metadata was collected), or None if encoding returned None
    """
    # In EPD flow, output.disaggregated_params might be None, use the input params
    params_to_encode = (
        output.disaggregated_params
        if output.disaggregated_params is not None
        else disaggregated_params
    )

    # In EPD flow, manually preserve multimodal_embedding_handles from input
    # because TRT-LLM engine may not propagate them through prefill
    if params_to_encode is not None and disaggregated_params is not None:
        input_handles = getattr(
            disaggregated_params,
            "multimodal_embedding_handles",
            None,
        )
        output_handles = getattr(
            params_to_encode, "multimodal_embedding_handles", None
        )
        # Only fill in handles that are missing; existing ones are left alone
        if input_handles is not None and output_handles is None:
            params_to_encode.multimodal_embedding_handles = input_handles
            # Also preserve hashes if they exist
            input_hashes = getattr(disaggregated_params, "multimodal_hashes", None)
            if input_hashes is not None:
                params_to_encode.multimodal_hashes = input_hashes

    encoded_params = DisaggregatedParamsCodec.encode(params_to_encode)
    if encoded_params is None:
        logging.error("PREFILL: encoded_params is None - decode worker will fail!")
        return None

    logging.debug("PREFILL: Successfully encoded disaggregated params")
    params_dict = asdict(encoded_params)

    # Pack prefill metadata for DECODE worker optimization
    # The frontend only forwards disaggregated_params from prefill response
    # Note: max_tokens is already handled by Rust frontend's PrefillRouter
    prefill_metadata = {}

    # ALWAYS pack prompt info for DECODE to skip re-processing
    # Per TRT-LLM team: DECODE never needs to reload images - KV cache has the context
    # Use processed_input['prompt'] (from process_openai_request) which is the actual
    # multimodal prompt used by TRT-LLM, not res.prompt which might be raw
    if (
        processed_input
        and isinstance(processed_input, dict)
        and processed_input.get("prompt")
    ):
        prefill_metadata["_prefill_prompt"] = processed_input["prompt"]
    elif res.prompt:
        prefill_metadata["_prefill_prompt"] = res.prompt
    if res.prompt_token_ids:
        prefill_metadata["_prefill_prompt_token_ids"] = list(res.prompt_token_ids)

    # EPD-specific: when the request carries the encoder keys, forward the
    # prompt / token ids under the same keys. The stored values come from
    # `res`, not from the request entries themselves.
    if "_epd_processed_prompt" in request and res.prompt:
        prefill_metadata["_epd_processed_prompt"] = res.prompt
    if "_epd_prompt_token_ids" in request and res.prompt_token_ids:
        prefill_metadata["_epd_prompt_token_ids"] = list(res.prompt_token_ids)

    # Add metadata to the disaggregated_params dict
    if prefill_metadata:
        params_dict["_epd_metadata"] = prefill_metadata
    return params_dict
def _setup_disaggregated_params_for_mode(
    self,
    request: dict,
    ep_disaggregated_params: Optional[Any],
) -> tuple[Any, Any, dict]:
    """
    Build the disaggregated params for the current request.

    When this handler is in PREFILL mode the params are marked
    ``request_type="context_only"``: the encode worker's params are reused if
    given (and mutated in place), otherwise a fresh ``LlmDisaggregatedParams``
    is created.

    Independently of the mode check above, when the request carries a
    ``prefill_result`` with ``disaggregated_params``, those are decoded and
    replace both the returned params and ``ep_disaggregated_params``.

    Args:
        request: Request dictionary (may contain ``prefill_result``)
        ep_disaggregated_params: Optional params from the encode worker

    Returns:
        Tuple of (disaggregated_params, ep_disaggregated_params, epd_metadata)
    """
    params = None
    metadata: dict = {}

    if self.disaggregation_mode == DisaggregationMode.PREFILL:
        if not ep_disaggregated_params:
            params = LlmDisaggregatedParams(request_type="context_only")
        else:
            ep_disaggregated_params.request_type = "context_only"
            params = ep_disaggregated_params

    prefill_result = request.get("prefill_result")
    if not prefill_result or "disaggregated_params" not in prefill_result:
        # Nothing handed over from a prefill worker.
        return params, ep_disaggregated_params, metadata

    params, metadata = self._decode_disaggregated_params_from_prefill(
        prefill_result
    )
    # The decoded params are also exposed as ep_disaggregated_params so the
    # multimodal processor can see them.
    return params, params, metadata
async def _prepare_input_for_generation(
    self,
    request: dict,
    embeddings: Optional[Union[torch.Tensor, dict]],
    ep_disaggregated_params: Optional[Any],
    epd_metadata: dict,
) -> Any:
    """
    Produce the input object to hand to generation.

    Resolution order:
    1. DECODE mode with a prompt in ``epd_metadata``: reuse the prompt and
       token ids from prefill and attach a None multimodal entry.
    2. A multimodal processor is configured and returns a truthy result:
       use that result.
    3. Otherwise: return ``request["token_ids"]`` (None if absent).

    Args:
        request: Request dictionary
        embeddings: Optional embeddings forwarded to the multimodal processor
        ep_disaggregated_params: Optional params forwarded to the multimodal processor
        epd_metadata: Metadata dict extracted from the prefill result (may be empty)

    Returns:
        A dict with prompt / prompt_token_ids (path 1), the multimodal
        processor's result (path 2), or the request's token ids (path 3)
    """
    # Prefer the generic prefill prompt, falling back to the EPD one.
    cached_prompt = None
    if epd_metadata:
        cached_prompt = epd_metadata.get("_prefill_prompt") or epd_metadata.get(
            "_epd_processed_prompt"
        )

    if self.disaggregation_mode == DisaggregationMode.DECODE and cached_prompt:
        cached_token_ids = epd_metadata.get(
            "_prefill_prompt_token_ids"
        ) or epd_metadata.get("_epd_prompt_token_ids")

        # Mirror the multimodal key prefill used: the EPD flow is identified
        # by the presence of _epd_processed_prompt. The value is None because
        # no multimodal payload is re-sent on the decode side.
        if epd_metadata.get("_epd_processed_prompt") is not None:
            mm_key = "multi_modal_embeddings"
        else:
            mm_key = "multi_modal_data"
        return {
            "prompt": cached_prompt,
            "prompt_token_ids": cached_token_ids,
            mm_key: None,
        }

    # Other modes, or DECODE without cached prompt: try the multimodal path.
    if self.multimodal_processor:
        multimodal_input = await self.multimodal_processor.process_openai_request(
            request, embeddings, ep_disaggregated_params
        )
        if multimodal_input:
            return multimodal_input

    # Text-only fallback.
    return request.get("token_ids")
def _normalize_request_format(self, request: dict) -> None:
"""
Convert OpenAI request format to TRT-LLM internal format.
Moves fields from OpenAI locations to where TRT-LLM expects them:
- max_tokens: top-level → stop_conditions.max_tokens
- temperature: top-level → sampling_options.temperature
Note: The Rust frontend's PrefillRouter handles the *value* of max_tokens
(sets to 1 for prefill, restores original for decode). This method only
moves fields to the correct location.
Args:
request: Request dictionary to normalize (modified in place)
"""
# Ensure stop_conditions exists
if "stop_conditions" not in request:
request["stop_conditions"] = {}
if "max_tokens" in request and "max_tokens" not in request["stop_conditions"]:
request["stop_conditions"]["max_tokens"] = request.pop("max_tokens")
# Ensure sampling_options exists
if "sampling_options" not in request:
request["sampling_options"] = {}
if (
"temperature" in request
and "temperature" not in request["sampling_options"]
):
request["sampling_options"]["temperature"] = request.pop("temperature")
async def _initiate_shutdown(self, error: Exception): async def _initiate_shutdown(self, error: Exception):
"""Initiate graceful shutdown after fatal error""" """Initiate graceful shutdown after fatal error"""
logging.warning(f"Initiating graceful shutdown due to: {error}") logging.warning(f"Initiating graceful shutdown due to: {error}")
...@@ -242,6 +525,7 @@ class HandlerBase: ...@@ -242,6 +525,7 @@ class HandlerBase:
request: dict, request: dict,
context: Context, context: Context,
embeddings: Optional[Union[torch.Tensor, dict]] = None, embeddings: Optional[Union[torch.Tensor, dict]] = None,
ep_disaggregated_params: Optional[DisaggregatedParams] = None,
): ):
""" """
Generate responses based on the disaggregation mode in the request. Generate responses based on the disaggregation mode in the request.
...@@ -250,22 +534,24 @@ class HandlerBase: ...@@ -250,22 +534,24 @@ class HandlerBase:
request: The request dictionary containing generation parameters request: The request dictionary containing generation parameters
context: Context object for cancellation handling context: Context object for cancellation handling
embeddings: Optional tensor or dict containing embeddings for multimodal processing embeddings: Optional tensor or dict containing embeddings for multimodal processing
ep_disaggregated_params: Optional DisaggregatedParams from encode worker (full EPD flow)
""" """
logging.debug(f"Request: {request}") logging.debug(f"Request: {request}")
# Default to text-based input. This will be overwritten if multimodal # Normalize OpenAI format to TRT-LLM internal format
# content is found and processed. self._normalize_request_format(request)
processed_input = None
# Check for multimodal request and process it # Setup disaggregated params based on PREFILL/DECODE mode
if self.multimodal_processor: (
processed_input = await self.multimodal_processor.process_openai_request( disaggregated_params,
request, embeddings ep_disaggregated_params,
) epd_metadata,
) = self._setup_disaggregated_params_for_mode(request, ep_disaggregated_params)
else: # Prepare input for generation (handles multimodal/text flows)
# text-only flow processed_input = await self._prepare_input_for_generation(
processed_input = request.get("token_ids") request, embeddings, ep_disaggregated_params, epd_metadata
)
# Check if there is an error in the publisher error queue # Check if there is an error in the publisher error queue
publishers_error = ( publishers_error = (
...@@ -274,31 +560,18 @@ class HandlerBase: ...@@ -274,31 +560,18 @@ class HandlerBase:
if publishers_error: if publishers_error:
raise publishers_error raise publishers_error
# Decode the disaggregated params from the request # For PREFILL mode, set max_tokens=1 (we only need to process context)
disaggregated_params = None
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
request["stop_conditions"]["max_tokens"] = 1 request["stop_conditions"]["max_tokens"] = 1
disaggregated_params = LlmDisaggregatedParams(request_type="context_only") # disaggregated_params is already set above (lines 460-468)
# Don't overwrite it here as it may contain multimodal_embedding_handles from encoder
if "prefill_result" in request:
if self.disaggregation_mode == DisaggregationMode.PREFILL:
raise ValueError("Cannot provide disaggregated_params in prefill mode")
request["prefill_result"].get("disaggregated_params", {}).pop(
"worker_id", None
)
disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(
**request["prefill_result"].get("disaggregated_params")
)
)
disaggregated_params.request_type = "generation_only"
if ( if (
self.disaggregation_mode == DisaggregationMode.DECODE self.disaggregation_mode == DisaggregationMode.DECODE
and disaggregated_params is None and disaggregated_params is None
): ):
logging.error("DECODE: disaggregated_params is None but required!")
logging.error(f"DECODE: Request keys: {list(request.keys())}")
raise ValueError("Disaggregated params are required for decode mode") raise ValueError("Disaggregated params are required for decode mode")
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
...@@ -416,9 +689,11 @@ class HandlerBase: ...@@ -416,9 +689,11 @@ class HandlerBase:
out["stop_reason"] = output.stop_reason out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
# Return the disaggregated params only when operating in prefill mode. # Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict( params_dict = self._encode_and_pack_disaggregated_params(
DisaggregatedParamsCodec.encode(output.disaggregated_params) output, disaggregated_params, request, res, processed_input
) )
if params_dict is not None:
out["disaggregated_params"] = params_dict
if out.get("finish_reason"): if out.get("finish_reason"):
num_input_tokens = len(request.get("token_ids", [])) num_input_tokens = len(request.get("token_ids", []))
......
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
from typing import Optional
from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo._core import Context from dynamo._core import Context
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -57,25 +60,32 @@ class EncodeHandler(HandlerBase): ...@@ -57,25 +60,32 @@ class EncodeHandler(HandlerBase):
def __init__(self, config: RequestHandlerConfig): def __init__(self, config: RequestHandlerConfig):
super().__init__(config) super().__init__(config)
# Initialize to None by default to avoid AttributeError if multimodal_processor is not set
self.model_dir = None
self.model_type = None
self.tokenizer = None
if self.multimodal_processor:
self.model_dir = self.multimodal_processor.model_dir
self.model_type = self.multimodal_processor.model_type
self.tokenizer = self.multimodal_processor.tokenizer
async def generate(self, request: dict, context: Context): async def generate(self, request: dict, context: Context):
logging.debug(f"New Request ID: {context.id()}") logging.debug(f"New Request ID: {context.id()}")
if self.connector: if self.multimodal_processor is None:
# Use helper method to process embedding request logging.error("encode handler: no multimodal_processor configured")
async for response in EncodeHelper.process_embedding_request( raise RuntimeError("encode handler: no multimodal_processor configured")
request, self.multimodal_processor, self.connector
): async for response in EncodeHelper.process_encode_request(
yield response request,
return self.multimodal_processor,
else: self.connector,
logging.error("encode handler: no Dynamo NIXL connector found") self.tokenizer,
raise RuntimeError("encode handler: no Dynamo NIXL connector found") self.model_dir,
self.model_type,
if not request.get("streaming", False): self.engine,
yield request ):
return yield response
return
yield request
class PrefillHandler(HandlerBase): class PrefillHandler(HandlerBase):
...@@ -86,7 +96,39 @@ class PrefillHandler(HandlerBase): ...@@ -86,7 +96,39 @@ class PrefillHandler(HandlerBase):
def __init__(self, config: RequestHandlerConfig): def __init__(self, config: RequestHandlerConfig):
super().__init__(config) super().__init__(config)
async def remote_encode_full_epd(self, request: dict):
    """
    Ask the encode worker to handle a full EPD request and unpack its reply.

    The request is sent through ``self.encode_client.round_robin`` and only
    the first item of the returned stream is consumed.

    Args:
        request: Request dict forwarded to the encode worker. It is also
            handed to ``_unpack_full_epd_response`` together with the reply.

    Returns:
        Whatever ``_unpack_full_epd_response`` returns for the first reply
        (the encoder's DisaggregatedParams to be used by the prefill worker).

    Raises:
        RuntimeError: If the stream yields nothing, or the first reply is
            empty/falsy.
    """
    first_reply = None
    async for item in await self.encode_client.round_robin(request):
        # One reply is all that is needed; stop after the first item.
        first_reply = item.data()
        break

    if not first_reply:
        raise RuntimeError("Did not receive a response from the encode worker.")

    return self._unpack_full_epd_response(first_reply, request)
async def remote_encode_with_nixl(self, request: dict): async def remote_encode_with_nixl(self, request: dict):
"""
Call encode worker for NIXL flow to load embeddings and unpack the response.
Args:
request: Request dict
Returns:
Encoder's embeddings tensor to be used by the prefill worker
"""
# Get response with shape info and readable metadata # Get response with shape info and readable metadata
encode_response = None encode_response = None
async for res in await self.encode_client.round_robin(request): async for res in await self.encode_client.round_robin(request):
...@@ -101,6 +143,43 @@ class PrefillHandler(HandlerBase): ...@@ -101,6 +143,43 @@ class PrefillHandler(HandlerBase):
encode_response, self.connector encode_response, self.connector
) )
def _unpack_full_epd_response(
    self, encode_response: dict, request: dict
) -> Optional[DisaggregatedParams]:
    """
    Turn an encode-worker reply (full EPD flow) into DisaggregatedParams.

    When the reply carries ``ep_disaggregated_params``, the params object is
    rebuilt from that dict and its ``request_type`` is set to
    ``"context_only"``. The encoder's processed prompt and prompt token IDs,
    if present, are copied into ``request`` under ``_epd_processed_prompt``
    and ``_epd_prompt_token_ids``.

    Args:
        encode_response: Response dict from the encode worker.
        request: Request dict that receives the EPD metadata (modified
            in place). It is left untouched when no params are returned.

    Returns:
        The reconstructed DisaggregatedParams, or None when
        ``ep_disaggregated_params`` is absent or None in the response.
    """
    raw_params = encode_response.get("ep_disaggregated_params")
    if raw_params is None:
        # Missing key and explicit None are both treated as "no params".
        return None

    # Rebuild the params object from its dict form.
    unpacked = DisaggregatedParams(**raw_params)
    unpacked.request_type = "context_only"

    # Copy encoder outputs into the request: the processed prompt (includes
    # <image> tokens) and the prompt token IDs for the decode worker.
    for source_key, target_key in (
        ("processed_prompt", "_epd_processed_prompt"),
        ("prompt_token_ids", "_epd_prompt_token_ids"),
    ):
        if source_key in encode_response:
            request[target_key] = encode_response[source_key]

    return unpacked
async def generate(self, request: dict, context: Context): async def generate(self, request: dict, context: Context):
""" """
Prefill worker: process prompt and return disaggregated_params. Prefill worker: process prompt and return disaggregated_params.
...@@ -109,25 +188,42 @@ class PrefillHandler(HandlerBase): ...@@ -109,25 +188,42 @@ class PrefillHandler(HandlerBase):
logging.debug(f"Prefill Request ID: {context.id()}") logging.debug(f"Prefill Request ID: {context.id()}")
logging.debug(f"PrefillHandler.generate received request: {request}") logging.debug(f"PrefillHandler.generate received request: {request}")
embeddings_tensor = None embeddings_tensor = None
ep_disaggregated_params = None
if self.multimodal_processor: if self.multimodal_processor:
# Extract messages from extra_args (set by Rust preprocessor) or fall back to direct field # Extract messages from extra_args (set by Rust preprocessor) or fall back to direct field
messages = request.get("extra_args", {}).get( messages = request.get("extra_args", {}).get(
"messages", request.get("messages", []) "messages", request.get("messages", [])
) )
_, _, embedding_paths = self.multimodal_processor.extract_prompt_and_media( (
messages _,
) image_urls,
embedding_paths,
) = self.multimodal_processor.extract_prompt_and_media(messages)
# Handle embedding paths (NIXL transfer of pre-computed embeddings)
if embedding_paths: if embedding_paths:
if self.encode_client and self.connector: if self.encode_client and self.connector:
logging.debug( logging.info(f"PrefillHandler: embedding_paths={embedding_paths}")
"PrefillHandler calling Encode Worker via remote_encode_with_nixl"
)
embeddings_tensor = await self.remote_encode_with_nixl(request) embeddings_tensor = await self.remote_encode_with_nixl(request)
else:
# We can still handle embedding_paths without NIXL:
# `MultimodalRequestProcessor.process_openai_request` will load the embeddings
# locally in the prefill worker as a fallback. The encode-worker+NIXL path is
# useful when you want a dedicated I/O stage and/or explicit RDMA transfer.
logging.info(
"PrefillHandler: no encode_client/connector; falling back to local embedding load"
)
# Handle image URLs (full E-PD flow with MultimodalEncoder)
elif image_urls:
if self.encode_client:
ep_disaggregated_params = await self.remote_encode_full_epd(request)
# Generate prefill response locally and return disaggregated_params # Normal flow: Generate the prefill response locally with embeddings
response_count = 0 response_count = 0
async for res in self.generate_locally(request, context, embeddings_tensor): async for res in self.generate_locally(
request, context, embeddings_tensor, ep_disaggregated_params
):
response_count += 1 response_count += 1
if response_count > 1: if response_count > 1:
raise ValueError("Prefill response should be generated only once.") raise ValueError("Prefill response should be generated only once.")
......
...@@ -89,4 +89,4 @@ async def test_get_llm_engine_forwards_backend(backend): ...@@ -89,4 +89,4 @@ async def test_get_llm_engine_forwards_backend(backend):
async with get_llm_engine(engine_args=engine_args): async with get_llm_engine(engine_args=engine_args):
pass pass
mocked_engine.assert_called_once_with(engine_args) mocked_engine.assert_called_once_with(engine_args, None)
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import base64 import base64
import dataclasses
from tensorrt_llm.llmapi import DisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams
...@@ -30,18 +31,10 @@ class DisaggregatedParamsCodec: ...@@ -30,18 +31,10 @@ class DisaggregatedParamsCodec:
if disaggregated_params is None: if disaggregated_params is None:
return None return None
opaque_state = ( opaque_state = disaggregated_params.opaque_state
base64.b64decode(disaggregated_params.opaque_state) if isinstance(opaque_state, str):
if disaggregated_params.opaque_state is not None opaque_state = base64.b64decode(opaque_state)
else None return dataclasses.replace(disaggregated_params, opaque_state=opaque_state)
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
@staticmethod @staticmethod
def encode( def encode(
...@@ -50,15 +43,7 @@ class DisaggregatedParamsCodec: ...@@ -50,15 +43,7 @@ class DisaggregatedParamsCodec:
if disaggregated_params is None: if disaggregated_params is None:
return None return None
encoded_opaque_state = ( opaque_state = disaggregated_params.opaque_state
base64.b64encode(disaggregated_params.opaque_state).decode("utf-8") if isinstance(opaque_state, (bytes, bytearray)):
if disaggregated_params.opaque_state is not None opaque_state = base64.b64encode(opaque_state).decode("utf-8")
else None return dataclasses.replace(disaggregated_params, opaque_state=opaque_state)
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=encoded_opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
...@@ -47,10 +47,11 @@ TRT-LLM supports aggregated and traditional disaggregated patterns. See [Archite ...@@ -47,10 +47,11 @@ TRT-LLM supports aggregated and traditional disaggregated patterns. See [Archite
| Pattern | Supported | Launch Script | Notes | | Pattern | Supported | Launch Script | Notes |
|---------|-----------|---------------|-------| |---------|-----------|---------------|-------|
| EPD (Simple Aggregated) | ✅ | `agg.sh` | Easiest setup | | Aggregated | ✅ | `agg.sh` | Easiest setup, single worker |
| E/PD (Encode Separate) | ❌ | N/A | Not supported | | EP/D (Traditional Disaggregated) | ✅ | `disagg_multimodal.sh` | Prefill handles encoding, 2 workers |
| E/P/D (Full Disaggregation) | 🚧 WIP | N/A | PR #4668 in progress | | E/P/D (Full - Image URLs) | ✅ | `epd_multimodal_image_and_embeddings.sh` | Standalone encoder with `MultimodalEncoder`, 3 workers |
| EP/D (Traditional Disaggregated) | ✅ | `disagg_multimodal.sh` | Prefill handles encoding | | E/P/D (Full - Pre-computed Embeddings) | ✅ | `epd_multimodal_image_and_embeddings.sh` | Standalone encoder with NIXL transfer, 3 workers |
| E/P/D (Large Models) | ✅ | `epd_disagg.sh` | For Llama-4 Scout/Maverick, multi-node |
### Component Flags ### Component Flags
...@@ -59,7 +60,7 @@ TRT-LLM supports aggregated and traditional disaggregated patterns. See [Archite ...@@ -59,7 +60,7 @@ TRT-LLM supports aggregated and traditional disaggregated patterns. See [Archite
| Worker | `--modality multimodal` | Complete pipeline (aggregated) | | Worker | `--modality multimodal` | Complete pipeline (aggregated) |
| Prefill Worker | `--disaggregation-mode prefill` | Image processing + Prefill (multimodal tokenization happens here) | | Prefill Worker | `--disaggregation-mode prefill` | Image processing + Prefill (multimodal tokenization happens here) |
| Decode Worker | `--disaggregation-mode decode` | Decode only | | Decode Worker | `--disaggregation-mode decode` | Decode only |
| Encode Worker (WIP) | `--disaggregation-mode encode` | Image encoding (E/P/D flow) | | Encode Worker | `--disaggregation-mode encode` | Image encoding (E/P/D flow) |
## Aggregated Serving ## Aggregated Serving
...@@ -143,6 +144,90 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d ' ...@@ -143,6 +144,90 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '
For a large model like `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, a multi-node setup is required for disaggregated serving (see [Multi-node Deployment](#multi-node-deployment-slurm) below), while aggregated serving can run on a single node. This is because the model with a disaggregated configuration is too large to fit on a single node's GPUs. For instance, running this model in disaggregated mode requires 2 nodes with 8xH200 GPUs or 4 nodes with 4xGB200 GPUs. For a large model like `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, a multi-node setup is required for disaggregated serving (see [Multi-node Deployment](#multi-node-deployment-slurm) below), while aggregated serving can run on a single node. This is because the model with a disaggregated configuration is too large to fit on a single node's GPUs. For instance, running this model in disaggregated mode requires 2 nodes with 8xH200 GPUs or 4 nodes with 4xGB200 GPUs.
## Full E/P/D Flow (Image URLs)
For high-performance multimodal inference, Dynamo supports a standalone encoder with an **Encode-Prefill-Decode (E/P/D)** flow using TRT-LLM's `MultimodalEncoder`. This separates the vision encoding stage from prefill and decode, enabling better GPU utilization and scalability.
### Supported Input Formats
| Format | Example | Description |
|--------|---------|-------------|
| **HTTP/HTTPS URL** | `https://example.com/image.jpg` | Remote image files |
| **Base64 Data URL** | `data:image/jpeg;base64,...` | Inline base64-encoded images |
### How It Works
In the full E/P/D flow:
1. **Encode Worker**: Runs TRT-LLM's `MultimodalEncoder.generate()` to process image URLs through the vision encoder and projector
2. **Prefill Worker**: Receives `disaggregated_params` containing multimodal embedding handles, processes context and generates KV cache
3. **Decode Worker**: Performs streaming token generation using the KV cache
The encode worker uses TRT-LLM's `MultimodalEncoder` class (which inherits from `BaseLLM`) and only requires the model path and batch size - no KV cache configuration is needed since it only runs the vision encoder + projector.
### How to Launch
```bash
cd $DYNAMO_HOME
# Launch 3-worker E/P/D flow with image URL support
./examples/backends/trtllm/launch/epd_multimodal_image_and_embeddings.sh
```
### Example Request
```bash
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "llava-v1.6-mistral-7b-hf",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Describe the image"},
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
}
}
]
}
],
"max_tokens": 160
}'
```
### E/P/D Architecture (Image URLs)
```mermaid
sequenceDiagram
participant Client
participant Frontend
participant PrefillWorker as "Prefill Worker"
participant EncodeWorker as "Encode Worker"
participant DecodeWorker as "Decode Worker"
Client->>Frontend: POST /v1/chat/completions (image URL)
Frontend->>PrefillWorker: Route to prefill worker
PrefillWorker->>EncodeWorker: Send request (image URL)
Note over EncodeWorker: MultimodalEncoder.generate()<br/>runs vision encoder + projector
EncodeWorker->>PrefillWorker: Return disaggregated_params<br/>(multimodal_embedding_handles)
Note over PrefillWorker: Process context with embeddings<br/>Generate KV cache
PrefillWorker->>Frontend: Return prefill response
Frontend->>DecodeWorker: Route to decode worker
DecodeWorker->>Frontend: Stream response chunks
Frontend->>Client: Stream response
```
### Key Differences from EP/D (Traditional Disaggregated)
| Aspect | EP/D (Traditional) | E/P/D (Full) |
|--------|-------------------|--------------|
| **Encoding** | Prefill worker handles image encoding | Dedicated encode worker |
| **Prefill Load** | Higher (encoding + prefill) | Lower (prefill only) |
| **Use Case** | Simpler setup | Better scalability for vision-heavy workloads |
| **Launch Script** | `disagg_multimodal.sh` | `epd_multimodal_image_and_embeddings.sh` |
## Pre-computed Embeddings with E/P/D Flow ## Pre-computed Embeddings with E/P/D Flow
For high-performance multimodal inference, Dynamo supports pre-computed embeddings with an **Encode-Prefill-Decode (E/P/D)** flow using **NIXL (RDMA)** for zero-copy tensor transfer. For high-performance multimodal inference, Dynamo supports pre-computed embeddings with an **Encode-Prefill-Decode (E/P/D)** flow using **NIXL (RDMA)** for zero-copy tensor transfer.
...@@ -318,10 +403,11 @@ pkill srun ...@@ -318,10 +403,11 @@ pkill srun
| Use Case | Script | NIXL Used? | Data Transfer | | Use Case | Script | NIXL Used? | Data Transfer |
|----------|--------|------------|---------------| |----------|--------|------------|---------------|
| EPD (Simple Aggregated) | `agg.sh` | No | All in one worker | | Aggregated | `agg.sh` | No | All in one worker |
| EP/D (Traditional Disaggregated) | `disagg_multimodal.sh` | Optional | Prefill → Decode (KV cache via UCX or NIXL) | | EP/D (Traditional Disaggregated) | `disagg_multimodal.sh` | Optional | Prefill → Decode (KV cache via UCX or NIXL) |
| E/P/D (pre-computed embeddings) | `epd_disagg.sh` | Yes | Encoder → Prefill (embeddings via NIXL) | | E/P/D (Image URLs) | `epd_multimodal_image_and_embeddings.sh` | No | Encoder → Prefill (handles via params), Prefill → Decode (KV cache) |
| E/P/D (WIP) | N/A | No | Encoder → Prefill (handles via params), Prefill → Decode (KV cache) | | E/P/D (Pre-computed Embeddings) | `epd_multimodal_image_and_embeddings.sh` | Yes | Encoder → Prefill (embeddings via NIXL RDMA) |
| E/P/D (Large Models) | `epd_disagg.sh` | Yes | Encoder → Prefill (embeddings via NIXL), Prefill → Decode (KV cache) |
> **Note:** NIXL for KV cache transfer is currently beta and only supported on AMD64 (x86_64) architecture. > **Note:** NIXL for KV cache transfer is currently beta and only supported on AMD64 (x86_64) architecture.
...@@ -349,26 +435,29 @@ await register_llm( ...@@ -349,26 +435,29 @@ await register_llm(
| Transfer Stage | Message | NIXL Transfer | | Transfer Stage | Message | NIXL Transfer |
|----------------|---------|---------------| |----------------|---------|---------------|
| **Frontend → Prefill** | Request with image URL or embedding path | No | | **Frontend → Prefill** | Request with image URL or embedding path | No |
| **Encode → Prefill (pre-computed)** | NIXL metadata | Yes (Embeddings tensor) | | **Prefill → Encode (Image URL)** | Request with image URL | No |
| **Encode → Prefill (Image URL) (WIP)** | Disaggregated params with multimodal handles | No | | **Encode → Prefill (Image URL)** | `ep_disaggregated_params` with `multimodal_embedding_handles`, processed prompt, and token IDs | No |
| **Prefill → Decode** | Disaggregated params | Configurable (KV cache: NIXL default, UCX optional) | | **Prefill → Encode (Embedding Path)** | Request with embedding file path | No |
| **Encode → Prefill (Embedding Path)** | NIXL readable metadata + shape/dtype + auxiliary data | Yes (Embeddings tensor via RDMA) |
| **Prefill → Decode** | `disaggregated_params` with `_epd_metadata` (prompt, token IDs) | Configurable (KV cache: NIXL default, UCX optional) |
## Known Limitations ## Known Limitations
- **No Data URL support** - Only HTTP/HTTPS URLs supported; `data:image/...` base64 URLs not supported
- **No video support** - No video encoder implementation - **No video support** - No video encoder implementation
- **No audio support** - No audio encoder implementation - **No audio support** - No audio encoder implementation
- **Multimodal preprocessing/tokenization happens in Python** - Rust may forward token_ids, but multimodal requests are parsed and re-tokenized in the Python worker - **Multimodal preprocessing/tokenization happens in Python** - Rust may forward token_ids, but multimodal requests are parsed and re-tokenized in the Python worker
- **E/P/D mode is WIP** - Full E/P/D with image URLs under development
- **Multi-node H100 limitation** - Loading `meta-llama/Llama-4-Maverick-17B-128E-Instruct` with 8 nodes of H100 with TP=16 is not possible due to head count divisibility (`num_attention_heads: 40` not divisible by `tp_size: 16`) - **Multi-node H100 limitation** - Loading `meta-llama/Llama-4-Maverick-17B-128E-Instruct` with 8 nodes of H100 with TP=16 is not possible due to head count divisibility (`num_attention_heads: 40` not divisible by `tp_size: 16`)
- **llava-v1.6-mistral-7b-hf model crash** - Known issue with TRTLLM backend compatibility with `TensorRT LLM version: 1.2.0rc6.post1`. To use the LLaVA model, download revision `revision='52320fb52229'` locally using HF.
- **Embeddings file crash** - Known issue with TRTLLM backend compatibility with `TensorRT LLM version: 1.2.0rc6.post1`. Embedding file parsing crashes in `attach_multimodal_embeddings()`. To be fixed in the next TRTLLM upgrade.
## Supported Models ## Supported Models
Multimodal models listed in [TensorRT-LLM supported models](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/models/supported-models.md) are supported by Dynamo. Multimodal models listed in [TensorRT-LLM supported models](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/models/supported-models.md) are supported by Dynamo.
Common examples: Common examples:
- Llama 4 Vision models (Maverick, Scout) - **Llama 4 Vision models** (Maverick, Scout) - Recommended for large-scale deployments
- Qwen2-VL models - **LLaVA models** (e.g., `llava-hf/llava-v1.6-mistral-7b-hf`) - Default model for E/P/D examples
- **Qwen2-VL models** - Supported in traditional disaggregated mode
- Other vision-language models with TRT-LLM support - Other vision-language models with TRT-LLM support
## Key Files ## Key Files
...@@ -376,8 +465,12 @@ Common examples: ...@@ -376,8 +465,12 @@ Common examples:
| File | Description | | File | Description |
|------|-------------| |------|-------------|
| `components/src/dynamo/trtllm/main.py` | Worker initialization and setup | | `components/src/dynamo/trtllm/main.py` | Worker initialization and setup |
| `components/src/dynamo/trtllm/utils/trtllm_utils.py` | Command-line argument parsing | | `components/src/dynamo/trtllm/engine.py` | TensorRTLLMEngine wrapper (LLM and MultimodalEncoder) |
| `components/src/dynamo/trtllm/constants.py` | DisaggregationMode enum (AGGREGATED, PREFILL, DECODE, ENCODE) |
| `components/src/dynamo/trtllm/encode_helper.py` | Encode worker request processing (embedding-path and full EPD flows) |
| `components/src/dynamo/trtllm/multimodal_processor.py` | Multimodal request processing | | `components/src/dynamo/trtllm/multimodal_processor.py` | Multimodal request processing |
| `components/src/dynamo/trtllm/request_handlers/handlers.py` | Request handler factory | | `components/src/dynamo/trtllm/request_handlers/handlers.py` | Request handlers (EncodeHandler, PrefillHandler, DecodeHandler) |
| `components/src/dynamo/trtllm/request_handlers/handler_base.py` | Base handler and disaggregation modes | | `components/src/dynamo/trtllm/request_handlers/handler_base.py` | Base handler with disaggregated params encoding/decoding |
| `components/src/dynamo/trtllm/utils/disagg_utils.py` | DisaggregatedParamsCodec for network transfer |
| `components/src/dynamo/trtllm/utils/trtllm_utils.py` | Command-line argument parsing |
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
disable_overlap_scheduler: false
kv_cache_config:
free_gpu_memory_fraction: 0.30
enable_block_reuse: false
cache_transceiver_config:
backend: DEFAULT
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
# Overlap scheduler not currently supported in prefill only workers.
disable_overlap_scheduler: true
# Note: kv_cache_config is not needed for encode workers since MultimodalEncoder
# only runs vision encoder + projector and doesn't need KV cache memory.
cache_transceiver_config:
backend: DEFAULT
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
# Overlap scheduler not currently supported in prefill only workers.
disable_overlap_scheduler: true
kv_cache_config:
free_gpu_memory_fraction: 0.30
enable_block_reuse: false
cache_transceiver_config:
backend: DEFAULT
\ No newline at end of file
...@@ -4,13 +4,14 @@ ...@@ -4,13 +4,14 @@
# Environment variables with defaults # Environment variables with defaults
export DYNAMO_HOME=${DYNAMO_HOME:-"/workspace"} export DYNAMO_HOME=${DYNAMO_HOME:-"/workspace"}
export MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2-VL-7B-Instruct"} export MODEL_PATH=${MODEL_PATH:-"llava-hf/llava-v1.6-mistral-7b-hf"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"Qwen/Qwen2-VL-7B-Instruct"} export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"llava-v1.6-mistral-7b-hf"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/qwen2-vl-7b-instruct/prefill.yaml"} export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/qwen2-vl-7b-instruct/decode.yaml"} export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/decode.yaml"}
export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"} export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"}
export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"1"} export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"1"}
export MODALITY=${MODALITY:-"multimodal"} export MODALITY=${MODALITY:-"multimodal"}
export CUSTOM_TEMPLATE=${CUSTOM_TEMPLATE:-"$DYNAMO_HOME/examples/backends/trtllm/templates/llava_multimodal.jinja"}
# Setup cleanup trap # Setup cleanup trap
cleanup() { cleanup() {
...@@ -33,6 +34,7 @@ CUDA_VISIBLE_DEVICES=$PREFILL_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \ ...@@ -33,6 +34,7 @@ CUDA_VISIBLE_DEVICES=$PREFILL_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--served-model-name "$SERVED_MODEL_NAME" \ --served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$PREFILL_ENGINE_ARGS" \ --extra-engine-args "$PREFILL_ENGINE_ARGS" \
--modality "$MODALITY" \ --modality "$MODALITY" \
--custom-jinja-template "$CUSTOM_TEMPLATE" \
--disaggregation-mode prefill & --disaggregation-mode prefill &
PREFILL_PID=$! PREFILL_PID=$!
...@@ -42,4 +44,5 @@ CUDA_VISIBLE_DEVICES=$DECODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \ ...@@ -42,4 +44,5 @@ CUDA_VISIBLE_DEVICES=$DECODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--served-model-name "$SERVED_MODEL_NAME" \ --served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$DECODE_ENGINE_ARGS" \ --extra-engine-args "$DECODE_ENGINE_ARGS" \
--modality "$MODALITY" \ --modality "$MODALITY" \
--custom-jinja-template "$CUSTOM_TEMPLATE" \
--disaggregation-mode decode --disaggregation-mode decode
...@@ -8,7 +8,6 @@ export MODEL_PATH=${MODEL_PATH:-"meta-llama/Llama-4-Scout-17B-16E-Instruct"} ...@@ -8,7 +8,6 @@ export MODEL_PATH=${MODEL_PATH:-"meta-llama/Llama-4-Scout-17B-16E-Instruct"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"meta-llama/Llama-4-Scout-17B-16E-Instruct"} export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"meta-llama/Llama-4-Scout-17B-16E-Instruct"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llama4/multimodal/llama4-Scout/prefill.yaml"} export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llama4/multimodal/llama4-Scout/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llama4/multimodal/llama4-Scout/decode.yaml"} export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llama4/multimodal/llama4-Scout/decode.yaml"}
# Placeholder for now: this is a no-op because the encoder just loads the embeddings path; it is kept for consistency with the other workers and future API enhancements
export ENCODE_ENGINE_ARGS=${ENCODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llama4/multimodal/llama4-Scout/encode.yaml"} export ENCODE_ENGINE_ARGS=${ENCODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llama4/multimodal/llama4-Scout/encode.yaml"}
export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0,1,2,3"} export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0,1,2,3"}
export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"4,5,6,7"} export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"4,5,6,7"}
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Environment variables with defaults.
# Every setting keeps a value already present in the caller's environment and
# only falls back to the default written here when it is unset or empty.
# Root directory used to build the default engine-config and template paths.
export DYNAMO_HOME=${DYNAMO_HOME:-"/workspace"}
# Model given to every worker via --model-path / --served-model-name.
export MODEL_PATH=${MODEL_PATH:-"llava-hf/llava-v1.6-mistral-7b-hf"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"llava-v1.6-mistral-7b-hf"}
# Per-worker engine config files, handed to each worker via --extra-engine-args.
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/decode.yaml"}
export ENCODE_ENGINE_ARGS=${ENCODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/encode.yaml"}
# GPU assignment per worker; each value becomes CUDA_VISIBLE_DEVICES for that
# worker's process. The defaults place the three workers on GPUs 0, 1 and 2.
export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"}
export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"1"}
export ENCODE_CUDA_VISIBLE_DEVICES=${ENCODE_CUDA_VISIBLE_DEVICES:-"2"}
# Endpoint handed to the prefill worker via --encode-endpoint.
export ENCODE_ENDPOINT=${ENCODE_ENDPOINT:-"dyn://dynamo.tensorrt_llm_encode.generate"}
# Value handed to every worker via --modality.
export MODALITY=${MODALITY:-"multimodal"}
# Chat template handed to the prefill and decode workers via --custom-jinja-template.
export CUSTOM_TEMPLATE=${CUSTOM_TEMPLATE:-"$DYNAMO_HOME/examples/backends/trtllm/templates/llava_multimodal.jinja"}
# Setup cleanup trap
cleanup() {
    # Stop every background process started by this script and reap it.
    #
    # This function is trapped on EXIT as well as on INT/TERM, so after a
    # signal it would otherwise run a second time when the script finally
    # exits. The guard below turns any repeated invocation into a no-op.
    if [ -n "${_CLEANUP_DONE:-}" ]; then
        return 0
    fi
    _CLEANUP_DONE=1

    echo "Cleaning up background processes..."
    # The PID variables are intentionally unquoted: one that was never set
    # expands to nothing rather than to an empty argument. Errors are
    # discarded because some of the processes may already have exited.
    kill $DYNAMO_PID $PREFILL_PID $DECODE_PID $ENCODE_PID 2>/dev/null || true
    wait $DYNAMO_PID $PREFILL_PID $DECODE_PID $ENCODE_PID 2>/dev/null || true
    echo "Cleanup complete."
}
# Run cleanup when the script exits or receives SIGINT / SIGTERM.
trap cleanup EXIT INT TERM
# Frontend: started in the background with --http-port 8000; its PID is kept
# so that the script can wait on it and cleanup can stop it.
python3 -m dynamo.frontend --http-port 8000 &
DYNAMO_PID=$!
# Encode worker: runs in the background on the GPUs listed in
# ENCODE_CUDA_VISIBLE_DEVICES with --disaggregation-mode encode. Unlike the
# prefill and decode workers it is not given --custom-jinja-template.
CUDA_VISIBLE_DEVICES=$ENCODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$ENCODE_ENGINE_ARGS" \
--modality "$MODALITY" \
--disaggregation-mode encode &
ENCODE_PID=$!
# Prefill worker: runs in the background with --disaggregation-mode prefill;
# it is the only worker that receives --encode-endpoint.
CUDA_VISIBLE_DEVICES=$PREFILL_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$PREFILL_ENGINE_ARGS" \
--modality "$MODALITY" \
--disaggregation-mode prefill \
--encode-endpoint "$ENCODE_ENDPOINT" \
--custom-jinja-template "$CUSTOM_TEMPLATE" &
PREFILL_PID=$!
# Decode worker: runs in the background with --disaggregation-mode decode.
CUDA_VISIBLE_DEVICES=$DECODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$DECODE_ENGINE_ARGS" \
--modality "$MODALITY" \
--disaggregation-mode decode \
--custom-jinja-template "$CUSTOM_TEMPLATE" &
DECODE_PID=$!
# Block until the frontend process ends; the EXIT trap then stops the workers.
wait $DYNAMO_PID
\ No newline at end of file
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Environment variables with defaults.
# Every setting keeps a value already present in the caller's environment and
# only falls back to the default written here when it is unset or empty.
# Root directory used to build the default engine-config and template paths.
export DYNAMO_HOME=${DYNAMO_HOME:-"/workspace"}
# Model given to every worker via --model-path / --served-model-name.
export MODEL_PATH=${MODEL_PATH:-"llava-hf/llava-v1.6-mistral-7b-hf"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"llava-v1.6-mistral-7b-hf"}
# Per-worker engine config files, handed to each worker via --extra-engine-args.
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/decode.yaml"}
export ENCODE_ENGINE_ARGS=${ENCODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/encode.yaml"}
# GPU assignment per worker; each value becomes CUDA_VISIBLE_DEVICES for that
# worker's process. The defaults place the three workers on GPUs 0, 1 and 2.
export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"}
export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"1"}
export ENCODE_CUDA_VISIBLE_DEVICES=${ENCODE_CUDA_VISIBLE_DEVICES:-"2"}
# Endpoint handed to the prefill worker via --encode-endpoint.
export ENCODE_ENDPOINT=${ENCODE_ENDPOINT:-"dyn://dynamo.tensorrt_llm_encode.generate"}
# Value handed to every worker via --modality.
export MODALITY=${MODALITY:-"multimodal"}
# Values handed to the encode and decode workers via --allowed-local-media-path
# and --max-file-size-mb (the prefill worker does not receive them).
export ALLOWED_LOCAL_MEDIA_PATH=${ALLOWED_LOCAL_MEDIA_PATH:-"/tmp"}
export MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50}
# Chat template handed to the prefill and decode workers via --custom-jinja-template.
export CUSTOM_TEMPLATE=${CUSTOM_TEMPLATE:-"$DYNAMO_HOME/examples/backends/trtllm/templates/llava_multimodal.jinja"}
# Setup cleanup trap
cleanup() {
    # Stop every background process started by this script and reap it.
    #
    # This function is trapped on EXIT as well as on INT/TERM, so after a
    # signal it would otherwise run a second time when the script finally
    # exits. The guard below turns any repeated invocation into a no-op.
    if [ -n "${_CLEANUP_DONE:-}" ]; then
        return 0
    fi
    _CLEANUP_DONE=1

    echo "Cleaning up background processes..."
    # The PID variables are intentionally unquoted: one that was never set
    # expands to nothing rather than to an empty argument. Errors are
    # discarded because some of the processes may already have exited.
    kill $DYNAMO_PID $PREFILL_PID $DECODE_PID $ENCODE_PID 2>/dev/null || true
    wait $DYNAMO_PID $PREFILL_PID $DECODE_PID $ENCODE_PID 2>/dev/null || true
    echo "Cleanup complete."
}
# Run cleanup when the script exits or receives SIGINT / SIGTERM.
trap cleanup EXIT INT TERM
# Frontend: started in the background with --http-port 8000; its PID is kept
# so that the script can wait on it and cleanup can stop it.
python3 -m dynamo.frontend --http-port 8000 &
DYNAMO_PID=$!
# Encode worker: runs in the background on the GPUs listed in
# ENCODE_CUDA_VISIBLE_DEVICES with --disaggregation-mode encode. It receives
# the local-media path and file-size settings but no --custom-jinja-template.
CUDA_VISIBLE_DEVICES=$ENCODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$ENCODE_ENGINE_ARGS" \
--modality "$MODALITY" \
--allowed-local-media-path "$ALLOWED_LOCAL_MEDIA_PATH" \
--max-file-size-mb "$MAX_FILE_SIZE_MB" \
--disaggregation-mode encode &
ENCODE_PID=$!
# Prefill worker: runs in the background with --disaggregation-mode prefill;
# it is the only worker that receives --encode-endpoint, and it is not given
# the local-media path or file-size settings.
CUDA_VISIBLE_DEVICES=$PREFILL_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$PREFILL_ENGINE_ARGS" \
--modality "$MODALITY" \
--disaggregation-mode prefill \
--encode-endpoint "$ENCODE_ENDPOINT" \
--custom-jinja-template "$CUSTOM_TEMPLATE" &
PREFILL_PID=$!
# Decode worker: runs in the background with --disaggregation-mode decode and
# receives the same local-media path and file-size settings as the encode worker.
CUDA_VISIBLE_DEVICES=$DECODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$DECODE_ENGINE_ARGS" \
--modality "$MODALITY" \
--allowed-local-media-path "$ALLOWED_LOCAL_MEDIA_PATH" \
--max-file-size-mb "$MAX_FILE_SIZE_MB" \
--disaggregation-mode decode \
--custom-jinja-template "$CUSTOM_TEMPLATE" &
DECODE_PID=$!
# Block until the frontend process ends; the EXIT trap then stops the workers.
wait $DYNAMO_PID
\ No newline at end of file
{#- Chat template that renders the `messages` list in order.
    system    -> content wrapped between <<SYS>> and <</SYS>>.
    user      -> content wrapped between [INST] and [/INST]. String content is
                 emitted as-is; otherwise each item is visited: `image_url`
                 items emit an <image> placeholder followed by a newline,
                 `text` items emit their text, other item types emit nothing.
    assistant -> content followed by the </s> marker.
    Messages with any other role produce no output.
    This comment strips its surrounding whitespace so the rendered prompt is
    unchanged. -#}
{% for message in messages %}
{%- if message['role'] == 'system' -%}
<<SYS>>
{{ message['content'] }}
<</SYS>>
{% elif message['role'] == 'user' -%}
[INST] {% if message['content'] is string %}{{ message['content'] }}{% else %}{% for item in message['content'] %}{% if item['type'] == 'image_url' %}<image>
{% elif item['type'] == 'text' %}{{ item['text'] }}{% endif %}{% endfor %}{% endif %} [/INST]
{% elif message['role'] == 'assistant' -%}
{{ message['content'] }}</s>
{% endif %}
{%- endfor -%}
...@@ -199,6 +199,22 @@ trtllm_configs = { ...@@ -199,6 +199,22 @@ trtllm_configs = {
delayed_start=60, delayed_start=60,
request_payloads=[multimodal_payload_default()], request_payloads=[multimodal_payload_default()],
), ),
"epd_multimodal_image_and_embeddings": TRTLLMConfig(
name="epd_multimodal_image_and_embeddings",
directory=trtllm_dir,
script_name="epd_multimodal_image_and_embeddings.sh",
marks=[
pytest.mark.gpu_4,
pytest.mark.trtllm,
pytest.mark.multimodal,
pytest.mark.nightly,
],
model="llava-hf/llava-v1.6-mistral-7b-hf",
frontend_port=DefaultPort.FRONTEND.value,
timeout=1200,
delayed_start=120,
request_payloads=[multimodal_payload_default()],
),
"completions_only": TRTLLMConfig( "completions_only": TRTLLMConfig(
name="completions_only", name="completions_only",
directory=trtllm_dir, directory=trtllm_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment