sglang / Commits / cf9815ba

Unverified commit cf9815ba, authored Jun 04, 2025 by Xinyuan Tong, committed by GitHub on Jun 04, 2025
[Refactor] Multimodal data processing for VLM (#6659)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Parent: bd75690f
Showing 11 changed files with 249 additions and 168 deletions (+249 −168)
docs/backend/vlm_query.ipynb                                          +1   −1
python/sglang/srt/managers/multimodal_processors/base_processor.py  +183  −25
python/sglang/srt/managers/multimodal_processors/gemma3.py            +4  −31
python/sglang/srt/managers/multimodal_processors/kimi_vl.py           +4  −42
python/sglang/srt/managers/multimodal_processors/qwen_vl.py          +21  −48
python/sglang/srt/managers/schedule_batch.py                          +4   −1
python/sglang/srt/models/gemma3_mm.py                                +20   −8
python/sglang/srt/models/kimi_vl.py                                   +4   −4
python/sglang/srt/models/qwen2_5_vl.py                                +3   −3
python/sglang/srt/models/qwen2_vl.py                                  +3   −3
test/srt/test_vlm_input_format.py                                     +2   −2
docs/backend/vlm_query.ipynb

@@ -132,7 +132,7 @@
 "\n",
 "mm_item = dict(\n",
 "    modality=\"IMAGE\",\n",
-"    image_grid_thws=processed_prompt[\"image_grid_thw\"],\n",
+"    image_grid_thw=processed_prompt[\"image_grid_thw\"],\n",
 "    precomputed_features=precomputed_features,\n",
 ")\n",
 "out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
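For orientation, here is a minimal sketch of how a notebook cell like this typically builds its inputs, assuming a Qwen2.5-VL checkpoint loaded with Hugging Face transformers. The checkpoint name, the prompt string, and the feature-extraction call are illustrative assumptions, not part of this commit:

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_id = "Qwen/Qwen2.5-VL-7B-Instruct"  # assumed checkpoint
processor = AutoProcessor.from_pretrained(model_id)
hf_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"
)

image = Image.open("example.jpg")  # placeholder image
prompt = (
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "Describe this image.<|im_end|>\n<|im_start|>assistant\n"
)
processed_prompt = processor(images=[image], text=prompt, return_tensors="pt")
input_ids = processed_prompt["input_ids"][0].tolist()

# Precompute vision features with the HF vision tower (a sketch; sglang only
# requires that precomputed_features match what its own tower would emit).
with torch.inference_mode():
    precomputed_features = hf_model.visual(
        processed_prompt["pixel_values"].to("cuda", torch.bfloat16),
        grid_thw=processed_prompt["image_grid_thw"].to("cuda"),
    )

The resulting mm_item then carries image_grid_thw (singular, after this commit) plus precomputed_features, exactly as in the cell above.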
python/sglang/srt/managers/multimodal_processors/base_processor.py

@@ -5,7 +5,8 @@ import multiprocessing as mp
 import os
 import re
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple, Union
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -16,16 +17,24 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import encode_video, load_audio, load_image


+class MultimodalInputFormat(Enum):
+    """Enum for different multimodal input formats."""
+
+    RAW_IMAGES = "raw_images"
+    PRECOMPUTED_FEATURES = "precomputed_features"
+    PIXEL_VALUES = "pixel_values"
+
+
 @dataclasses.dataclass
 class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with a image_token
     input_text: str

     # frames loaded from image and video, in given order
-    images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None
+    images: Optional[list[Union[Image.Image, dict]]] = None

     # audios
-    audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None
+    audios: Optional[list[Union[np.ndarray, dict]]] = None

     def normalize(self):
         for field_name in ["images", "audios"]:
@@ -170,8 +179,6 @@ class BaseMultimodalProcessor(ABC):
     ):
         """Static method that can be pickled for multiprocessing"""
         if isinstance(data, dict):
-            return MultimodalDataItem.from_dict(data)
-        if isinstance(data, MultimodalDataItem):
             return data
         try:
             if is_audio:
@@ -370,29 +377,180 @@ class BaseMultimodalProcessor(ABC):
         return list(zip(indices_start.tolist(), indices_end.tolist()))

-    def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
-        """Returns true if all images are preprocessed, false if all are not, and error otherwise."""
-        if not mm_inputs:
-            return True
-        ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs)
-        if ret and not all(
-            isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
-        ):
-            raise ValueError(
-                "Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
-            )
-        return ret
-
     @staticmethod
     def _extract_processor_features(
-        items: List[Any], attr_name: str
+        items: List[dict], attr_name: str
     ) -> Optional[torch.Tensor]:
         """
         Helper function to concat extracted attributes from processor output.
         """
-        values = [
-            getattr(item, attr_name)
-            for item in items
-            if getattr(item, attr_name) is not None
-        ]
-        return torch.concat(values) if values else None
+        values = [value for item in items if (value := item.get(attr_name)) is not None]
+        return torch.cat(values) if values else None
+
+    # When we assume that all the items have the same attributes
+    def _extract_processor_features_from_all_attributes(
+        self, items: List[dict]
+    ) -> dict:
+        values = {}
+        # Verify all items have the same keys
+        first_keys = set(items[0].keys())
+        for item in items[1:]:
+            if set(item.keys()) != first_keys:
+                raise ValueError(
+                    f"All items must have the same attributes. "
+                    f"First item has {first_keys}, but found {set(item.keys())}"
+                )
+
+        # Process each attribute
+        for k, v in items[0].items():
+            if isinstance(v, list):
+                values[k] = self._extract_processor_features(items, k)
+            else:
+                # Verify all items have the same value for non-list attributes
+                for item in items[1:]:
+                    if item[k] != v:
+                        raise ValueError(
+                            f"All items must have the same value for attribute {k}. "
+                            f"First item has {v}, but found {item[k]}"
+                        )
+                values[k] = v
+        return values
+
+    def process_and_combine_mm_data(
+        self, base_output: BaseMultiModalProcessorOutput
+    ) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
+        """
+        Process multimodal data and return the combined multimodal item and input_ids.
+        Handles all three input formats at the same abstraction level.
+
+        Returns:
+            Tuple of (combined_mm_item, input_ids)
+        """
+
+        def tokenize_text(input_text: str) -> torch.Tensor:
+            """Tokenize input text."""
+            return self._processor.tokenizer(
+                input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+
+        def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat:
+            """Categorize multimodal inputs and validate consistency."""
+            try:
+                has_image = False
+                has_pixel_values = False
+                has_precomputed_features = False
+
+                for mm_input in mm_inputs:
+                    if isinstance(mm_input, Image.Image):
+                        has_image = True
+                    elif isinstance(mm_input, dict):
+                        if mm_input.get("precomputed_features", None) is not None:
+                            has_precomputed_features = True
+                        elif mm_input.get("pixel_values", None) is not None:
+                            has_pixel_values = True
+                        else:
+                            raise ValueError(
+                                f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
+                            )
+                    else:
+                        raise ValueError(
+                            f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
+                        )
+
+                # Validate format consistency
+                format_count = sum(
+                    [has_image, has_pixel_values, has_precomputed_features]
+                )
+                if format_count > 1:
+                    raise ValueError(
+                        "Unsupported: mixture of multimodal input formats. "
+                        f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
+                        f"precomputed_features={has_precomputed_features}"
+                    )
+
+                if has_image:
+                    return MultimodalInputFormat.RAW_IMAGES
+                elif has_precomputed_features:
+                    return MultimodalInputFormat.PRECOMPUTED_FEATURES
+                elif has_pixel_values:
+                    return MultimodalInputFormat.PIXEL_VALUES
+                else:
+                    raise ValueError("No valid multimodal input format found")
+            except Exception as e:
+                raise ValueError(f"Failed to categorize inputs: {e}")
+
+        def process_raw_images(
+            base_output: BaseMultiModalProcessorOutput,
+        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
+            """Process raw Image.Image objects using transformers processor."""
+            ret = self.process_mm_data(
+                input_text=base_output.input_text,
+                images=base_output.images,
+            )
+            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
+
+            # Copy all fields from processor output except input_ids
+            for key, value in ret.items():
+                if key != "input_ids" and hasattr(combined_mm_item, key):
+                    setattr(combined_mm_item, key, value)
+
+            input_ids = ret["input_ids"].flatten()
+            return combined_mm_item, input_ids
+
+        def process_precomputed_features(
+            base_output: BaseMultiModalProcessorOutput,
+        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
+            """Process inputs with precomputed features."""
+            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
+            combined_mm_item.precomputed_features = self._extract_processor_features(
+                base_output.images, "precomputed_features"
+            )
+            input_ids = tokenize_text(base_output.input_text)
+            return combined_mm_item, input_ids
+
+        def process_pixel_values(
+            base_output: BaseMultiModalProcessorOutput,
+        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
+            """Process inputs with pixel values."""
+            values = self._extract_processor_features_from_all_attributes(
+                base_output.images
+            )
+            combined_mm_item = MultimodalDataItem.from_dict(values)
+            input_ids = tokenize_text(base_output.input_text)
+            return combined_mm_item, input_ids
+
+        def finalize_mm_item(
+            combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
+        ) -> MultimodalDataItem:
+            """Apply common post-processing to the multimodal item."""
+            combined_mm_item.image_offsets = self.get_mm_items_offset(
+                input_ids=input_ids,
+                mm_token_id=self.IM_TOKEN_ID,
+            )
+            return combined_mm_item
+
+        # Main logic
+        mm_inputs = base_output.images
+        if not mm_inputs:
+            # Return text-only case
+            input_ids = tokenize_text(base_output.input_text)
+            return None, input_ids
+
+        # Categorize input formats
+        input_format = categorize_mm_inputs(mm_inputs)
+
+        # Process based on format
+        if input_format == MultimodalInputFormat.RAW_IMAGES:
+            combined_mm_item, input_ids = process_raw_images(base_output)
+        elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
+            combined_mm_item, input_ids = process_precomputed_features(base_output)
+        elif input_format == MultimodalInputFormat.PIXEL_VALUES:
+            combined_mm_item, input_ids = process_pixel_values(base_output)
+        else:
+            raise ValueError(f"Unknown input format: {input_format}")

+        # Finalize with common processing
+        combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
+        return combined_mm_item, input_ids
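With these helpers in place, a model-specific processor collapses to a single call. A minimal sketch of the pattern (a hypothetical subclass mirroring the gemma3 and kimi_vl changes below; the only contract it assumes is that the subclass sets self.IM_TOKEN_ID, which finalize_mm_item uses to compute image_offsets):

class MyImageProcessor(BaseMultimodalProcessor):
    async def process_mm_data_async(
        self, image_data, input_text, max_req_input_len, **kwargs
    ):
        if isinstance(image_data, str):
            image_data = [image_data]
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
            max_req_input_len=max_req_input_len,
        )
        # Raw PIL images, pixel_values dicts, and precomputed_features dicts
        # all take the same path from here.
        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
        return {
            "input_ids": input_ids.tolist(),
            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
            "im_token_id": self.IM_TOKEN_ID,
        }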
python/sglang/srt/managers/multimodal_processors/gemma3.py

@@ -27,6 +27,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         )
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index
+        self.IM_TOKEN_ID = hf_config.image_token_index

     async def process_mm_data_async(
         self,
@@ -42,49 +43,21 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         if isinstance(image_data, str):
             image_data = [image_data]

-        image_token = self.IMAGE_TOKEN
-        image_token_regex = self.IMAGE_TOKEN_REGEX
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
-                image_token=image_token, image_token_regex=image_token_regex
+                image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
             ),
             max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )

-        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
-        ret = self.process_mm_data(
-            input_text=base_output.input_text,
-            images=None if images_are_preprocessed else base_output.images,
-        )
-
-        items = []
-        input_ids = ret["input_ids"].flatten()
-        image_offsets = self.get_mm_items_offset(
-            input_ids=input_ids,
-            mm_token_id=self.hf_config.image_token_index,
-        )
-        for i, image in enumerate(base_output.images):
-            if images_are_preprocessed:
-                pixel_values = image.pixel_values
-                precomputed_features = image.precomputed_features
-            else:
-                pixel_values = ret["pixel_values"][i]
-                precomputed_features = None
-
-            item = MultimodalDataItem(
-                pixel_values=pixel_values,
-                precomputed_features=precomputed_features,
-                modality=Modality.IMAGE,
-                image_offsets=image_offsets[i],
-            )
-            items += [item]
+        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)

         return {
-            "mm_items": items,
             "input_ids": input_ids.tolist(),
+            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }
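The observable contract change here, sketched with placeholders: before, this processor emitted one MultimodalDataItem per image, each holding its own offset; after, it emits a single combined item whose image_offsets are attached by finalize_mm_item using self.IM_TOKEN_ID (which is why __init__ now records hf_config.image_token_index under that name):

# Before: one item per image (ret and image_offsets are placeholders)
mm_items = [
    MultimodalDataItem(
        pixel_values=ret["pixel_values"][i],
        modality=Modality.IMAGE,
        image_offsets=image_offsets[i],
    )
    for i in range(len(base_output.images))
]

# After: one combined item for the whole request
mm_items = [combined_mm_item] if combined_mm_item is not None else []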
python/sglang/srt/managers/multimodal_processors/kimi_vl.py

@@ -21,7 +21,7 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "<|media_pad|>"
         self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
-        self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
+        self.IM_TOKEN_ID = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)

     async def process_mm_data_async(
         self,
@@ -46,48 +46,10 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
             max_req_input_len=max_req_input_len,
         )

-        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
-        if not images_are_preprocessed:
-            ret = self.process_mm_data(
-                input_text=base_output.input_text,
-                images=base_output.images,
-            )
-            input_ids = ret["input_ids"].flatten()
-            image_grid_thws = ret["image_grid_hws"]
-            pixel_values = ret["pixel_values"]
-            precomputed_features = None
-        else:
-            input_ids = self._processor.tokenizer(
-                base_output.input_text,
-                return_tensors="pt",
-                add_special_tokens=True,
-            ).input_ids.flatten()
-            image_grid_thws = self._extract_processor_features(
-                base_output.images, "image_grid_thws"
-            )
-            precomputed_features = self._extract_processor_features(
-                base_output.images, "precomputed_features"
-            )
-            pixel_values = self._extract_processor_features(
-                base_output.images, "pixel_values"
-            )
-        image_offsets = self.get_mm_items_offset(
-            input_ids=input_ids,
-            mm_token_id=self.im_token_id,
-        )
+        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+
         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [
-                MultimodalDataItem(
-                    pixel_values=pixel_values,
-                    image_grid_thws=image_grid_thws,
-                    precomputed_features=precomputed_features,
-                    modality=Modality.IMAGE,
-                    image_offsets=image_offsets,
-                )
-            ],
-            "im_token_id": self.im_token_id,
+            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "im_token_id": self.IM_TOKEN_ID,
         }
python/sglang/srt/managers/multimodal_processors/qwen_vl.py

@@ -32,8 +32,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         )
         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
-        self.image_token_id = hf_config.image_token_id
-        self.video_token_id = hf_config.video_token_id
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.VIDEO_TOKEN_ID = hf_config.video_token_id
         self.vision_start_token_id = hf_config.vision_start_token_id
         self.vision_end_token_id = hf_config.vision_end_token_id
         self.NUM_TOKEN_PER_FRAME = 770
@@ -125,72 +125,45 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         async def resize_image_async(image):
             return resize_image(image)

-        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
-        if base_output.images and not images_are_preprocessed:
+        # Qwen-specific: resize images if they are raw Image objects
+        if base_output.images and isinstance(base_output.images[0], Image.Image):
             resize_tasks = [resize_image_async(image) for image in base_output.images]
             base_output.images = await asyncio.gather(*resize_tasks)

-        ret = self.process_mm_data(
-            input_text=base_output.input_text,
-            images=None if images_are_preprocessed else base_output.images,
-        )
-
-        input_ids = ret["input_ids"].flatten().tolist()
-        image_offsets = self.get_mm_items_offset(
-            input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id
-        )
-        image_grid_thw = None
-        video_grid_thw = None  # TODO
+        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)

-        items = []
-        if base_output.images:
-            if images_are_preprocessed:
-                image_grid_thw = self._extract_processor_features(
-                    base_output.images, "image_grid_thws"
-                )
-                precomputed_features = self._extract_processor_features(
-                    base_output.images, "precomputed_features"
-                )
-                pixel_values = self._extract_processor_features(
-                    base_output.images, "pixel_values"
-                )
-            else:
-                image_grid_thw = ret["image_grid_thw"]
-                pixel_values = ret["pixel_values"]
-                precomputed_features = None
-            items += [
-                MultimodalDataItem(
-                    pixel_values=pixel_values,
-                    image_grid_thws=image_grid_thw,
-                    video_grid_thws=video_grid_thw,
-                    precomputed_features=precomputed_features,
-                    image_offsets=image_offsets,
-                    modality=Modality.IMAGE,
-                )
-            ]
+        if combined_mm_item is None:
+            # Note(Xinyuan): This is the case where image loading fails.
+            return None
+
+        video_grid_thw = None  # TODO
+        second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)

         mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
             spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
-            image_token_id=self.image_token_id,
-            video_token_id=self.video_token_id,
+            image_token_id=self.IM_TOKEN_ID,
+            video_token_id=self.VIDEO_TOKEN_ID,
             vision_start_token_id=self.vision_start_token_id,
             model_type=self.hf_config.model_type,
             tokens_per_second=getattr(
                 self.hf_config.vision_config, "tokens_per_second", None
             ),
-            input_ids=torch.tensor(input_ids).unsqueeze(0),
-            image_grid_thw=image_grid_thw,
+            input_ids=input_ids.unsqueeze(0),
+            image_grid_thw=combined_mm_item.image_grid_thw,
             video_grid_thw=video_grid_thw,
-            second_per_grid_ts=ret.get("second_per_grid_ts", None),
+            second_per_grid_ts=second_per_grid_ts,
         )
         mrope_positions = mrope_positions.squeeze(1)

         return {
-            "input_ids": input_ids,
-            "mm_items": items,
+            "input_ids": input_ids.tolist(),
+            "mm_items": [combined_mm_item],
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
-            "im_token_id": self.image_token_id,
-            "video_token_id": self.video_token_id,
+            "im_token_id": self.IM_TOKEN_ID,
+            "video_token_id": self.VIDEO_TOKEN_ID,
             "mrope_positions": mrope_positions,
             "mrope_position_delta": mrope_position_delta,
         }
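A self-contained restatement of the bookkeeping get_rope_index depends on for the Qwen-VL family (our sketch, not commit code): each row (t, h, w) of image_grid_thw corresponds to t*h*w / spatial_merge_size**2 image-pad tokens in input_ids, which is why the combined item's image_grid_thw must agree with the tokenized prompt:

import torch

def num_image_pad_tokens(
    image_grid_thw: torch.Tensor, spatial_merge_size: int = 2
) -> torch.Tensor:
    # image_grid_thw: [num_images, 3], rows of (t, h, w) patch-grid sizes
    t, h, w = image_grid_thw.unbind(dim=-1)
    return (t * h * w) // (spatial_merge_size**2)

# A 1x16x16 patch grid with merge size 2 collapses to 64 placeholder tokens.
assert num_image_pad_tokens(torch.tensor([[1, 16, 16]])).item() == 64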
python/sglang/srt/managers/schedule_batch.py

@@ -188,7 +188,7 @@ class MultimodalDataItem:
     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
     pixel_values: Union[torch.Tensor, np.ndarray] = None
-    image_grid_thws: Union[torch.Tensor, np.ndarray] = None
+    image_grid_thw: Union[torch.Tensor, np.ndarray] = None
     video_grid_thws: Union[torch.Tensor, np.ndarray] = None
     image_emb_mask: Optional[torch.Tensor] = None
@@ -198,6 +198,9 @@ class MultimodalDataItem:
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None

+    # kimi-vl related
+    image_grid_hws: Optional[List[torch.Tensor]] = None
+
     audio_features: Union[torch.Tensor, np.ndarray] = None
     audio_feature_lens: Optional[List[torch.Tensor]] = None
     audio_offsets: Optional[List[Tuple[int, int]]] = None
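The singular field name matters because preprocessed image items arrive as plain dicts and are hydrated through MultimodalDataItem.from_dict. A hedged round-trip sketch (placeholder tensor shapes; from_dict's handling of the modality string is assumed from its use elsewhere in this commit):

import torch

pv = torch.randn(256, 1176)        # placeholder pixel_values
thw = torch.tensor([[1, 16, 16]])  # placeholder grid

item = MultimodalDataItem.from_dict(
    dict(modality="IMAGE", pixel_values=pv, image_grid_thw=thw)
)
assert item.image_grid_thw is not None  # field was image_grid_thws before this commit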
python/sglang/srt/models/gemma3_mm.py

@@ -286,14 +286,26 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
         vision_outputs_list = []

-        for pixel_value in all_pixel_values:
-            # Add batch dimension for single image processing
-            pixel_value_batch = pixel_value.unsqueeze(0)
-            pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
-            pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
-            vision_output = self.vision_tower(pixel_values=pixel_value_batch)
-            vision_outputs_list.append(vision_output)
+        for pixel_values_batch in all_pixel_values:
+            # Normalize input shape to [batch_size, channels, height, width]
+            if pixel_values_batch.dim() == 5:
+                pixel_values_batch = pixel_values_batch.squeeze(0)
+            elif pixel_values_batch.dim() == 3:
+                pixel_values_batch = pixel_values_batch.unsqueeze(0)
+            elif pixel_values_batch.dim() != 4:
+                raise ValueError(
+                    f"Unexpected pixel_values shape: {pixel_values_batch.shape}"
+                )
+
+            # Process each image in the batch
+            batch_size = pixel_values_batch.shape[0]
+            for i in range(batch_size):
+                pixel_value = pixel_values_batch[i : i + 1]  # Keep batch dimension as 1
+                pixel_value = pixel_value.to(
+                    device=self.vision_tower.device, dtype=self.language_model.dtype()
+                )
+                vision_output = self.vision_tower(pixel_values=pixel_value)
+                vision_outputs_list.append(vision_output)

         # Concatenate all vision outputs
         vision_outputs = torch.cat(vision_outputs_list, dim=0)
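The new loop accepts three pixel_values layouts. The shape handling in isolation, restated as a standalone helper (a sketch mirroring the code above, not an extract):

import torch

def normalize_pixel_values(pixel_values_batch: torch.Tensor) -> torch.Tensor:
    """Coerce [1, B, C, H, W], [C, H, W], or [B, C, H, W] to [B, C, H, W]."""
    if pixel_values_batch.dim() == 5:
        pixel_values_batch = pixel_values_batch.squeeze(0)
    elif pixel_values_batch.dim() == 3:
        pixel_values_batch = pixel_values_batch.unsqueeze(0)
    elif pixel_values_batch.dim() != 4:
        raise ValueError(f"Unexpected pixel_values shape: {pixel_values_batch.shape}")
    return pixel_values_batch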
python/sglang/srt/models/kimi_vl.py

@@ -144,10 +144,10 @@ class KimiVLForConditionalGeneration(nn.Module):
             .type(self.vision_tower.dtype)
             .to(self.vision_tower.device)
         )
-        image_grid_thws = torch.concat(
-            [item.image_grid_thws for item in items], dim=0
-        ).to(self.vision_tower.device)
-        image_features = self.vision_tower(pixel_values, image_grid_thws)
+        image_grid_hws = torch.cat([item.image_grid_hws for item in items], dim=0).to(
+            self.vision_tower.device
+        )
+        image_features = self.vision_tower(pixel_values, image_grid_hws)
         assert isinstance(image_features, list)
         # lengths = [x.shape[0] for x in image_features]
         res = self.multi_modal_projector(torch.cat(image_features))  # .split(lengths)
python/sglang/srt/models/qwen2_5_vl.py

@@ -503,10 +503,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
         )
-        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
-        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
-        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
+        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
python/sglang/srt/models/qwen2_vl.py

@@ -490,10 +490,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
         )
-        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
-        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
-        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
+        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
test/srt/test_vlm_input_format.py

@@ -156,7 +156,7 @@ class TestQwenVLUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestC
     def _pixel_values_image_data(self, processor_output):
         return dict(
             modality="IMAGE",
-            image_grid_thws=processor_output["image_grid_thw"],
+            image_grid_thw=processor_output["image_grid_thw"],
             pixel_values=processor_output["pixel_values"],
         )
@@ -207,8 +207,8 @@ class TestKimiVLImageUnderstandsImage(
     def _pixel_values_image_data(self, processor_output):
         return dict(
             modality="IMAGE",
-            image_grid_thws=processor_output["image_grid_hws"],
             pixel_values=processor_output["pixel_values"],
+            image_grid_hws=processor_output["image_grid_hws"],
         )