sglang · commit b5e3d603

vlm: support video as an input modality (#5888)

Authored Jul 10, 2025 by Mick; committed by GitHub Jul 09, 2025 (unverified signature).
Parent: 4ed57807

Showing 20 changed files with 522 additions and 280 deletions (+522, -280)
Changed files:

python/sglang/srt/models/qwen2_5_vl.py                        +10   -7
python/sglang/srt/models/qwen2_vl.py                          +12   -1
python/sglang/srt/models/vila.py                               +8   -2
python/sglang/srt/multimodal/processors/base_processor.py   +197 -137
python/sglang/srt/multimodal/processors/deepseek_vl_v2.py     +1   -1
python/sglang/srt/multimodal/processors/gemma3.py             +4   -2
python/sglang/srt/multimodal/processors/gemma3n.py            +1   -1
python/sglang/srt/multimodal/processors/internvl.py           +1   -1
python/sglang/srt/multimodal/processors/janus_pro.py          +1   -1
python/sglang/srt/multimodal/processors/kimi_vl.py            +1   -1
python/sglang/srt/multimodal/processors/minicpm.py            +4   -3
python/sglang/srt/multimodal/processors/mllama4.py            +1   -1
python/sglang/srt/multimodal/processors/phi4mm.py             +1   -1
python/sglang/srt/multimodal/processors/pixtral.py            +1   -1
python/sglang/srt/multimodal/processors/qwen_vl.py          +203  -80
python/sglang/srt/multimodal/processors/vila.py               +1   -1
python/sglang/srt/utils.py                                   +55  -30
test/srt/test_jinja_template_utils.py                        +12   -7
test/srt/test_vision_openai_server_a.py                       +6   -0
test/srt/test_vision_openai_server_b.py                       +2   -2
python/sglang/srt/models/qwen2_5_vl.py

@@ -56,7 +56,6 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInp
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
 from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)

@@ -507,11 +506,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds

-    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
-        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
-        video_embeds = self.visual(
-            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
-        )
+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [getattr(item, "pixel_values_videos") for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
         return video_embeds

     def get_input_embeddings(self):

@@ -553,7 +556,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
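The new get_video_feature hook mirrors get_image_feature: each MultimodalDataItem carries a pixel_values_videos tensor and a video_grid_thw tensor, and the model concatenates them along dim 0 before calling the vision tower. A minimal sketch of that batching step with a stand-in item class (the tensor shapes are illustrative, not the real Qwen2.5-VL patch dimensions):

    from dataclasses import dataclass
    import torch

    @dataclass
    class FakeItem:                        # stand-in for MultimodalDataItem
        pixel_values_videos: torch.Tensor  # (num_patches, hidden)
        video_grid_thw: torch.Tensor       # (num_videos, 3) = (t, h, w)

    items = [
        FakeItem(torch.randn(8, 1176), torch.tensor([[2, 2, 2]])),
        FakeItem(torch.randn(16, 1176), torch.tensor([[4, 2, 2]])),
    ]

    # same batching as get_video_feature: concat patches and grids across items
    pixel_values = torch.cat([it.pixel_values_videos for it in items], dim=0)
    video_grid_thw = torch.concat([it.video_grid_thw for it in items], dim=0)
    assert pixel_values.dim() == 2 and video_grid_thw.dim() == 2
    print(pixel_values.shape, video_grid_thw.shape)  # [24, 1176] and [2, 3]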
python/sglang/srt/models/qwen2_vl.py

@@ -493,6 +493,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds

+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [item.pixel_values_videos for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
+        return video_embeds
+
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
         pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
         video_embeds = self.visual(

@@ -538,7 +549,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
python/sglang/srt/models/vila.py

@@ -17,7 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM

@@ -223,7 +227,9 @@ class VILAForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             get_embedding=get_embedding,
             positions=positions,
         )
python/sglang/srt/multimodal/processors/base_processor.py

@@ -5,7 +5,7 @@ import multiprocessing as mp
 import os
 import re
 from abc import ABC, abstractmethod
-from enum import Enum
+from functools import lru_cache
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np

@@ -14,7 +14,7 @@ from PIL import Image
 from transformers import BaseImageProcessorFast

 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.utils import encode_video, load_audio, load_image
+from sglang.srt.utils import load_audio, load_image, load_video, logger


 @dataclasses.dataclass

@@ -25,14 +25,22 @@ class BaseMultiModalProcessorOutput:
     # frames loaded from image and video, in given order
     images: Optional[list[Union[Image.Image, dict]]] = None

+    # videos
+    videos: Optional[list[Union[torch.Tensor, dict]]] = None
+
     # audios
     audios: Optional[list[Union[np.ndarray, dict]]] = None

-    def normalize(self):
-        for field_name in ["images", "audios"]:
-            field = getattr(self, field_name, None)
-            if field is not None and isinstance(field, list) and len(field) == 0:
-                setattr(self, field_name, None)
+    def organize_results(self) -> List[Tuple[Modality, Any]]:
+        """
+        :return: a list of results, with their corresponding modalities
+        """
+        return (
+            [(Modality.IMAGE, data) for data in self.images]
+            + [(Modality.VIDEO, data) for data in self.videos]
+            + [(Modality.AUDIO, data) for data in self.audios]
+        )


 @dataclasses.dataclass

@@ -41,6 +49,10 @@ class MultimodalSpecialTokens:
     video_token: Optional[Union[int, str, List[str]]] = None
     audio_token: Optional[Union[int, str, List[str]]] = None
+    image_token_regex: Optional[re.Pattern] = None
+    video_token_regex: Optional[re.Pattern] = None
+    audio_token_regex: Optional[re.Pattern] = None

     def convert_to_str(self, token: Union[str, int], processor) -> str:
         if token is None:
             return token

@@ -53,11 +65,29 @@ class MultimodalSpecialTokens:
         self.video_token = self.convert_to_str(self.video_token, processor)
         self.audio_token = self.convert_to_str(self.audio_token, processor)

-    image_token_regex: Optional[re.Pattern] = None
-    video_token_regex: Optional[re.Pattern] = None
-    audio_token_regex: Optional[re.Pattern] = None
-
-    def __post_init__(self):
+    def get_modality_of_token(self, token) -> Optional[Modality]:
+        """
+        :return: the modality associated with the given token, if the token is a special_token or matches with the multimodal token regex
+        """
+        modality = {
+            self.image_token: Modality.IMAGE,
+            self.video_token: Modality.VIDEO,
+            self.audio_token: Modality.AUDIO,
+        }.get(token)
+        if modality:
+            return modality
+
+        for regex, modality in [
+            (self.image_token_regex, Modality.IMAGE),
+            (self.video_token_regex, Modality.VIDEO),
+            (self.audio_token_regex, Modality.AUDIO),
+        ]:
+            if regex and regex.match(token):
+                return modality
+        return None
+
+    def parse_regex(self):
         if self.image_token_regex is None and self.image_token is not None:
             self.image_token_regex = re.compile(re.escape(self.image_token))
         if self.video_token_regex is None and self.video_token is not None:

@@ -65,7 +95,7 @@ class MultimodalSpecialTokens:
         if self.audio_token_regex is None and self.audio_token is not None:
             self.audio_token_regex = re.compile(re.escape(self.audio_token))

-    def collect(self) -> re.Pattern:
+    def combine_regex(self) -> re.Pattern:
         tokens = [
             self.image_token_regex,
             self.video_token_regex,
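The refactor replaces the single collect() pattern with parse_regex() plus combine_regex(), and adds get_modality_of_token() so the data loader can dispatch on a Modality enum instead of per-modality booleans. A standalone sketch of how a combined pattern splits a prompt and maps each special token back to its modality (the enum and tokens here are stand-ins, not the real sglang classes):

    import re
    from enum import Enum

    class Modality(Enum):   # stand-in for sglang.srt.managers.schedule_batch.Modality
        IMAGE = 1
        VIDEO = 2
        AUDIO = 3

    tokens = {"<image>": Modality.IMAGE, "<video>": Modality.VIDEO, "<audio>": Modality.AUDIO}
    # parse_regex(): one escaped pattern per special token
    regexes = {re.compile(re.escape(t)): m for t, m in tokens.items()}
    # combine_regex(): a single alternation used to split the prompt, keeping the tokens
    combined = re.compile("(" + "|".join(r.pattern for r in regexes) + ")")

    prompt = "describe <image> then compare with <video>"
    for part in re.split(combined, prompt):
        modality = next((m for r, m in regexes.items() if r.match(part)), None)
        print(repr(part), "->", modality)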
@@ -105,6 +135,7 @@ class BaseMultimodalProcessor(ABC):
         self.ATTR_NAME_TO_MODALITY = {
             # Image-related attributes
             "pixel_values": Modality.IMAGE,
+            "pixel_values_videos": Modality.VIDEO,
             "image_sizes": Modality.IMAGE,
             "image_grid_thw": Modality.IMAGE,
             "image_emb_mask": Modality.IMAGE,

@@ -120,7 +151,7 @@ class BaseMultimodalProcessor(ABC):
             "input_features": Modality.AUDIO,
             "input_features_mask": Modality.AUDIO,
             # Video-related attributes
-            "video_grid_thws": Modality.VIDEO,
+            "video_grid_thw": Modality.VIDEO,
             # Generic attributes that could apply to multiple modalities
             # "precomputed_features" - handled specially as it can be any modality
         }
@@ -196,20 +227,25 @@ class BaseMultimodalProcessor(ABC):
     @staticmethod
     def _load_single_item(
-        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+        data, modality: Modality, frame_count_limit=None, discard_alpha_channel=True
     ):
-        """Static method that can be pickled for multiprocessing"""
+        """
+        Load a single multimodal data.
+
+        If data is precomputed, returns directly.
+
+        Static method that can be pickled for multiprocessing"""
         if isinstance(data, dict):
             return data
         try:
-            if is_audio:
-                return load_audio(data)
-            elif is_video:
-                path = data[len("video:") :]
-                return encode_video(path, frame_count_limit)
-            else:
+            if modality == Modality.IMAGE:
                 img, _ = load_image(data)
                 return img.convert("RGB") if discard_alpha_channel else img
+            elif modality == Modality.VIDEO:
+                return load_video(data, frame_count_limit)
+            elif modality == Modality.AUDIO:
+                return load_audio(data)
         except Exception as e:
             raise RuntimeError(f"Error while loading data {data}: {e}")
@@ -217,75 +253,78 @@ class BaseMultimodalProcessor(ABC):
     def submit_data_loading_tasks(
         self,
         text_parts: List[str],
         multimodal_tokens: MultimodalSpecialTokens,
-        image_data: Optional[list] = None,
-        audio_data: Optional[list] = None,
+        data_iterators: dict,
         discard_alpha_channel: bool = True,
-    ):
+        image_estimated_frames_iter: Optional[iter] = None,
+        image_scaling_factor: float = 1.0,
+        max_image_frames: int = 30,
+    ) -> Tuple[List, List]:
         """
-        load multimodal data parallelly
+        load multimodal data parallelly using iterators.
         """
-
-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
-
-        assert len(image_data) == len(estimated_frames_list)

         # Submit all tasks
         futures = []
         task_info = []
-        image_index, audio_index = 0, 0

         for text_part in text_parts:
-            if (
-                multimodal_tokens.image_token_regex
-                and multimodal_tokens.image_token_regex.match(text_part)
-            ):
-                data = image_data[image_index]
-                is_video = isinstance(data, str) and data.startswith("video:")
-                estimated_frames = estimated_frames_list[image_index]
-                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+            modality = multimodal_tokens.get_modality_of_token(text_part)
+            if modality is not None:
+                data_iterator = data_iterators.get(modality)
+                if data_iterator is None:
+                    raise ValueError(f"No data iterator found for token: {text_part}")
+
+                try:
+                    data = next(data_iterator)
+                except StopIteration:
+                    raise ValueError(
+                        f"Mismatch: More '{text_part}' tokens found than corresponding data items provided."
+                    )
+
+                frame_count_limit = None
+                if modality == Modality.IMAGE and image_estimated_frames_iter:
+                    try:
+                        estimated_frames = next(image_estimated_frames_iter)
+                        # Use the pre-calculated scaling factor and max frames
+                        frame_count_limit = max(
+                            1, int(estimated_frames * image_scaling_factor)
+                        )
+                        # Ensure we don't exceed the absolute max (redundant if scaling_factor handles it)
+                        # frame_count_limit = min(frame_count_limit, max_image_frames)
+                    except StopIteration:
+                        raise ValueError(
+                            "Mismatch between image tokens and estimated frame counts."
+                        )
+
                 futures.append(
                     self.io_executor.submit(
                         BaseMultimodalProcessor._load_single_item,
                         data,
-                        is_video,
-                        False,
+                        modality,
                         frame_count_limit,
                         discard_alpha_channel,
                     )
                 )
-                task_info.append((Modality.IMAGE, data, frame_count_limit))
-                image_index += 1
-            elif (
-                multimodal_tokens.audio_token_regex
-                and multimodal_tokens.audio_token_regex.match(text_part)
-            ):
-                data = audio_data[audio_index]
-                futures.append(
-                    self.io_executor.submit(
-                        BaseMultimodalProcessor._load_single_item,
-                        data,
-                        False,
-                        True,
-                        None,
-                        discard_alpha_channel,
-                    )
-                )
-                task_info.append((Modality.AUDIO, data, None))
-                audio_index += 1
+                task_info.append((modality, data, frame_count_limit))
+
+        for modality, iterator in data_iterators.items():
+            try:
+                next(iterator)
+                logger.warning(
+                    f"Warning: More {modality.name.lower()} data items provided than corresponding tokens found in the prompt."
+                )
+            except StopIteration:
+                pass
+            except Exception:
+                pass

         return futures, task_info

     def load_mm_data(
         self,
-        prompt: str | List[int],
+        prompt: str,
         multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
         image_data: Optional[list] = None,
+        video_data: Optional[list] = None,
         audio_data: Optional[list] = None,
         return_text: Optional[bool] = True,
         discard_alpha_channel: bool = True,
@@ -299,14 +338,9 @@ class BaseMultimodalProcessor(ABC):
             discard_alpha_channel: if True, discards the alpha channel in the returned images
         """
         if not return_text:
             raise NotImplementedError()
-        if image_data is None:
-            image_data = []

         multimodal_tokens.convert_to_strs(self._processor)
-        multimodal_tokens_pattern = multimodal_tokens.collect()
+        multimodal_tokens.parse_regex()
+        multimodal_tokens_pattern = multimodal_tokens.combine_regex()

-        if isinstance(prompt, list) and return_text:
-            assert len(prompt) and isinstance(prompt[0], int)
-            prompt = self._processor.tokenizer.decode(prompt)
@@ -317,59 +351,84 @@ class BaseMultimodalProcessor(ABC):
         # split text into list of normal text and special tokens
         text_parts = re.split(multimodal_tokens_pattern, prompt)

+        # collect all data
+        data_iterators = {}
+        if multimodal_tokens.image_token and image_data:
+            data_iterators[Modality.IMAGE] = iter(image_data)
+        if multimodal_tokens.video_token and video_data:
+            data_iterators[Modality.VIDEO] = iter(video_data)
+        if multimodal_tokens.audio_token and audio_data:
+            data_iterators[Modality.AUDIO] = iter(audio_data)
+
+        # futures: the futures of loaded data
+        # task_info: modality, raw_data, and other metadata of each data
         futures, task_info = self.submit_data_loading_tasks(
             text_parts=text_parts,
             multimodal_tokens=multimodal_tokens,
-            image_data=image_data,
-            audio_data=audio_data,
+            data_iterators=data_iterators,
             discard_alpha_channel=discard_alpha_channel,
         )

         # Process results
-        images, audios = [], []
-        new_text = ""
-        task_ptr = 0
+        images, videos, audios = [], [], []
+        new_text_parts = []
+        task_info_iter = iter(task_info)
+        futures_iter = iter(futures)

         for text_part in text_parts:
-            if multimodal_tokens_pattern.match(text_part):
-                task_type, data, frame_limit = task_info[task_ptr]
-                result = futures[task_ptr].result()
-                task_ptr += 1
-
-                if task_type == Modality.IMAGE:
-                    # If data is already processed it will be a
-                    # dictionary. In this case we want to keep the
-                    # expanded tokens in text_part. Otherwise, we will
-                    # call the processor code, so keep only a single image
-                    # token.
-                    mm_tokens = (
-                        text_part if isinstance(data, dict) else multimodal_tokens.image_token
-                    )
-                    frames = [result] if not isinstance(result, list) else result
-                    if frames:
-                        images += frames
-                        new_text += mm_tokens * len(frames)
-                elif task_type == Modality.AUDIO:
-                    # audio
-                    mm_tokens = (
-                        text_part if isinstance(data, dict) else multimodal_tokens.audio_token
-                    )
-                    audios.append(result)
-                    new_text += mm_tokens
-                # TODO: handle video
-            else:
-                new_text += text_part
-
-        out = BaseMultiModalProcessorOutput(
-            input_text=new_text,
+            try:
+                if multimodal_tokens_pattern.match(text_part):
+                    modality, raw_data, frame_limit = next(task_info_iter)
+                    is_precomputed = isinstance(raw_data, dict)
+                    result = next(futures_iter).result()
+
+                    if modality == Modality.IMAGE:
+                        # If data is already processed it will be a
+                        # dictionary(precomputed). In this case we want to keep the
+                        # expanded tokens in text_part. Otherwise, we will
+                        # call the processor code, so keep only a single image
+                        # token.
+                        mm_tokens = (
+                            text_part if is_precomputed else multimodal_tokens.image_token
+                        )
+                        frames = [result] if not isinstance(result, list) else result
+                        if frames:
+                            # only for minicpmv
+                            images += frames
+                            new_text_parts += mm_tokens * len(frames)
+                    elif modality == Modality.VIDEO:
+                        # load as video
+                        mm_tokens = (
+                            text_part if is_precomputed else multimodal_tokens.video_token
+                        )
+                        videos += [result]
+                        new_text_parts += mm_tokens
+                    elif modality == Modality.AUDIO:
+                        # audio
+                        mm_tokens = (
+                            text_part if is_precomputed else multimodal_tokens.audio_token
+                        )
+                        audios += [result]
+                        new_text_parts += mm_tokens
+                else:
+                    # normal text
+                    new_text_parts += [text_part]
+            except Exception as e:
+                raise RuntimeError(
+                    f"An exception occurred while loading multimodal data: {e}"
+                )
+        return BaseMultiModalProcessorOutput(
             images=images,
             audios=audios,
+            videos=videos,
+            input_text="".join(new_text_parts),
         )
-        out.normalize()
-        return out

     @staticmethod
     def get_mm_items_offset(
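load_mm_data now builds one iterator per modality and consumes them in prompt order, so a mismatch between special tokens and supplied data raises a clear error instead of silently mis-pairing items. A toy version of that pairing loop (simplified: keyed by token string rather than Modality, and without the IO executor that does the real loading via _load_single_item):

    # Pair each multimodal token in the prompt with the next item of its modality.
    image_data = ["img_a.png", "img_b.png"]
    video_data = ["clip.mp4"]
    iterators = {"<image>": iter(image_data), "<video>": iter(video_data)}

    text_parts = ["look at ", "<image>", " and ", "<image>", " versus ", "<video>", ""]
    task_info = []
    for part in text_parts:
        it = iterators.get(part)
        if it is None:
            continue  # plain text, no data to fetch
        try:
            task_info.append((part, next(it)))
        except StopIteration:
            raise ValueError(f"More '{part}' tokens than data items provided")

    print(task_info)
    # [('<image>', 'img_a.png'), ('<image>', 'img_b.png'), ('<video>', 'clip.mp4')]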
@@ -460,21 +519,19 @@ class BaseMultimodalProcessor(ABC):
                 )
             except ValueError:
                 modality = Modality.IMAGE

             if modality:
                 # Create item if needed
                 if modality not in items:
                     items[modality] = MultimodalDataItem(modality=modality)
                 # Set attribute
-                if hasattr(items[modality], attr_name):
-                    setattr(items[modality], attr_name, value)
+                setattr(items[modality], attr_name, value)

         return list(items.values())

     def _process_and_collect_mm_items(
         self, input_text: str, images=None, audios=None, videos=None, **kwargs
-    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
         """
         Helper method to process multimodal data and create mm_items in one step.

@@ -488,11 +545,11 @@ class BaseMultimodalProcessor(ABC):
         input_ids = ret["input_ids"].flatten()
         collected_items = self.collect_mm_items_from_processor_output(ret)

-        return collected_items, input_ids
+        return collected_items, input_ids, ret

     def process_and_combine_mm_data(
         self, base_output: BaseMultiModalProcessorOutput
-    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
         """
         Process multimodal data and return the combined multimodal items and input_ids.
         Supports mixed modalities (images and audio in the same request).

@@ -501,8 +558,7 @@ class BaseMultimodalProcessor(ABC):
             Tuple of (list of mm_items, input_ids)
         """
         # Collect all items and categorize them
-        all_items = (base_output.images or []) + (base_output.audios or [])
+        all_items = base_output.organize_results()

         # Handle text-only case
         if not all_items:
             input_ids = self._processor.tokenizer(

@@ -510,19 +566,20 @@ class BaseMultimodalProcessor(ABC):
                 return_tensors="pt",
                 add_special_tokens=True,
             ).input_ids.flatten()
-            return [], input_ids
+            return [], input_ids, {}

-        dict_items, raw_images, raw_audios = [], [], []
-        for item in all_items:
+        dict_items, raw_images, raw_audios, raw_videos = [], [], [], []
+        for modality, item in all_items:
             if isinstance(item, dict):
                 dict_items.append(item)
-            elif isinstance(item, Image.Image):
+            elif modality == Modality.IMAGE:
                 raw_images.append(item)
-            elif isinstance(item, np.ndarray):
+            elif modality == Modality.AUDIO:
                 raw_audios.append(item)
+            elif modality == Modality.VIDEO:
+                raw_videos.append(item)
             else:
                 raise ValueError(f"Unknown multimodal item type: {type(item)}")

         # Process items and get input_ids
         all_collected_items = []
         input_ids = None

@@ -534,13 +591,16 @@ class BaseMultimodalProcessor(ABC):
             )

         # Handle raw items (need processing)
-        if raw_images or raw_audios:
-            collected_items, input_ids = self._process_and_collect_mm_items(
+        if raw_images or raw_audios or raw_videos:
+            collected_items, input_ids, ret = self._process_and_collect_mm_items(
                 input_text=base_output.input_text,
                 images=raw_images,
                 audios=raw_audios,
+                videos=raw_videos,
             )
             all_collected_items.extend(collected_items)
+        else:
+            ret = None

         # Fallback tokenization if no raw items were processed
         if input_ids is None:

@@ -553,21 +613,21 @@ class BaseMultimodalProcessor(ABC):
         # Add offsets to all items
         for mm_item in all_collected_items:
             if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
-                mm_item.image_offsets = self.get_mm_items_offset(
+                mm_item.offsets = self.get_mm_items_offset(
                     input_ids=input_ids,
                     mm_token_id=self.IM_TOKEN_ID,
                 )
             elif mm_item.modality == Modality.AUDIO:
-                mm_item.audio_offsets = self.get_mm_items_offset(
+                mm_item.offsets = self.get_mm_items_offset(
                     input_ids=input_ids,
                     mm_token_id=self.AUDIO_TOKEN_ID,
                 )
             elif mm_item.modality == Modality.VIDEO:
-                mm_item.video_offsets = self.get_mm_items_offset(
+                mm_item.offsets = self.get_mm_items_offset(
                     input_ids=input_ids,
                     mm_token_id=self.VIDEO_TOKEN_ID,
                 )
             else:
                 raise ValueError(f"Unknown modality: {mm_item.modality}")

-        return all_collected_items, input_ids
+        return all_collected_items, input_ids, ret
python/sglang/srt/multimodal/processors/deepseek_vl_v2.py

@@ -69,7 +69,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         )
         item = MultimodalDataItem(
             pixel_values=res["images"],
-            image_offsets=image_offsets,
+            offsets=image_offsets,
             modality=Modality.IMAGE,
             image_emb_mask=images_seq_mask,
             image_spatial_crop=batched_images_spatial_crop,
python/sglang/srt/multimodal/processors/gemma3.py

@@ -36,6 +36,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
+        print(f"{image_data=}")
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,

@@ -46,8 +47,9 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
             discard_alpha_channel=True,
         )

-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
+        print(f"{base_output=}")
+        print(f"{mm_items=}")

         return {
             "input_ids": input_ids.tolist(),
             "mm_items": mm_items,
python/sglang/srt/multimodal/processors/gemma3n.py

@@ -72,7 +72,7 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
             ),
         )

-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
python/sglang/srt/multimodal/processors/internvl.py

@@ -225,7 +225,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             MultimodalDataItem(
                 pixel_values=pixel_values,
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
python/sglang/srt/multimodal/processors/janus_pro.py

@@ -49,7 +49,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
             MultimodalDataItem(
                 pixel_values=res["pixel_values"],
                 image_emb_mask=res["images_emb_mask"],
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 modality=Modality.IMAGE,
             )
         ],
python/sglang/srt/multimodal/processors/kimi_vl.py

@@ -39,7 +39,7 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
             max_req_input_len=max_req_input_len,
         )

-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
python/sglang/srt/multimodal/processors/minicpm.py

@@ -19,6 +19,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.image_token = "(<image>./</image>)"
         self.audio_token = "(<audio>./</audio>)"
+        self.video_token = "(<video>./</video>)"

     async def process_mm_data_async(
         self,

@@ -36,6 +37,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=self.image_token,
+                video_token=self.video_token,
                 audio_token=self.audio_token,
             ),
         )

@@ -113,7 +115,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         if len(pixel_values) != 0:
             item = MultimodalDataItem(
                 pixel_values=pixel_values,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 tgt_size=tgt_sizes_flat,
                 modality=Modality.IMAGE,
             )

@@ -135,11 +137,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             item = MultimodalDataItem(
                 audio_features=[res["audio_features"]],
                 audio_feature_lens=res["audio_feature_lens"],
-                audio_offsets=audio_offsets,
+                offsets=audio_offsets,
                 modality=Modality.AUDIO,
             )
             items += [item]

         return {
             "mm_items": items,
             "input_ids": input_ids.tolist(),
python/sglang/srt/multimodal/processors/mllama4.py

@@ -144,7 +144,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             MultimodalDataItem(
                 pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
python/sglang/srt/multimodal/processors/phi4mm.py

@@ -65,7 +65,7 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
                 pixel_values=res["input_image_embeds"],
                 image_sizes=res["image_sizes"],
                 image_emb_mask=res["image_attention_mask"],
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 modality=Modality.IMAGE,
             )
         ]
python/sglang/srt/multimodal/processors/pixtral.py

@@ -106,7 +106,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
                 pixel_values=processor_output["pixel_values"],
                 image_sizes=processor_output["image_sizes"],
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
python/sglang/srt/multimodal/processors/qwen_vl.py

 import asyncio
 import math
+import os
 import re
-from typing import Dict, List, Union
+from typing import List, Union

 import torch
+import torchvision
 from PIL import Image
+from torchvision.transforms import InterpolationMode

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration

@@ -12,6 +16,185 @@ from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
 from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
+from sglang.utils import logger
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+VIDEO_TOTAL_PIXELS = int(
+    float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9))
+)
+VIDEO_MIN_PIXELS = 128 * 28 * 28
+VIDEO_MAX_PIXELS = 768 * 28 * 28
+FRAME_FACTOR = 2
+FPS = 2.0
+FPS_MIN_FRAMES = 4
+FPS_MAX_FRAMES = 768
+
+
+def smart_resize(
+    height: int,
+    width: int,
+    factor: int = IMAGE_FACTOR,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    if max(height, width) / min(height, width) > MAX_RATIO:
+        raise ValueError(
+            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+        )
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
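As a quick check of the arithmetic: a 1080 x 1920 frame is within the pixel budget, so smart_resize only snaps both sides to the nearest multiple of 28, giving 1092 x 1932 (the helper is redefined locally here so the snippet runs on its own):

    IMAGE_FACTOR, MIN_PIXELS, MAX_PIXELS = 28, 4 * 28 * 28, 16384 * 28 * 28

    def round_by_factor(n: int, f: int) -> int:
        # same helper as defined later in this file
        return round(n / f) * f

    # smart_resize arithmetic for a 1080p frame: no rescaling needed, only 28-alignment
    height, width = 1080, 1920
    h_bar = max(IMAGE_FACTOR, round_by_factor(height, IMAGE_FACTOR))
    w_bar = max(IMAGE_FACTOR, round_by_factor(width, IMAGE_FACTOR))
    assert MIN_PIXELS <= h_bar * w_bar <= MAX_PIXELS  # 1092 * 1932 is about 2.1M pixels
    print(h_bar, w_bar)  # 1092 1932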
+def resize_image(image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+    width, height = image.size
+    min_pixels = MIN_PIXELS
+    max_pixels = MAX_PIXELS
+    resized_height, resized_width = smart_resize(
+        height,
+        width,
+        factor=size_factor,
+        min_pixels=min_pixels,
+        max_pixels=max_pixels,
+    )
+    image = image.resize((resized_width, resized_height))
+    return image
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+
+async def resize_image_async(image):
+    return resize_image(image)
+
+
+def smart_nframes(
+    ele: dict,
+    total_frames: int,
+    video_fps: int | float,
+) -> int:
+    """calculate the number of frames for video used for model inputs.
+
+    Args:
+        ele (dict): a dict contains the configuration of video.
+            support either `fps` or `nframes`:
+                - nframes: the number of frames to extract for model inputs.
+                - fps: the fps to extract frames for model inputs.
+                    - min_frames: the minimum number of frames of the video, only used when fps is provided.
+                    - max_frames: the maximum number of frames of the video, only used when fps is provided.
+        total_frames (int): the original total number of frames of the video.
+        video_fps (int | float): the original fps of the video.
+
+    Raises:
+        ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
+
+    Returns:
+        int: the number of frames for video used for model inputs.
+    """
+    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
+    if "nframes" in ele:
+        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+    else:
+        fps = ele.get("fps", FPS)
+        min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+        max_frames = floor_by_factor(
+            ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR
+        )
+        nframes = total_frames / video_fps * fps
+        if nframes > total_frames:
+            logger.warning(
+                f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]"
+            )
+        nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
+        nframes = floor_by_factor(nframes, FRAME_FACTOR)
+    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
+        raise ValueError(
+            f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
+        )
+    return nframes
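A worked example of the default sampling policy: a 300-frame clip at 30 fps with FPS = 2.0 yields 300 / 30 * 2 = 20 sampled frames, clamped to the min/max frame bounds and floored to a multiple of FRAME_FACTOR (the helpers are redefined locally so the snippet is self-contained):

    import math

    FRAME_FACTOR, FPS, FPS_MIN_FRAMES, FPS_MAX_FRAMES = 2, 2.0, 4, 768

    def ceil_by_factor(n, f): return math.ceil(n / f) * f
    def floor_by_factor(n, f): return math.floor(n / f) * f

    total_frames, video_fps = 300, 30.0   # a 10-second clip
    min_frames = ceil_by_factor(FPS_MIN_FRAMES, FRAME_FACTOR)
    max_frames = floor_by_factor(min(FPS_MAX_FRAMES, total_frames), FRAME_FACTOR)
    nframes = total_frames / video_fps * FPS
    nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
    nframes = floor_by_factor(nframes, FRAME_FACTOR)
    print(nframes)  # 20 frames are sampled uniformly from the clip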
+# process video, qwen-specific
+async def preprocess_video(
+    vr,
+    image_factor: int = IMAGE_FACTOR,
+    # vr: VideoReader, image_factor: int = IMAGE_FACTOR
+) -> torch.Tensor:
+    ele = {}
+    total_frames, video_fps = len(vr), vr.get_avg_fps()
+    nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
+    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
+    video = vr.get_batch(idx).asnumpy()
+    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
+    nframes, _, height, width = video.shape
+    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
+    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
+    max_pixels = max(
+        min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
+        int(min_pixels * 1.05),
+    )
+    max_pixels_supposed = ele.get("max_pixels", max_pixels)
+    if max_pixels_supposed > max_pixels:
+        logger.warning(
+            f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}]."
+        )
+    max_pixels = min(max_pixels_supposed, max_pixels)
+    if "resized_height" in ele and "resized_width" in ele:
+        resized_height, resized_width = smart_resize(
+            ele["resized_height"],
+            ele["resized_width"],
+            factor=image_factor,
+        )
+    else:
+        resized_height, resized_width = smart_resize(
+            height,
+            width,
+            factor=image_factor,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+    video = torchvision.transforms.functional.resize(
+        video,
+        [resized_height, resized_width],
+        interpolation=InterpolationMode.BICUBIC,
+        antialias=True,
+    ).float()
+    return video
+
+
 # Compatible with Qwen2VL and Qwen2_5VL
@@ -37,104 +220,44 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
-        self.MIN_PIXELS = 4 * 28 * 28
-        self.MAX_PIXELS = 16384 * 28 * 28
-        self.MAX_RATIO = 200
+        # TODO(mick): move all MultimodalSpecialTokens initializations into processor init
+        self.mm_special_tokens = MultimodalSpecialTokens(
+            image_token=self.IMAGE_TOKEN,
+            image_token_regex=self.IMAGE_TOKEN_REGEX,
+            video_token=self.VIDEO_TOKEN_ID,
+        )

     async def process_mm_data_async(
         self,
-        image_data: List[Union[str, bytes, Dict]],
+        image_data: List[Union[str, bytes]],
         input_text,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.IMAGE_TOKEN,
-                image_token_regex=self.IMAGE_TOKEN_REGEX,
-            ),
+            video_data=request_obj.video_data,
+            multimodal_tokens=self.mm_special_tokens,
             max_req_input_len=max_req_input_len,
         )

-        def smart_resize(
-            height: int,
-            width: int,
-            factor: int = self.IMAGE_FACTOR,
-            min_pixels: int = self.MIN_PIXELS,
-            max_pixels: int = self.MAX_PIXELS,
-        ) -> tuple[int, int]:
-            """
-            Rescales the image so that the following conditions are met:
-
-            1. Both dimensions (height and width) are divisible by 'factor'.
-            2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
-            3. The aspect ratio of the image is maintained as closely as possible.
-            """
-            if max(height, width) / min(height, width) > self.MAX_RATIO:
-                raise ValueError(
-                    f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
-                )
-            h_bar = max(factor, round_by_factor(height, factor))
-            w_bar = max(factor, round_by_factor(width, factor))
-            if h_bar * w_bar > max_pixels:
-                beta = math.sqrt((height * width) / max_pixels)
-                h_bar = floor_by_factor(height / beta, factor)
-                w_bar = floor_by_factor(width / beta, factor)
-            elif h_bar * w_bar < min_pixels:
-                beta = math.sqrt(min_pixels / (height * width))
-                h_bar = ceil_by_factor(height * beta, factor)
-                w_bar = ceil_by_factor(width * beta, factor)
-            return h_bar, w_bar
-
-        def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
-            width, height = image.size
-            min_pixels = self.MIN_PIXELS
-            max_pixels = self.MAX_PIXELS
-            resized_height, resized_width = smart_resize(
-                height,
-                width,
-                factor=size_factor,
-                min_pixels=min_pixels,
-                max_pixels=max_pixels,
-            )
-            image = image.resize((resized_width, resized_height))
-            return image
-
-        def round_by_factor(number: int, factor: int) -> int:
-            """Returns the closest integer to 'number' that is divisible by 'factor'."""
-            return round(number / factor) * factor
-
-        def ceil_by_factor(number: int, factor: int) -> int:
-            """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
-            return math.ceil(number / factor) * factor
-
-        def floor_by_factor(number: int, factor: int) -> int:
-            """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
-            return math.floor(number / factor) * factor
-
-        async def resize_image_async(image):
-            return resize_image(image)

         # Qwen-specific: resize images if they are raw Image objects
         if base_output.images and isinstance(base_output.images[0], Image.Image):
             resize_tasks = [resize_image_async(image) for image in base_output.images]
             base_output.images = await asyncio.gather(*resize_tasks)

-        video_grid_thw = None  # TODO
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
-        if not mm_items:
-            # Note(Xinyuan): This is the case where image loading fails.
-            return None
+        if base_output.videos:
+            base_output.videos = [
+                await preprocess_video(video) for video in base_output.videos
+            ]

-        combined_mm_item = mm_items[0]  # only image is supported for now
-        video_grid_thw = None  # TODO
-        second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(base_output)

         input_ids = input_ids.flatten()
         mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
             spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
             image_token_id=self.IM_TOKEN_ID,

@@ -145,9 +268,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                 self.hf_config.vision_config, "tokens_per_second", None
             ),
             input_ids=input_ids.unsqueeze(0),
-            image_grid_thw=combined_mm_item.image_grid_thw,
-            video_grid_thw=video_grid_thw,
-            second_per_grid_ts=second_per_grid_ts,
+            image_grid_thw=getattr(ret, "image_grid_thw", None),
+            video_grid_thw=getattr(ret, "video_grid_thw", None),
+            second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
         )
         mrope_positions = mrope_positions.squeeze(1)
python/sglang/srt/multimodal/processors/vila.py

@@ -57,7 +57,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
             image_data=image_data,
         )

-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
python/sglang/srt/utils.py

@@ -728,33 +728,6 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     return audio


-def encode_video(video_path, frame_count_limit=None):
-    # Lazy import because decord is not available on some arm platforms.
-    from decord import VideoReader, cpu
-
-    if not os.path.exists(video_path):
-        logger.error(f"Video {video_path} does not exist")
-        return []
-
-    if frame_count_limit == 0:
-        return []
-
-    def uniform_sample(l, n):
-        gap = len(l) / n
-        idxs = [int(i * gap + gap / 2) for i in range(n)]
-        return [l[i] for i in idxs]
-
-    vr = VideoReader(video_path, ctx=cpu(0))
-    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-    frame_indices = [i for i in range(0, len(vr), sample_fps)]
-    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-        frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-    frames = vr.get_batch(frame_indices).asnumpy()
-    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-    return frames
-
-
 def load_image(
     image_file: Union[Image.Image, str, bytes],
 ) -> tuple[Image.Image, tuple[int, int]]:
@@ -774,9 +747,6 @@ def load_image(
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
-    elif image_file.startswith("video:"):
-        image_file = image_file.replace("video:", "")
-        image, image_size = decode_video_base64(image_file)
     elif isinstance(image_file, str):
         image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     else:
@@ -785,6 +755,61 @@ def load_image(
     return image, image_size


+def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
+    # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
+    from decord import VideoReader, cpu, gpu
+
+    try:
+        from decord.bridge import decord_bridge
+
+        ctx = gpu(0)
+        _ = decord_bridge.get_ctx_device(ctx)
+    except Exception:
+        ctx = cpu(0)
+
+    tmp_file = None
+    vr = None
+    try:
+        if isinstance(video_file, bytes):
+            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+            tmp_file.write(video_file)
+            tmp_file.close()
+            vr = VideoReader(tmp_file.name, ctx=ctx)
+        elif isinstance(video_file, str):
+            if video_file.startswith(("http://", "https://")):
+                timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
+                response = requests.get(video_file, stream=True, timeout=timeout)
+                response.raise_for_status()
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                for chunk in response.iter_content(chunk_size=8192):
+                    tmp_file.write(chunk)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif video_file.startswith("data:"):
+                _, encoded = video_file.split(",", 1)
+                video_bytes = base64.b64decode(encoded)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif os.path.isfile(video_file):
+                vr = VideoReader(video_file, ctx=ctx)
+            else:
+                video_bytes = base64.b64decode(video_file)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+        else:
+            raise ValueError(f"Unsupported video input type: {type(video_file)}")
+        return vr
+    finally:
+        if tmp_file and os.path.exists(tmp_file.name):
+            os.unlink(tmp_file.name)
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
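Putting the two new pieces together, a caller can turn any supported video reference (path, URL, raw bytes, or base64) into a decord VideoReader with load_video and then hand it to the Qwen processor's preprocess_video to get a resized TCHW float tensor. A usage sketch under the assumptions that decord is installed and a local file "clip.mp4" exists:

    import asyncio

    from sglang.srt.utils import load_video
    from sglang.srt.multimodal.processors.qwen_vl import preprocess_video

    async def main():
        vr = load_video("clip.mp4")         # decord.VideoReader (GPU ctx if available)
        video = await preprocess_video(vr)  # (nframes, 3, H', W') float tensor, 28-aligned
        print(video.shape)

    asyncio.run(main())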
test/srt/test_jinja_template_utils.py

@@ -3,7 +3,6 @@ Unit tests for Jinja chat template utils.
 """

 import unittest
-from unittest.mock import patch

 from sglang.srt.jinja_template_utils import (
     detect_jinja_template_content_format,

@@ -76,11 +75,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
         )

         # Check that image_data was extracted

@@ -111,11 +111,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "string", image_data, audio_data, modalities
+            msg_dict, "string", image_data, video_data, audio_data, modalities
         )

         # For string format, should flatten to text only

@@ -139,11 +140,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
         )

         # Check that audio_data was extracted

@@ -162,11 +164,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         msg_dict = {"role": "user", "content": "Hello world"}
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
         )

         # Should pass through unchanged

@@ -188,11 +191,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
         )

         # Check that modalities was extracted

@@ -209,11 +213,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "string", image_data, audio_data, modalities
+            msg_dict, "string", image_data, video_data, audio_data, modalities
         )

         # None values should be filtered out
test/srt/test_vision_openai_server_a.py

@@ -35,6 +35,9 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

+    def test_video_chat_completion(self):
+        self._test_video_chat_completion()
+

 class TestQwen2_5_VLServer(TestOpenAIVisionServer):
     @classmethod

@@ -54,6 +57,9 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

+    def test_video_chat_completion(self):
+        self._test_video_chat_completion()
+

 class TestVLMContextLengthIssue(CustomTestCase):
     @classmethod
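These hooks exercise video through the OpenAI-compatible endpoint; the actual request shape lives in _test_video_chat_completion, which is not part of this diff. The sketch below is an assumption of what such a request could look like, using the openai client against a locally launched sglang server and a hypothetical video_url content part:

    # Hypothetical example -- the content-part name ("video_url") and the clip URL
    # are assumptions, not taken from this diff.
    import openai

    client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
    response = client.chat.completions.create(
        model="Qwen/Qwen2.5-VL-7B-Instruct",
        messages=[{
            "role": "user",
            "content": [
                {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
                {"type": "text", "text": "Describe what happens in this video."},
            ],
        }],
        max_tokens=128,
    )
    print(response.choices[0].message.content)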
test/srt/test_vision_openai_server_b.py

@@ -93,7 +93,7 @@ class TestJanusProServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

-    def test_video_chat_completion(self):
+    def test_video_images_chat_completion(self):
         pass

     def test_single_image_chat_completion(self):

@@ -170,7 +170,7 @@ class TestKimiVLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

-    def test_video_chat_completion(self):
+    def test_video_images_chat_completion(self):
         pass