Unverified Commit 3967ec0f authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

feat: Use SGLang MMEncoder for multimodal EPD encode worker (#6162)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent a3f1e7ec
...@@ -5,14 +5,8 @@ from dynamo.sglang.multimodal_utils.multimodal_chat_processor import ( ...@@ -5,14 +5,8 @@ from dynamo.sglang.multimodal_utils.multimodal_chat_processor import (
multimodal_request_to_sglang, multimodal_request_to_sglang,
process_sglang_stream_response, process_sglang_stream_response,
) )
from dynamo.sglang.multimodal_utils.multimodal_encode_utils import (
encode_image_embeddings,
)
from dynamo.sglang.multimodal_utils.multimodal_image_loader import ImageLoader
__all__ = [ __all__ = [
"multimodal_request_to_sglang", "multimodal_request_to_sglang",
"process_sglang_stream_response", "process_sglang_stream_response",
"encode_image_embeddings",
"ImageLoader",
] ]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from pathlib import Path
from typing import Any, Dict, Optional
import torch
logger = logging.getLogger(__name__)
class SupportedModels:
    """Namespace holding the identifiers of multimodal models handled here."""

    # Hugging Face repository id, in "organization/model-name" form.
    QWEN_2_5_VL_7B: str = "Qwen/Qwen2.5-VL-7B-Instruct"
def normalize_model_name(model_name: str) -> str:
    """
    Reduce a model reference to an "organization/model-name" identifier.

    Args:
        model_name: Model identifier which can be:
            - A simple model name: "Qwen/Qwen2.5-VL-7B-Instruct"
            - A HuggingFace cache path (absolute or relative) containing a
              "models--ORG--MODEL-NAME" directory
            - A local path to a model directory

    Returns:
        The "organization/model-name" form when it can be derived, the
        directory name for an existing local directory, otherwise the input
        unchanged.

    Examples:
        >>> normalize_model_name("Qwen/Qwen2.5-VL-7B-Instruct")
        'Qwen/Qwen2.5-VL-7B-Instruct'
        >>> normalize_model_name("/root/.cache/huggingface/hub/models--Qwen--Qwen2.5-VL-7B-Instruct/snapshots/...")
        'Qwen/Qwen2.5-VL-7B-Instruct'
    """
    cache_marker = "models--"

    # Look for the cache layout before the plain "org/model" shortcut:
    # a relative cache path also contains "/" and does not start with "/",
    # so the shortcut would otherwise return it untouched.
    if cache_marker in model_name:
        remainder = model_name.split(cache_marker, 1)[1]
        # Only the cache directory name itself encodes org/model. Cut the
        # path off at the first "/" so that "--" in later components
        # (snapshot names, file names) cannot leak into the result.
        cache_dir_name = remainder.split("/", 1)[0]
        segments = cache_dir_name.split("--")
        if len(segments) >= 2:
            # Last segment is the model; anything before it is the org,
            # rejoined with the separator it was split on.
            org = "--".join(segments[:-1])
            model = segments[-1]
            return f"{org}/{model}"

    # Already a simple "org/model" style name.
    if "/" in model_name and not model_name.startswith("/"):
        return model_name

    # Existing local directory: use its final path component.
    path = Path(model_name)
    if path.exists() and path.is_dir():
        return path.name

    # Nothing matched; hand the input back unchanged.
    return model_name
def is_model_supported(model_name: str, supported_model: str) -> bool:
    """
    Tell whether ``model_name`` refers to ``supported_model``.

    Both names are normalized and lower-cased before comparison, so cache
    paths and differently-cased names compare equal to the canonical id.

    Args:
        model_name: The model name to check (may be path, cache name, etc.)
        supported_model: The supported model identifier

    Returns:
        True if the model is supported, False otherwise
    """
    candidate = normalize_model_name(model_name).lower()
    reference = normalize_model_name(supported_model).lower()

    # The full identifier always counts as a match.
    accepted = {reference}
    # A local directory normalizes to the bare model name, so also accept
    # the part after the organization, e.g. "qwen2.5-vl-7b-instruct".
    if "/" in reference:
        accepted.add(reference.rsplit("/", 1)[-1])

    return candidate in accepted
def get_qwen_image_features(
    vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any]
) -> torch.Tensor:
    """
    Extract image features using Qwen-style vision encoder.

    Args:
        vision_encoder: The vision encoder model. It must expose a ``device``
            attribute and a ``get_image_features(pixel_values, grid_thw)``
            method.
        image_embeds: Dictionary containing "pixel_values" and
            "image_grid_thw" entries.

    Returns:
        Whatever ``vision_encoder.get_image_features`` returns for the inputs.

    Raises:
        KeyError: If "pixel_values" is missing from ``image_embeds``.
        ValueError: If "image_grid_thw" is missing or None.
    """
    pixel_values = image_embeds["pixel_values"]
    grid_thw = image_embeds.get("image_grid_thw")

    # Validate before touching the device so a bad request fails without
    # first copying the pixel tensor.
    if grid_thw is None:
        raise ValueError("grid_thw is not provided")

    device = vision_encoder.device
    pixel_values = pixel_values.to(device)
    grid_thw = grid_thw.to(device)
    logger.debug("Qwen grid_thw shape: %s", grid_thw.shape)

    # grid_thw is guaranteed to be present here, so there is a single call form.
    return vision_encoder.get_image_features(pixel_values, grid_thw)  # type: ignore
def encode_image_embeddings(
    model_name: str,
    image_embeds: Dict[str, Any],
    vision_encoder: torch.nn.Module,
    projector: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
    """
    Run the model-specific vision encoder and return its embeddings.

    Args:
        model_name: The model identifier used to pick the encoder routine
        image_embeds: Dictionary containing processed image data
        vision_encoder: The vision encoder module
        projector: Accepted for interface compatibility; not used by this
            function.

    Returns:
        The embeddings tensor. A tuple/list result is reduced to its first
        element, and a 2-D tensor gains a leading dimension of size 1.

    Raises:
        NotImplementedError: If model is not supported
    """
    # Reject unknown models up front, reporting both spellings of the name.
    if not is_model_supported(model_name, SupportedModels.QWEN_2_5_VL_7B):
        normalized_name = normalize_model_name(model_name)
        raise NotImplementedError(
            f"Model not supported: {normalized_name} (original: {model_name})"
        )

    with torch.no_grad():
        features = get_qwen_image_features(vision_encoder, image_embeds)

        # Some encoders hand back a sequence; the first entry is the tensor.
        if isinstance(features, (tuple, list)):
            features = features[0]

        # Give 2-D output a leading batch dimension.
        if features.ndim == 2:
            features = features.unsqueeze(0)

        return features
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import binascii
import logging
from io import BytesIO
from typing import Optional
from urllib.parse import urlparse
import httpx
from PIL import Image
logger = logging.getLogger(__name__)
# Module-wide HTTP client, created lazily by get_http_client().
_global_http_client: Optional[httpx.AsyncClient] = None


def get_http_client(timeout: float = 60.0) -> httpx.AsyncClient:
    """
    Return the shared HTTP client, building it on first use.

    A new client is also built when the previous one has been closed.

    Args:
        timeout: Timeout for HTTP requests. Only applied when a client is
            created; an already-open shared client is returned as-is.

    Returns:
        Shared HTTP client instance
    """
    global _global_http_client

    existing = _global_http_client
    if existing is not None and not existing.is_closed:
        return existing

    pool_limits = httpx.Limits(max_keepalive_connections=20, max_connections=100)
    _global_http_client = httpx.AsyncClient(
        timeout=timeout,
        follow_redirects=True,
        limits=pool_limits,
    )
    logger.info(f"Shared HTTP client initialized with timeout={timeout}s")
    return _global_http_client
class ImageLoader:
    """Load images from http(s) URLs or base64 ``data:`` URLs as RGB images.

    Images fetched over http(s) are kept in a small in-memory cache keyed by
    the exact URL; the oldest entry is evicted once the cache is full.
    """

    CACHE_SIZE_MAXIMUM = 8
    # Formats PIL is allowed to parse. Restricting the list prevents PSD
    # parsing (GHSA-cfh3-3jmp-rvhc).
    _ALLOWED_FORMATS = ("JPEG", "PNG", "WEBP")

    def __init__(
        self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0
    ):
        """
        Args:
            cache_size: Maximum number of cached images. A value <= 0 leaves
                the cache unbounded.
            http_timeout: Timeout passed to get_http_client().
        """
        self._http_timeout = http_timeout
        self._cache_size = cache_size
        # Plain dicts preserve insertion order, so the first key is always
        # the oldest entry and no separate eviction queue is needed.
        self._image_cache: dict[str, Image.Image] = {}

    @staticmethod
    def _read_data_url(path: str) -> BytesIO:
        """Decode the payload of a ``data:image/...;base64,<data>`` URL path."""
        if not path.startswith("image/"):
            raise ValueError("Data URL must be an image type")

        media_type, separator, data = path.partition(",")
        if not separator:
            raise ValueError("Data URL is missing the ',' payload separator")
        if ";base64" not in media_type:
            raise ValueError("Data URL must be base64 encoded")

        try:
            return BytesIO(base64.b64decode(data))
        except binascii.Error as e:
            raise ValueError(f"Invalid base64 encoding: {e}") from e

    @classmethod
    def _decode_rgb(cls, image_data: BytesIO) -> Image.Image:
        """Open, validate and convert an image to RGB (blocking; run in a thread)."""
        image = Image.open(image_data, formats=list(cls._ALLOWED_FORMATS))
        if image.format not in cls._ALLOWED_FORMATS:
            raise ValueError(f"Unsupported image format: {image.format}")
        # Image.open is lazy; convert() is where pixel data is actually
        # decoded, so it belongs in the worker thread as well.
        return image.convert("RGB")

    def _remember(self, image_url: str, image: Image.Image) -> None:
        """Store an image in the cache, evicting the oldest entries if full."""
        if image_url in self._image_cache:
            # A concurrent load of the same URL got here first; refreshing
            # the value keeps a single entry per URL.
            self._image_cache[image_url] = image
            return

        if self._cache_size > 0:
            while len(self._image_cache) >= self._cache_size:
                oldest_image_url = next(iter(self._image_cache))
                del self._image_cache[oldest_image_url]

        self._image_cache[image_url] = image

    async def load_image(self, image_url: str) -> Image.Image:
        """
        Load an image and return it converted to RGB.

        Args:
            image_url: An http(s) URL or a base64 ``data:image/...`` URL.

        Returns:
            The decoded image in RGB mode. For http(s) URLs a cached image
            object may be returned.

        Raises:
            httpx.HTTPError: If the HTTP request fails.
            ValueError: For any other failure (unsupported scheme or format,
                invalid data URL, empty response, decode error).
        """
        parsed_url = urlparse(image_url)
        is_remote = parsed_url.scheme in ("http", "https")

        # For HTTP(S) URLs, check cache first. URL paths and queries are
        # case-sensitive, so the URL is used verbatim as the key.
        if is_remote:
            cached_image = self._image_cache.get(image_url)
            if cached_image is not None:
                logger.debug(f"Image found in cache for URL: {image_url}")
                return cached_image

        try:
            if parsed_url.scheme == "data":
                image_data = self._read_data_url(parsed_url.path)
            elif is_remote:
                http_client = get_http_client(self._http_timeout)
                response = await http_client.get(image_url)
                response.raise_for_status()

                if not response.content:
                    raise ValueError("Empty response content from image URL")

                image_data = BytesIO(response.content)
            else:
                raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")

            # PIL is sync, so offload to a thread to avoid blocking the event loop
            image_converted = await asyncio.to_thread(self._decode_rgb, image_data)

            if is_remote:
                self._remember(image_url, image_converted)

            return image_converted

        except httpx.HTTPError as e:
            logger.error(f"HTTP error loading image: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading image: {e}")
            raise ValueError(f"Failed to load image: {e}") from e
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, List, Literal, Optional, Tuple, Union from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
...@@ -122,9 +122,12 @@ class SglangMultimodalRequest(BaseModel): ...@@ -122,9 +122,12 @@ class SglangMultimodalRequest(BaseModel):
multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput) multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput)
image_grid_thw: Optional[List[Any]] = None image_grid_thw: Optional[List[Any]] = None
embeddings_shape: Optional[ embeddings_shape: Optional[
Union[Tuple[int, int, int], Tuple[int, int, int, int]] Union[Tuple[int, int], Tuple[int, int, int], Tuple[int, int, int, int]]
] = None ] = None
serialized_request: Optional[connect.RdmaMetadata] = None serialized_request: Optional[connect.RdmaMetadata] = None
# Processor metadata (e.g. image_grid_thw) carried from encode worker
# to PD/prefill worker for building the format="processor_output" mm_item.
processor_output: Optional[Dict[str, Any]] = None
class DisaggSglangMultimodalRequest(BaseModel): class DisaggSglangMultimodalRequest(BaseModel):
......
...@@ -6,14 +6,19 @@ import logging ...@@ -6,14 +6,19 @@ import logging
from typing import AsyncIterator, Optional from typing import AsyncIterator, Optional
import torch import torch
# MMEncoder chain imports compiled CUDA ops; may fail in CPU-only environments.
try:
from sglang.srt.disaggregation.encode_server import MMEncoder
except (ImportError, OSError):
MMEncoder = None # type: ignore[assignment]
from sglang.srt.parser.conversation import chat_templates from sglang.srt.parser.conversation import chat_templates
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer from transformers import AutoTokenizer
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo._core import Client, Component, Context from dynamo._core import Client, Component, Context
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.multimodal_utils import ImageLoader, encode_image_embeddings
from dynamo.sglang.protocol import SglangMultimodalRequest from dynamo.sglang.protocol import SglangMultimodalRequest
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
...@@ -32,8 +37,6 @@ except ImportError as e: ...@@ -32,8 +37,6 @@ except ImportError as e:
DEVICE = "cpu" DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
class MultimodalEncodeWorkerHandler(BaseWorkerHandler): class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
""" """
...@@ -53,18 +56,19 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -53,18 +56,19 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
) )
self.pd_worker_client = pd_worker_client self.pd_worker_client = pd_worker_client
self.model = config.server_args.model_path self.model = config.server_args.model_path
self.served_model_name = config.server_args.served_model_name
self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM) if MMEncoder is None:
raise RuntimeError(
"MMEncoder is not available. "
"Multimodal encode worker requires a CUDA environment."
)
self.image_processor = AutoImageProcessor.from_pretrained( # torch.distributed requires a dist_init_method even for tp=1;
self.model, trust_remote_code=True # port 0 lets the OS assign a free port.
) self.encoder = MMEncoder(
self.vision_model = AutoModel.from_pretrained( server_args=config.server_args,
self.model, dist_init_method="tcp://127.0.0.1:0",
device_map="auto", rank=0,
torch_dtype=torch.float16,
trust_remote_code=True,
) )
# Load tokenizer to convert image token string to integer ID # Load tokenizer to convert image token string to integer ID
...@@ -112,45 +116,37 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -112,45 +116,37 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
request = SglangMultimodalRequest.model_validate(request) request = SglangMultimodalRequest.model_validate(request)
# The following steps encode the requested image for SGLang: # The following steps encode the requested image for SGLang:
# 1. Open the image from the provided URL. # 1. Pass the image URL to MMEncoder which loads, preprocesses, and
# 2. Process the image using the processor (which handles tokenization). # runs the vision encoder.
# 3. Extract input_ids and image data from processed result. # 2. Add a batch dimension and store metadata on the request.
# 4. Run the image through the vision model to get precomputed embeddings. # 3. Expand the single image placeholder token to match patch count.
# 5. Create SGLang-specific multimodal data format. # 4. Create a NIXL descriptor and send embeddings to downstream worker.
# 6. Create a descriptor for the embeddings and send to downstream worker. # 5. Stream the downstream worker's response back to the caller.
try: try:
if not request.multimodal_input.image_url: if not request.multimodal_input.image_url:
raise ValueError("image_url is required for the encode worker.") raise ValueError("image_url is required for the encode worker.")
image = await self.image_loader.load_image( image_grid_dim, mm_embedding = await self.encoder._encode(
request.multimodal_input.image_url [request.multimodal_input.image_url]
)
image_embeds = self.image_processor(images=image, return_tensors="pt")
precomputed_embeddings = encode_image_embeddings(
model_name=self.served_model_name,
image_embeds=image_embeds,
vision_encoder=self.vision_model,
projector=None,
) )
image_grid_thw = ( image_grid_thw = (
image_embeds["image_grid_thw"].tolist() image_grid_dim.tolist()
if "image_grid_thw" in image_embeds if isinstance(image_grid_dim, torch.Tensor)
else None else image_grid_dim
) )
# Store the image data info in the request for downstream # Store the image data info in the request for downstream
request.processor_output = {"image_grid_thw": image_grid_thw}
request.image_grid_thw = image_grid_thw request.image_grid_thw = image_grid_thw
request.embeddings_shape = tuple(precomputed_embeddings.shape) request.embeddings_shape = tuple(mm_embedding.shape)
# Replace the single image token with multiple image tokens based on embedding shape # Replace the single image token with multiple image tokens based on embedding shape
image_token_id_index = request.request.token_ids.index(self.image_token_id) image_token_id_index = request.request.token_ids.index(self.image_token_id)
num_image_tokens = precomputed_embeddings.shape[ num_image_tokens = mm_embedding.shape[0] # Number of image patches
1
] # Number of image patches
# Replace single image token with multiple image tokens # Replace single image token with multiple image tokens
request.request.token_ids = ( request.request.token_ids = (
request.request.token_ids[:image_token_id_index] request.request.token_ids[:image_token_id_index]
...@@ -161,7 +157,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -161,7 +157,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
) )
# Create descriptor for the multimodal data # Create descriptor for the multimodal data
descriptor = connect.Descriptor(precomputed_embeddings) descriptor = connect.Descriptor(mm_embedding)
with await self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
......
...@@ -115,28 +115,27 @@ class EmbeddingsProcessor: ...@@ -115,28 +115,27 @@ class EmbeddingsProcessor:
def create_multimodal_item( def create_multimodal_item(
embeddings: torch.Tensor, request: SglangMultimodalRequest embeddings: torch.Tensor, request: SglangMultimodalRequest
) -> dict: ) -> dict:
""" """Create mm_item dict for SGLang's engine.async_generate(image_data=[...]).
Create multimodal item for SGLang generation.
Uses format="precomputed_embedding" since Dynamo's Encoder has already Uses format="processor_output" with precomputed_embeddings so SGLang
run the vision encoder. SGLang expects 2D embeddings (num_patches, hidden_dim). bypasses get_image_feature() entirely (model-agnostic path).
""" """
precomputed = embeddings.to(MultimodalConfig.EMBEDDINGS_DTYPE) precomputed = embeddings.to(MultimodalConfig.EMBEDDINGS_DTYPE)
# SGLang expects 2D tensor for precomputed_embedding format # Convert list fields back to tensors (JSON roundtrip loses tensor type)
# Encoder outputs 3D (1, num_patches, hidden_dim) for internal consistency processor_output = request.processor_output or {}
# Squeeze batch dimension at SGLang boundary for key, value in processor_output.items():
if precomputed.dim() == 3 and precomputed.shape[0] == 1: if isinstance(value, list):
precomputed = precomputed.squeeze(0) processor_output[key] = torch.tensor(value)
grid_thw_tensor = torch.tensor(request.image_grid_thw) mm_item = dict(processor_output)
mm_item.update(
mm_item = { {
"format": "precomputed_embedding", "format": "processor_output",
"feature": precomputed, "precomputed_embeddings": precomputed,
"image_grid_thw": grid_thw_tensor, "modality": "IMAGE",
"modality": "IMAGE", }
} )
return mm_item return mm_item
......
...@@ -45,7 +45,7 @@ SGLang supports EPD, E/PD, and E/P/D patterns. See [Multimodal Architecture Patt ...@@ -45,7 +45,7 @@ SGLang supports EPD, E/PD, and E/P/D patterns. See [Multimodal Architecture Patt
### SGLang-Specific Characteristics ### SGLang-Specific Characteristics
- **Vision Encoder in Python**: Encode worker loads vision model (AutoModel) and image processor (AutoImageProcessor) - **Vision Encoder in Python**: Encode worker uses SGLang's MMEncoder for model-agnostic vision encoding
- **Token Expansion**: Single `<|image_pad|>` token replaced with N tokens based on embedding shape - **Token Expansion**: Single `<|image_pad|>` token replaced with N tokens based on embedding shape
- **NIXL Transfer**: Embeddings transferred from Encoder → PD Worker using NIXL - **NIXL Transfer**: Embeddings transferred from Encoder → PD Worker using NIXL
- **No Rust Processing**: All tokenization and image handling happens in Python - **No Rust Processing**: All tokenization and image handling happens in Python
...@@ -338,18 +338,19 @@ await read_op.wait_for_completion() ...@@ -338,18 +338,19 @@ await read_op.wait_for_completion()
### Encode Worker Components ### Encode Worker Components
The encode worker loads and runs the vision model in Python: The encode worker uses SGLang's `MMEncoder` for model-agnostic vision encoding. `MMEncoder` handles vision model loading, image preprocessing, and feature extraction internally:
```python ```python
self.image_processor = AutoImageProcessor.from_pretrained( from sglang.srt.disaggregation.encode_server import MMEncoder
model_path, trust_remote_code=True
) self.encoder = MMEncoder(
self.vision_model = AutoModel.from_pretrained( server_args=config.server_args,
model_path, dist_init_method="tcp://127.0.0.1:0",
device_map="auto", rank=0,
torch_dtype=torch.float16,
trust_remote_code=True
) )
# At request time:
image_grid_dim, mm_embedding = await self.encoder._encode([image_url])
``` ```
### Token Expansion Process ### Token Expansion Process
...@@ -390,6 +391,19 @@ Supported templates: `qwen2-vl`, `llama-3`, `vicuna`, etc. ...@@ -390,6 +391,19 @@ Supported templates: `qwen2-vl`, `llama-3`, `vicuna`, etc.
**Key Difference:** SGLang P/D uses bootstrap mechanism, not NIXL for KV cache like vLLM. **Key Difference:** SGLang P/D uses bootstrap mechanism, not NIXL for KV cache like vLLM.
## Environment Variables
### `SGLANG_ENCODER_MM_LOAD_WORKERS`
Controls how many threads the encoder uses to fetch and load images concurrently. When a request contains multiple images (URLs, file paths, or base64 data), each image is loaded in a separate thread. Default is 4. Increase if image loading (network fetch or disk I/O) is the bottleneck rather than GPU compute. Has no effect if the vision encoder itself is the bottleneck, since encoding is sequential on GPU after all images are loaded.
```bash
# Example: allow up to 16 concurrent image loads per encoder
export SGLANG_ENCODER_MM_LOAD_WORKERS=16
```
Only applies to the EPD encode worker (which uses [SGLang's MMEncoder](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/disaggregation/encode_server.py) internally).
## Known Limitations ## Known Limitations
- **No Data URL support** - Only HTTP/HTTPS URLs supported; `data:image/...` base64 URLs not supported - **No Data URL support** - Only HTTP/HTTPS URLs supported; `data:image/...` base64 URLs not supported
...@@ -404,9 +418,9 @@ Supported templates: `qwen2-vl`, `llama-3`, `vicuna`, etc. ...@@ -404,9 +418,9 @@ Supported templates: `qwen2-vl`, `llama-3`, `vicuna`, etc.
SGLang multimodal **only supports image-based vision-language models**: SGLang multimodal **only supports image-based vision-language models**:
- **Qwen2-VL** / **Qwen2.5-VL** (primary support) - **Qwen2-VL** / **Qwen2.5-VL** - `Qwen/Qwen2.5-VL-7B-Instruct`
- Models with `AutoImageProcessor` and vision tower - **Qwen3-VL** - `Qwen/Qwen3-VL-30B-A3B-Instruct`
- Models compatible with SGLang's image embedding format - Models supported by SGLang's MMEncoder
## Key Files ## Key Files
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment