"lib/vscode:/vscode.git/clone" did not exist on "7fabe7bfe2a5601858cbe72c94f16f04da2ca0f7"
Unverified commit f923777e authored by Indrajit Bhosale, committed by GitHub
Browse files

fix: replace torch.load with safetensors and enable Rust frontend media...


fix: replace torch.load with safetensors and enable Rust frontend media decoding for TRT-LLM multimodal (#8295)
Signed-off-by: Indrajit Bhosale <iamindrajitb@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent fd361c82
...@@ -137,6 +137,17 @@ class ImageLoader: ...@@ -137,6 +137,17 @@ class ImageLoader:
finally: finally:
self._inflight.pop(key, None) self._inflight.pop(key, None)
async def _read_and_convert_nixl_image(
    self, metadata: Dict[str, Any]
) -> Image.Image:
    """Read a frontend-decoded image via NIXL and return it as a PIL Image.

    Args:
        metadata: Descriptor of the decoded media; passed unchanged to
            ``read_decoded_media_via_nixl``.

    Returns:
        The PIL image built from the array produced by the NIXL read.

    Raises:
        RuntimeError: If the NIXL connector has not been initialized.
    """
    # Raise instead of assert: asserts are stripped under ``python -O``, and
    # load_image_batch reports this same condition with a RuntimeError.
    if self._nixl_connector is None:
        raise RuntimeError("NIXL connector is not initialized")
    arr = await read_decoded_media_via_nixl(self._nixl_connector, metadata)
    # TRT-LLM's input processor requires PIL Images (accesses .height/.width
    # for token count calculation).
    # NOTE(review): fromarray() is expected to be cheap next to decoding, but
    # whether it avoids copying pixel data depends on how Pillow handles the
    # array's mode/layout — confirm before relying on zero-copy behavior.
    return Image.fromarray(arr)
@_nvtx.annotate("mm:img:load_image", color="lime") @_nvtx.annotate("mm:img:load_image", color="lime")
async def load_image(self, image_url: str) -> Image.Image: async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url) parsed_url = urlparse(image_url)
...@@ -222,9 +233,7 @@ class ImageLoader: ...@@ -222,9 +233,7 @@ class ImageLoader:
metadata = item[DECODED_VARIANT_KEY] metadata = item[DECODED_VARIANT_KEY]
if self._nixl_connector is None: if self._nixl_connector is None:
raise RuntimeError("NIXL connector is not initialized") raise RuntimeError("NIXL connector is not initialized")
image_futures.append( image_futures.append(self._read_and_convert_nixl_image(metadata))
read_decoded_media_via_nixl(self._nixl_connector, metadata)
)
else: else:
logger.error( logger.error(
"Received Decoded multimodal data but enable_frontend_decoding=False. " "Received Decoded multimodal data but enable_frontend_decoding=False. "
......
...@@ -215,6 +215,17 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -215,6 +215,17 @@ class DynamoTrtllmArgGroup(ArgGroup):
arg_type=int, arg_type=int,
help="Maximum size of downloadable embedding files/Image URLs.", help="Maximum size of downloadable embedding files/Image URLs.",
) )
add_negatable_bool_argument(
g,
flag_name="--frontend-decoding",
env_var="DYN_TRTLLM_FRONTEND_DECODING",
default=False,
help=(
"Enable frontend decoding of multimodal images. "
"When enabled, images are decoded in the Rust frontend and transferred to the backend via NIXL RDMA. "
"Without this flag, images are decoded in the Python backend (default behavior)."
),
)
# --- Guided Decoding --- # --- Guided Decoding ---
add_argument( add_argument(
...@@ -479,6 +490,7 @@ class DynamoTrtllmConfig(ConfigBase): ...@@ -479,6 +490,7 @@ class DynamoTrtllmConfig(ConfigBase):
encode_endpoint: str encode_endpoint: str
allowed_local_media_path: str allowed_local_media_path: str
max_file_size_mb: int max_file_size_mb: int
frontend_decoding: bool
default_height: int default_height: int
default_width: int default_width: int
......
...@@ -209,8 +209,8 @@ class EncodeHelper: ...@@ -209,8 +209,8 @@ class EncodeHelper:
# Two supported flows: # Two supported flows:
# #
# 1. EMBEDDING-PATH FLOW (Pre-computed embeddings via NIXL) # 1. EMBEDDING-PATH FLOW (Pre-computed embeddings via NIXL)
# - User sends URL ending in .pt/.pth/.bin # - User sends URL ending in .safetensors
# - Encode worker loads tensor, creates NIXL readable op # - Encode worker loads tensor (via safetensors), creates NIXL readable op
# - Prefill worker reads embeddings via RDMA # - Prefill worker reads embeddings via RDMA
# - Use case: Customer has pre-computed embeddings from custom encoder # - Use case: Customer has pre-computed embeddings from custom encoder
# #
...@@ -235,7 +235,7 @@ class EncodeHelper: ...@@ -235,7 +235,7 @@ class EncodeHelper:
for the prefill worker to read via RDMA. for the prefill worker to read via RDMA.
Args: Args:
embedding_paths: List of paths to embedding files (.pt/.pth/.bin) embedding_paths: List of paths to embedding files (.safetensors)
multimodal_processor: Processor to load embeddings multimodal_processor: Processor to load embeddings
connector: NIXL connector for RDMA transfer connector: NIXL connector for RDMA transfer
...@@ -460,5 +460,5 @@ class EncodeHelper: ...@@ -460,5 +460,5 @@ class EncodeHelper:
# No valid multimodal content found # No valid multimodal content found
else: else:
yield { yield {
"error": "No embedding_paths or image_urls found in request, or image_urls without text_prompt or token_ids" "error": "No embedding_paths (.safetensors) or image_urls found in request, or image_urls without text_prompt or token_ids"
} }
...@@ -15,13 +15,14 @@ ...@@ -15,13 +15,14 @@
import logging import logging
import time import time
from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Protocol, Tuple from typing import Any, Dict, List, Optional, Protocol, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from urllib.request import urlopen
import httpx
import torch import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import load_file as safetensors_load_file
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.common.multimodal.image_loader import ImageLoader from dynamo.common.multimodal.image_loader import ImageLoader
...@@ -57,6 +58,7 @@ class MultimodalRequestProcessor: ...@@ -57,6 +58,7 @@ class MultimodalRequestProcessor:
max_file_size_mb: int, max_file_size_mb: int,
tokenizer: Optional[TokenizerProtocol] = None, tokenizer: Optional[TokenizerProtocol] = None,
allowed_local_media_path: str = "", allowed_local_media_path: str = "",
enable_frontend_decoding: bool = False,
): ):
self.model_type = model_type self.model_type = model_type
self.model_dir = model_dir self.model_dir = model_dir
...@@ -73,7 +75,9 @@ class MultimodalRequestProcessor: ...@@ -73,7 +75,9 @@ class MultimodalRequestProcessor:
else: else:
self.tokenizer = tokenizer_factory(model_dir) self.tokenizer = tokenizer_factory(model_dir)
self.image_loader = ImageLoader() self.image_loader = ImageLoader(
enable_frontend_decoding=enable_frontend_decoding
)
def is_url(self, path: str) -> bool: def is_url(self, path: str) -> bool:
"""Check if a path is a URL.""" """Check if a path is a URL."""
...@@ -83,45 +87,86 @@ class MultimodalRequestProcessor: ...@@ -83,45 +87,86 @@ class MultimodalRequestProcessor:
return False return False
return bool(parsed.scheme and parsed.netloc) return bool(parsed.scheme and parsed.netloc)
def load_tensor_from_path_or_url(self, path: str) -> torch.Tensor: def _unwrap_safetensors(
"""Load a tensor from either a local file path or a URL.""" self, data: Dict[str, torch.Tensor]
) -> "torch.Tensor | Dict[str, torch.Tensor]":
"""Return a single tensor when the file has one key, else the full dict.
Multi-key files (e.g. Maverick/Scout with mm_embeddings +
image_special_tokens + image_special_token_offsets) need the
full dict so encode_helper can extract auxiliary data.
"""
if len(data) == 1:
return next(iter(data.values()))
return data
def load_tensor_from_path_or_url(
self, path: str
) -> "torch.Tensor | Dict[str, torch.Tensor]":
"""Load tensors from a local .safetensors path or URL.
Returns a single tensor for single-key files (e.g. LLaVA-NeXT),
or a dict of tensors for multi-key files (e.g. Maverick/Scout).
Only .safetensors format is accepted.
"""
parsed = urlparse(path)
lower_path = parsed.path.lower()
if lower_path.endswith((".pt", ".pth", ".bin")):
raise RuntimeError(
"Unsafe tensor format: .pt/.pth/.bin files are not allowed. "
"Use .safetensors format instead."
)
if not lower_path.endswith(".safetensors"):
raise RuntimeError("Only .safetensors embedding files are supported.")
if self.is_url(path): if self.is_url(path):
# Download directly to memory using BytesIO (no filesystem ops) if parsed.scheme not in ("http", "https"):
raise RuntimeError(f"Unsupported URL scheme: {parsed.scheme}")
try: try:
with urlopen(path) as response: with httpx.Client(timeout=300.0) as client:
# Read at most max_size + 1 bytes to detect if file exceeds limit with client.stream("GET", path) as resp:
data = response.read(self.max_file_size_bytes + 1) resp.raise_for_status()
if len(data) > self.max_file_size_bytes: content_length = resp.headers.get("content-length")
raise RuntimeError( if (
f"File size exceeds limit: {len(data) // (1024*1024)}MB > " content_length
f"{self.max_file_size_mb}MB " and int(content_length) > self.max_file_size_bytes
) ):
tensor_stream = BytesIO(data) raise RuntimeError(
tensor = torch.load( f"File size exceeds limit: "
tensor_stream, map_location="cpu", weights_only=True f"{int(content_length) // (1024*1024)}MB > "
) f"{self.max_file_size_mb}MB"
return tensor )
chunks = []
downloaded = 0
for chunk in resp.iter_bytes():
downloaded += len(chunk)
if downloaded > self.max_file_size_bytes:
raise RuntimeError(
f"File size exceeds limit: "
f"{downloaded // (1024*1024)}MB > "
f"{self.max_file_size_mb}MB"
)
chunks.append(chunk)
content = b"".join(chunks)
data = safetensors_load(content)
return self._unwrap_safetensors(data)
except RuntimeError:
raise
except Exception as e: except Exception as e:
# Log actual error for debugging, return generic error to user
logging.error(f"Failed to download or load tensor from URL: {e}") logging.error(f"Failed to download or load tensor from URL: {e}")
raise RuntimeError("Failed to load tensor") raise RuntimeError("Failed to load tensor")
else: else:
# Restrict local file access to configured directory only
try: try:
# Check if local media path is configured
if not self.allowed_local_media_path: if not self.allowed_local_media_path:
logging.warning( logging.warning(
"Local file access attempted but no allowed path configured" "Local file access attempted but no allowed path configured"
) )
raise RuntimeError("Failed to load tensor") raise RuntimeError("Failed to load tensor")
# Strip file:// prefix if present
local_path = path.removeprefix("file://") local_path = path.removeprefix("file://")
resolved_path = Path(local_path).resolve() resolved_path = Path(local_path).resolve()
allowed_path = Path(self.allowed_local_media_path).resolve() allowed_path = Path(self.allowed_local_media_path).resolve()
# Secure path validation: Check if the resolved path is actually within allowed directory
try: try:
resolved_path.relative_to(allowed_path) resolved_path.relative_to(allowed_path)
except ValueError: except ValueError:
...@@ -130,17 +175,19 @@ class MultimodalRequestProcessor: ...@@ -130,17 +175,19 @@ class MultimodalRequestProcessor:
) )
raise RuntimeError("Failed to load tensor") raise RuntimeError("Failed to load tensor")
# Check file size before loading if not resolved_path.exists():
if resolved_path.exists(): raise RuntimeError(f"Embedding file not found: {resolved_path}")
file_size = resolved_path.stat().st_size file_size = resolved_path.stat().st_size
if file_size > self.max_file_size_bytes: if file_size > self.max_file_size_bytes:
raise RuntimeError( raise RuntimeError(
f"File size ({file_size // (1024*1024)}MB) exceeds " f"File size ({file_size // (1024*1024)}MB) exceeds "
f"maximum allowed size ({self.max_file_size_bytes // (1024*1024)}MB)" f"maximum allowed size ({self.max_file_size_bytes // (1024*1024)}MB)"
) )
return torch.load(resolved_path, map_location="cpu", weights_only=True) data = safetensors_load_file(str(resolved_path))
return self._unwrap_safetensors(data)
except RuntimeError:
raise
except Exception as e: except Exception as e:
# Log actual error for debugging, return generic error to user
logging.error(f"Failed to load tensor from local path: {e}") logging.error(f"Failed to load tensor from local path: {e}")
raise RuntimeError("Failed to load tensor") raise RuntimeError("Failed to load tensor")
...@@ -164,7 +211,7 @@ class MultimodalRequestProcessor: ...@@ -164,7 +211,7 @@ class MultimodalRequestProcessor:
if not url: if not url:
continue continue
self.modality = "image" self.modality = "image"
if url.endswith((".pt", ".pth", ".bin")): if url.endswith(".safetensors"):
embedding_paths.append(url) embedding_paths.append(url)
else: else:
image_urls.append(url) image_urls.append(url)
...@@ -247,7 +294,7 @@ class MultimodalRequestProcessor: ...@@ -247,7 +294,7 @@ class MultimodalRequestProcessor:
multi_modal_data = request.get("multi_modal_data") multi_modal_data = request.get("multi_modal_data")
if multi_modal_data and isinstance(multi_modal_data, dict): if multi_modal_data and isinstance(multi_modal_data, dict):
processed_mm_data = {} processed_mm_data = {}
loaded_embeddings = [] loaded_embeddings: list[torch.Tensor] = []
# Process images and embedding paths from image_url field # Process images and embedding paths from image_url field
image_items = multi_modal_data.get("image_url", []) image_items = multi_modal_data.get("image_url", [])
...@@ -274,8 +321,7 @@ class MultimodalRequestProcessor: ...@@ -274,8 +321,7 @@ class MultimodalRequestProcessor:
) )
continue continue
# Check if this is an embedding file based on extension if url.endswith(".safetensors"):
if url.endswith((".pt", ".pth", ".bin")):
embedding_paths.append(url) embedding_paths.append(url)
else: else:
# Keep original item format for load_image_batch # Keep original item format for load_image_batch
...@@ -299,14 +345,25 @@ class MultimodalRequestProcessor: ...@@ -299,14 +345,25 @@ class MultimodalRequestProcessor:
logging.error(f"Failed to load images: {e}") logging.error(f"Failed to load images: {e}")
return None return None
# Load embedding files (.pt, .pth, .bin) for PD flow # Load pre-computed vision encoder embeddings (.safetensors) for PD flow
# These are pre-computed vision encoder outputs
if embedding_paths: if embedding_paths:
try: try:
loaded_embeddings = [ raw_loaded = [
self.load_tensor_from_path_or_url(path) self.load_tensor_from_path_or_url(path)
for path in embedding_paths for path in embedding_paths
] ]
loaded_embeddings = []
for item in raw_loaded:
if isinstance(item, dict):
emb = item.get("mm_embeddings")
if emb is None:
logging.error(
"Dictionary embeddings missing 'mm_embeddings' key"
)
return None
loaded_embeddings.append(emb)
else:
loaded_embeddings.append(item)
if loaded_embeddings: if loaded_embeddings:
logging.info( logging.info(
f"Loaded {len(loaded_embeddings)} embedding file(s) from paths: {embedding_paths}" f"Loaded {len(loaded_embeddings)} embedding file(s) from paths: {embedding_paths}"
......
...@@ -63,6 +63,18 @@ from dynamo.trtllm.request_handlers.handlers import ( ...@@ -63,6 +63,18 @@ from dynamo.trtllm.request_handlers.handlers import (
) )
from dynamo.trtllm.utils.trtllm_utils import deep_update from dynamo.trtllm.utils.trtllm_utils import deep_update
# Optional imports for Rust frontend media decoding support
MediaDecoder: type | None = None
MediaFetcher: type | None = None
try:
from dynamo.llm import MediaDecoder, MediaFetcher
MEDIA_DECODER_AVAILABLE = True
except ImportError:
MediaDecoder = None
MediaFetcher = None
MEDIA_DECODER_AVAILABLE = False
# Default buffer size for kv cache events. # Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024 DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
...@@ -410,6 +422,7 @@ async def init_llm_worker( ...@@ -410,6 +422,7 @@ async def init_llm_worker(
max_file_size_mb=config.max_file_size_mb, max_file_size_mb=config.max_file_size_mb,
tokenizer=tokenizer, tokenizer=tokenizer,
allowed_local_media_path=config.allowed_local_media_path, allowed_local_media_path=config.allowed_local_media_path,
enable_frontend_decoding=config.frontend_decoding,
) )
else: else:
...@@ -586,6 +599,21 @@ async def init_llm_worker( ...@@ -586,6 +599,21 @@ async def init_llm_worker(
disagg_machine_id=int(endpoint.connection_id()) % 1021, disagg_machine_id=int(endpoint.connection_id()) % 1021,
) )
media_decoder = None
media_fetcher = None
if config.frontend_decoding:
if not MEDIA_DECODER_AVAILABLE:
raise RuntimeError(
"--frontend-decoding requires MediaDecoder support. "
"Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
)
assert MediaDecoder is not None and MediaFetcher is not None
media_decoder = MediaDecoder()
media_decoder.enable_image({"limits": {"max_alloc": 128 * 1024 * 1024}})
media_fetcher = MediaFetcher()
media_fetcher.timeout_ms(30000)
media_fetcher.allow_direct_port(False)
# Register the model with runtime config # Register the model with runtime config
# Encode workers do NOT register - they're internal workers only # Encode workers do NOT register - they're internal workers only
# Prefill and decode workers register - frontend detects their role via ModelType # Prefill and decode workers register - frontend detects their role via ModelType
...@@ -600,6 +628,8 @@ async def init_llm_worker( ...@@ -600,6 +628,8 @@ async def init_llm_worker(
kv_cache_block_size=config.kv_block_size, kv_cache_block_size=config.kv_block_size,
runtime_config=runtime_config, runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template, custom_template_path=config.custom_jinja_template,
media_decoder=media_decoder,
media_fetcher=media_fetcher,
) )
# Get health check payload (checks env var and falls back to TensorRT-LLM default) # Get health check payload (checks env var and falls back to TensorRT-LLM default)
......
...@@ -685,7 +685,7 @@ async def register_vllm_model( ...@@ -685,7 +685,7 @@ async def register_vllm_model(
media_fetcher = MediaFetcher() media_fetcher = MediaFetcher()
media_fetcher.timeout_ms(30000) media_fetcher.timeout_ms(30000)
media_fetcher.allow_direct_port(True) media_fetcher.allow_direct_port(False)
await register_model( await register_model(
model_input, model_input,
......
...@@ -17,7 +17,7 @@ You can provide multimodal inputs in the following ways: ...@@ -17,7 +17,7 @@ You can provide multimodal inputs in the following ways:
| Modality | Input Format | Aggregated | Disaggregated | Notes | | Modality | Input Format | Aggregated | Disaggregated | Notes |
|----------|--------------|------------|---------------|-------| |----------|--------------|------------|---------------|-------|
| **Image** | HTTP/HTTPS URL | Yes | Yes | Full support for all image models | | **Image** | HTTP/HTTPS URL | Yes | Yes | Full support for all image models |
| **Image** | Pre-computed Embeddings (.pt, .pth, .bin) | Yes | Yes | Direct embedding files | | **Image** | Pre-computed Embeddings (.safetensors) | Yes | Yes | Direct embedding files |
| **Video** | HTTP/HTTPS URL | No | No | Not implemented | | **Video** | HTTP/HTTPS URL | No | No | Not implemented |
| **Audio** | HTTP/HTTPS URL | No | No | Not implemented | | **Audio** | HTTP/HTTPS URL | No | No | Not implemented |
...@@ -26,7 +26,7 @@ You can provide multimodal inputs in the following ways: ...@@ -26,7 +26,7 @@ You can provide multimodal inputs in the following ways:
| Format | Example | Description | | Format | Example | Description |
|--------|---------|-------------| |--------|---------|-------------|
| **HTTP/HTTPS** | `http://example.com/image.jpg` | Remote media files | | **HTTP/HTTPS** | `http://example.com/image.jpg` | Remote media files |
| **Pre-computed Embeddings** | `/path/to/embedding.pt` | Local embedding files (.pt, .pth, .bin) | | **Pre-computed Embeddings** | `/path/to/embedding.safetensors` | Local embedding files (.safetensors only) |
## Deployment Patterns ## Deployment Patterns
...@@ -221,40 +221,24 @@ For high-performance multimodal inference, Dynamo supports pre-computed embeddin ...@@ -221,40 +221,24 @@ For high-performance multimodal inference, Dynamo supports pre-computed embeddin
### Supported File Types ### Supported File Types
- `.pt` - PyTorch tensor files - `.safetensors` - Safe tensor files ([safetensors format](https://huggingface.co/docs/safetensors))
- `.pth` - PyTorch checkpoint files
- `.bin` - Binary tensor files
### Embedding File Formats > **Security Note:** `.pt`, `.pth`, and `.bin` files are **rejected** because they use Python pickle deserialization, which can execute arbitrary code. Only `.safetensors` format is accepted.
TRT-LLM supports two formats for embedding files: ### Embedding File Formats
**1. Simple Tensor Format** Embedding files must use the `.safetensors` format. The first tensor key in the file is used as the embedding tensor.
Direct tensor saved as `.pt` file containing only the embedding tensor: **Saving embeddings:**
```python ```python
embedding_tensor = torch.rand(1, 576, 4096) # [batch, seq_len, hidden_dim] from safetensors.torch import save_file
torch.save(embedding_tensor, "embedding.pt") import torch
```
**2. Dictionary Format with Auxiliary Data**
Dictionary containing multiple keys, used by models like Llama-4 that require additional metadata: embedding_tensor = torch.rand(1, 576, 4096) # [batch, seq_len, hidden_dim]
save_file({"embedding": embedding_tensor}, "embedding.safetensors")
```python
embedding_dict = {
"mm_embeddings": torch.rand(1, 576, 4096),
"special_tokens": [128256, 128257],
"image_token_offsets": [[0, 576]],
# ... other model-specific metadata
}
torch.save(embedding_dict, "llama4_embedding.pt")
``` ```
- **Simple tensors**: Loaded directly and passed to `mm_embeddings` parameter
- **Dictionary format**: `mm_embeddings` key extracted as main tensor, other keys preserved as auxiliary data
### How to Launch ### How to Launch
```bash ```bash
...@@ -264,7 +248,7 @@ cd $DYNAMO_HOME/examples/backends/trtllm ...@@ -264,7 +248,7 @@ cd $DYNAMO_HOME/examples/backends/trtllm
./launch/epd_disagg.sh ./launch/epd_disagg.sh
``` ```
> **Note:** This script is designed for 8-node H200 with `Llama-4-Scout-17B-16E-Instruct` model and assumes you have a model-specific embedding file ready. > **Note:** This script is designed for 8-node H200 with `Llama-4-Scout-17B-16E-Instruct` model and assumes you have a model-specific `.safetensors` embedding file ready.
### Configuration ### Configuration
...@@ -289,7 +273,7 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d ' ...@@ -289,7 +273,7 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '
"role": "user", "role": "user",
"content": [ "content": [
{"type": "text", "text": "Describe the image"}, {"type": "text", "text": "Describe the image"},
{"type": "image_url", "image_url": {"url": "/path/to/embedding.pt"}} {"type": "image_url", "image_url": {"url": "/path/to/embedding.safetensors"}}
] ]
} }
], ],
...@@ -316,7 +300,7 @@ sequenceDiagram ...@@ -316,7 +300,7 @@ sequenceDiagram
Client->>Frontend: POST /v1/chat/completions Client->>Frontend: POST /v1/chat/completions
Frontend->>PrefillWorker: Route to prefill worker Frontend->>PrefillWorker: Route to prefill worker
PrefillWorker->>EncodeWorker: Send request (embedding paths) PrefillWorker->>EncodeWorker: Send request (embedding .safetensors paths)
EncodeWorker->>NIXL: Create readable operation EncodeWorker->>NIXL: Create readable operation
EncodeWorker->>PrefillWorker: Send metadata + NIXL info EncodeWorker->>PrefillWorker: Send metadata + NIXL info
PrefillWorker->>NIXL: Begin read operation PrefillWorker->>NIXL: Begin read operation
...@@ -401,10 +385,10 @@ await register_model( ...@@ -401,10 +385,10 @@ await register_model(
| Transfer Stage | Message | NIXL Transfer | | Transfer Stage | Message | NIXL Transfer |
|----------------|---------|---------------| |----------------|---------|---------------|
| **Frontend → Prefill** | Request with image URL or embedding path | No | | **Frontend → Prefill** | Request with image URL or .safetensors embedding path | No |
| **Prefill → Encode (Image URL)** | Request with image URL | No | | **Prefill → Encode (Image URL)** | Request with image URL | No |
| **Encode → Prefill (Image URL)** | `ep_disaggregated_params` with `multimodal_embedding_handles`, processed prompt, and token IDs | No | | **Encode → Prefill (Image URL)** | `ep_disaggregated_params` with `multimodal_embedding_handles`, processed prompt, and token IDs | No |
| **Prefill → Encode (Embedding Path)** | Request with embedding file path | No | | **Prefill → Encode (Embedding Path)** | Request with .safetensors embedding file path | No |
| **Encode → Prefill (Embedding Path)** | NIXL readable metadata + shape/dtype + auxiliary data | Yes (Embeddings tensor via RDMA) | | **Encode → Prefill (Embedding Path)** | NIXL readable metadata + shape/dtype + auxiliary data | Yes (Embeddings tensor via RDMA) |
| **Prefill → Decode** | `disaggregated_params` with `_epd_metadata` (prompt, token IDs) | Configurable (KV cache: NIXL default, UCX optional) | | **Prefill → Decode** | `disaggregated_params` with `_epd_metadata` (prompt, token IDs) | Configurable (KV cache: NIXL default, UCX optional) |
......
...@@ -5,18 +5,18 @@ ...@@ -5,18 +5,18 @@
# LLaVA Raw-Embeddings E/PD Test # LLaVA Raw-Embeddings E/PD Test
# #
# Phase 1 — Run HuggingFace vision encoder standalone to produce # Phase 1 — Run HuggingFace vision encoder standalone to produce
# pre-computed embeddings at $EMBEDDINGS_FILE (.pt tensor). # pre-computed embeddings at $EMBEDDINGS_FILE (.safetensors format).
# #
# Phase 2 — Start Encode + Aggregated PD workers for LLaVA, then # Phase 2 — Start Encode + Aggregated PD workers for LLaVA, then
# accept chat/completions requests whose image_url points # accept chat/completions requests whose image_url points
# to the embeddings file (file:///tmp/llava_embeddings.pt). # to the embeddings file (file:///tmp/llava_embeddings.safetensors).
# #
# Known limitation: The default revision of llava-hf/llava-v1.6-mistral-7b-hf # Known limitation: The default revision of llava-hf/llava-v1.6-mistral-7b-hf
# may crash with certain TRT-LLM versions. Set MODEL_REVISION to pin a # may crash with certain TRT-LLM versions. Set MODEL_REVISION to pin a
# safe commit (e.g. 52320fb52229). # safe commit (e.g. 52320fb52229).
set -e set -e
trap 'echo Cleaning up...; rm -f "${EMBEDDINGS_FILE:-/tmp/llava_embeddings.pt}" /tmp/_resolved_model_path.txt; kill 0' EXIT trap 'echo Cleaning up...; rm -f "${EMBEDDINGS_FILE:-/tmp/llava_embeddings.safetensors}" /tmp/_resolved_model_path.txt; kill 0' EXIT
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
...@@ -37,7 +37,7 @@ export MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50} ...@@ -37,7 +37,7 @@ export MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50}
export CUSTOM_TEMPLATE=${CUSTOM_TEMPLATE:-"$DYNAMO_HOME/examples/backends/trtllm/templates/llava_multimodal.jinja"} export CUSTOM_TEMPLATE=${CUSTOM_TEMPLATE:-"$DYNAMO_HOME/examples/backends/trtllm/templates/llava_multimodal.jinja"}
# Embeddings configuration # Embeddings configuration
EMBEDDINGS_FILE="${EMBEDDINGS_FILE:-/tmp/llava_embeddings.pt}" EMBEDDINGS_FILE="${EMBEDDINGS_FILE:-/tmp/llava_embeddings.safetensors}"
TEST_IMAGE_URL="${TEST_IMAGE_URL:-https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png}" TEST_IMAGE_URL="${TEST_IMAGE_URL:-https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png}"
# Extra arguments forwarded to the PD worker (e.g. --multimodal-embedding-cache-capacity-gb 10) # Extra arguments forwarded to the PD worker (e.g. --multimodal-embedding-cache-capacity-gb 10)
...@@ -71,13 +71,14 @@ CUDA_VISIBLE_DEVICES=0 python3 - <<'PYEOF' ...@@ -71,13 +71,14 @@ CUDA_VISIBLE_DEVICES=0 python3 - <<'PYEOF'
import torch, io, os, urllib.request import torch, io, os, urllib.request
from PIL import Image from PIL import Image
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from safetensors.torch import save_file as safetensors_save_file
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
model_id = os.environ["MODEL_PATH"] model_id = os.environ["MODEL_PATH"]
revision = os.environ.get("MODEL_REVISION", "") or None revision = os.environ.get("MODEL_REVISION", "") or None
image_url = os.environ.get("TEST_IMAGE_URL", image_url = os.environ.get("TEST_IMAGE_URL",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png") "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
output = os.environ.get("EMBEDDINGS_FILE", "/tmp/llava_embeddings.pt") output = os.environ.get("EMBEDDINGS_FILE", "/tmp/llava_embeddings.safetensors")
# ── Download / resolve model ── # ── Download / resolve model ──
print(f"Resolving model {model_id} (revision={revision}) …") print(f"Resolving model {model_id} (revision={revision}) …")
...@@ -125,8 +126,8 @@ with torch.no_grad(): ...@@ -125,8 +126,8 @@ with torch.no_grad():
print(f"Embeddings: shape={embeddings.shape}, dtype={embeddings.dtype}") print(f"Embeddings: shape={embeddings.shape}, dtype={embeddings.dtype}")
# ── Save to disk ── # ── Save to disk as safetensors (safe format, no pickle) ──
torch.save(embeddings.cpu(), output) safetensors_save_file({"embedding": embeddings.cpu()}, output)
print(f"Saved embeddings → {output}") print(f"Saved embeddings → {output}")
# ── Write resolved model path so Phase 2 uses the exact same revision ── # ── Write resolved model path so Phase 2 uses the exact same revision ──
......
...@@ -341,12 +341,12 @@ trtllm_configs = { ...@@ -341,12 +341,12 @@ trtllm_configs = {
), ),
# LLaVA raw-embeddings E/PD test # LLaVA raw-embeddings E/PD test
# Validates the raw-embeddings code path where pre-computed vision embeddings # Validates the raw-embeddings code path where pre-computed vision embeddings
# (.pt tensor file) are sent via file:// URL instead of a raw image URL. # (.safetensors file) are sent via file:// URL instead of a raw image URL.
# #
# Flow: # Flow:
# 1. Launch script generates embeddings using standalone HF vision encoder # 1. Launch script generates embeddings using standalone HF vision encoder
# 2. Encode + Aggregated PD workers start for LLaVA # 2. Encode + Aggregated PD workers start for LLaVA
# 3. Test sends chat/completions request with file:///tmp/llava_embeddings.pt # 3. Test sends chat/completions request with file:///tmp/llava_embeddings.safetensors
# #
# Uses gpu_2: encode worker on GPU 0, PD worker on GPU 1. # Uses gpu_2: encode worker on GPU 0, PD worker on GPU 1.
# The 7B LLaVA model requires two GPUs because both encode and PD workers # The 7B LLaVA model requires two GPUs because both encode and PD workers
...@@ -372,7 +372,7 @@ trtllm_configs = { ...@@ -372,7 +372,7 @@ trtllm_configs = {
delayed_start=180, delayed_start=180,
request_payloads=[ request_payloads=[
multimodal_payload_default( multimodal_payload_default(
image_url="file:///tmp/llava_embeddings.pt", image_url="file:///tmp/llava_embeddings.safetensors",
text="Describe what this image shows.", text="Describe what this image shows.",
expected_response=["bench", "person", "image", "picture"], expected_response=["bench", "person", "image", "picture"],
) )
...@@ -440,6 +440,35 @@ trtllm_configs = { ...@@ -440,6 +440,35 @@ trtllm_configs = {
), ),
], ],
), ),
# Aggregated multimodal with --frontend-decoding enabled.
# Verifies image URL inference works when images are decoded by the Rust
# MediaDecoder in the frontend instead of the Python backend.
"aggregated_multimodal_frontend_decoding": TRTLLMConfig(
name="aggregated_multimodal_frontend_decoding",
directory=trtllm_dir,
script_name="agg_multimodal.sh",
marks=[
pytest.mark.gpu_1,
pytest.mark.trtllm,
pytest.mark.multimodal,
pytest.mark.pre_merge,
pytest.mark.timeout(900),
],
model="Qwen/Qwen3-VL-2B-Instruct",
frontend_port=DefaultPort.FRONTEND.value,
timeout=900,
delayed_start=60,
request_payloads=[
multimodal_payload_default(
text="Describe what you see in this image.",
expected_response=["mountain", "rock", "trees", "road"],
)
],
env={
"AGG_ENGINE_ARGS": "/workspace/examples/backends/trtllm/engine_configs/qwen3-vl-2b-instruct/agg.yaml",
"DYN_TRTLLM_FRONTEND_DECODING": "true",
},
),
"completions_only": TRTLLMConfig( "completions_only": TRTLLMConfig(
name="completions_only", name="completions_only",
directory=trtllm_dir, directory=trtllm_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment