Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
adc95380
Unverified
Commit
adc95380
authored
Mar 05, 2026
by
Wang, Yi
Committed by
GitHub
Mar 04, 2026
Browse files
feat: multi-image in request support for sglang backend (#6068)
Signed-off-by:
Wang, Yi
<
yi.a.wang@intel.com
>
parent
203249e1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
192 additions
and
78 deletions
+192
-78
components/src/dynamo/sglang/protocol.py
components/src/dynamo/sglang/protocol.py
+9
-6
components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py
...lang/request_handlers/multimodal/encode_worker_handler.py
+111
-25
components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py
...o/sglang/request_handlers/multimodal/processor_handler.py
+25
-8
components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py
...namo/sglang/request_handlers/multimodal/worker_handler.py
+47
-39
No files found.
components/src/dynamo/sglang/protocol.py
View file @
adc95380
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
List
,
Literal
,
Optional
,
Tuple
,
Union
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
from
sglang.srt.entrypoints.openai.protocol
import
ChatCompletionRequest
from
sglang.srt.entrypoints.openai.protocol
import
ChatCompletionRequest
...
@@ -115,18 +115,21 @@ class MultiModalInput(BaseModel):
...
@@ -115,18 +115,21 @@ class MultiModalInput(BaseModel):
video_url
:
Optional
[
str
]
=
None
video_url
:
Optional
[
str
]
=
None
class
Sglang
Multi
m
odal
Request
(
BaseModel
):
class
Multi
M
odal
Group
(
BaseModel
):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
request
:
PreprocessedRequest
multimodal_input
:
Optional
[
MultiModalInput
]
=
Field
(
default_factory
=
MultiModalInput
)
multimodal_input
:
Optional
[
MultiModalInput
]
=
Field
(
default_factory
=
MultiModalInput
)
image_grid_thw
:
Optional
[
List
[
Any
]]
=
None
image_grid_thw
:
Optional
[
List
[
Any
]]
=
None
class
SglangMultimodalRequest
(
BaseModel
):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
request
:
PreprocessedRequest
multimodal_inputs
:
List
[
MultiModalGroup
]
=
Field
(
default_factory
=
list
)
# Shared embedding transfer metadata for the entire multimodal request.
embeddings_shape
:
Optional
[
embeddings_shape
:
Optional
[
Union
[
Tuple
[
int
,
int
],
Tuple
[
int
,
int
,
int
],
Tuple
[
int
,
int
,
int
,
int
]]
Union
[
Tuple
[
int
,
int
],
Tuple
[
int
,
int
,
int
],
Tuple
[
int
,
int
,
int
,
int
]]
]
=
None
]
=
None
serialized_request
:
Optional
[
connect
.
RdmaMetadata
]
=
None
serialized_request
:
Optional
[
connect
.
RdmaMetadata
]
=
None
# Processor metadata (e.g. image_grid_thw) carried from encode worker
# to PD/prefill worker for building the format="processor_output" mm_item.
processor_output
:
Optional
[
Dict
[
str
,
Any
]]
=
None
class
DisaggSglangMultimodalRequest
(
BaseModel
):
class
DisaggSglangMultimodalRequest
(
BaseModel
):
...
...
components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py
View file @
adc95380
...
@@ -115,50 +115,136 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
...
@@ -115,50 +115,136 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
# The following steps encode the requested image for SGLang:
# The following steps encode the requested image for SGLang:
# 1. Pass the image URL to MMEncoder which loads, preprocesses, and
# 1. Pass the image URL to MMEncoder which loads, preprocesses, and
# runs the vision encoder.
# runs the vision encoder.
# 2.
Add a bat
ch
d
im
ension and store metadata on the reques
t.
# 2.
Expand ea
ch im
age placeholder token to match patch coun
t.
# 3.
Expand the single image placeholder token to match patch count
.
# 3.
Create a single NIXL descriptor for concatenated embeddings
.
# 4.
Create a NIXL descriptor and send embeddings
to downstream worker.
# 4.
Send request + metadata
to downstream worker.
# 5. Stream the downstream worker's response back to the caller.
# 5. Stream the downstream worker's response back to the caller.
try
:
try
:
if
not
request
.
multimodal_input
.
image_url
:
multimodal_groups
=
request
.
multimodal_inputs
raise
ValueError
(
"image_url is required for the encode worker."
)
if
not
multimodal_groups
:
raise
ValueError
(
"multimodal_inputs is required for the encode worker."
)
image_urls
=
[]
for
idx
,
mm_group
in
enumerate
(
multimodal_groups
):
mm_input
=
mm_group
.
multimodal_input
if
not
mm_input
or
not
mm_input
.
image_url
:
raise
ValueError
(
f
"image_url is required for the encode worker (index=
{
idx
}
)."
)
if
mm_input
.
video_url
is
not
None
:
raise
NotImplementedError
(
"video_url encoding is not supported in SGLang encode worker"
)
image_urls
.
append
(
mm_input
.
image_url
)
image_grid_dim
,
mm
_embedding
=
await
self
.
encoder
.
_encode
(
image_grid_dim
,
precomputed
_embedding
s
=
await
self
.
encoder
.
_encode
(
[
request
.
multimodal_input
.
image_url
]
image_url
s
)
)
image_grid_thw
=
(
image_grid_thw
_list
=
(
image_grid_dim
.
tolist
()
image_grid_dim
.
tolist
()
if
isinstance
(
image_grid_dim
,
torch
.
Tensor
)
if
isinstance
(
image_grid_dim
,
torch
.
Tensor
)
else
image_grid_dim
else
image_grid_dim
)
)
# Store the image data info in the request for downstream
if
len
(
image_grid_thw_list
)
!=
len
(
multimodal_groups
):
request
.
processor_output
=
{
"image_grid_thw"
:
image_grid_thw
}
raise
ValueError
(
"image_grid_thw size mismatch"
)
request
.
image_grid_thw
=
image_grid_thw
request
.
embeddings_shape
=
tuple
(
mm_embedding
.
shape
)
def
_build_token_counts
(
total_tokens
:
int
)
->
list
[
int
]:
if
total_tokens
<=
0
:
raise
ValueError
(
"Invalid token statistics for embeddings"
)
# image_grid_thw is [t, h, w]. We derive per-item relative sizes
# from spatial grid (h * w), then infer merge factor
# from the total embedding token count.
grid_sizes
=
[]
for
image_grid_thw
in
image_grid_thw_list
:
if
not
isinstance
(
image_grid_thw
,
list
)
or
len
(
image_grid_thw
)
!=
3
:
raise
ValueError
(
"Cannot split embeddings: invalid image_grid_thw"
)
grid_sizes
.
append
(
int
(
image_grid_thw
[
1
]
*
image_grid_thw
[
2
]))
total_grid_tokens
=
sum
(
grid_sizes
)
if
total_grid_tokens
<=
0
:
raise
ValueError
(
"Invalid grid statistics for embeddings"
)
if
total_grid_tokens
%
total_tokens
!=
0
:
raise
ValueError
(
"Cannot infer merge factor: grid token total is not divisible by embedding token total"
)
# Replace the single image token with multiple image tokens based on embedding shape
merge_factor
=
total_grid_tokens
//
total_tokens
image_token_id_index
=
request
.
request
.
token_ids
.
index
(
self
.
image_token_id
)
token_counts
=
[]
for
grid_count
in
grid_sizes
:
if
grid_count
%
merge_factor
!=
0
:
raise
ValueError
(
"Cannot split embeddings: per-image grid token count not divisible by inferred merge factor"
)
token_counts
.
append
(
grid_count
//
merge_factor
)
if
sum
(
token_counts
)
!=
total_tokens
:
raise
ValueError
(
"Cannot split embeddings: per-image token counts do not match embedding token total"
)
num_image_tokens
=
mm_embedding
.
shape
[
0
]
# Number of image patche
s
return
token_count
s
# Replace single image token with multiple image tokens
if
isinstance
(
precomputed_embeddings
,
torch
.
Tensor
):
request
.
request
.
token_ids
=
(
if
precomputed_embeddings
.
ndim
!=
2
:
request
.
request
.
token_ids
[:
image_token_id_index
]
raise
ValueError
(
+
[
self
.
image_token_id
]
*
num_image_tokens
"Unsupported embeddings tensor rank from encoder: "
+
request
.
request
.
token_ids
[
f
"
{
precomputed_embeddings
.
ndim
}
. Expected 2D [tokens, hidden]."
image_token_id_index
+
1
:
)
]
# Skip the original token
token_counts
=
_build_token_counts
(
precomputed_embeddings
.
shape
[
0
])
else
:
raise
ValueError
(
"Unsupported embeddings type from encoder: "
f
"
{
type
(
precomputed_embeddings
)
}
"
)
image_placeholder_count
=
request
.
request
.
token_ids
.
count
(
self
.
image_token_id
)
)
if
image_placeholder_count
<
len
(
multimodal_groups
):
raise
ValueError
(
"Not enough image placeholders in token_ids for provided images"
)
# Create descriptor for the multimodal data
# Keep per-image grid metadata in request groups for worker-side mm_item.
descriptor
=
connect
.
Descriptor
(
mm_embedding
)
for
idx
,
(
mm_group
,
image_grid_thw
)
in
enumerate
(
zip
(
multimodal_groups
,
image_grid_thw_list
)
):
mm_group
.
image_grid_thw
=
image_grid_thw
mm_group
.
multimodal_input
.
image_url
=
None
# Store shared serialized tensor metadata at request level.
request
.
embeddings_shape
=
tuple
(
precomputed_embeddings
.
shape
)
request
.
serialized_request
=
None
search_start
=
0
for
num_image_tokens
in
token_counts
:
try
:
image_token_id_index
=
request
.
request
.
token_ids
.
index
(
self
.
image_token_id
,
search_start
)
except
ValueError
as
e
:
raise
ValueError
(
"Not enough image tokens found for provided images"
)
from
e
request
.
request
.
token_ids
=
(
request
.
request
.
token_ids
[:
image_token_id_index
]
+
[
self
.
image_token_id
]
*
num_image_tokens
+
request
.
request
.
token_ids
[
image_token_id_index
+
1
:]
)
search_start
=
image_token_id_index
+
num_image_tokens
descriptor
=
connect
.
Descriptor
(
precomputed_embeddings
)
with
await
self
.
_connector
.
create_readable
(
descriptor
)
as
readable
:
with
await
self
.
_connector
.
create_readable
(
descriptor
)
as
readable
:
request
.
serialized_request
=
readable
.
metadata
()
request
.
serialized_request
=
readable
.
metadata
()
logger
.
debug
(
f
"Request:
{
request
.
model_dump_json
()
}
"
)
logger
.
debug
(
f
"Request:
{
request
.
model_dump_json
()
}
"
)
# Get the response generator from downstream worker
# Get the response generator from downstream worker
...
...
components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py
View file @
adc95380
...
@@ -17,6 +17,7 @@ from dynamo.sglang.multimodal_utils import (
...
@@ -17,6 +17,7 @@ from dynamo.sglang.multimodal_utils import (
process_sglang_stream_response
,
process_sglang_stream_response
,
)
)
from
dynamo.sglang.protocol
import
(
from
dynamo.sglang.protocol
import
(
MultiModalGroup
,
MultiModalInput
,
MultiModalInput
,
MultiModalRequest
,
MultiModalRequest
,
SglangMultimodalRequest
,
SglangMultimodalRequest
,
...
@@ -67,21 +68,37 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
...
@@ -67,21 +68,37 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
# If the request is not MultiModalRequest, convert it to MultiModalRequest
# If the request is not MultiModalRequest, convert it to MultiModalRequest
raw_request
=
MultiModalRequest
.
model_validate
(
raw_request
)
raw_request
=
MultiModalRequest
.
model_validate
(
raw_request
)
multimodal_input
=
MultiModalInput
()
image_urls
:
list
[
str
]
=
[]
video_url
:
str
|
None
=
None
for
message
in
raw_request
.
messages
:
for
message
in
raw_request
.
messages
:
for
item
in
message
.
content
:
for
item
in
message
.
content
:
if
item
.
type
==
"image_url"
:
if
item
.
type
==
"image_url"
:
multimodal_input
.
image_url
=
item
.
image_url
.
url
if
video_url
is
not
None
:
raise
ValueError
(
"Cannot provide both image and video URLs"
)
image_urls
.
append
(
item
.
image_url
.
url
)
elif
item
.
type
==
"video_url"
:
elif
item
.
type
==
"video_url"
:
if
multimodal_input
.
image_url
is
not
None
:
if
image_urls
:
raise
ValueError
(
"Cannot provide both image and video URLs"
)
raise
ValueError
(
"Cannot provide both image and video URLs"
)
multimodal_input
.
video_url
=
item
.
video_url
.
url
if
video_url
is
not
None
:
raise
ValueError
(
"Multiple video URLs are not supported"
)
video_url
=
item
.
video_url
.
url
if
multimodal_input
.
image_url
is
None
and
multimodal_input
.
video_url
is
None
:
if
not
image_urls
and
video_url
is
None
:
raise
ValueError
(
"Either image URL or video URL is required"
)
raise
ValueError
(
"Either image URL or video URL is required"
)
async
for
response
in
self
.
_generate
(
raw_request
,
multimodal_input
):
multimodal_groups
:
list
[
MultiModalGroup
]
=
[]
if
image_urls
:
multimodal_groups
=
[
MultiModalGroup
(
multimodal_input
=
MultiModalInput
(
image_url
=
url
))
for
url
in
image_urls
]
elif
video_url
is
not
None
:
multimodal_groups
=
[
MultiModalGroup
(
multimodal_input
=
MultiModalInput
(
video_url
=
video_url
))
]
async
for
response
in
self
.
_generate
(
raw_request
,
multimodal_groups
):
logger
.
debug
(
logger
.
debug
(
f
"Generated response type
{
type
(
response
)
}
, content:
{
response
}
"
f
"Generated response type
{
type
(
response
)
}
, content:
{
response
}
"
)
)
...
@@ -90,7 +107,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
...
@@ -90,7 +107,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
async
def
_generate
(
async
def
_generate
(
self
,
self
,
raw_request
:
MultiModalRequest
,
raw_request
:
MultiModalRequest
,
multimodal_
input
:
MultiModal
Input
,
multimodal_
groups
:
list
[
MultiModal
Group
]
,
):
):
# Generate a unique request ID for tracking
# Generate a unique request ID for tracking
request_id
=
str
(
uuid
.
uuid4
().
hex
)
request_id
=
str
(
uuid
.
uuid4
().
hex
)
...
@@ -103,7 +120,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
...
@@ -103,7 +120,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
worker_request
=
SglangMultimodalRequest
(
worker_request
=
SglangMultimodalRequest
(
request
=
sglang_request
,
request
=
sglang_request
,
multimodal_input
=
multimodal_
input
,
multimodal_input
s
=
multimodal_
groups
,
)
)
# Send to encoder worker
# Send to encoder worker
...
...
components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py
View file @
adc95380
...
@@ -81,16 +81,24 @@ class EmbeddingsProcessor:
...
@@ -81,16 +81,24 @@ class EmbeddingsProcessor:
self
.
_connector
=
connect
.
Connector
()
self
.
_connector
=
connect
.
Connector
()
async
def
process_embeddings
(
self
,
request
:
SglangMultimodalRequest
):
async
def
process_embeddings
(
self
,
request
:
SglangMultimodalRequest
):
"""Process embeddings from serialized request"""
"""Process one concatenated embedding tensor from serialized request."""
logger
.
debug
(
"Processing embeddings with shape: "
f
"
{
request
.
embeddings_shape
}
"
)
logger
.
debug
(
f
"Processing embeddings with shape:
{
request
.
embeddings_shape
}
"
)
multimodal_groups
=
request
.
multimodal_inputs
# Validate embeddings shape
if
not
multimodal_groups
:
if
request
.
embeddings_shape
is
None
or
len
(
request
.
embeddings_shape
)
<
2
:
raise
ValueError
(
"multimodal_inputs is required"
)
raise
ValueError
(
f
"Invalid embeddings shape:
{
request
.
embeddings_shape
}
"
)
serialized_request
=
request
.
serialized_request
embeddings_shape
=
request
.
embeddings_shape
if
serialized_request
is
None
:
raise
ValueError
(
"serialized_request is required on request"
)
if
embeddings_shape
is
None
:
raise
ValueError
(
"embeddings_shape is required on request"
)
if
len
(
embeddings_shape
)
<
2
:
raise
ValueError
(
f
"Invalid embeddings shape:
{
embeddings_shape
}
"
)
embeddings
=
torch
.
empty
(
embeddings
=
torch
.
empty
(
request
.
embeddings_shape
,
embeddings_shape
,
dtype
=
MultimodalConfig
.
EMBEDDINGS_DTYPE
,
dtype
=
MultimodalConfig
.
EMBEDDINGS_DTYPE
,
device
=
MultimodalConfig
.
EMBEDDINGS_DEVICE
,
device
=
MultimodalConfig
.
EMBEDDINGS_DEVICE
,
)
)
...
@@ -105,17 +113,13 @@ class EmbeddingsProcessor:
...
@@ -105,17 +113,13 @@ class EmbeddingsProcessor:
)
)
self
.
_connector
=
connect
.
Connector
()
self
.
_connector
=
connect
.
Connector
()
read_op
=
await
self
.
_connector
.
begin_read
(
read_op
=
await
self
.
_connector
.
begin_read
(
serialized_request
,
descriptor
)
request
.
serialized_request
,
descriptor
)
await
read_op
.
wait_for_completion
()
await
read_op
.
wait_for_completion
()
return
embeddings
,
descriptor
return
embeddings
,
descriptor
@
staticmethod
@
staticmethod
def
create_multimodal_item
(
def
create_multimodal_item
(
embeddings
:
torch
.
Tensor
,
image_grid_thw
)
->
dict
:
embeddings
:
torch
.
Tensor
,
request
:
SglangMultimodalRequest
)
->
dict
:
"""Create mm_item dict for SGLang's engine.async_generate(image_data=[...]).
"""Create mm_item dict for SGLang's engine.async_generate(image_data=[...]).
Uses format="processor_output" with precomputed_embeddings so SGLang
Uses format="processor_output" with precomputed_embeddings so SGLang
...
@@ -123,13 +127,7 @@ class EmbeddingsProcessor:
...
@@ -123,13 +127,7 @@ class EmbeddingsProcessor:
"""
"""
precomputed
=
embeddings
.
to
(
MultimodalConfig
.
EMBEDDINGS_DTYPE
)
precomputed
=
embeddings
.
to
(
MultimodalConfig
.
EMBEDDINGS_DTYPE
)
# Convert list fields back to tensors (JSON roundtrip loses tensor type)
mm_item
=
{
"image_grid_thw"
:
torch
.
tensor
(
image_grid_thw
)}
processor_output
=
request
.
processor_output
or
{}
for
key
,
value
in
processor_output
.
items
():
if
isinstance
(
value
,
list
):
processor_output
[
key
]
=
torch
.
tensor
(
value
)
mm_item
=
dict
(
processor_output
)
mm_item
.
update
(
mm_item
.
update
(
{
{
"format"
:
"processor_output"
,
"format"
:
"processor_output"
,
...
@@ -246,6 +244,23 @@ class ErrorResponseBuilder:
...
@@ -246,6 +244,23 @@ class ErrorResponseBuilder:
return
json
.
dumps
(
response
)
return
json
.
dumps
(
response
)
async
def
_build_mm_items
(
request
:
SglangMultimodalRequest
,
embeddings_processor
:
EmbeddingsProcessor
)
->
tuple
[
list
[
dict
],
torch
.
Tensor
]:
"""Process embeddings and build a single multimodal item for SGLang."""
embeddings
,
_
=
await
embeddings_processor
.
process_embeddings
(
request
)
image_grid_thw_list
=
[
group
.
image_grid_thw
for
group
in
request
.
multimodal_inputs
]
if
any
(
item
is
None
for
item
in
image_grid_thw_list
):
raise
ValueError
(
"image_grid_thw is required"
)
mm_items
=
[
embeddings_processor
.
create_multimodal_item
(
embeddings
,
image_grid_thw_list
)
]
return
mm_items
,
embeddings
class
MultimodalWorkerHandler
(
BaseWorkerHandler
):
class
MultimodalWorkerHandler
(
BaseWorkerHandler
):
"""
"""
Multimodal worker handler for LLM inference with multimodal data.
Multimodal worker handler for LLM inference with multimodal data.
...
@@ -355,23 +370,19 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
...
@@ -355,23 +370,19 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
try
:
try
:
sampling_params
=
SglangUtils
.
build_sampling_params
(
request
)
sampling_params
=
SglangUtils
.
build_sampling_params
(
request
)
embeddings
,
descriptor
=
await
self
.
embeddings_processor
.
process_embeddings
(
mm_items
,
combined_embeddings
=
await
_build_mm_items
(
request
request
,
self
.
embeddings_processor
)
# Create multimodal item
mm_item
=
self
.
embeddings_processor
.
create_multimodal_item
(
embeddings
,
request
)
)
logger
.
debug
(
logger
.
debug
(
f
"Generated multimodal item with embeddings shape:
{
embeddings
.
shape
}
"
"Generated combined multimodal item with embeddings shape: "
f
"
{
combined_embeddings
.
shape
}
"
)
)
logger
.
debug
(
f
"Input token sequence length:
{
len
(
input_ids
)
}
"
)
logger
.
debug
(
f
"Input token sequence length:
{
len
(
input_ids
)
}
"
)
agg_stream
=
await
self
.
engine
.
async_generate
(
agg_stream
=
await
self
.
engine
.
async_generate
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
image_data
=
[
mm_item
]
,
image_data
=
mm_item
s
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
stream
=
True
,
stream
=
True
,
)
)
...
@@ -385,12 +396,14 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
...
@@ -385,12 +396,14 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
"Shape mismatch error - this likely indicates a tokenization/embedding alignment issue"
"Shape mismatch error - this likely indicates a tokenization/embedding alignment issue"
)
)
logger
.
error
(
f
"Request token IDs length:
{
len
(
input_ids
)
}
"
)
logger
.
error
(
f
"Request token IDs length:
{
len
(
input_ids
)
}
"
)
logger
.
error
(
f
"Embeddings shape:
{
request
.
embeddings_shape
}
"
)
logger
.
error
(
"Embeddings shape:
"
f
"
{
request
.
embeddings_shape
}
"
)
logger
.
error
(
f
"Token sequence preview:
{
input_ids
[:
20
]
}
..."
)
logger
.
error
(
f
"Token sequence preview:
{
input_ids
[:
20
]
}
..."
)
error_msg
=
(
error_msg
=
(
f
"Multimodal embedding alignment error:
{
str
(
e
)
}
. "
f
"Multimodal embedding alignment error:
{
str
(
e
)
}
. "
f
"This usually happens when the tokenization changes between requests. "
f
"This usually happens when the tokenization changes between requests. "
f
"Token count:
{
len
(
input_ids
)
}
, Embedding shape:
{
request
.
embeddings_shape
}
"
"Token count: "
f
"
{
len
(
input_ids
)
}
, Embedding shape: "
f
"
{
request
.
embeddings_shape
}
"
)
)
yield
ErrorResponseBuilder
.
build_error_response
(
RuntimeError
(
error_msg
))
yield
ErrorResponseBuilder
.
build_error_response
(
RuntimeError
(
error_msg
))
else
:
else
:
...
@@ -515,17 +528,12 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
...
@@ -515,17 +528,12 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
sampling_params
=
disagg_request
.
sampling_params
sampling_params
=
disagg_request
.
sampling_params
# Process embeddings from encode worker using our embeddings processor
# Process embeddings from encode worker using our embeddings processor
embeddings
,
descriptor
=
await
self
.
embeddings_processor
.
process_embeddings
(
mm_items
,
_
=
await
_build_mm_items
(
request
,
self
.
embeddings_processor
)
request
)
# Create multimodal item for prefill generation
mm_item
=
self
.
embeddings_processor
.
create_multimodal_item
(
embeddings
,
request
)
# Start SGLang prefill generation (like regular SGLang)
# Start SGLang prefill generation (like regular SGLang)
results
=
await
self
.
engine
.
async_generate
(
results
=
await
self
.
engine
.
async_generate
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
image_data
=
[
mm_item
]
,
image_data
=
mm_item
s
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
stream
=
True
,
stream
=
True
,
bootstrap_host
=
self
.
bootstrap_host
,
bootstrap_host
=
self
.
bootstrap_host
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment