Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
adc95380
Unverified
Commit
adc95380
authored
Mar 05, 2026
by
Wang, Yi
Committed by
GitHub
Mar 04, 2026
Browse files
feat: multi-image in request support for sglang backend (#6068)
Signed-off-by:
Wang, Yi
<
yi.a.wang@intel.com
>
parent
203249e1
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
192 additions
and
78 deletions
+192
-78
components/src/dynamo/sglang/protocol.py
components/src/dynamo/sglang/protocol.py
+9
-6
components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py
...lang/request_handlers/multimodal/encode_worker_handler.py
+111
-25
components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py
...o/sglang/request_handlers/multimodal/processor_handler.py
+25
-8
components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py
...namo/sglang/request_handlers/multimodal/worker_handler.py
+47
-39
No files found.
components/src/dynamo/sglang/protocol.py
View file @
adc95380
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
List
,
Literal
,
Optional
,
Tuple
,
Union
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
from
sglang.srt.entrypoints.openai.protocol
import
ChatCompletionRequest
...
...
@@ -115,18 +115,21 @@ class MultiModalInput(BaseModel):
video_url
:
Optional
[
str
]
=
None
class
Sglang
Multi
m
odal
Request
(
BaseModel
):
class
Multi
M
odal
Group
(
BaseModel
):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
request
:
PreprocessedRequest
multimodal_input
:
Optional
[
MultiModalInput
]
=
Field
(
default_factory
=
MultiModalInput
)
image_grid_thw
:
Optional
[
List
[
Any
]]
=
None
class
SglangMultimodalRequest
(
BaseModel
):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
request
:
PreprocessedRequest
multimodal_inputs
:
List
[
MultiModalGroup
]
=
Field
(
default_factory
=
list
)
# Shared embedding transfer metadata for the entire multimodal request.
embeddings_shape
:
Optional
[
Union
[
Tuple
[
int
,
int
],
Tuple
[
int
,
int
,
int
],
Tuple
[
int
,
int
,
int
,
int
]]
]
=
None
serialized_request
:
Optional
[
connect
.
RdmaMetadata
]
=
None
# Processor metadata (e.g. image_grid_thw) carried from encode worker
# to PD/prefill worker for building the format="processor_output" mm_item.
processor_output
:
Optional
[
Dict
[
str
,
Any
]]
=
None
class
DisaggSglangMultimodalRequest
(
BaseModel
):
...
...
components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py
View file @
adc95380
...
...
@@ -115,50 +115,136 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# The following steps encode the requested image for SGLang:
# 1. Pass the image URL to MMEncoder which loads, preprocesses, and
# runs the vision encoder.
# 2.
Add a bat
ch
d
im
ension and store metadata on the reques
t.
# 3.
Expand the single image placeholder token to match patch count
.
# 4.
Create a NIXL descriptor and send embeddings
to downstream worker.
# 2.
Expand ea
ch im
age placeholder token to match patch coun
t.
# 3.
Create a single NIXL descriptor for concatenated embeddings
.
# 4.
Send request + metadata
to downstream worker.
# 5. Stream the downstream worker's response back to the caller.
try
:
if
not
request
.
multimodal_input
.
image_url
:
raise
ValueError
(
"image_url is required for the encode worker."
)
multimodal_groups
=
request
.
multimodal_inputs
if
not
multimodal_groups
:
raise
ValueError
(
"multimodal_inputs is required for the encode worker."
)
image_urls
=
[]
for
idx
,
mm_group
in
enumerate
(
multimodal_groups
):
mm_input
=
mm_group
.
multimodal_input
if
not
mm_input
or
not
mm_input
.
image_url
:
raise
ValueError
(
f
"image_url is required for the encode worker (index=
{
idx
}
)."
)
if
mm_input
.
video_url
is
not
None
:
raise
NotImplementedError
(
"video_url encoding is not supported in SGLang encode worker"
)
image_urls
.
append
(
mm_input
.
image_url
)
image_grid_dim
,
mm
_embedding
=
await
self
.
encoder
.
_encode
(
[
request
.
multimodal_input
.
image_url
]
image_grid_dim
,
precomputed
_embedding
s
=
await
self
.
encoder
.
_encode
(
image_url
s
)
image_grid_thw
=
(
image_grid_thw
_list
=
(
image_grid_dim
.
tolist
()
if
isinstance
(
image_grid_dim
,
torch
.
Tensor
)
else
image_grid_dim
)
# Store the image data info in the request for downstream
request
.
processor_output
=
{
"image_grid_thw"
:
image_grid_thw
}
request
.
image_grid_thw
=
image_grid_thw
request
.
embeddings_shape
=
tuple
(
mm_embedding
.
shape
)
if
len
(
image_grid_thw_list
)
!=
len
(
multimodal_groups
):
raise
ValueError
(
"image_grid_thw size mismatch"
)
def
_build_token_counts
(
total_tokens
:
int
)
->
list
[
int
]:
if
total_tokens
<=
0
:
raise
ValueError
(
"Invalid token statistics for embeddings"
)
# image_grid_thw is [t, h, w]. We derive per-item relative sizes
# from spatial grid (h * w), then infer merge factor
# from the total embedding token count.
grid_sizes
=
[]
for
image_grid_thw
in
image_grid_thw_list
:
if
not
isinstance
(
image_grid_thw
,
list
)
or
len
(
image_grid_thw
)
!=
3
:
raise
ValueError
(
"Cannot split embeddings: invalid image_grid_thw"
)
grid_sizes
.
append
(
int
(
image_grid_thw
[
1
]
*
image_grid_thw
[
2
]))
total_grid_tokens
=
sum
(
grid_sizes
)
if
total_grid_tokens
<=
0
:
raise
ValueError
(
"Invalid grid statistics for embeddings"
)
if
total_grid_tokens
%
total_tokens
!=
0
:
raise
ValueError
(
"Cannot infer merge factor: grid token total is not divisible by embedding token total"
)
merge_factor
=
total_grid_tokens
//
total_tokens
token_counts
=
[]
for
grid_count
in
grid_sizes
:
if
grid_count
%
merge_factor
!=
0
:
raise
ValueError
(
"Cannot split embeddings: per-image grid token count not divisible by inferred merge factor"
)
token_counts
.
append
(
grid_count
//
merge_factor
)
if
sum
(
token_counts
)
!=
total_tokens
:
raise
ValueError
(
"Cannot split embeddings: per-image token counts do not match embedding token total"
)
return
token_counts
if
isinstance
(
precomputed_embeddings
,
torch
.
Tensor
):
if
precomputed_embeddings
.
ndim
!=
2
:
raise
ValueError
(
"Unsupported embeddings tensor rank from encoder: "
f
"
{
precomputed_embeddings
.
ndim
}
. Expected 2D [tokens, hidden]."
)
token_counts
=
_build_token_counts
(
precomputed_embeddings
.
shape
[
0
])
else
:
raise
ValueError
(
"Unsupported embeddings type from encoder: "
f
"
{
type
(
precomputed_embeddings
)
}
"
)
image_placeholder_count
=
request
.
request
.
token_ids
.
count
(
self
.
image_token_id
)
if
image_placeholder_count
<
len
(
multimodal_groups
):
raise
ValueError
(
"Not enough image placeholders in token_ids for provided images"
)
# Keep per-image grid metadata in request groups for worker-side mm_item.
for
idx
,
(
mm_group
,
image_grid_thw
)
in
enumerate
(
zip
(
multimodal_groups
,
image_grid_thw_list
)
):
mm_group
.
image_grid_thw
=
image_grid_thw
mm_group
.
multimodal_input
.
image_url
=
None
# Replace the single image token with multiple image tokens based on embedding shape
image_token_id_index
=
request
.
request
.
token_ids
.
index
(
self
.
image_token_id
)
# Store shared serialized tensor metadata at request level.
request
.
embeddings_shape
=
tuple
(
precomputed_embeddings
.
shape
)
request
.
serialized_request
=
None
num_image_tokens
=
mm_embedding
.
shape
[
0
]
# Number of image patches
search_start
=
0
for
num_image_tokens
in
token_counts
:
try
:
image_token_id_index
=
request
.
request
.
token_ids
.
index
(
self
.
image_token_id
,
search_start
)
except
ValueError
as
e
:
raise
ValueError
(
"Not enough image tokens found for provided images"
)
from
e
# Replace single image token with multiple image tokens
request
.
request
.
token_ids
=
(
request
.
request
.
token_ids
[:
image_token_id_index
]
+
[
self
.
image_token_id
]
*
num_image_tokens
+
request
.
request
.
token_ids
[
image_token_id_index
+
1
:
]
# Skip the original token
+
request
.
request
.
token_ids
[
image_token_id_index
+
1
:]
)
search_start
=
image_token_id_index
+
num_image_tokens
# Create descriptor for the multimodal data
descriptor
=
connect
.
Descriptor
(
mm_embedding
)
descriptor
=
connect
.
Descriptor
(
precomputed_embeddings
)
with
await
self
.
_connector
.
create_readable
(
descriptor
)
as
readable
:
request
.
serialized_request
=
readable
.
metadata
()
logger
.
debug
(
f
"Request:
{
request
.
model_dump_json
()
}
"
)
# Get the response generator from downstream worker
...
...
components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py
View file @
adc95380
...
...
@@ -17,6 +17,7 @@ from dynamo.sglang.multimodal_utils import (
process_sglang_stream_response
,
)
from
dynamo.sglang.protocol
import
(
MultiModalGroup
,
MultiModalInput
,
MultiModalRequest
,
SglangMultimodalRequest
,
...
...
@@ -67,21 +68,37 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
# If the request is not MultiModalRequest, convert it to MultiModalRequest
raw_request
=
MultiModalRequest
.
model_validate
(
raw_request
)
multimodal_input
=
MultiModalInput
()
image_urls
:
list
[
str
]
=
[]
video_url
:
str
|
None
=
None
for
message
in
raw_request
.
messages
:
for
item
in
message
.
content
:
if
item
.
type
==
"image_url"
:
multimodal_input
.
image_url
=
item
.
image_url
.
url
if
video_url
is
not
None
:
raise
ValueError
(
"Cannot provide both image and video URLs"
)
image_urls
.
append
(
item
.
image_url
.
url
)
elif
item
.
type
==
"video_url"
:
if
multimodal_input
.
image_url
is
not
None
:
if
image_urls
:
raise
ValueError
(
"Cannot provide both image and video URLs"
)
multimodal_input
.
video_url
=
item
.
video_url
.
url
if
video_url
is
not
None
:
raise
ValueError
(
"Multiple video URLs are not supported"
)
video_url
=
item
.
video_url
.
url
if
multimodal_input
.
image_url
is
None
and
multimodal_input
.
video_url
is
None
:
if
not
image_urls
and
video_url
is
None
:
raise
ValueError
(
"Either image URL or video URL is required"
)
async
for
response
in
self
.
_generate
(
raw_request
,
multimodal_input
):
multimodal_groups
:
list
[
MultiModalGroup
]
=
[]
if
image_urls
:
multimodal_groups
=
[
MultiModalGroup
(
multimodal_input
=
MultiModalInput
(
image_url
=
url
))
for
url
in
image_urls
]
elif
video_url
is
not
None
:
multimodal_groups
=
[
MultiModalGroup
(
multimodal_input
=
MultiModalInput
(
video_url
=
video_url
))
]
async
for
response
in
self
.
_generate
(
raw_request
,
multimodal_groups
):
logger
.
debug
(
f
"Generated response type
{
type
(
response
)
}
, content:
{
response
}
"
)
...
...
@@ -90,7 +107,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
async
def
_generate
(
self
,
raw_request
:
MultiModalRequest
,
multimodal_
input
:
MultiModal
Input
,
multimodal_
groups
:
list
[
MultiModal
Group
]
,
):
# Generate a unique request ID for tracking
request_id
=
str
(
uuid
.
uuid4
().
hex
)
...
...
@@ -103,7 +120,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
worker_request
=
SglangMultimodalRequest
(
request
=
sglang_request
,
multimodal_input
=
multimodal_
input
,
multimodal_input
s
=
multimodal_
groups
,
)
# Send to encoder worker
...
...
components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py
View file @
adc95380
...
...
@@ -81,16 +81,24 @@ class EmbeddingsProcessor:
self
.
_connector
=
connect
.
Connector
()
async
def
process_embeddings
(
self
,
request
:
SglangMultimodalRequest
):
"""Process embeddings from serialized request"""
logger
.
debug
(
f
"Processing embeddings with shape:
{
request
.
embeddings_shape
}
"
)
# Validate embeddings shape
if
request
.
embeddings_shape
is
None
or
len
(
request
.
embeddings_shape
)
<
2
:
raise
ValueError
(
f
"Invalid embeddings shape:
{
request
.
embeddings_shape
}
"
)
"""Process one concatenated embedding tensor from serialized request."""
logger
.
debug
(
"Processing embeddings with shape: "
f
"
{
request
.
embeddings_shape
}
"
)
multimodal_groups
=
request
.
multimodal_inputs
if
not
multimodal_groups
:
raise
ValueError
(
"multimodal_inputs is required"
)
serialized_request
=
request
.
serialized_request
embeddings_shape
=
request
.
embeddings_shape
if
serialized_request
is
None
:
raise
ValueError
(
"serialized_request is required on request"
)
if
embeddings_shape
is
None
:
raise
ValueError
(
"embeddings_shape is required on request"
)
if
len
(
embeddings_shape
)
<
2
:
raise
ValueError
(
f
"Invalid embeddings shape:
{
embeddings_shape
}
"
)
embeddings
=
torch
.
empty
(
request
.
embeddings_shape
,
embeddings_shape
,
dtype
=
MultimodalConfig
.
EMBEDDINGS_DTYPE
,
device
=
MultimodalConfig
.
EMBEDDINGS_DEVICE
,
)
...
...
@@ -105,17 +113,13 @@ class EmbeddingsProcessor:
)
self
.
_connector
=
connect
.
Connector
()
read_op
=
await
self
.
_connector
.
begin_read
(
request
.
serialized_request
,
descriptor
)
read_op
=
await
self
.
_connector
.
begin_read
(
serialized_request
,
descriptor
)
await
read_op
.
wait_for_completion
()
return
embeddings
,
descriptor
@
staticmethod
def
create_multimodal_item
(
embeddings
:
torch
.
Tensor
,
request
:
SglangMultimodalRequest
)
->
dict
:
def
create_multimodal_item
(
embeddings
:
torch
.
Tensor
,
image_grid_thw
)
->
dict
:
"""Create mm_item dict for SGLang's engine.async_generate(image_data=[...]).
Uses format="processor_output" with precomputed_embeddings so SGLang
...
...
@@ -123,13 +127,7 @@ class EmbeddingsProcessor:
"""
precomputed
=
embeddings
.
to
(
MultimodalConfig
.
EMBEDDINGS_DTYPE
)
# Convert list fields back to tensors (JSON roundtrip loses tensor type)
processor_output
=
request
.
processor_output
or
{}
for
key
,
value
in
processor_output
.
items
():
if
isinstance
(
value
,
list
):
processor_output
[
key
]
=
torch
.
tensor
(
value
)
mm_item
=
dict
(
processor_output
)
mm_item
=
{
"image_grid_thw"
:
torch
.
tensor
(
image_grid_thw
)}
mm_item
.
update
(
{
"format"
:
"processor_output"
,
...
...
@@ -246,6 +244,23 @@ class ErrorResponseBuilder:
return
json
.
dumps
(
response
)
async
def
_build_mm_items
(
request
:
SglangMultimodalRequest
,
embeddings_processor
:
EmbeddingsProcessor
)
->
tuple
[
list
[
dict
],
torch
.
Tensor
]:
"""Process embeddings and build a single multimodal item for SGLang."""
embeddings
,
_
=
await
embeddings_processor
.
process_embeddings
(
request
)
image_grid_thw_list
=
[
group
.
image_grid_thw
for
group
in
request
.
multimodal_inputs
]
if
any
(
item
is
None
for
item
in
image_grid_thw_list
):
raise
ValueError
(
"image_grid_thw is required"
)
mm_items
=
[
embeddings_processor
.
create_multimodal_item
(
embeddings
,
image_grid_thw_list
)
]
return
mm_items
,
embeddings
class
MultimodalWorkerHandler
(
BaseWorkerHandler
):
"""
Multimodal worker handler for LLM inference with multimodal data.
...
...
@@ -355,23 +370,19 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
try
:
sampling_params
=
SglangUtils
.
build_sampling_params
(
request
)
embeddings
,
descriptor
=
await
self
.
embeddings_processor
.
process_embeddings
(
request
)
# Create multimodal item
mm_item
=
self
.
embeddings_processor
.
create_multimodal_item
(
embeddings
,
request
mm_items
,
combined_embeddings
=
await
_build_mm_items
(
request
,
self
.
embeddings_processor
)
logger
.
debug
(
f
"Generated multimodal item with embeddings shape:
{
embeddings
.
shape
}
"
"Generated combined multimodal item with embeddings shape: "
f
"
{
combined_embeddings
.
shape
}
"
)
logger
.
debug
(
f
"Input token sequence length:
{
len
(
input_ids
)
}
"
)
agg_stream
=
await
self
.
engine
.
async_generate
(
input_ids
=
input_ids
,
image_data
=
[
mm_item
]
,
image_data
=
mm_item
s
,
sampling_params
=
sampling_params
,
stream
=
True
,
)
...
...
@@ -385,12 +396,14 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
"Shape mismatch error - this likely indicates a tokenization/embedding alignment issue"
)
logger
.
error
(
f
"Request token IDs length:
{
len
(
input_ids
)
}
"
)
logger
.
error
(
f
"Embeddings shape:
{
request
.
embeddings_shape
}
"
)
logger
.
error
(
"Embeddings shape:
"
f
"
{
request
.
embeddings_shape
}
"
)
logger
.
error
(
f
"Token sequence preview:
{
input_ids
[:
20
]
}
..."
)
error_msg
=
(
f
"Multimodal embedding alignment error:
{
str
(
e
)
}
. "
f
"This usually happens when the tokenization changes between requests. "
f
"Token count:
{
len
(
input_ids
)
}
, Embedding shape:
{
request
.
embeddings_shape
}
"
"Token count: "
f
"
{
len
(
input_ids
)
}
, Embedding shape: "
f
"
{
request
.
embeddings_shape
}
"
)
yield
ErrorResponseBuilder
.
build_error_response
(
RuntimeError
(
error_msg
))
else
:
...
...
@@ -515,17 +528,12 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
sampling_params
=
disagg_request
.
sampling_params
# Process embeddings from encode worker using our embeddings processor
embeddings
,
descriptor
=
await
self
.
embeddings_processor
.
process_embeddings
(
request
)
# Create multimodal item for prefill generation
mm_item
=
self
.
embeddings_processor
.
create_multimodal_item
(
embeddings
,
request
)
mm_items
,
_
=
await
_build_mm_items
(
request
,
self
.
embeddings_processor
)
# Start SGLang prefill generation (like regular SGLang)
results
=
await
self
.
engine
.
async_generate
(
input_ids
=
input_ids
,
image_data
=
[
mm_item
]
,
image_data
=
mm_item
s
,
sampling_params
=
sampling_params
,
stream
=
True
,
bootstrap_host
=
self
.
bootstrap_host
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment