Unverified Commit adc95380 authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

feat: multi-image in request support for sglang backend (#6068)


Signed-off-by: default avatarWang, Yi <yi.a.wang@intel.com>
parent 203249e1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Literal, Optional, Tuple, Union from typing import Any, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
...@@ -115,18 +115,21 @@ class MultiModalInput(BaseModel): ...@@ -115,18 +115,21 @@ class MultiModalInput(BaseModel):
video_url: Optional[str] = None video_url: Optional[str] = None
class SglangMultimodalRequest(BaseModel): class MultiModalGroup(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
request: PreprocessedRequest
multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput) multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput)
image_grid_thw: Optional[List[Any]] = None image_grid_thw: Optional[List[Any]] = None
class SglangMultimodalRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
request: PreprocessedRequest
multimodal_inputs: List[MultiModalGroup] = Field(default_factory=list)
# Shared embedding transfer metadata for the entire multimodal request.
embeddings_shape: Optional[ embeddings_shape: Optional[
Union[Tuple[int, int], Tuple[int, int, int], Tuple[int, int, int, int]] Union[Tuple[int, int], Tuple[int, int, int], Tuple[int, int, int, int]]
] = None ] = None
serialized_request: Optional[connect.RdmaMetadata] = None serialized_request: Optional[connect.RdmaMetadata] = None
# Processor metadata (e.g. image_grid_thw) carried from encode worker
# to PD/prefill worker for building the format="processor_output" mm_item.
processor_output: Optional[Dict[str, Any]] = None
class DisaggSglangMultimodalRequest(BaseModel): class DisaggSglangMultimodalRequest(BaseModel):
......
...@@ -115,50 +115,136 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -115,50 +115,136 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# The following steps encode the requested image for SGLang: # The following steps encode the requested image for SGLang:
# 1. Pass the image URL to MMEncoder which loads, preprocesses, and # 1. Pass the image URL to MMEncoder which loads, preprocesses, and
# runs the vision encoder. # runs the vision encoder.
# 2. Add a batch dimension and store metadata on the request. # 2. Expand each image placeholder token to match patch count.
# 3. Expand the single image placeholder token to match patch count. # 3. Create a single NIXL descriptor for concatenated embeddings.
# 4. Create a NIXL descriptor and send embeddings to downstream worker. # 4. Send request + metadata to downstream worker.
# 5. Stream the downstream worker's response back to the caller. # 5. Stream the downstream worker's response back to the caller.
try: try:
if not request.multimodal_input.image_url: multimodal_groups = request.multimodal_inputs
raise ValueError("image_url is required for the encode worker.") if not multimodal_groups:
raise ValueError("multimodal_inputs is required for the encode worker.")
image_urls = []
for idx, mm_group in enumerate(multimodal_groups):
mm_input = mm_group.multimodal_input
if not mm_input or not mm_input.image_url:
raise ValueError(
f"image_url is required for the encode worker (index={idx})."
)
if mm_input.video_url is not None:
raise NotImplementedError(
"video_url encoding is not supported in SGLang encode worker"
)
image_urls.append(mm_input.image_url)
image_grid_dim, mm_embedding = await self.encoder._encode( image_grid_dim, precomputed_embeddings = await self.encoder._encode(
[request.multimodal_input.image_url] image_urls
) )
image_grid_thw = ( image_grid_thw_list = (
image_grid_dim.tolist() image_grid_dim.tolist()
if isinstance(image_grid_dim, torch.Tensor) if isinstance(image_grid_dim, torch.Tensor)
else image_grid_dim else image_grid_dim
) )
# Store the image data info in the request for downstream if len(image_grid_thw_list) != len(multimodal_groups):
request.processor_output = {"image_grid_thw": image_grid_thw} raise ValueError("image_grid_thw size mismatch")
request.image_grid_thw = image_grid_thw
request.embeddings_shape = tuple(mm_embedding.shape) def _build_token_counts(total_tokens: int) -> list[int]:
if total_tokens <= 0:
raise ValueError("Invalid token statistics for embeddings")
# image_grid_thw is [t, h, w]. We derive per-item relative sizes
# from spatial grid (h * w), then infer merge factor
# from the total embedding token count.
grid_sizes = []
for image_grid_thw in image_grid_thw_list:
if not isinstance(image_grid_thw, list) or len(image_grid_thw) != 3:
raise ValueError(
"Cannot split embeddings: invalid image_grid_thw"
)
grid_sizes.append(int(image_grid_thw[1] * image_grid_thw[2]))
total_grid_tokens = sum(grid_sizes)
if total_grid_tokens <= 0:
raise ValueError("Invalid grid statistics for embeddings")
if total_grid_tokens % total_tokens != 0:
raise ValueError(
"Cannot infer merge factor: grid token total is not divisible by embedding token total"
)
# Replace the single image token with multiple image tokens based on embedding shape merge_factor = total_grid_tokens // total_tokens
image_token_id_index = request.request.token_ids.index(self.image_token_id) token_counts = []
for grid_count in grid_sizes:
if grid_count % merge_factor != 0:
raise ValueError(
"Cannot split embeddings: per-image grid token count not divisible by inferred merge factor"
)
token_counts.append(grid_count // merge_factor)
if sum(token_counts) != total_tokens:
raise ValueError(
"Cannot split embeddings: per-image token counts do not match embedding token total"
)
num_image_tokens = mm_embedding.shape[0] # Number of image patches return token_counts
# Replace single image token with multiple image tokens if isinstance(precomputed_embeddings, torch.Tensor):
request.request.token_ids = ( if precomputed_embeddings.ndim != 2:
request.request.token_ids[:image_token_id_index] raise ValueError(
+ [self.image_token_id] * num_image_tokens "Unsupported embeddings tensor rank from encoder: "
+ request.request.token_ids[ f"{precomputed_embeddings.ndim}. Expected 2D [tokens, hidden]."
image_token_id_index + 1 : )
] # Skip the original token
token_counts = _build_token_counts(precomputed_embeddings.shape[0])
else:
raise ValueError(
"Unsupported embeddings type from encoder: "
f"{type(precomputed_embeddings)}"
)
image_placeholder_count = request.request.token_ids.count(
self.image_token_id
) )
if image_placeholder_count < len(multimodal_groups):
raise ValueError(
"Not enough image placeholders in token_ids for provided images"
)
# Create descriptor for the multimodal data # Keep per-image grid metadata in request groups for worker-side mm_item.
descriptor = connect.Descriptor(mm_embedding) for idx, (mm_group, image_grid_thw) in enumerate(
zip(multimodal_groups, image_grid_thw_list)
):
mm_group.image_grid_thw = image_grid_thw
mm_group.multimodal_input.image_url = None
# Store shared serialized tensor metadata at request level.
request.embeddings_shape = tuple(precomputed_embeddings.shape)
request.serialized_request = None
search_start = 0
for num_image_tokens in token_counts:
try:
image_token_id_index = request.request.token_ids.index(
self.image_token_id, search_start
)
except ValueError as e:
raise ValueError(
"Not enough image tokens found for provided images"
) from e
request.request.token_ids = (
request.request.token_ids[:image_token_id_index]
+ [self.image_token_id] * num_image_tokens
+ request.request.token_ids[image_token_id_index + 1 :]
)
search_start = image_token_id_index + num_image_tokens
descriptor = connect.Descriptor(precomputed_embeddings)
with await self._connector.create_readable(descriptor) as readable: with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata() request.serialized_request = readable.metadata()
logger.debug(f"Request: {request.model_dump_json()}") logger.debug(f"Request: {request.model_dump_json()}")
# Get the response generator from downstream worker # Get the response generator from downstream worker
......
...@@ -17,6 +17,7 @@ from dynamo.sglang.multimodal_utils import ( ...@@ -17,6 +17,7 @@ from dynamo.sglang.multimodal_utils import (
process_sglang_stream_response, process_sglang_stream_response,
) )
from dynamo.sglang.protocol import ( from dynamo.sglang.protocol import (
MultiModalGroup,
MultiModalInput, MultiModalInput,
MultiModalRequest, MultiModalRequest,
SglangMultimodalRequest, SglangMultimodalRequest,
...@@ -67,21 +68,37 @@ class MultimodalProcessorHandler(BaseWorkerHandler): ...@@ -67,21 +68,37 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
# If the request is not MultiModalRequest, convert it to MultiModalRequest # If the request is not MultiModalRequest, convert it to MultiModalRequest
raw_request = MultiModalRequest.model_validate(raw_request) raw_request = MultiModalRequest.model_validate(raw_request)
multimodal_input = MultiModalInput() image_urls: list[str] = []
video_url: str | None = None
for message in raw_request.messages: for message in raw_request.messages:
for item in message.content: for item in message.content:
if item.type == "image_url": if item.type == "image_url":
multimodal_input.image_url = item.image_url.url if video_url is not None:
raise ValueError("Cannot provide both image and video URLs")
image_urls.append(item.image_url.url)
elif item.type == "video_url": elif item.type == "video_url":
if multimodal_input.image_url is not None: if image_urls:
raise ValueError("Cannot provide both image and video URLs") raise ValueError("Cannot provide both image and video URLs")
multimodal_input.video_url = item.video_url.url if video_url is not None:
raise ValueError("Multiple video URLs are not supported")
video_url = item.video_url.url
if multimodal_input.image_url is None and multimodal_input.video_url is None: if not image_urls and video_url is None:
raise ValueError("Either image URL or video URL is required") raise ValueError("Either image URL or video URL is required")
async for response in self._generate(raw_request, multimodal_input): multimodal_groups: list[MultiModalGroup] = []
if image_urls:
multimodal_groups = [
MultiModalGroup(multimodal_input=MultiModalInput(image_url=url))
for url in image_urls
]
elif video_url is not None:
multimodal_groups = [
MultiModalGroup(multimodal_input=MultiModalInput(video_url=video_url))
]
async for response in self._generate(raw_request, multimodal_groups):
logger.debug( logger.debug(
f"Generated response type {type(response)}, content: {response}" f"Generated response type {type(response)}, content: {response}"
) )
...@@ -90,7 +107,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler): ...@@ -90,7 +107,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
async def _generate( async def _generate(
self, self,
raw_request: MultiModalRequest, raw_request: MultiModalRequest,
multimodal_input: MultiModalInput, multimodal_groups: list[MultiModalGroup],
): ):
# Generate a unique request ID for tracking # Generate a unique request ID for tracking
request_id = str(uuid.uuid4().hex) request_id = str(uuid.uuid4().hex)
...@@ -103,7 +120,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler): ...@@ -103,7 +120,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
worker_request = SglangMultimodalRequest( worker_request = SglangMultimodalRequest(
request=sglang_request, request=sglang_request,
multimodal_input=multimodal_input, multimodal_inputs=multimodal_groups,
) )
# Send to encoder worker # Send to encoder worker
......
...@@ -81,16 +81,24 @@ class EmbeddingsProcessor: ...@@ -81,16 +81,24 @@ class EmbeddingsProcessor:
self._connector = connect.Connector() self._connector = connect.Connector()
async def process_embeddings(self, request: SglangMultimodalRequest): async def process_embeddings(self, request: SglangMultimodalRequest):
"""Process embeddings from serialized request""" """Process one concatenated embedding tensor from serialized request."""
logger.debug("Processing embeddings with shape: " f"{request.embeddings_shape}")
logger.debug(f"Processing embeddings with shape: {request.embeddings_shape}")
multimodal_groups = request.multimodal_inputs
# Validate embeddings shape if not multimodal_groups:
if request.embeddings_shape is None or len(request.embeddings_shape) < 2: raise ValueError("multimodal_inputs is required")
raise ValueError(f"Invalid embeddings shape: {request.embeddings_shape}")
serialized_request = request.serialized_request
embeddings_shape = request.embeddings_shape
if serialized_request is None:
raise ValueError("serialized_request is required on request")
if embeddings_shape is None:
raise ValueError("embeddings_shape is required on request")
if len(embeddings_shape) < 2:
raise ValueError(f"Invalid embeddings shape: {embeddings_shape}")
embeddings = torch.empty( embeddings = torch.empty(
request.embeddings_shape, embeddings_shape,
dtype=MultimodalConfig.EMBEDDINGS_DTYPE, dtype=MultimodalConfig.EMBEDDINGS_DTYPE,
device=MultimodalConfig.EMBEDDINGS_DEVICE, device=MultimodalConfig.EMBEDDINGS_DEVICE,
) )
...@@ -105,17 +113,13 @@ class EmbeddingsProcessor: ...@@ -105,17 +113,13 @@ class EmbeddingsProcessor:
) )
self._connector = connect.Connector() self._connector = connect.Connector()
read_op = await self._connector.begin_read( read_op = await self._connector.begin_read(serialized_request, descriptor)
request.serialized_request, descriptor
)
await read_op.wait_for_completion() await read_op.wait_for_completion()
return embeddings, descriptor return embeddings, descriptor
@staticmethod @staticmethod
def create_multimodal_item( def create_multimodal_item(embeddings: torch.Tensor, image_grid_thw) -> dict:
embeddings: torch.Tensor, request: SglangMultimodalRequest
) -> dict:
"""Create mm_item dict for SGLang's engine.async_generate(image_data=[...]). """Create mm_item dict for SGLang's engine.async_generate(image_data=[...]).
Uses format="processor_output" with precomputed_embeddings so SGLang Uses format="processor_output" with precomputed_embeddings so SGLang
...@@ -123,13 +127,7 @@ class EmbeddingsProcessor: ...@@ -123,13 +127,7 @@ class EmbeddingsProcessor:
""" """
precomputed = embeddings.to(MultimodalConfig.EMBEDDINGS_DTYPE) precomputed = embeddings.to(MultimodalConfig.EMBEDDINGS_DTYPE)
# Convert list fields back to tensors (JSON roundtrip loses tensor type) mm_item = {"image_grid_thw": torch.tensor(image_grid_thw)}
processor_output = request.processor_output or {}
for key, value in processor_output.items():
if isinstance(value, list):
processor_output[key] = torch.tensor(value)
mm_item = dict(processor_output)
mm_item.update( mm_item.update(
{ {
"format": "processor_output", "format": "processor_output",
...@@ -246,6 +244,23 @@ class ErrorResponseBuilder: ...@@ -246,6 +244,23 @@ class ErrorResponseBuilder:
return json.dumps(response) return json.dumps(response)
async def _build_mm_items(
request: SglangMultimodalRequest, embeddings_processor: EmbeddingsProcessor
) -> tuple[list[dict], torch.Tensor]:
"""Process embeddings and build a single multimodal item for SGLang."""
embeddings, _ = await embeddings_processor.process_embeddings(request)
image_grid_thw_list = [group.image_grid_thw for group in request.multimodal_inputs]
if any(item is None for item in image_grid_thw_list):
raise ValueError("image_grid_thw is required")
mm_items = [
embeddings_processor.create_multimodal_item(embeddings, image_grid_thw_list)
]
return mm_items, embeddings
class MultimodalWorkerHandler(BaseWorkerHandler): class MultimodalWorkerHandler(BaseWorkerHandler):
""" """
Multimodal worker handler for LLM inference with multimodal data. Multimodal worker handler for LLM inference with multimodal data.
...@@ -355,23 +370,19 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -355,23 +370,19 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
try: try:
sampling_params = SglangUtils.build_sampling_params(request) sampling_params = SglangUtils.build_sampling_params(request)
embeddings, descriptor = await self.embeddings_processor.process_embeddings( mm_items, combined_embeddings = await _build_mm_items(
request request, self.embeddings_processor
)
# Create multimodal item
mm_item = self.embeddings_processor.create_multimodal_item(
embeddings, request
) )
logger.debug( logger.debug(
f"Generated multimodal item with embeddings shape: {embeddings.shape}" "Generated combined multimodal item with embeddings shape: "
f"{combined_embeddings.shape}"
) )
logger.debug(f"Input token sequence length: {len(input_ids)}") logger.debug(f"Input token sequence length: {len(input_ids)}")
agg_stream = await self.engine.async_generate( agg_stream = await self.engine.async_generate(
input_ids=input_ids, input_ids=input_ids,
image_data=[mm_item], image_data=mm_items,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=True, stream=True,
) )
...@@ -385,12 +396,14 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -385,12 +396,14 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
"Shape mismatch error - this likely indicates a tokenization/embedding alignment issue" "Shape mismatch error - this likely indicates a tokenization/embedding alignment issue"
) )
logger.error(f"Request token IDs length: {len(input_ids)}") logger.error(f"Request token IDs length: {len(input_ids)}")
logger.error(f"Embeddings shape: {request.embeddings_shape}") logger.error("Embeddings shape: " f"{request.embeddings_shape}")
logger.error(f"Token sequence preview: {input_ids[:20]}...") logger.error(f"Token sequence preview: {input_ids[:20]}...")
error_msg = ( error_msg = (
f"Multimodal embedding alignment error: {str(e)}. " f"Multimodal embedding alignment error: {str(e)}. "
f"This usually happens when the tokenization changes between requests. " f"This usually happens when the tokenization changes between requests. "
f"Token count: {len(input_ids)}, Embedding shape: {request.embeddings_shape}" "Token count: "
f"{len(input_ids)}, Embedding shape: "
f"{request.embeddings_shape}"
) )
yield ErrorResponseBuilder.build_error_response(RuntimeError(error_msg)) yield ErrorResponseBuilder.build_error_response(RuntimeError(error_msg))
else: else:
...@@ -515,17 +528,12 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -515,17 +528,12 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
sampling_params = disagg_request.sampling_params sampling_params = disagg_request.sampling_params
# Process embeddings from encode worker using our embeddings processor # Process embeddings from encode worker using our embeddings processor
embeddings, descriptor = await self.embeddings_processor.process_embeddings( mm_items, _ = await _build_mm_items(request, self.embeddings_processor)
request
)
# Create multimodal item for prefill generation
mm_item = self.embeddings_processor.create_multimodal_item(embeddings, request)
# Start SGLang prefill generation (like regular SGLang) # Start SGLang prefill generation (like regular SGLang)
results = await self.engine.async_generate( results = await self.engine.async_generate(
input_ids=input_ids, input_ids=input_ids,
image_data=[mm_item], image_data=mm_items,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=True, stream=True,
bootstrap_host=self.bootstrap_host, bootstrap_host=self.bootstrap_host,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment