Unverified commit adc95380, authored by Wang, Yi and committed by GitHub
Browse files

feat: multi-image in request support for sglang backend (#6068)


Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
parent 203249e1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, ConfigDict, Field
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
......@@ -115,18 +115,21 @@ class MultiModalInput(BaseModel):
video_url: Optional[str] = None
class SglangMultimodalRequest(BaseModel):
class MultiModalGroup(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
request: PreprocessedRequest
multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput)
image_grid_thw: Optional[List[Any]] = None
class SglangMultimodalRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
request: PreprocessedRequest
multimodal_inputs: List[MultiModalGroup] = Field(default_factory=list)
# Shared embedding transfer metadata for the entire multimodal request.
embeddings_shape: Optional[
Union[Tuple[int, int], Tuple[int, int, int], Tuple[int, int, int, int]]
] = None
serialized_request: Optional[connect.RdmaMetadata] = None
# Processor metadata (e.g. image_grid_thw) carried from encode worker
# to PD/prefill worker for building the format="processor_output" mm_item.
processor_output: Optional[Dict[str, Any]] = None
class DisaggSglangMultimodalRequest(BaseModel):
......
......@@ -115,50 +115,136 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# The following steps encode the requested image for SGLang:
# 1. Pass the image URL to MMEncoder which loads, preprocesses, and
# runs the vision encoder.
# 2. Add a batch dimension and store metadata on the request.
# 3. Expand the single image placeholder token to match patch count.
# 4. Create a NIXL descriptor and send embeddings to downstream worker.
# 2. Expand each image placeholder token to match patch count.
# 3. Create a single NIXL descriptor for concatenated embeddings.
# 4. Send request + metadata to downstream worker.
# 5. Stream the downstream worker's response back to the caller.
try:
if not request.multimodal_input.image_url:
raise ValueError("image_url is required for the encode worker.")
multimodal_groups = request.multimodal_inputs
if not multimodal_groups:
raise ValueError("multimodal_inputs is required for the encode worker.")
image_urls = []
for idx, mm_group in enumerate(multimodal_groups):
mm_input = mm_group.multimodal_input
if not mm_input or not mm_input.image_url:
raise ValueError(
f"image_url is required for the encode worker (index={idx})."
)
if mm_input.video_url is not None:
raise NotImplementedError(
"video_url encoding is not supported in SGLang encode worker"
)
image_urls.append(mm_input.image_url)
image_grid_dim, mm_embedding = await self.encoder._encode(
[request.multimodal_input.image_url]
image_grid_dim, precomputed_embeddings = await self.encoder._encode(
image_urls
)
image_grid_thw = (
image_grid_thw_list = (
image_grid_dim.tolist()
if isinstance(image_grid_dim, torch.Tensor)
else image_grid_dim
)
# Store the image data info in the request for downstream
request.processor_output = {"image_grid_thw": image_grid_thw}
request.image_grid_thw = image_grid_thw
request.embeddings_shape = tuple(mm_embedding.shape)
if len(image_grid_thw_list) != len(multimodal_groups):
raise ValueError("image_grid_thw size mismatch")
def _build_token_counts(total_tokens: int) -> list[int]:
if total_tokens <= 0:
raise ValueError("Invalid token statistics for embeddings")
# image_grid_thw is [t, h, w]. We derive per-item relative sizes
# from spatial grid (h * w), then infer merge factor
# from the total embedding token count.
grid_sizes = []
for image_grid_thw in image_grid_thw_list:
if not isinstance(image_grid_thw, list) or len(image_grid_thw) != 3:
raise ValueError(
"Cannot split embeddings: invalid image_grid_thw"
)
grid_sizes.append(int(image_grid_thw[1] * image_grid_thw[2]))
total_grid_tokens = sum(grid_sizes)
if total_grid_tokens <= 0:
raise ValueError("Invalid grid statistics for embeddings")
if total_grid_tokens % total_tokens != 0:
raise ValueError(
"Cannot infer merge factor: grid token total is not divisible by embedding token total"
)
merge_factor = total_grid_tokens // total_tokens
token_counts = []
for grid_count in grid_sizes:
if grid_count % merge_factor != 0:
raise ValueError(
"Cannot split embeddings: per-image grid token count not divisible by inferred merge factor"
)
token_counts.append(grid_count // merge_factor)
if sum(token_counts) != total_tokens:
raise ValueError(
"Cannot split embeddings: per-image token counts do not match embedding token total"
)
return token_counts
if isinstance(precomputed_embeddings, torch.Tensor):
if precomputed_embeddings.ndim != 2:
raise ValueError(
"Unsupported embeddings tensor rank from encoder: "
f"{precomputed_embeddings.ndim}. Expected 2D [tokens, hidden]."
)
token_counts = _build_token_counts(precomputed_embeddings.shape[0])
else:
raise ValueError(
"Unsupported embeddings type from encoder: "
f"{type(precomputed_embeddings)}"
)
image_placeholder_count = request.request.token_ids.count(
self.image_token_id
)
if image_placeholder_count < len(multimodal_groups):
raise ValueError(
"Not enough image placeholders in token_ids for provided images"
)
# Keep per-image grid metadata in request groups for worker-side mm_item.
for idx, (mm_group, image_grid_thw) in enumerate(
zip(multimodal_groups, image_grid_thw_list)
):
mm_group.image_grid_thw = image_grid_thw
mm_group.multimodal_input.image_url = None
# Replace the single image token with multiple image tokens based on embedding shape
image_token_id_index = request.request.token_ids.index(self.image_token_id)
# Store shared serialized tensor metadata at request level.
request.embeddings_shape = tuple(precomputed_embeddings.shape)
request.serialized_request = None
num_image_tokens = mm_embedding.shape[0] # Number of image patches
search_start = 0
for num_image_tokens in token_counts:
try:
image_token_id_index = request.request.token_ids.index(
self.image_token_id, search_start
)
except ValueError as e:
raise ValueError(
"Not enough image tokens found for provided images"
) from e
# Replace single image token with multiple image tokens
request.request.token_ids = (
request.request.token_ids[:image_token_id_index]
+ [self.image_token_id] * num_image_tokens
+ request.request.token_ids[
image_token_id_index + 1 :
] # Skip the original token
+ request.request.token_ids[image_token_id_index + 1 :]
)
search_start = image_token_id_index + num_image_tokens
# Create descriptor for the multimodal data
descriptor = connect.Descriptor(mm_embedding)
descriptor = connect.Descriptor(precomputed_embeddings)
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
logger.debug(f"Request: {request.model_dump_json()}")
# Get the response generator from downstream worker
......
......@@ -17,6 +17,7 @@ from dynamo.sglang.multimodal_utils import (
process_sglang_stream_response,
)
from dynamo.sglang.protocol import (
MultiModalGroup,
MultiModalInput,
MultiModalRequest,
SglangMultimodalRequest,
......@@ -67,21 +68,37 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
# If the request is not MultiModalRequest, convert it to MultiModalRequest
raw_request = MultiModalRequest.model_validate(raw_request)
multimodal_input = MultiModalInput()
image_urls: list[str] = []
video_url: str | None = None
for message in raw_request.messages:
for item in message.content:
if item.type == "image_url":
multimodal_input.image_url = item.image_url.url
if video_url is not None:
raise ValueError("Cannot provide both image and video URLs")
image_urls.append(item.image_url.url)
elif item.type == "video_url":
if multimodal_input.image_url is not None:
if image_urls:
raise ValueError("Cannot provide both image and video URLs")
multimodal_input.video_url = item.video_url.url
if video_url is not None:
raise ValueError("Multiple video URLs are not supported")
video_url = item.video_url.url
if multimodal_input.image_url is None and multimodal_input.video_url is None:
if not image_urls and video_url is None:
raise ValueError("Either image URL or video URL is required")
async for response in self._generate(raw_request, multimodal_input):
multimodal_groups: list[MultiModalGroup] = []
if image_urls:
multimodal_groups = [
MultiModalGroup(multimodal_input=MultiModalInput(image_url=url))
for url in image_urls
]
elif video_url is not None:
multimodal_groups = [
MultiModalGroup(multimodal_input=MultiModalInput(video_url=video_url))
]
async for response in self._generate(raw_request, multimodal_groups):
logger.debug(
f"Generated response type {type(response)}, content: {response}"
)
......@@ -90,7 +107,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
async def _generate(
self,
raw_request: MultiModalRequest,
multimodal_input: MultiModalInput,
multimodal_groups: list[MultiModalGroup],
):
# Generate a unique request ID for tracking
request_id = str(uuid.uuid4().hex)
......@@ -103,7 +120,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
worker_request = SglangMultimodalRequest(
request=sglang_request,
multimodal_input=multimodal_input,
multimodal_inputs=multimodal_groups,
)
# Send to encoder worker
......
......@@ -81,16 +81,24 @@ class EmbeddingsProcessor:
self._connector = connect.Connector()
async def process_embeddings(self, request: SglangMultimodalRequest):
"""Process embeddings from serialized request"""
logger.debug(f"Processing embeddings with shape: {request.embeddings_shape}")
# Validate embeddings shape
if request.embeddings_shape is None or len(request.embeddings_shape) < 2:
raise ValueError(f"Invalid embeddings shape: {request.embeddings_shape}")
"""Process one concatenated embedding tensor from serialized request."""
logger.debug("Processing embeddings with shape: " f"{request.embeddings_shape}")
multimodal_groups = request.multimodal_inputs
if not multimodal_groups:
raise ValueError("multimodal_inputs is required")
serialized_request = request.serialized_request
embeddings_shape = request.embeddings_shape
if serialized_request is None:
raise ValueError("serialized_request is required on request")
if embeddings_shape is None:
raise ValueError("embeddings_shape is required on request")
if len(embeddings_shape) < 2:
raise ValueError(f"Invalid embeddings shape: {embeddings_shape}")
embeddings = torch.empty(
request.embeddings_shape,
embeddings_shape,
dtype=MultimodalConfig.EMBEDDINGS_DTYPE,
device=MultimodalConfig.EMBEDDINGS_DEVICE,
)
......@@ -105,17 +113,13 @@ class EmbeddingsProcessor:
)
self._connector = connect.Connector()
read_op = await self._connector.begin_read(
request.serialized_request, descriptor
)
read_op = await self._connector.begin_read(serialized_request, descriptor)
await read_op.wait_for_completion()
return embeddings, descriptor
@staticmethod
def create_multimodal_item(
embeddings: torch.Tensor, request: SglangMultimodalRequest
) -> dict:
def create_multimodal_item(embeddings: torch.Tensor, image_grid_thw) -> dict:
"""Create mm_item dict for SGLang's engine.async_generate(image_data=[...]).
Uses format="processor_output" with precomputed_embeddings so SGLang
......@@ -123,13 +127,7 @@ class EmbeddingsProcessor:
"""
precomputed = embeddings.to(MultimodalConfig.EMBEDDINGS_DTYPE)
# Convert list fields back to tensors (JSON roundtrip loses tensor type)
processor_output = request.processor_output or {}
for key, value in processor_output.items():
if isinstance(value, list):
processor_output[key] = torch.tensor(value)
mm_item = dict(processor_output)
mm_item = {"image_grid_thw": torch.tensor(image_grid_thw)}
mm_item.update(
{
"format": "processor_output",
......@@ -246,6 +244,23 @@ class ErrorResponseBuilder:
return json.dumps(response)
async def _build_mm_items(
    request: SglangMultimodalRequest, embeddings_processor: EmbeddingsProcessor
) -> tuple[list[dict], torch.Tensor]:
    """Load the request's embeddings and wrap them in one SGLang mm_item.

    The embeddings are obtained via ``embeddings_processor.process_embeddings``
    and passed, together with the ``image_grid_thw`` value of every entry in
    ``request.multimodal_inputs``, to ``create_multimodal_item``.

    Returns:
        A pair ``(mm_items, embeddings)`` where ``mm_items`` is a one-element
        list holding the created item.

    Raises:
        ValueError: If any multimodal group has no ``image_grid_thw``.
    """
    # The second element of the result is not needed here.
    embeddings, _descriptor = await embeddings_processor.process_embeddings(request)

    grids = []
    for group in request.multimodal_inputs:
        grids.append(group.image_grid_thw)

    for grid in grids:
        if grid is None:
            raise ValueError("image_grid_thw is required")

    item = embeddings_processor.create_multimodal_item(embeddings, grids)
    return [item], embeddings
class MultimodalWorkerHandler(BaseWorkerHandler):
"""
Multimodal worker handler for LLM inference with multimodal data.
......@@ -355,23 +370,19 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
try:
sampling_params = SglangUtils.build_sampling_params(request)
embeddings, descriptor = await self.embeddings_processor.process_embeddings(
request
)
# Create multimodal item
mm_item = self.embeddings_processor.create_multimodal_item(
embeddings, request
mm_items, combined_embeddings = await _build_mm_items(
request, self.embeddings_processor
)
logger.debug(
f"Generated multimodal item with embeddings shape: {embeddings.shape}"
"Generated combined multimodal item with embeddings shape: "
f"{combined_embeddings.shape}"
)
logger.debug(f"Input token sequence length: {len(input_ids)}")
agg_stream = await self.engine.async_generate(
input_ids=input_ids,
image_data=[mm_item],
image_data=mm_items,
sampling_params=sampling_params,
stream=True,
)
......@@ -385,12 +396,14 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
"Shape mismatch error - this likely indicates a tokenization/embedding alignment issue"
)
logger.error(f"Request token IDs length: {len(input_ids)}")
logger.error(f"Embeddings shape: {request.embeddings_shape}")
logger.error("Embeddings shape: " f"{request.embeddings_shape}")
logger.error(f"Token sequence preview: {input_ids[:20]}...")
error_msg = (
f"Multimodal embedding alignment error: {str(e)}. "
f"This usually happens when the tokenization changes between requests. "
f"Token count: {len(input_ids)}, Embedding shape: {request.embeddings_shape}"
"Token count: "
f"{len(input_ids)}, Embedding shape: "
f"{request.embeddings_shape}"
)
yield ErrorResponseBuilder.build_error_response(RuntimeError(error_msg))
else:
......@@ -515,17 +528,12 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
sampling_params = disagg_request.sampling_params
# Process embeddings from encode worker using our embeddings processor
embeddings, descriptor = await self.embeddings_processor.process_embeddings(
request
)
# Create multimodal item for prefill generation
mm_item = self.embeddings_processor.create_multimodal_item(embeddings, request)
mm_items, _ = await _build_mm_items(request, self.embeddings_processor)
# Start SGLang prefill generation (like regular SGLang)
results = await self.engine.async_generate(
input_ids=input_ids,
image_data=[mm_item],
image_data=mm_items,
sampling_params=sampling_params,
stream=True,
bootstrap_host=self.bootstrap_host,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment