Unverified Commit 781fa062 authored by Indrajit Bhosale's avatar Indrajit Bhosale Committed by GitHub
Browse files

feat: Add Encode Worker and NIXL support to trtllm multimodal flow (#2452)


Signed-off-by: default avatarIndrajit Bhosale <iamindrajitb@gmail.com>
parent 17861703
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
trust_remote_code: true
backend: pytorch
disable_overlap_scheduler: false
cuda_graph_config:
max_batch_size: 16
kv_cache_config:
free_gpu_memory_fraction: 0.85
cache_transceiver_config:
backend: default
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Environment variables with defaults
export MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2-VL-7B-Instruct"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"Qwen/Qwen2-VL-7B-Instruct"}
export DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"decode_first"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"engine_configs/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"engine_configs/decode.yaml"}
export ENCODE_ENGINE_ARGS=${ENCODE_ENGINE_ARGS:-"engine_configs/encode.yaml"}
export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"}
export DECODE_CUDA_VISIBLE_DEVICES=${DECODE_CUDA_VISIBLE_DEVICES:-"1"}
export ENCODE_CUDA_VISIBLE_DEVICES=${ENCODE_CUDA_VISIBLE_DEVICES:-"2"}
export ENCODE_ENDPOINT=${ENCODE_ENDPOINT:-"dyn://dynamo.tensorrt_llm_encode.generate"}
export MODALITY=${MODALITY:-"multimodal"}
export ALLOWED_LOCAL_MEDIA_PATH=${ALLOWED_LOCAL_MEDIA_PATH:-"/tmp"}
export MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50}
# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $DYNAMO_PID $PREFILL_PID $DECODE_PID $ENCODE_PID 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID $DECODE_PID $ENCODE_PID 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
# run clear_namespace
python3 utils/clear_namespace.py --namespace dynamo
# run frontend
python3 -m dynamo.frontend --http-port 8000 &
DYNAMO_PID=$!
# run encode worker
CUDA_VISIBLE_DEVICES=$ENCODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$ENCODE_ENGINE_ARGS" \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
--modality "$MODALITY" \
--allowed-local-media-path "$ALLOWED_LOCAL_MEDIA_PATH" \
--max-file-size-mb "$MAX_FILE_SIZE_MB" \
--disaggregation-mode encode &
ENCODE_PID=$!
# run prefill worker
CUDA_VISIBLE_DEVICES=$PREFILL_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$PREFILL_ENGINE_ARGS" \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
--modality "$MODALITY" \
--disaggregation-mode prefill \
--encode-endpoint "$ENCODE_ENDPOINT" &
PREFILL_PID=$!
# run decode worker
CUDA_VISIBLE_DEVICES=$DECODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$DECODE_ENGINE_ARGS" \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
--modality "$MODALITY" \
--disaggregation-mode decode &
DECODE_PID=$!
wait $DYNAMO_PID
## Encode-Prefill-Decode (EPD) Flow with NIXL
For high-performance multimodal inference with large embeddings, Dynamo supports a specialized **Encode-Prefill-Decode (EPD)** flow using **NIXL (RDMA)** for zero-copy tensor transfer.
### Enabling the Feature
This is an experimental feature that requires using a specific TensorRT-LLM commit.
To enable it build the dynamo container with the `--tensorrtllm-commit` flag, followed by the commit hash:
```bash
./container/build.sh --framework trtllm --tensorrtllm-commit b4065d8ca64a64eee9fdc64b39cb66d73d4be47c
```
### Key Features
- **High Performance**: Zero-copy RDMA transfer for embeddings
- **Dynamic Shape Allocation**: Automatically handles variable embedding shapes per image
- **Multi-Format Support**: Works with tensor files (`.pt`) and dictionary-based embeddings
- **Hybrid Transfer**: Large tensors via NIXL, small metadata via JSON
### How to use
```bash
cd $DYNAMO_HOME/components/backends/trtllm
# Launch 3-worker EPD flow with NIXL
./launch/epd_disagg.sh
```
### Configuration
The EPD flow uses a dedicated **Encode Worker** that runs separately from the Prefill and Decode workers. The `ENCODE_ENDPOINT` environment variable specifies how the Prefill worker communicates with the Encode worker:
```bash
export ENCODE_ENDPOINT="dyn://dynamo.tensorrt_llm_encode.generate"
```
This endpoint follows Dynamo's standard format: `dyn://namespace.component.endpoint` where the Encode worker registers itself as `dynamo.tensorrt_llm_encode.generate`.
For local embedding file access, use the `--allowed-local-media-path "$ALLOWED_LOCAL_MEDIA_PATH"` parameter to specify the secure directory path where embedding files can be loaded from (default: `/tmp`). This prevents path traversal attacks while allowing flexible file access within the designated directory.
```bash
export ALLOWED_LOCAL_MEDIA_PATH="/tmp"
```
For tensor file size protection, use the `--max-file-size-mb "$MAX_FILE_SIZE_MB"` parameter to limit the maximum size of downloadable embedding files/Image URLs (default: `50MB`). This prevents Denial of Service (DoS) attacks from maliciously large files while accommodating typical embedding file sizes.
```bash
export MAX_FILE_SIZE_MB=50
```
### Architecture Overview
The EPD flow implements a **3-worker architecture** for high-performance multimodal inference:
- **Encode Worker**: Loads and processes multimodal embeddings
- **Prefill Worker**: Handles initial context processing and KV-cache generation
- **Decode Worker**: Performs streaming token generation
### Request Flow Diagrams
#### Prefill-First Disaggregation Strategy
```mermaid
sequenceDiagram
participant Client
participant Gateway
participant PrefillWorker as "Prefill Worker<br/>(AggregatedHandler)"
participant EncodeWorker as "Encode Worker<br/>(EncodeHandler)"
participant DecodeWorker as "Decode Worker<br/>(DecodeHandler)"
participant NIXL as "NIXL<br/>(RDMA Transfer)"
Note over Client,NIXL: Prefill-First Strategy: Context processing first, then streaming generation
Client->>Gateway: POST /v1/chat/completions<br/>(multimodal request)
Gateway->>PrefillWorker: Route request
Note over PrefillWorker: Check for multimodal content
PrefillWorker->>EncodeWorker: Send request<br/>(contains embedding paths)
Note over EncodeWorker: Load embeddings from file/url<br/>
EncodeWorker->>NIXL: Create readable operation<br/>
EncodeWorker->>PrefillWorker: Send metadata + NIXL info<br/>(JSON: shape, dtype, aux_data)
Note over PrefillWorker: Allocate tensor with dynamic shape
PrefillWorker->>NIXL: Begin read operation
NIXL-->>PrefillWorker: Zero-copy transfer complete<br/>
Note over PrefillWorker: Reconstruct embeddings<br/>(mm_embeddings + special_tokens + offsets)
Note over PrefillWorker: Process full context<br/>(text + multimodal embeddings)
Note over PrefillWorker: Generate KV-cache<br/>(max_tokens=1 in prefill mode)
PrefillWorker->>DecodeWorker: Transfer KV-cache + disaggregated_params<br/>(generation_only mode)
Note over DecodeWorker: Continue generation<br/>(streaming tokens)
DecodeWorker->>Gateway: Stream response chunk 1
Gateway->>Client: Response chunk 1
DecodeWorker->>Gateway: Stream response chunk 2
Gateway->>Client: Response chunk 2
DecodeWorker->>Gateway: ... (continue streaming)
Gateway->>Client: ... (continue streaming)
DecodeWorker->>Gateway: Final response + [DONE]
Gateway->>Client: Final response + [DONE]
```
#### Decode-First Disaggregation Strategy
```mermaid
sequenceDiagram
participant Client
participant Gateway
participant DecodeWorker as "Decode Worker<br/>(DecodeHandler)<br/>PRIMARY"
participant PrefillWorker as "Prefill Worker<br/>(PrefillHandler)"
participant EncodeWorker as "Encode Worker<br/>(EncodeHandler)"
participant NIXL as "NIXL<br/>(RDMA Transfer)"
Note over Client,NIXL: Decode-First Strategy: DecodeWorker orchestrates prefill then handles generation
Client->>Gateway: POST /v1/chat/completions<br/>(multimodal request)
Gateway->>DecodeWorker: Route request<br/>(primary worker)
Note over DecodeWorker: Check disaggregation_strategy == DECODE_FIRST
Note over DecodeWorker: Call remote_prefill() to trigger prefill
DecodeWorker->>PrefillWorker: Send request via remote_prefill()
Note over PrefillWorker: Check for multimodal content
PrefillWorker->>EncodeWorker: Send request<br/>(contains embedding paths)
Note over EncodeWorker: Load embeddings from file<br/>
EncodeWorker->>NIXL: Create readable operation<br/>
EncodeWorker->>PrefillWorker: Send metadata + NIXL info<br/>(JSON: shape, dtype, aux_data)
Note over PrefillWorker: Allocate tensor with dynamic shape
PrefillWorker->>NIXL: Begin read operation
NIXL-->>PrefillWorker: Zero-copy transfer complete<br/>
Note over PrefillWorker: Reconstruct embeddings<br/>(mm_embeddings + special_tokens + offsets)
Note over PrefillWorker: Process full context<br/>(text + multimodal embeddings)
Note over PrefillWorker: Generate prefill response<br/>(max_tokens=1 in prefill mode)
PrefillWorker->>DecodeWorker: Return prefill response<br/>(disaggregated_params)
Note over DecodeWorker: Extract disaggregated_params<br/>from prefill_response
Note over DecodeWorker: Update request with params<br/>request["disaggregated_params"] = response_data["disaggregated_params"]
Note over DecodeWorker: Begin local generation<br/>(generate_locally with prefill state)
DecodeWorker->>Gateway: Stream response chunk 1
Gateway->>Client: Response chunk 1
DecodeWorker->>Gateway: Stream response chunk 2
Gateway->>Client: Response chunk 2
DecodeWorker->>Gateway: ... (continue streaming)
Gateway->>Client: ... (continue streaming)
DecodeWorker->>Gateway: Final response + [DONE]
Gateway->>Client: Final response + [DONE]
```
### How the System Works
1. **Request Processing**: Multimodal requests containing embedding file paths OR urls are routed based on disaggregation strategy
2. **Multimodal Loading**: EncodeWorker loads large embedding files and extracts auxiliary metadata
3. **NIXL Transfer**: Main tensors transferred via zero-copy RDMA, small metadata via JSON for efficiency
4. **Dynamic Allocation**: Consumer workers allocate tensors with exact shapes received from EncodeWorker
5. **Reconstruction**: Original embedding format (dictionary or tensor) is reconstructed for model processing
### Example Request
The request format is identical to regular multimodal requests:
```bash
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Describe the image"},
{
"type": "image_url",
"image_url": {"url": "/path/to/embeddings.pt"}
}
]
}
],
"max_tokens": 160
}'
```
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Multimodal Support
> [!Important]
> There are some known issues in tensorrt_llm==1.0.0rc6 version for multimodal support
> It is important to rebuild the dynamo container with a specific version of tensorrt_llm
> commit to use multimodal feature.
## Build Container
```bash
./container/build.sh --framework trtllm --tensorrtllm-commit b4065d8ca64a64eee9fdc64b39cb66d73d4be47c
```
## Run Container
```bash
./container/run.sh --framework trtllm -it
```
## Usage Guide
TRTLLM supports multimodal models with dynamo. You can provide multimodal inputs in the following ways:
- By sending image URLs
- By providing paths to pre-computed embedding files
Please note that you should provide **either image URLs or embedding file paths** in a single request.
### Aggregated
Here are quick steps to launch Llama-4 Maverick BF16 in aggregated mode
```bash
cd $DYNAMO_HOME/components/backends/trtllm
export AGG_ENGINE_ARGS=./engine_configs/multinode/agg.yaml
export SERVED_MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
export MODEL_PATH="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
./launch/agg.sh
```
### Example Requests
#### With Image URL
Below is an example of an image being sent to `Llama-4-Maverick-17B-128E-Instruct` model
Request :
```bash
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the image"
},
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
}
}
]
}
],
"stream": false,
"max_tokens": 160
}'
```
Response :
```
{"id":"unknown-id","choices":[{"index":0,"message":{"content":"The image depicts a serene landscape featuring a large rock formation, likely El Capitan in Yosemite National Park, California. The scene is characterized by a winding road that curves from the bottom-right corner towards the center-left of the image, with a few rocks and trees lining its edge.\n\n**Key Features:**\n\n* **Rock Formation:** A prominent, tall, and flat-topped rock formation dominates the center of the image.\n* **Road:** A paved road winds its way through the landscape, curving from the bottom-right corner towards the center-left.\n* **Trees and Rocks:** Trees are visible on both sides of the road, with rocks scattered along the left side.\n* **Sky:** The sky above is blue, dotted with white clouds.\n* **Atmosphere:** The overall atmosphere of the","refusal":null,"tool_calls":null,"role":"assistant","function_call":null,"audio":null},"finish_reason":"stop","logprobs":null}],"created":1753322607,"model":"meta-llama/Llama-4-Maverick-17B-128E-Instruct","service_tier":null,"system_fingerprint":null,"object":"chat.completion","usage":null}
```
### Disaggregated
Here are quick steps to launch in disaggregated mode.
The following is an example of launching a model in disaggregated mode. While this example uses `Qwen/Qwen2-VL-7B-Instruct`, you can adapt it for other models by modifying the environment variables for the model path and engine configurations.
```bash
cd $DYNAMO_HOME/components/backends/trtllm
export MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2-VL-7B-Instruct"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"Qwen/Qwen2-VL-7B-Instruct"}
export DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"decode_first"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"engine_configs/multimodal/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"engine_configs/multimodal/decode.yaml"}
export MODALITY=${MODALITY:-"multimodal"}
./launch/disagg.sh
```
For a large model like `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, a multi-node setup is required for disaggregated serving, while aggregated serving can run on a single node. This is because the model with a disaggregated configuration is too large to fit on a single node's GPUs. For instance, running this model in disaggregated mode requires a setup of 2 nodes with 8xH200 GPUs or 4 nodes with 4xGB200 GPUs.
In general, disaggregated serving can run on a single node, provided the model fits on the GPU. The multi-node requirement in this example is specific to the size and configuration of the `meta-llama/Llama-4-Maverick-17B-128E-Instruct` model.
To deploy `Llama-4-Maverick-17B-128E-Instruct` in disaggregated mode, you will need to follow the multi-node setup instructions, which can be found [here](./multinode/multinode-multimodal-example.md).
### Using Pre-computed Embeddings (Experimental)
Dynamo with TensorRT-LLM supports providing pre-computed embeddings directly in an inference request. This bypasses the need for the model to process an image and generate embeddings itself, which is useful for performance optimization or when working with custom, pre-generated embeddings.
#### How to Use
Once the container is built, you can send requests with paths to local embedding files.
- **Format:** Provide the embedding as part of the `messages` array, using the `image_url` content type.
- **URL:** The `url` field should contain the absolute or relative path to your embedding file on the local filesystem.
- **File Types:** Supported embedding file extensions are `.pt`, `.pth`, and `.bin`. Dynamo will automatically detect these extensions.
When a request with a supported embedding file is received, Dynamo will load the tensor from the file and pass it directly to the model for inference, skipping the image-to-embedding pipeline.
#### Example Request
Here is an example of how to send a request with a pre-computed embedding file.
```bash
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the content represented by the embeddings"
},
{
"type": "image_url",
"image_url": {
"url": "/path/to/your/embedding.pt"
}
}
]
}
],
"stream": false,
"max_tokens": 160
}'
```
### Encode-Prefill-Decode (EPD) Flow with NIXL
Dynamo with the TensorRT-LLM backend supports multimodal models in Encode -> Decode -> Prefill fashion, enabling you to process embeddings seperately in a seperate worker. For detailed setup instructions, example requests, and best practices, see the [Multimodal EPD Support Guide](./multimodal_epd.md).
## Supported Multimodal Models
Multimodel models listed [here](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/inputs/utils.py#L221) are supported by dynamo.
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any, Dict, Union
import torch
import dynamo.nixl_connect as nixl_connect
class EncodeHelper:
"""Utility class for encoding and serialization operations."""
@staticmethod
def serialize_tensor_dict(tensor_dict: dict) -> dict:
"""Serialize a dictionary of tensors to JSON-serializable format.
Args:
tensor_dict: Dictionary containing tensors and other values
Returns:
Dictionary with tensors converted to JSON-serializable format
Example:
>>> tensor_dict = {"tokens": torch.tensor([1, 2, 3], dtype=torch.int64)}
>>> serialized = EncodeHelper.serialize_tensor_dict(tensor_dict)
>>> # Result: {"tokens": {"data": [1, 2, 3], "shape": [3], "dtype": "torch.int64"}}
"""
serialized = {}
for key, tensor in tensor_dict.items():
if isinstance(tensor, torch.Tensor):
serialized[key] = {
"data": tensor.tolist(),
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
}
else:
# Non-tensor values pass through unchanged
serialized[key] = tensor
return serialized
@staticmethod
def deserialize_tensor_dict(serialized_dict: dict) -> dict:
"""Deserialize a dictionary back to tensors.
Args:
serialized_dict: Dictionary with serialized tensor data
Returns:
Dictionary with tensors reconstructed from serialized format
Example:
>>> serialized = {"tokens": {"data": [1, 2, 3], "shape": [3], "dtype": "torch.int64"}}
>>> tensors = EncodeHelper.deserialize_tensor_dict(serialized)
>>> # Result: {"tokens": tensor([1, 2, 3], dtype=torch.int64)}
"""
deserialized = {}
for key, value in serialized_dict.items():
if (
isinstance(value, dict)
and "data" in value
and "shape" in value
and "dtype" in value
):
# Reconstruct tensor from serialized format
dtype = EncodeHelper.get_torch_dtype_from_string(value["dtype"])
tensor = torch.tensor(value["data"], dtype=dtype)
deserialized[key] = tensor
else:
# Non-tensor values pass through unchanged
deserialized[key] = value
return deserialized
@staticmethod
def get_torch_dtype_from_string(dtype_str: str) -> torch.dtype:
"""Convert dtype string to torch.dtype object.
Args:
dtype_str: String representation of torch dtype (e.g., "torch.float32")
Returns:
Corresponding torch.dtype object
Example:
>>> dtype = EncodeHelper.get_torch_dtype_from_string("torch.bfloat16")
>>> # Result: torch.bfloat16
"""
dtype_map = {
# Floating point types
"torch.float64": torch.float64,
"torch.float32": torch.float32,
"torch.float16": torch.float16,
"torch.bfloat16": torch.bfloat16,
# FP8 types
"torch.float8_e4m3fn": torch.float8_e4m3fn,
"torch.float8_e4m3fnuz": torch.float8_e4m3fnuz,
"torch.float8_e5m2": torch.float8_e5m2,
"torch.float8_e5m2fnuz": torch.float8_e5m2fnuz,
"torch.float8_e8m0fnu": torch.float8_e8m0fnu,
# Signed integer types
"torch.int64": torch.int64,
"torch.int32": torch.int32,
"torch.int16": torch.int16,
"torch.int8": torch.int8,
# Unsigned integer types
"torch.uint64": torch.uint64,
"torch.uint32": torch.uint32,
"torch.uint16": torch.uint16,
"torch.uint8": torch.uint8,
# Complex types
"torch.complex128": torch.complex128,
"torch.complex64": torch.complex64,
# Quantized types
"torch.qint8": torch.qint8,
"torch.quint8": torch.quint8,
"torch.qint32": torch.qint32,
"torch.quint4x2": torch.quint4x2,
# Boolean type
"torch.bool": torch.bool,
}
return dtype_map.get(dtype_str, torch.float32)
@staticmethod
async def read_embeddings_from_encode_response(
encode_response: Dict[str, Any], connector: nixl_connect.Connector
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Read embeddings from encode worker response using NIXL and reconstruct original format.
Args:
encode_response: Response from encode worker containing metadata and NIXL info
connector: NIXL connector for reading operations
Returns:
Either a single tensor or dictionary containing mm_embeddings and auxiliary data
Raises:
RuntimeError: If there's an error in the encode response or NIXL operations
"""
if nixl_connect is None:
raise RuntimeError("Dynamo NIXL Connect library is not available.")
if "error" in encode_response:
raise RuntimeError(f"EncodeHandler error: {encode_response['error']}")
# Extract dynamic shape, metadata, and auxiliary data
embeddings_shape = encode_response["embeddings_shape"]
embeddings_dtype_str = encode_response["embeddings_dtype"]
auxiliary_data = encode_response.get("auxiliary_data", {})
readable_metadata = nixl_connect.RdmaMetadata.model_validate(
encode_response["nixl_readable_metadata"]
)
# Dynamically allocate tensor with correct shape and dtype
embeddings_dtype = EncodeHelper.get_torch_dtype_from_string(
embeddings_dtype_str
)
encodings_tensor = torch.zeros(*embeddings_shape, dtype=embeddings_dtype)
# Create descriptor for our allocated tensor
descriptor = nixl_connect.Descriptor(encodings_tensor)
# Create read operation to read from EncodeHandler
read_op = await connector.begin_read(readable_metadata, descriptor)
with read_op:
# Wait for the read operation to complete
await read_op.wait_for_completion()
logging.debug(
f"Successfully read embeddings via NIXL: {encodings_tensor.shape}"
)
# Reconstruct original format and return
if auxiliary_data:
# Deserialize auxiliary tensors and reconstruct dictionary format
deserialized_auxiliary = EncodeHelper.deserialize_tensor_dict(
auxiliary_data
)
result = {"mm_embeddings": encodings_tensor}
result.update(deserialized_auxiliary)
return result
else:
# Return just the tensor
return encodings_tensor
@staticmethod
async def process_embedding_request(
request: Dict[str, Any],
multimodal_processor,
connector: nixl_connect.Connector,
):
"""
Process embedding request by loading embeddings and creating NIXL readable operation.
Args:
request: Request containing messages with embedding paths
multimodal_processor: Multimodal processor for loading embeddings
connector: NIXL connector for creating readable operations
Yields:
Response dictionary with NIXL metadata and embeddings info, or error response
"""
# Load embeddings first to get the actual shape
messages = request.get("messages", [])
_, _, embedding_paths = multimodal_processor.extract_prompt_and_media(messages)
if not embedding_paths:
# Placeholder for TRTLLM Encoder to be called
# TRTLLM Encoder will return a memory handler on the encoder GPU with the encodings
logging.warning(
"No embedding paths found, NIXL transfer for image urls not supported by TRTLLM Encoder yet"
)
yield {"error": "No embedding paths found"}
return
# Load the embeddings data
loaded_data = multimodal_processor.load_tensor_from_path_or_url(
embedding_paths[0]
)
# Handle both tensor and dictionary formats
if isinstance(loaded_data, dict):
# Dictionary format (e.g., maverick_mm_embed_seashore_v3.pt)
encodings = loaded_data.get("mm_embeddings")
if encodings is None:
yield {"error": "Dictionary embeddings missing 'mm_embeddings' key"}
return
# Store auxiliary data for later transmission
auxiliary_data = {
k: v for k, v in loaded_data.items() if k != "mm_embeddings"
}
else:
# Tensor format (e.g., llava_next_mm_embed_seashore.pt)
encodings = loaded_data
auxiliary_data = {}
# Create readable operation with main embeddings tensor (works for both formats)
descriptor = nixl_connect.Descriptor(encodings)
with connector.create_readable(descriptor) as readable_op:
# Get the metadata for the readable operation
op_metadata = readable_op.metadata()
# Send back shape info, readable metadata, and serialized auxiliary data
response = {
"nixl_readable_metadata": op_metadata.model_dump(),
"embeddings_shape": list(encodings.shape),
"embeddings_dtype": str(encodings.dtype),
"auxiliary_data": EncodeHelper.serialize_tensor_dict(auxiliary_data),
}
yield response
# Wait for the prefill worker to complete the read operation
logging.debug(
"EncodeHelper waiting for PrefillHandler to read embeddings..."
)
await readable_op.wait_for_completion()
logging.debug("EncodeHelper completed readable operation.")
......@@ -20,6 +20,7 @@ from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from torch.cuda import device_count
from transformers import AutoConfig
import dynamo.nixl_connect as nixl_connect
from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -121,6 +122,21 @@ async def init(runtime: DistributedRuntime, config: Config):
.client()
)
encode_client = None
if config.encode_endpoint:
logging.info(
f"Initializing encode worker client for endpoint: {config.encode_endpoint}"
)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
config.encode_endpoint
)
encode_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
......@@ -218,13 +234,20 @@ async def init(runtime: DistributedRuntime, config: Config):
multimodal_processor = MultimodalRequestProcessor(
model_type=model_config.model_type,
model_dir=config.model_path,
max_file_size_mb=config.max_file_size_mb,
tokenizer=tokenizer,
allowed_local_media_path=config.allowed_local_media_path,
)
else:
# We already detokenize inside HandlerBase. No need to also do it in TRTLLM.
default_sampling_params.detokenize = False
connector = None
logging.info("Initializing NIXL Connect.")
connector = nixl_connect.Connector()
await connector.initialize()
async with get_llm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint)
......@@ -248,7 +271,9 @@ async def init(runtime: DistributedRuntime, config: Config):
disaggregation_mode=config.disaggregation_mode,
disaggregation_strategy=config.disaggregation_strategy,
next_client=next_client,
encode_client=encode_client,
multimodal_processor=multimodal_processor,
connector=connector,
)
if next_client:
......
......@@ -13,12 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Protocol, Tuple
from urllib.parse import urlparse
from urllib.request import urlopen
import torch
from tensorrt_llm.inputs import default_multimodal_input_loader
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
class TokenizerProtocol(Protocol):
"""
......@@ -39,12 +48,80 @@ class MultimodalRequestProcessor:
self,
model_type: str,
model_dir: str,
max_file_size_mb: int,
tokenizer: Optional[TokenizerProtocol] = None,
allowed_local_media_path: str = "",
):
self.model_type = model_type
self.model_dir = model_dir
self.tokenizer = tokenizer
self.modality = ""
self.allowed_local_media_path = allowed_local_media_path
self.max_file_size_mb = max_file_size_mb
self.max_file_size_bytes = max_file_size_mb * 1024 * 1024
def is_url(self, path: str) -> bool:
"""Check if a path is a URL."""
parsed = urlparse(path)
return bool(parsed.scheme and parsed.netloc)
def load_tensor_from_path_or_url(self, path: str) -> torch.Tensor:
"""Load a tensor from either a local file path or a URL."""
if self.is_url(path):
# Download directly to memory using BytesIO (no filesystem ops)
try:
with urlopen(path) as response:
# Read at most max_size + 1 bytes to detect if file exceeds limit
data = response.read(self.max_file_size_bytes + 1)
if len(data) > self.max_file_size_bytes:
raise RuntimeError(
f"File size exceeds limit: {len(data) // (1024*1024)}MB > "
f"{self.max_file_size_mb}MB "
)
tensor_stream = BytesIO(data)
tensor = torch.load(
tensor_stream, map_location="cpu", weights_only=True
)
return tensor
except Exception as e:
# Log actual error for debugging, return generic error to user
logging.error(f"Failed to download or load tensor from URL: {e}")
raise RuntimeError("Failed to load tensor")
else:
# Restrict local file access to configured directory only
try:
# Check if local media path is configured
if not self.allowed_local_media_path:
logging.warning(
"Local file access attempted but no allowed path configured"
)
raise RuntimeError("Failed to load tensor")
resolved_path = Path(path).resolve()
allowed_path = Path(self.allowed_local_media_path).resolve()
# Secure path validation: Check if the resolved path is actually within allowed directory
try:
resolved_path.relative_to(allowed_path)
except ValueError:
logging.warning(
f"Blocked access to file outside {self.allowed_local_media_path}: {path}"
)
raise RuntimeError("Failed to load tensor")
# Check file size before loading
if resolved_path.exists():
file_size = resolved_path.stat().st_size
if file_size > self.max_file_size_bytes:
raise RuntimeError(
f"File size ({file_size // (1024*1024)}MB) exceeds "
f"maximum allowed size ({self.max_file_size_bytes // (1024*1024)}MB)"
)
return torch.load(resolved_path, map_location="cpu", weights_only=True)
except Exception as e:
# Log actual error for debugging, return generic error to user
logging.error(f"Failed to load tensor from local path: {e}")
raise RuntimeError("Failed to load tensor")
def extract_prompt_and_media(
self, messages: List[Dict]
......@@ -70,7 +147,9 @@ class MultimodalRequestProcessor:
return " ".join(text_parts), image_urls, embedding_paths
async def process_openai_request(self, request: Dict) -> Optional[Any]:
async def process_openai_request(
self, request: Dict, embeddings: Any
) -> Optional[Any]:
"""Process OpenAI request and return with multimodal data."""
# Normalize the request to handle OpenAI format
if "stop_conditions" not in request:
......@@ -92,15 +171,23 @@ class MultimodalRequestProcessor:
)
if not image_urls and not embedding_paths:
# No multimodal content, return None
logging.warning("No multimodal content, returning None")
return None
loader_kwargs = {}
if embedding_paths:
mm_embeds = [torch.load(path) for path in embedding_paths]
loader_kwargs["mm_embeddings"] = mm_embeds
if embeddings is not None:
# EPD flow
loader_kwargs["mm_embeddings"] = [embeddings]
logging.debug(f"Using NIXL embeddings in prefill worker: {embeddings}")
elif image_urls:
# Image-only flow
loader_kwargs["media"] = [image_urls]
elif embedding_paths:
# PD flow with no NIXL and no encoder
loader_kwargs["mm_embeddings"] = [
self.load_tensor_from_path_or_url(path) for path in embedding_paths
]
logging.debug(f"Using embedding paths in prefill worker: {embedding_paths}")
# Process with default_multimodal_input_loader
processed_inputs = default_multimodal_input_loader(
......
......@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Optional
from typing import Optional, Union
import torch
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from dynamo.nixl_connect import Connector
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
......@@ -37,6 +40,7 @@ class DisaggregationMode(Enum):
AGGREGATED = "prefill_and_decode"
PREFILL = "prefill"
DECODE = "decode"
ENCODE = "encode"
class DisaggregationStrategy(Enum):
......@@ -57,9 +61,11 @@ class RequestHandlerConfig:
disaggregation_mode: DisaggregationMode
disaggregation_strategy: DisaggregationStrategy
next_client: object
encode_client: Optional[object] = None
multimodal_processor: Optional[
MultimodalRequestProcessor
] = None # for multimodal support
connector: Optional[Connector] = None
class HandlerBase:
......@@ -75,8 +81,10 @@ class HandlerBase:
self.disaggregation_mode = config.disaggregation_mode
self.disaggregation_strategy = config.disaggregation_strategy
self.next_client = config.next_client
self.encode_client = config.encode_client
self.multimodal_processor = config.multimodal_processor
self.first_generation = True
self.connector = config.connector
def check_error(self, result: dict):
"""
......@@ -89,9 +97,15 @@ class HandlerBase:
result["finish_reason"] == "stop" or result["finish_reason"] == "error"
)
async def generate_locally(self, request: dict):
async def generate_locally(
self, request: dict, embeddings: Optional[Union[torch.Tensor, dict]] = None
):
"""
Generate responses based on the disaggregation mode in the request.
Args:
request: The request dictionary containing generation parameters
embeddings: Optional tensor or dict containing embeddings for multimodal processing
"""
logging.debug(f"Request: {request}")
......@@ -102,7 +116,7 @@ class HandlerBase:
# Check for multimodal request and process it
if self.multimodal_processor:
processed_input = await self.multimodal_processor.process_openai_request(
request
request, embeddings
)
else:
......@@ -139,7 +153,7 @@ class HandlerBase:
num_output_tokens_so_far = 0
sampling_params = self.default_sampling_params
sampling_params = copy.deepcopy(self.default_sampling_params)
for key, value in request["sampling_options"].items():
if not value:
......
......@@ -2,7 +2,10 @@
# SPDX-License-Identifier: Apache-2.0
import copy
import logging
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.encode_helper import EncodeHelper
from dynamo.trtllm.request_handlers.handler_base import (
DisaggregationMode,
DisaggregationStrategy,
......@@ -10,12 +13,15 @@ from dynamo.trtllm.request_handlers.handler_base import (
RequestHandlerConfig,
)
configure_dynamo_logging()
class RequestHandlerFactory:
def __init__(self):
self.handlers = {
"prefill": PrefillHandler,
"decode": DecodeHandler,
"encode": EncodeHandler,
"prefill_and_decode": AggregatedHandler,
}
......@@ -66,6 +72,33 @@ class AggregatedHandler(HandlerBase):
yield res
class EncodeHandler(HandlerBase):
"""
Handler for the encode mode.
"""
def __init__(self, config: RequestHandlerConfig):
super().__init__(config)
async def generate(self, request: dict):
if self.connector:
# Use helper method to process embedding request
async for response in EncodeHelper.process_embedding_request(
request, self.multimodal_processor, self.connector
):
yield response
return
else:
logging.error("encode handler: no Dynamo NIXL connector found")
raise RuntimeError("encode handler: no Dynamo NIXL connector found")
if not request.get("streaming", False):
yield request
return
yield request
class PrefillHandler(HandlerBase):
"""
Handler for the prefill mode.
......@@ -74,16 +107,45 @@ class PrefillHandler(HandlerBase):
def __init__(self, config: RequestHandlerConfig):
super().__init__(config)
async def remote_encode_with_nixl(self, request: dict):
# 2. Get response with shape info and readable metadata
encode_response = None
async for res in await self.encode_client.round_robin(request):
encode_response = res.data()
break
if not encode_response:
raise RuntimeError("Did not receive a response from the encode worker.")
# Use utility function to handle NIXL reading and reconstruction
return await EncodeHelper.read_embeddings_from_encode_response(
encode_response, self.connector
)
async def remote_decode(self, request: dict):
async for res in await self.next_client.round_robin(request):
yield res.data()
async def generate(self, request: dict):
logging.debug(f"PrefillHandler.generate received request: {request}")
embeddings_tensor = None
if self.multimodal_processor:
_, _, embedding_paths = self.multimodal_processor.extract_prompt_and_media(
request.get("messages", [])
)
# This check will be removed once TRTLLM Encoder is integrated.
if embedding_paths:
if self.encode_client and self.connector:
logging.debug(
"PrefillHandler calling Encode Worker via remote_encode_with_nixl"
)
embeddings_tensor = await self.remote_encode_with_nixl(request)
# Generate the prefill response locally
prefill_request = copy.deepcopy(request)
prefill_response = None
response_count = 0
async for res in self.generate_locally(prefill_request):
async for res in self.generate_locally(prefill_request, embeddings_tensor):
prefill_response = res
response_count += 1
if response_count > 1:
......
......@@ -16,6 +16,7 @@ from dynamo.trtllm.request_handlers.handler_base import (
DEFAULT_ENDPOINT = "dyn://dynamo.tensorrt_llm.generate"
DEFAULT_MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DEFAULT_NEXT_ENDPOINT = "dyn://dynamo.tensorrt_llm_next.generate"
DEFAULT_ENCODE_ENDPOINT = "dyn://dynamo.tensorrt_llm_encode.generate"
DEFAULT_DISAGGREGATION_STRATEGY = DisaggregationStrategy.DECODE_FIRST
DEFAULT_DISAGGREGATION_MODE = DisaggregationMode.AGGREGATED
......@@ -47,8 +48,10 @@ class Config:
DEFAULT_DISAGGREGATION_STRATEGY
)
self.next_endpoint: str = ""
self.encode_endpoint: str = ""
self.modality: str = "text"
self.allowed_local_media_path: str = ""
self.max_file_size_mb: int = 50
self.reasoning_parser: Optional[str] = None
self.tool_call_parser: Optional[str] = None
......@@ -75,9 +78,12 @@ class Config:
f"disaggregation_mode={self.disaggregation_mode}, "
f"disaggregation_strategy={self.disaggregation_strategy}, "
f"next_endpoint={self.next_endpoint}, "
f"modality={self.modality})"
f"reasoning_parser={self.reasoning_parser})"
f"tool_call_parser={self.tool_call_parser})"
f"encode_endpoint={self.encode_endpoint}, "
f"modality={self.modality}, "
f"allowed_local_media_path={self.allowed_local_media_path}, "
f"max_file_size_mb={self.max_file_size_mb}, "
f"reasoning_parser={self.reasoning_parser}, "
f"tool_call_parser={self.tool_call_parser}"
)
......@@ -219,6 +225,12 @@ def cmd_line_args():
choices=[mode.value for mode in DisaggregationMode],
help=f"Mode to use for disaggregation. Default: {DEFAULT_DISAGGREGATION_MODE}",
)
parser.add_argument(
"--use-nixl-connect",
type=bool,
default=False,
help="Use NIXL Connect for communication between workers.",
)
parser.add_argument(
"--disaggregation-strategy",
type=str,
......@@ -239,7 +251,24 @@ def cmd_line_args():
default="",
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send requests to when running in disaggregation mode. Default: {DEFAULT_NEXT_ENDPOINT} if first worker, empty if next worker",
)
parser.add_argument(
"--encode-endpoint",
type=str,
default="",
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) for the encode worker. Default: {DEFAULT_ENCODE_ENDPOINT}",
)
parser.add_argument(
"--allowed-local-media-path",
type=str,
default="",
help="Path to a directory that is allowed to be accessed by the model. Default: empty",
)
parser.add_argument(
"--max-file-size-mb",
type=int,
default=50,
help="Maximum size of downloadable embedding files/Image URLs. Default: 50MB",
)
# To avoid name conflicts with different backends, adoped prefix "dyn-" for dynamo specific args
parser.add_argument(
"--dyn-tool-call-parser",
......@@ -280,6 +309,9 @@ def cmd_line_args():
and config.disaggregation_mode != DisaggregationMode.AGGREGATED
):
args.next_endpoint = DEFAULT_NEXT_ENDPOINT
elif config.disaggregation_mode == DisaggregationMode.ENCODE:
if args.endpoint == "":
args.endpoint = DEFAULT_ENCODE_ENDPOINT
else:
if args.endpoint == "":
args.endpoint = DEFAULT_NEXT_ENDPOINT
......@@ -294,6 +326,9 @@ def cmd_line_args():
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.next_endpoint = args.next_endpoint
config.encode_endpoint = args.encode_endpoint
config.allowed_local_media_path = args.allowed_local_media_path
config.max_file_size_mb = args.max_file_size_mb
config.tensor_parallel_size = args.tensor_parallel_size
if args.pipeline_parallel_size is not None:
......
......@@ -46,9 +46,7 @@ The metadata contains required information (identifiers, keys, etc.) which enabl
```python
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
self.connector = dynamo.nixl_connect.Connector(runtime=runtime)
self.connector = dynamo.nixl_connect.Connector()
await self.connector.initialize()
```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment