Unverified Commit 3967ec0f authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

feat: Use SGLang MMEncoder for multimodal EPD encode worker (#6162)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent a3f1e7ec
......@@ -5,14 +5,8 @@ from dynamo.sglang.multimodal_utils.multimodal_chat_processor import (
multimodal_request_to_sglang,
process_sglang_stream_response,
)
from dynamo.sglang.multimodal_utils.multimodal_encode_utils import (
encode_image_embeddings,
)
from dynamo.sglang.multimodal_utils.multimodal_image_loader import ImageLoader
__all__ = [
"multimodal_request_to_sglang",
"process_sglang_stream_response",
"encode_image_embeddings",
"ImageLoader",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from pathlib import Path
from typing import Any, Dict, Optional
import torch
logger = logging.getLogger(__name__)
class SupportedModels:
"""Supported multimodal model identifiers"""
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
def normalize_model_name(model_name: str) -> str:
"""
Extract and normalize model name from various formats including HuggingFace cache paths.
Args:
model_name: Model identifier which can be:
- A simple model name: "Qwen/Qwen2.5-VL-7B-Instruct"
- A HuggingFace cache path: "/root/.cache/huggingface/hub/models--Qwen--Qwen2.5-VL-7B-Instruct/..."
- A local path to a model directory
Returns:
Normalized model name in the format "organization/model-name"
Examples:
>>> normalize_model_name("Qwen/Qwen2.5-VL-7B-Instruct")
"Qwen/Qwen2.5-VL-7B-Instruct"
>>> normalize_model_name("/root/.cache/huggingface/hub/models--Qwen--Qwen2.5-VL-7B-Instruct/snapshots/...")
"Qwen/Qwen2.5-VL-7B-Instruct"
"""
# If it's already a simple model name (org/model format), return as-is
if "/" in model_name and not model_name.startswith("/"):
return model_name
# Handle HuggingFace cache paths
if "models--" in model_name:
# Extract from cache path format: models--ORG--MODEL-NAME
# Split on "models--" then on "--" to handle dashes in org/model names
parts_after_models = model_name.split("models--", 1)
if len(parts_after_models) > 1:
# Split the remaining part on "--" and take the last two segments
segments = parts_after_models[1].split("--")
if len(segments) >= 2:
# Take all segments except the last as org (rejoined with dashes)
# and the last segment (before any slash) as model name
org_segments = segments[:-1]
model_segment = segments[-1].split("/")[
0
] # Remove any path after model name
org = "--".join(org_segments) # Rejoin org parts with dashes
model = model_segment
return f"{org}/{model}"
# Handle local directory paths - extract the last directory name
path = Path(model_name)
if path.exists() and path.is_dir():
return path.name
# If no pattern matches, return the original name
return model_name
def is_model_supported(model_name: str, supported_model: str) -> bool:
"""
Check if a model name matches a supported model, handling various naming formats.
Args:
model_name: The model name to check (may be path, cache name, etc.)
supported_model: The supported model identifier
Returns:
True if the model is supported, False otherwise
"""
normalized_name = normalize_model_name(model_name).lower()
normalized_supported = normalize_model_name(supported_model).lower()
# Exact match
if normalized_name == normalized_supported:
return True
# Handle local path case: compare only the model name part (without organization)
# e.g., "qwen2.5-vl-7b-instruct" matches "qwen/qwen2.5-vl-7b-instruct"
if "/" in normalized_supported:
model_part = normalized_supported.split("/")[-1]
if normalized_name == model_part:
return True
return False
def get_qwen_image_features(
vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any]
) -> torch.Tensor:
"""
Extract image features using Qwen-style vision encoder.
Args:
vision_encoder: The vision encoder model
image_embeds: Dictionary containing pixel values and grid information
Returns:
Processed image features tensor
Raises:
ValueError: If grid_thw is not provided for Qwen model
"""
pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)
grid_thw = image_embeds.get("image_grid_thw", None)
if grid_thw is not None:
grid_thw = grid_thw.to(vision_encoder.device)
logger.debug(f"Qwen grid_thw shape: {grid_thw.shape}")
else:
raise ValueError("grid_thw is not provided")
return (
vision_encoder.get_image_features(pixel_values, grid_thw) # type: ignore
if grid_thw is not None
else vision_encoder.get_image_features(pixel_values) # type: ignore
)
def encode_image_embeddings(
model_name: str,
image_embeds: Dict[str, Any],
vision_encoder: torch.nn.Module,
projector: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
"""
Encode image embeddings using the appropriate model-specific encoder.
Args:
model_name: The model identifier
image_embeds: Dictionary containing processed image data
vision_encoder: The vision encoder module
projector: The multimodal projector (required for LLaVA-style models)
Returns:
Encoded embeddings tensor with normalized shape
Raises:
ValueError: If projector is missing for LLaVA models
NotImplementedError: If model is not supported
"""
with torch.no_grad():
# Route through the correct encoder based on model
if is_model_supported(model_name, SupportedModels.QWEN_2_5_VL_7B):
embeddings = get_qwen_image_features(vision_encoder, image_embeds)
else:
# Provide more helpful error message with normalized model name
normalized_name = normalize_model_name(model_name)
raise NotImplementedError(
f"Model not supported: {normalized_name} (original: {model_name})"
)
# Normalize output shape
if isinstance(embeddings, (tuple, list)):
embeddings = embeddings[0]
embeddings = embeddings.unsqueeze(0) if embeddings.ndim == 2 else embeddings
return embeddings
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import binascii
import logging
from io import BytesIO
from typing import Optional
from urllib.parse import urlparse
import httpx
from PIL import Image
logger = logging.getLogger(__name__)
# Global HTTP client instance
_global_http_client: Optional[httpx.AsyncClient] = None
def get_http_client(timeout: float = 60.0) -> httpx.AsyncClient:
"""
Get or create a shared HTTP client instance.
Args:
timeout: Timeout for HTTP requests
Returns:
Shared HTTP client instance
"""
global _global_http_client
if _global_http_client is None or _global_http_client.is_closed:
_global_http_client = httpx.AsyncClient(
timeout=timeout,
follow_redirects=True,
limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
)
logger.info(f"Shared HTTP client initialized with timeout={timeout}s")
return _global_http_client
class ImageLoader:
CACHE_SIZE_MAXIMUM = 8
def __init__(
self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0
):
self._http_timeout = http_timeout
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url)
# For HTTP(S) URLs, check cache first
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
if image_url_lower in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}")
return self._image_cache[image_url_lower]
try:
if parsed_url.scheme == "data":
# Parse data URL format: data:[<media type>][;base64],<data>
if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type")
# Split the path into media type and data
media_type, data = parsed_url.path.split(",", 1)
if ";base64" not in media_type:
raise ValueError("Data URL must be base64 encoded")
try:
image_bytes = base64.b64decode(data)
image_data = BytesIO(image_bytes)
except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}")
elif parsed_url.scheme in ("http", "https"):
http_client = get_http_client(self._http_timeout)
response = await http_client.get(image_url)
response.raise_for_status()
if not response.content:
raise ValueError("Empty response content from image URL")
image_data = BytesIO(response.content)
else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
# PIL is sync, so offload to a thread to avoid blocking the event loop
# Restrict to supported formats to prevent PSD parsing (GHSA-cfh3-3jmp-rvhc)
image = await asyncio.to_thread(
Image.open, image_data, formats=["JPEG", "PNG", "WEBP"]
)
# Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")
image_converted = image.convert("RGB")
# Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
# Cache the image for future use, and evict the oldest image if the cache is full
if self._cache_queue.full():
oldest_image_url = await self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[image_url_lower] = image_converted
await self._cache_queue.put(image_url_lower)
return image_converted
except httpx.HTTPError as e:
logger.error(f"HTTP error loading image: {e}")
raise
except Exception as e:
logger.error(f"Error loading image: {e}")
raise ValueError(f"Failed to load image: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, ConfigDict, Field
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
......@@ -122,9 +122,12 @@ class SglangMultimodalRequest(BaseModel):
multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput)
image_grid_thw: Optional[List[Any]] = None
embeddings_shape: Optional[
Union[Tuple[int, int, int], Tuple[int, int, int, int]]
Union[Tuple[int, int], Tuple[int, int, int], Tuple[int, int, int, int]]
] = None
serialized_request: Optional[connect.RdmaMetadata] = None
# Processor metadata (e.g. image_grid_thw) carried from encode worker
# to PD/prefill worker for building the format="processor_output" mm_item.
processor_output: Optional[Dict[str, Any]] = None
class DisaggSglangMultimodalRequest(BaseModel):
......
......@@ -6,14 +6,19 @@ import logging
from typing import AsyncIterator, Optional
import torch
# MMEncoder chain imports compiled CUDA ops; may fail in CPU-only environments.
try:
from sglang.srt.disaggregation.encode_server import MMEncoder
except (ImportError, OSError):
MMEncoder = None # type: ignore[assignment]
from sglang.srt.parser.conversation import chat_templates
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
from transformers import AutoTokenizer
import dynamo.nixl_connect as connect
from dynamo._core import Client, Component, Context
from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config
from dynamo.sglang.multimodal_utils import ImageLoader, encode_image_embeddings
from dynamo.sglang.protocol import SglangMultimodalRequest
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
......@@ -32,8 +37,6 @@ except ImportError as e:
DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
"""
......@@ -53,18 +56,19 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
)
self.pd_worker_client = pd_worker_client
self.model = config.server_args.model_path
self.served_model_name = config.server_args.served_model_name
self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM)
if MMEncoder is None:
raise RuntimeError(
"MMEncoder is not available. "
"Multimodal encode worker requires a CUDA environment."
)
self.image_processor = AutoImageProcessor.from_pretrained(
self.model, trust_remote_code=True
)
self.vision_model = AutoModel.from_pretrained(
self.model,
device_map="auto",
torch_dtype=torch.float16,
trust_remote_code=True,
# torch.distributed requires a dist_init_method even for tp=1;
# port 0 lets the OS assign a free port.
self.encoder = MMEncoder(
server_args=config.server_args,
dist_init_method="tcp://127.0.0.1:0",
rank=0,
)
# Load tokenizer to convert image token string to integer ID
......@@ -112,45 +116,37 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
request = SglangMultimodalRequest.model_validate(request)
# The following steps encode the requested image for SGLang:
# 1. Open the image from the provided URL.
# 2. Process the image using the processor (which handles tokenization).
# 3. Extract input_ids and image data from processed result.
# 4. Run the image through the vision model to get precomputed embeddings.
# 5. Create SGLang-specific multimodal data format.
# 6. Create a descriptor for the embeddings and send to downstream worker.
# 1. Pass the image URL to MMEncoder which loads, preprocesses, and
# runs the vision encoder.
# 2. Add a batch dimension and store metadata on the request.
# 3. Expand the single image placeholder token to match patch count.
# 4. Create a NIXL descriptor and send embeddings to downstream worker.
# 5. Stream the downstream worker's response back to the caller.
try:
if not request.multimodal_input.image_url:
raise ValueError("image_url is required for the encode worker.")
image = await self.image_loader.load_image(
request.multimodal_input.image_url
)
image_embeds = self.image_processor(images=image, return_tensors="pt")
precomputed_embeddings = encode_image_embeddings(
model_name=self.served_model_name,
image_embeds=image_embeds,
vision_encoder=self.vision_model,
projector=None,
image_grid_dim, mm_embedding = await self.encoder._encode(
[request.multimodal_input.image_url]
)
image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
image_grid_dim.tolist()
if isinstance(image_grid_dim, torch.Tensor)
else image_grid_dim
)
# Store the image data info in the request for downstream
request.processor_output = {"image_grid_thw": image_grid_thw}
request.image_grid_thw = image_grid_thw
request.embeddings_shape = tuple(precomputed_embeddings.shape)
request.embeddings_shape = tuple(mm_embedding.shape)
# Replace the single image token with multiple image tokens based on embedding shape
image_token_id_index = request.request.token_ids.index(self.image_token_id)
num_image_tokens = precomputed_embeddings.shape[
1
] # Number of image patches
num_image_tokens = mm_embedding.shape[0] # Number of image patches
# Replace single image token with multiple image tokens
request.request.token_ids = (
request.request.token_ids[:image_token_id_index]
......@@ -161,7 +157,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
)
# Create descriptor for the multimodal data
descriptor = connect.Descriptor(precomputed_embeddings)
descriptor = connect.Descriptor(mm_embedding)
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
......
......@@ -115,28 +115,27 @@ class EmbeddingsProcessor:
def create_multimodal_item(
embeddings: torch.Tensor, request: SglangMultimodalRequest
) -> dict:
"""
Create multimodal item for SGLang generation.
"""Create mm_item dict for SGLang's engine.async_generate(image_data=[...]).
Uses format="precomputed_embedding" since Dynamo's Encoder has already
run the vision encoder. SGLang expects 2D embeddings (num_patches, hidden_dim).
Uses format="processor_output" with precomputed_embeddings so SGLang
bypasses get_image_feature() entirely (model-agnostic path).
"""
precomputed = embeddings.to(MultimodalConfig.EMBEDDINGS_DTYPE)
# SGLang expects 2D tensor for precomputed_embedding format
# Encoder outputs 3D (1, num_patches, hidden_dim) for internal consistency
# Squeeze batch dimension at SGLang boundary
if precomputed.dim() == 3 and precomputed.shape[0] == 1:
precomputed = precomputed.squeeze(0)
grid_thw_tensor = torch.tensor(request.image_grid_thw)
mm_item = {
"format": "precomputed_embedding",
"feature": precomputed,
"image_grid_thw": grid_thw_tensor,
"modality": "IMAGE",
}
# Convert list fields back to tensors (JSON roundtrip loses tensor type)
processor_output = request.processor_output or {}
for key, value in processor_output.items():
if isinstance(value, list):
processor_output[key] = torch.tensor(value)
mm_item = dict(processor_output)
mm_item.update(
{
"format": "processor_output",
"precomputed_embeddings": precomputed,
"modality": "IMAGE",
}
)
return mm_item
......
......@@ -45,7 +45,7 @@ SGLang supports EPD, E/PD, and E/P/D patterns. See [Multimodal Architecture Patt
### SGLang-Specific Characteristics
- **Vision Encoder in Python**: Encode worker loads vision model (AutoModel) and image processor (AutoImageProcessor)
- **Vision Encoder in Python**: Encode worker uses SGLang's MMEncoder for model-agnostic vision encoding
- **Token Expansion**: Single `<|image_pad|>` token replaced with N tokens based on embedding shape
- **NIXL Transfer**: Embeddings transferred from Encoder → PD Worker using NIXL
- **No Rust Processing**: All tokenization and image handling happens in Python
......@@ -338,18 +338,19 @@ await read_op.wait_for_completion()
### Encode Worker Components
The encode worker loads and runs the vision model in Python:
The encode worker uses SGLang's `MMEncoder` for model-agnostic vision encoding. `MMEncoder` handles vision model loading, image preprocessing, and feature extraction internally:
```python
self.image_processor = AutoImageProcessor.from_pretrained(
model_path, trust_remote_code=True
)
self.vision_model = AutoModel.from_pretrained(
model_path,
device_map="auto",
torch_dtype=torch.float16,
trust_remote_code=True
from sglang.srt.disaggregation.encode_server import MMEncoder
self.encoder = MMEncoder(
server_args=config.server_args,
dist_init_method="tcp://127.0.0.1:0",
rank=0,
)
# At request time:
image_grid_dim, mm_embedding = await self.encoder._encode([image_url])
```
### Token Expansion Process
......@@ -390,6 +391,19 @@ Supported templates: `qwen2-vl`, `llama-3`, `vicuna`, etc.
**Key Difference:** SGLang P/D uses bootstrap mechanism, not NIXL for KV cache like vLLM.
## Environment Variables
### `SGLANG_ENCODER_MM_LOAD_WORKERS`
Controls how many threads the encoder uses to fetch and load images concurrently. When a request contains multiple images (URLs, file paths, or base64 data), each image is loaded in a separate thread. Default is 4. Increase if image loading (network fetch or disk I/O) is the bottleneck rather than GPU compute. Has no effect if the vision encoder itself is the bottleneck, since encoding is sequential on GPU after all images are loaded.
```bash
# Example: allow up to 16 concurrent image loads per encoder
export SGLANG_ENCODER_MM_LOAD_WORKERS=16
```
Only applies to the EPD encode worker (which uses [SGLang's MMEncoder](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/disaggregation/encode_server.py) internally).
## Known Limitations
- **No Data URL support** - Only HTTP/HTTPS URLs supported; `data:image/...` base64 URLs not supported
......@@ -404,9 +418,9 @@ Supported templates: `qwen2-vl`, `llama-3`, `vicuna`, etc.
SGLang multimodal **only supports image-based vision-language models**:
- **Qwen2-VL** / **Qwen2.5-VL** (primary support)
- Models with `AutoImageProcessor` and vision tower
- Models compatible with SGLang's image embedding format
- **Qwen2-VL** / **Qwen2.5-VL** - `Qwen/Qwen2.5-VL-7B-Instruct`
- **Qwen3-VL** - `Qwen/Qwen3-VL-30B-A3B-Instruct`
- Models supported by SGLang's MMEncoder
## Key Files
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment