sglang · Commits · b5e3d603

vlm: support video as an input modality (#5888)

Unverified commit b5e3d603, authored Jul 10, 2025 by Mick, committed via GitHub Jul 09, 2025.
Parent: 4ed57807
Changes: 42 files in this commit. This page shows 20 changed files with 522 additions and 280 deletions (+522 −280).
python/sglang/srt/models/qwen2_5_vl.py                      +10  −7
python/sglang/srt/models/qwen2_vl.py                        +12  −1
python/sglang/srt/models/vila.py                            +8   −2
python/sglang/srt/multimodal/processors/base_processor.py   +197 −137
python/sglang/srt/multimodal/processors/deepseek_vl_v2.py   +1   −1
python/sglang/srt/multimodal/processors/gemma3.py           +4   −2
python/sglang/srt/multimodal/processors/gemma3n.py          +1   −1
python/sglang/srt/multimodal/processors/internvl.py         +1   −1
python/sglang/srt/multimodal/processors/janus_pro.py        +1   −1
python/sglang/srt/multimodal/processors/kimi_vl.py          +1   −1
python/sglang/srt/multimodal/processors/minicpm.py          +4   −3
python/sglang/srt/multimodal/processors/mllama4.py          +1   −1
python/sglang/srt/multimodal/processors/phi4mm.py           +1   −1
python/sglang/srt/multimodal/processors/pixtral.py          +1   −1
python/sglang/srt/multimodal/processors/qwen_vl.py          +203 −80
python/sglang/srt/multimodal/processors/vila.py             +1   −1
python/sglang/srt/utils.py                                  +55  −30
test/srt/test_jinja_template_utils.py                       +12  −7
test/srt/test_vision_openai_server_a.py                     +6   −0
test/srt/test_vision_openai_server_b.py                     +2   −2
python/sglang/srt/models/qwen2_5_vl.py

@@ -56,7 +56,6 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
 from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)

@@ -507,11 +506,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds

-    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
-        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
-        video_embeds = self.visual(
-            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
-        )
+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [getattr(item, "pixel_values_videos") for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
         return video_embeds

     def get_input_embeddings(self):

@@ -553,7 +556,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
python/sglang/srt/models/qwen2_vl.py

@@ -493,6 +493,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds

+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [item.pixel_values_videos for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
+        return video_embeds
+
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
         pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
         video_embeds = self.visual(

@@ -538,7 +549,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
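Note: both Qwen2-VL and Qwen2.5-VL now expose `get_video_feature` with the same batching contract as `get_image_feature`: each `MultimodalDataItem` carries a 2-D tensor of flattened video patches plus a `(num_clips, 3)` grid of (temporal, height, width) patch counts, and the per-item tensors are concatenated along dim 0 before a single `visual` call. A toy sketch of that shape contract, using a stand-in dataclass rather than the real `MultimodalDataItem` (the 1176 patch dim is illustrative):

import torch
from dataclasses import dataclass

@dataclass
class FakeItem:  # stand-in for MultimodalDataItem, illustration only
    pixel_values_videos: torch.Tensor  # (num_patches_i, patch_dim), always 2-D
    video_grid_thw: torch.Tensor       # (num_clips_i, 3): (t, h, w) per clip

items = [
    FakeItem(torch.randn(120, 1176), torch.tensor([[2, 6, 10]])),  # 2*6*10 = 120 patches
    FakeItem(torch.randn(240, 1176), torch.tensor([[4, 6, 10]])),  # 4*6*10 = 240 patches
]
pixel_values = torch.cat([it.pixel_values_videos for it in items], dim=0)
video_grid_thw = torch.concat([it.video_grid_thw for it in items], dim=0)
assert pixel_values.dim() == 2 and video_grid_thw.dim() == 2
print(pixel_values.shape, video_grid_thw.shape)  # (360, 1176), (2, 3)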
python/sglang/srt/models/vila.py

@@ -17,7 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM

@@ -223,7 +227,9 @@ class VILAForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             get_embedding=get_embedding,
             positions=positions,
         )
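Note: VILA spells the new contract out explicitly. Instead of one hard-wired `image_data_embedding_func`, the shared embedding routine receives a map from `Modality` to a feature extractor; the Qwen models above achieve the same end by passing `multimodal_model=self` and letting the routine find `get_image_feature` / `get_video_feature` on the model. A minimal, self-contained sketch of modality-keyed dispatch — illustrative only, not sglang's actual embedding routine:

from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List

class Modality(Enum):
    IMAGE = auto()
    VIDEO = auto()

@dataclass
class Item:
    modality: Modality
    payload: str

def embed_items(items: List[Item], data_embedding_funcs: Dict[Modality, Callable]):
    by_modality: Dict[Modality, List[Item]] = {}
    for item in items:  # group items by modality
        by_modality.setdefault(item.modality, []).append(item)
    return {
        m: data_embedding_funcs[m](group)  # one extractor call per modality
        for m, group in by_modality.items()
    }

funcs = {Modality.IMAGE: lambda g: f"img x{len(g)}", Modality.VIDEO: lambda g: f"vid x{len(g)}"}
print(embed_items([Item(Modality.IMAGE, "a"), Item(Modality.VIDEO, "b")], funcs))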
python/sglang/srt/multimodal/processors/base_processor.py

(This diff is collapsed in this view: +197 −137.)
python/sglang/srt/multimodal/processors/deepseek_vl_v2.py

@@ -69,7 +69,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         )
         item = MultimodalDataItem(
             pixel_values=res["images"],
-            image_offsets=image_offsets,
+            offsets=image_offsets,
             modality=Modality.IMAGE,
             image_emb_mask=images_seq_mask,
             image_spatial_crop=batched_images_spatial_crop,
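Note: this one-line rename recurs through most of the processor files below. The modality-specific `image_offsets` and `audio_offsets` keywords collapse into a single generic `offsets` field on `MultimodalDataItem`, with the existing `modality=` argument saying what the offsets describe. That keeps the item schema uniform now that a third modality (video) joins images and audio.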
python/sglang/srt/multimodal/processors/gemma3.py

@@ -36,6 +36,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
+        print(f"{image_data=}")
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,

@@ -46,8 +47,9 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
             discard_alpha_channel=True,
         )
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
-
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
+        print(f"{base_output=}")
+        print(f"{mm_items=}")
         return {
             "input_ids": input_ids.tolist(),
             "mm_items": mm_items,
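Note: the added prints use the f-string `=` specifier (Python 3.8+), which renders the expression text together with its value; they read like temporary debug output that shipped with the commit. For reference:

image_data = ["a.png"]
print(f"{image_data=}")  # prints: image_data=['a.png']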
python/sglang/srt/multimodal/processors/gemma3n.py

@@ -72,7 +72,7 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
             ),
         )
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
python/sglang/srt/multimodal/processors/internvl.py

@@ -225,7 +225,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             MultimodalDataItem(
                 pixel_values=pixel_values,
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
python/sglang/srt/multimodal/processors/janus_pro.py

@@ -49,7 +49,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
             MultimodalDataItem(
                 pixel_values=res["pixel_values"],
                 image_emb_mask=res["images_emb_mask"],
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 modality=Modality.IMAGE,
             )
         ],
python/sglang/srt/multimodal/processors/kimi_vl.py

@@ -39,7 +39,7 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
             max_req_input_len=max_req_input_len,
         )
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
python/sglang/srt/multimodal/processors/minicpm.py

@@ -19,6 +19,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.image_token = "(<image>./</image>)"
         self.audio_token = "(<audio>./</audio>)"
+        self.video_token = "(<video>./</video>)"

     async def process_mm_data_async(
         self,

@@ -36,6 +37,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=self.image_token,
+                video_token=self.video_token,
                 audio_token=self.audio_token,
             ),
         )

@@ -113,7 +115,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         if len(pixel_values) != 0:
             item = MultimodalDataItem(
                 pixel_values=pixel_values,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 tgt_size=tgt_sizes_flat,
                 modality=Modality.IMAGE,
             )

@@ -135,11 +137,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             item = MultimodalDataItem(
                 audio_features=[res["audio_features"]],
                 audio_feature_lens=res["audio_feature_lens"],
-                audio_offsets=audio_offsets,
+                offsets=audio_offsets,
                 modality=Modality.AUDIO,
             )
             items += [item]
-
         return {
             "mm_items": items,
             "input_ids": input_ids.tolist(),
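Note: MiniCPM marks media positions in the prompt with literal sentinel strings rather than dedicated token ids, and this hunk adds the video sentinel alongside the existing image and audio ones. An illustrative prompt mixing the three (the surrounding text is made up):

prompt = (
    "What happens in (<video>./</video>), and how does it relate to "
    "the photo (<image>./</image>) and the narration (<audio>./</audio>)?"
)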
python/sglang/srt/multimodal/processors/mllama4.py

@@ -144,7 +144,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             MultimodalDataItem(
                 pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
python/sglang/srt/multimodal/processors/phi4mm.py

@@ -65,7 +65,7 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
                 pixel_values=res["input_image_embeds"],
                 image_sizes=res["image_sizes"],
                 image_emb_mask=res["image_attention_mask"],
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 modality=Modality.IMAGE,
             )
         ]
python/sglang/srt/multimodal/processors/pixtral.py

@@ -106,7 +106,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
                 pixel_values=processor_output["pixel_values"],
                 image_sizes=processor_output["image_sizes"],
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
python/sglang/srt/multimodal/processors/qwen_vl.py

 import asyncio
 import math
+import os
 import re
-from typing import Dict, List, Union
+from typing import List, Union

+import torch
+import torchvision
 from PIL import Image
+from torchvision.transforms import InterpolationMode

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
...
@@ -12,6 +16,185 @@ from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
 from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
+from sglang.utils import logger
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+VIDEO_TOTAL_PIXELS = int(float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9)))
+VIDEO_MIN_PIXELS = 128 * 28 * 28
+VIDEO_MAX_PIXELS = 768 * 28 * 28
+FRAME_FACTOR = 2
+FPS = 2.0
+FPS_MIN_FRAMES = 4
+FPS_MAX_FRAMES = 768
+
+
+def smart_resize(
+    height: int,
+    width: int,
+    factor: int = IMAGE_FACTOR,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+    1. Both dimensions (height and width) are divisible by 'factor'.
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    if max(height, width) / min(height, width) > MAX_RATIO:
+        raise ValueError(
+            f"absolute aspect ratio must be smaller than {MAX_RATIO}, "
+            f"got {max(height, width) / min(height, width)}"
+        )
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
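Note: a worked example of `smart_resize` under the new video budget. A 1080x1920 frame already satisfies the image defaults, so it only snaps to multiples of 28; against `VIDEO_MAX_PIXELS = 768 * 28 * 28 = 602112`, the 28-aligned size 1092x1932 (about 2.1 M pixels) is too large, so both sides are scaled down by `beta = sqrt(1080*1920/602112) ≈ 1.856` and floored to multiples of 28:

print(smart_resize(1080, 1920))                             # (1092, 1932): within image budget
print(smart_resize(1080, 1920, max_pixels=768 * 28 * 28))   # (560, 1008): 560*1008 = 564480 <= 602112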
+
+def resize_image(image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+    width, height = image.size
+    min_pixels = MIN_PIXELS
+    max_pixels = MAX_PIXELS
+    resized_height, resized_width = smart_resize(
+        height,
+        width,
+        factor=size_factor,
+        min_pixels=min_pixels,
+        max_pixels=max_pixels,
+    )
+    image = image.resize((resized_width, resized_height))
+    return image
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+
+async def resize_image_async(image):
+    return resize_image(image)
+
+
+def smart_nframes(
+    ele: dict,
+    total_frames: int,
+    video_fps: int | float,
+) -> int:
+    """calculate the number of frames for video used for model inputs.
+
+    Args:
+        ele (dict): a dict contains the configuration of video.
+            support either `fps` or `nframes`:
+                - nframes: the number of frames to extract for model inputs.
+                - fps: the fps to extract frames for model inputs.
+                    - min_frames: the minimum number of frames of the video, only used when fps is provided.
+                    - max_frames: the maximum number of frames of the video, only used when fps is provided.
+        total_frames (int): the original total number of frames of the video.
+        video_fps (int | float): the original fps of the video.
+
+    Raises:
+        ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
+
+    Returns:
+        int: the number of frames for video used for model inputs.
+    """
+    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
+    if "nframes" in ele:
+        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+    else:
+        fps = ele.get("fps", FPS)
+        min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+        max_frames = floor_by_factor(
+            ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR
+        )
+        nframes = total_frames / video_fps * fps
+        if nframes > total_frames:
+            logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
+        nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
+        nframes = floor_by_factor(nframes, FRAME_FACTOR)
+    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
+        raise ValueError(
+            f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
+        )
+    return nframes
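Note: a worked example of `smart_nframes` with the defaults. For a 10-second clip at 30 fps (`total_frames=300`, `video_fps=30`), sampling at `FPS = 2.0` asks for 300 / 30 * 2 = 20 frames, which already lies between `min_frames=4` and the frame cap and is divisible by `FRAME_FACTOR = 2`. An explicit `nframes` is snapped to the nearest multiple of `FRAME_FACTOR` (via Python's round-half-to-even):

print(smart_nframes({}, total_frames=300, video_fps=30))                  # 20
print(smart_nframes({"nframes": 17}, total_frames=300, video_fps=30))     # 16: round(8.5) == 8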
+
+# process video, qwen-specific
+async def preprocess_video(
+    vr, image_factor: int = IMAGE_FACTOR  # vr: VideoReader, image_factor: int = IMAGE_FACTOR
+) -> torch.Tensor:
+    ele = {}
+    total_frames, video_fps = len(vr), vr.get_avg_fps()
+    nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
+    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
+    video = vr.get_batch(idx).asnumpy()
+    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
+    nframes, _, height, width = video.shape
+    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
+    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
+    max_pixels = max(
+        min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
+        int(min_pixels * 1.05),
+    )
+    max_pixels_supposed = ele.get("max_pixels", max_pixels)
+    if max_pixels_supposed > max_pixels:
+        logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
+    max_pixels = min(max_pixels_supposed, max_pixels)
+    if "resized_height" in ele and "resized_width" in ele:
+        resized_height, resized_width = smart_resize(
+            ele["resized_height"],
+            ele["resized_width"],
+            factor=image_factor,
+        )
+    else:
+        resized_height, resized_width = smart_resize(
+            height,
+            width,
+            factor=image_factor,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+    video = torchvision.transforms.functional.resize(
+        video,
+        [resized_height, resized_width],
+        interpolation=InterpolationMode.BICUBIC,
+        antialias=True,
+    ).float()
+    return video
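Note: taken together with `load_video` (added to `python/sglang/srt/utils.py` later in this commit), the path from an incoming video to model-ready pixels looks roughly like this. A hedged sketch with a placeholder URL; `load_video` returns a decord `VideoReader`, and `preprocess_video` turns it into a float TCHW tensor whose frame count and resolution respect the budgets above:

import asyncio

from sglang.srt.utils import load_video

async def demo():
    vr = load_video("https://example.com/clip.mp4")  # placeholder URL
    video = await preprocess_video(vr)  # float tensor, TCHW
    print(video.shape)  # e.g. torch.Size([20, 3, 560, 1008])

asyncio.run(demo())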

 # Compatible with Qwen2VL and Qwen2_5VL
@@ -37,104 +220,44 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.MIN_PIXELS = 4 * 28 * 28
         self.MAX_PIXELS = 16384 * 28 * 28
         self.MAX_RATIO = 200
+        # TODO(mick): move all MultimodalSpecialTokens initializations into processor init
+        self.mm_special_tokens = MultimodalSpecialTokens(
+            image_token=self.IMAGE_TOKEN,
+            image_token_regex=self.IMAGE_TOKEN_REGEX,
+            video_token=self.VIDEO_TOKEN_ID,
+        )

     async def process_mm_data_async(
         self,
-        image_data: List[Union[str, bytes, Dict]],
+        image_data: List[Union[str, bytes]],
         input_text,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.IMAGE_TOKEN,
-                image_token_regex=self.IMAGE_TOKEN_REGEX,
-            ),
+            video_data=request_obj.video_data,
+            multimodal_tokens=self.mm_special_tokens,
             max_req_input_len=max_req_input_len,
         )
-
-        def smart_resize(
-            height: int,
-            width: int,
-            factor: int = self.IMAGE_FACTOR,
-            min_pixels: int = self.MIN_PIXELS,
-            max_pixels: int = self.MAX_PIXELS,
-        ) -> tuple[int, int]:
-            """
-            Rescales the image so that the following conditions are met:
-            1. Both dimensions (height and width) are divisible by 'factor'.
-            2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
-            3. The aspect ratio of the image is maintained as closely as possible.
-            """
-            if max(height, width) / min(height, width) > self.MAX_RATIO:
-                raise ValueError(
-                    f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, "
-                    f"got {max(height, width) / min(height, width)}"
-                )
-            h_bar = max(factor, round_by_factor(height, factor))
-            w_bar = max(factor, round_by_factor(width, factor))
-            if h_bar * w_bar > max_pixels:
-                beta = math.sqrt((height * width) / max_pixels)
-                h_bar = floor_by_factor(height / beta, factor)
-                w_bar = floor_by_factor(width / beta, factor)
-            elif h_bar * w_bar < min_pixels:
-                beta = math.sqrt(min_pixels / (height * width))
-                h_bar = ceil_by_factor(height * beta, factor)
-                w_bar = ceil_by_factor(width * beta, factor)
-            return h_bar, w_bar
-
-        def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
-            width, height = image.size
-            min_pixels = self.MIN_PIXELS
-            max_pixels = self.MAX_PIXELS
-            resized_height, resized_width = smart_resize(
-                height,
-                width,
-                factor=size_factor,
-                min_pixels=min_pixels,
-                max_pixels=max_pixels,
-            )
-            image = image.resize((resized_width, resized_height))
-            return image
-
-        def round_by_factor(number: int, factor: int) -> int:
-            """Returns the closest integer to 'number' that is divisible by 'factor'."""
-            return round(number / factor) * factor
-
-        def ceil_by_factor(number: int, factor: int) -> int:
-            """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
-            return math.ceil(number / factor) * factor
-
-        def floor_by_factor(number: int, factor: int) -> int:
-            """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
-            return math.floor(number / factor) * factor
-
-        async def resize_image_async(image):
-            return resize_image(image)

         # Qwen-specific: resize images if they are raw Image objects
         if base_output.images and isinstance(base_output.images[0], Image.Image):
             resize_tasks = [resize_image_async(image) for image in base_output.images]
             base_output.images = await asyncio.gather(*resize_tasks)

-        video_grid_thw = None  # TODO
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
-
-        if not mm_items:
-            # Note(Xinyuan): This is the case where image loading fails.
-            return None
-
-        combined_mm_item = mm_items[0]  # only image is supported for now
-        second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
+        if base_output.videos:
+            base_output.videos = [
+                await preprocess_video(video) for video in base_output.videos
+            ]
+
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(base_output)

         input_ids = input_ids.flatten()
         mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
             spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
             image_token_id=self.IM_TOKEN_ID,
...
@@ -145,9 +268,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                 self.hf_config.vision_config, "tokens_per_second", None
             ),
             input_ids=input_ids.unsqueeze(0),
-            image_grid_thw=combined_mm_item.image_grid_thw,
-            video_grid_thw=video_grid_thw,
-            second_per_grid_ts=second_per_grid_ts,
+            image_grid_thw=getattr(ret, "image_grid_thw", None),
+            video_grid_thw=getattr(ret, "video_grid_thw", None),
+            second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
         )
         mrope_positions = mrope_positions.squeeze(1)
...
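Note: `process_and_combine_mm_data` now returns a third value (`ret`, the combined processor output), and the mrope setup reads `image_grid_thw`, `video_grid_thw`, and `second_per_grid_ts` off it with `getattr(..., None)`. A text-only or image-only request therefore feeds `None` for the modalities it lacks, replacing the old hard-coded `video_grid_thw = None  # TODO`. The other processors in this commit adopt the same three-tuple and simply discard the third element (`mm_items, input_ids, _ = ...`).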
python/sglang/srt/multimodal/processors/vila.py

@@ -57,7 +57,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
             image_data=image_data,
         )
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
python/sglang/srt/utils.py

@@ -728,33 +728,6 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray
     return audio


-def encode_video(video_path, frame_count_limit=None):
-    # Lazy import because decord is not available on some arm platforms.
-    from decord import VideoReader, cpu
-
-    if not os.path.exists(video_path):
-        logger.error(f"Video {video_path} does not exist")
-        return []
-
-    if frame_count_limit == 0:
-        return []
-
-    def uniform_sample(l, n):
-        gap = len(l) / n
-        idxs = [int(i * gap + gap / 2) for i in range(n)]
-        return [l[i] for i in idxs]
-
-    vr = VideoReader(video_path, ctx=cpu(0))
-    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-    frame_indices = [i for i in range(0, len(vr), sample_fps)]
-    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-        frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-    frames = vr.get_batch(frame_indices).asnumpy()
-    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-    return frames
-
-
 def load_image(
     image_file: Union[Image.Image, str, bytes],
 ) -> tuple[Image.Image, tuple[int, int]]:

@@ -774,9 +747,6 @@ def load_image(
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
-    elif image_file.startswith("video:"):
-        image_file = image_file.replace("video:", "")
-        image, image_size = decode_video_base64(image_file)
     elif isinstance(image_file, str):
         image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     else:

@@ -785,6 +755,61 @@ def load_image(
     return image, image_size


+def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
+    # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
+    from decord import VideoReader, cpu, gpu
+
+    try:
+        from decord.bridge import decord_bridge
+
+        ctx = gpu(0)
+        _ = decord_bridge.get_ctx_device(ctx)
+    except Exception:
+        ctx = cpu(0)
+
+    tmp_file = None
+    vr = None
+    try:
+        if isinstance(video_file, bytes):
+            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+            tmp_file.write(video_file)
+            tmp_file.close()
+            vr = VideoReader(tmp_file.name, ctx=ctx)
+        elif isinstance(video_file, str):
+            if video_file.startswith(("http://", "https://")):
+                timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
+                response = requests.get(video_file, stream=True, timeout=timeout)
+                response.raise_for_status()
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                for chunk in response.iter_content(chunk_size=8192):
+                    tmp_file.write(chunk)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif video_file.startswith("data:"):
+                _, encoded = video_file.split(",", 1)
+                video_bytes = base64.b64decode(encoded)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif os.path.isfile(video_file):
+                vr = VideoReader(video_file, ctx=ctx)
+            else:
+                video_bytes = base64.b64decode(video_file)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+        else:
+            raise ValueError(f"Unsupported video input type: {type(video_file)}")
+        return vr
+    finally:
+        if tmp_file and os.path.exists(tmp_file.name):
+            os.unlink(tmp_file.name)
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
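Note: `load_video` normalizes several input encodings to a decord `VideoReader`, preferring a GPU decode context when decord's GPU bridge is available and falling back to CPU otherwise. A usage sketch of the branches above (paths, URL, and base64 strings are placeholders):

from sglang.srt.utils import load_video

vr = load_video("/path/to/clip.mp4")                      # local file
vr = load_video("https://example.com/clip.mp4")           # http(s) URL, fetched to a temp file
vr = load_video("data:video/mp4;base64,AAAA")             # data URI
vr = load_video("AAAA")                                   # bare base64 string
vr = load_video(open("/path/to/clip.mp4", "rb").read())   # raw bytes
print(len(vr), vr.get_avg_fps())                          # frame count and average fps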
test/srt/test_jinja_template_utils.py

@@ -3,7 +3,6 @@ Unit tests for Jinja chat template utils.
 """

 import unittest
-from unittest.mock import patch

 from sglang.srt.jinja_template_utils import (
     detect_jinja_template_content_format,

@@ -76,11 +75,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }

         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
         )

         # Check that image_data was extracted

@@ -111,11 +111,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }

         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "string", image_data, audio_data, modalities
+            msg_dict, "string", image_data, video_data, audio_data, modalities
         )

         # For string format, should flatten to text only

@@ -139,11 +140,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }

         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
        )

         # Check that audio_data was extracted

@@ -162,11 +164,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         msg_dict = {"role": "user", "content": "Hello world"}

         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
         )

         # Should pass through unchanged

@@ -188,11 +191,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }

         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
         )

         # Check that modalities was extracted

@@ -209,11 +213,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }

         image_data = []
+        video_data = []
         audio_data = []
         modalities = []

         result = process_content_for_template_format(
-            msg_dict, "string", image_data, audio_data, modalities
+            msg_dict, "string", image_data, video_data, audio_data, modalities
         )

         # None values should be filtered out
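Note: the test updates track a signature change: `process_content_for_template_format` now takes `video_data` between `image_data` and `audio_data`. A minimal call mirroring the tests (the message content is illustrative):

from sglang.srt.jinja_template_utils import process_content_for_template_format

msg_dict = {"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
image_data, video_data, audio_data, modalities = [], [], [], []
result = process_content_for_template_format(
    msg_dict, "openai", image_data, video_data, audio_data, modalities
)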
test/srt/test_vision_openai_server_a.py

@@ -35,6 +35,9 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

+    def test_video_chat_completion(self):
+        self._test_video_chat_completion()
+

 class TestQwen2_5_VLServer(TestOpenAIVisionServer):
     @classmethod

@@ -54,6 +57,9 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

+    def test_video_chat_completion(self):
+        self._test_video_chat_completion()
+

 class TestVLMContextLengthIssue(CustomTestCase):
     @classmethod
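Note: the two Qwen suites gain a real video test; `_test_video_chat_completion` lives in the shared `TestOpenAIVisionServer` base class, which is not part of this page of the diff. Against a launched server, a request exercising the new modality could plausibly look like the following sketch (the `video_url` content-part shape and model name are assumptions based on this commit, not confirmed by the visible diff):

import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",  # illustrative
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
                {"type": "text", "text": "Describe what happens in this video."},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)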
test/srt/test_vision_openai_server_b.py

@@ -93,7 +93,7 @@ class TestJanusProServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

-    def test_video_chat_completion(self):
+    def test_video_images_chat_completion(self):
         pass

     def test_single_image_chat_completion(self):

@@ -170,7 +170,7 @@ class TestKimiVLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

-    def test_video_chat_completion(self):
+    def test_video_images_chat_completion(self):
         pass