sglang · Commit b5e3d603 (Unverified)

vlm: support video as an input modality (#5888)

Authored Jul 10, 2025 by Mick; committed by GitHub on Jul 09, 2025.
Parent: 4ed57807

Showing 20 changed files with 522 additions and 280 deletions (+522 -280)
python/sglang/srt/models/qwen2_5_vl.py  +10 -7
python/sglang/srt/models/qwen2_vl.py  +12 -1
python/sglang/srt/models/vila.py  +8 -2
python/sglang/srt/multimodal/processors/base_processor.py  +197 -137
python/sglang/srt/multimodal/processors/deepseek_vl_v2.py  +1 -1
python/sglang/srt/multimodal/processors/gemma3.py  +4 -2
python/sglang/srt/multimodal/processors/gemma3n.py  +1 -1
python/sglang/srt/multimodal/processors/internvl.py  +1 -1
python/sglang/srt/multimodal/processors/janus_pro.py  +1 -1
python/sglang/srt/multimodal/processors/kimi_vl.py  +1 -1
python/sglang/srt/multimodal/processors/minicpm.py  +4 -3
python/sglang/srt/multimodal/processors/mllama4.py  +1 -1
python/sglang/srt/multimodal/processors/phi4mm.py  +1 -1
python/sglang/srt/multimodal/processors/pixtral.py  +1 -1
python/sglang/srt/multimodal/processors/qwen_vl.py  +203 -80
python/sglang/srt/multimodal/processors/vila.py  +1 -1
python/sglang/srt/utils.py  +55 -30
test/srt/test_jinja_template_utils.py  +12 -7
test/srt/test_vision_openai_server_a.py  +6 -0
test/srt/test_vision_openai_server_b.py  +2 -2
python/sglang/srt/models/qwen2_5_vl.py

```diff
@@ -56,7 +56,6 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInp
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
 from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
@@ -507,11 +506,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds
 
-    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
-        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
-        video_embeds = self.visual(
-            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
-        )
+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [getattr(item, "pixel_values_videos") for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
         return video_embeds
 
     def get_input_embeddings(self):
@@ -553,7 +556,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
```
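For orientation, here is a small standalone sketch (not part of the commit) of the tensor shapes the new `get_video_feature` contract implies; the patch dimension and the (t, h, w) grids below are illustrative values, not taken from any model config:

```python
import torch

# Illustrative values only: patch_dim and the grids are made up.
patch_dim = 1176
item_a = torch.randn(1536, patch_dim)   # one video flattened to 1536 vision patches
item_b = torch.randn(2048, patch_dim)   # a second video with 2048 patches
grid_a = torch.tensor([[2, 24, 32]])    # (t, h, w) grid of video A: 2 * 24 * 32 = 1536
grid_b = torch.tensor([[2, 32, 32]])    # (t, h, w) grid of video B: 2 * 32 * 32 = 2048

# Mirrors what get_video_feature does with the MultimodalDataItem list:
pixel_values = torch.cat([item_a, item_b], dim=0)    # (3584, patch_dim), still 2-D
video_grid_thw = torch.cat([grid_a, grid_b], dim=0)  # (2, 3), still 2-D
assert pixel_values.dim() == 2 and video_grid_thw.dim() == 2
# self.visual(pixel_values, grid_thw=video_grid_thw) then produces the video embeddings,
# analogous to the image path above.
```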
python/sglang/srt/models/qwen2_vl.py

```diff
@@ -493,6 +493,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds
 
+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [item.pixel_values_videos for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
+        return video_embeds
+
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
         pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
         video_embeds = self.visual(
@@ -538,7 +549,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
```
python/sglang/srt/models/vila.py

```diff
@@ -17,7 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -223,7 +227,9 @@ class VILAForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             get_embedding=get_embedding,
             positions=positions,
         )
```
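The change above swaps the single `image_data_embedding_func` callback for a dict keyed by modality. A minimal, self-contained sketch of that dispatch pattern (the `Modality` enum and both feature functions here are stand-ins; VILA itself registers only `Modality.IMAGE` in this commit):

```python
from enum import Enum, auto
from typing import Callable, Dict, List

class Modality(Enum):  # stand-in for sglang.srt.managers.schedule_batch.Modality
    IMAGE = auto()
    VIDEO = auto()

def get_image_feature(items: List[str]) -> str:  # placeholder for the model method
    return f"image embeddings for {len(items)} item(s)"

def get_video_feature(items: List[str]) -> str:  # hypothetical video counterpart
    return f"video embeddings for {len(items)} item(s)"

# One embedding function per modality, looked up when a batch of that modality is embedded.
data_embedding_funcs: Dict[Modality, Callable[[List[str]], str]] = {
    Modality.IMAGE: get_image_feature,
    Modality.VIDEO: get_video_feature,
}

print(data_embedding_funcs[Modality.IMAGE](["item0"]))
```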
python/sglang/srt/multimodal/processors/base_processor.py

This diff is collapsed and not expanded on this page.
python/sglang/srt/multimodal/processors/deepseek_vl_v2.py

```diff
@@ -69,7 +69,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         )
         item = MultimodalDataItem(
             pixel_values=res["images"],
-            image_offsets=image_offsets,
+            offsets=image_offsets,
             modality=Modality.IMAGE,
             image_emb_mask=images_seq_mask,
             image_spatial_crop=batched_images_spatial_crop,
```
python/sglang/srt/multimodal/processors/gemma3.py

```diff
@@ -36,6 +36,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
+        print(f"{image_data=}")
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -46,8 +47,9 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
             discard_alpha_channel=True,
         )
 
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
+        print(f"{base_output=}")
+        print(f"{mm_items=}")
 
         return {
             "input_ids": input_ids.tolist(),
             "mm_items": mm_items,
```
python/sglang/srt/multimodal/processors/gemma3n.py

```diff
@@ -72,7 +72,7 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
             ),
         )
 
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
 
         return {
             "input_ids": input_ids.tolist(),
```
python/sglang/srt/multimodal/processors/internvl.py

```diff
@@ -225,7 +225,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             MultimodalDataItem(
                 pixel_values=pixel_values,
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
```
python/sglang/srt/multimodal/processors/janus_pro.py

```diff
@@ -49,7 +49,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
             MultimodalDataItem(
                 pixel_values=res["pixel_values"],
                 image_emb_mask=res["images_emb_mask"],
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 modality=Modality.IMAGE,
             )
         ],
```
python/sglang/srt/multimodal/processors/kimi_vl.py

```diff
@@ -39,7 +39,7 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
             max_req_input_len=max_req_input_len,
         )
 
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
 
         return {
             "input_ids": input_ids.tolist(),
```
python/sglang/srt/multimodal/processors/minicpm.py

```diff
@@ -19,6 +19,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.image_token = "(<image>./</image>)"
         self.audio_token = "(<audio>./</audio>)"
+        self.video_token = "(<video>./</video>)"
 
     async def process_mm_data_async(
         self,
@@ -36,6 +37,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=self.image_token,
+                video_token=self.video_token,
                 audio_token=self.audio_token,
             ),
         )
@@ -113,7 +115,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         if len(pixel_values) != 0:
             item = MultimodalDataItem(
                 pixel_values=pixel_values,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 tgt_size=tgt_sizes_flat,
                 modality=Modality.IMAGE,
             )
@@ -135,11 +137,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             item = MultimodalDataItem(
                 audio_features=[res["audio_features"]],
                 audio_feature_lens=res["audio_feature_lens"],
-                audio_offsets=audio_offsets,
+                offsets=audio_offsets,
                 modality=Modality.AUDIO,
             )
             items += [item]
 
         return {
             "mm_items": items,
             "input_ids": input_ids.tolist(),
```
python/sglang/srt/multimodal/processors/mllama4.py

```diff
@@ -144,7 +144,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             MultimodalDataItem(
                 pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
```
python/sglang/srt/multimodal/processors/phi4mm.py

```diff
@@ -65,7 +65,7 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
                 pixel_values=res["input_image_embeds"],
                 image_sizes=res["image_sizes"],
                 image_emb_mask=res["image_attention_mask"],
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 modality=Modality.IMAGE,
             )
         ]
```
python/sglang/srt/multimodal/processors/pixtral.py

```diff
@@ -106,7 +106,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
                 pixel_values=processor_output["pixel_values"],
                 image_sizes=processor_output["image_sizes"],
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
```
python/sglang/srt/multimodal/processors/qwen_vl.py

```diff
 import asyncio
 import math
+import os
 import re
-from typing import Dict, List, Union
+from typing import List, Union
 
+import torch
+import torchvision
 from PIL import Image
+from torchvision.transforms import InterpolationMode
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
@@ -12,58 +16,31 @@ from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
 from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
+from sglang.utils import logger
 
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+VIDEO_TOTAL_PIXELS = int(float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9)))
+VIDEO_MIN_PIXELS = 128 * 28 * 28
+VIDEO_MAX_PIXELS = 768 * 28 * 28
+FRAME_FACTOR = 2
+FPS = 2.0
+FPS_MIN_FRAMES = 4
+FPS_MAX_FRAMES = 768
 
-# Compatible with Qwen2VL and Qwen2_5VL
-class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
-    models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
-
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
-        # The single, pre-expanded image token.
-        self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
-        # The regex that matches expanded image tokens.
-        self.IMAGE_TOKEN_REGEX = re.compile(
-            r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
-        )
-        self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
-        self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
-        self.IM_TOKEN_ID = hf_config.image_token_id
-        self.VIDEO_TOKEN_ID = hf_config.video_token_id
-        self.vision_start_token_id = hf_config.vision_start_token_id
-        self.vision_end_token_id = hf_config.vision_end_token_id
-        self.NUM_TOKEN_PER_FRAME = 770
-        self.IMAGE_FACTOR = 28
-        self.MIN_PIXELS = 4 * 28 * 28
-        self.MAX_PIXELS = 16384 * 28 * 28
-        self.MAX_RATIO = 200
-
-    async def process_mm_data_async(
-        self,
-        image_data: List[Union[str, bytes, Dict]],
-        input_text,
-        request_obj,
-        max_req_input_len,
-        *args,
-        **kwargs,
-    ):
-        base_output = self.load_mm_data(
-            prompt=input_text,
-            image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.IMAGE_TOKEN,
-                image_token_regex=self.IMAGE_TOKEN_REGEX,
-            ),
-            max_req_input_len=max_req_input_len,
-        )
 
 def smart_resize(
     height: int,
     width: int,
-    factor: int = self.IMAGE_FACTOR,
-    min_pixels: int = self.MIN_PIXELS,
-    max_pixels: int = self.MAX_PIXELS,
+    factor: int = IMAGE_FACTOR,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
 ) -> tuple[int, int]:
     """
     Rescales the image so that the following conditions are met:
@@ -73,9 +50,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     3. The aspect ratio of the image is maintained as closely as possible.
     """
-    if max(height, width) / min(height, width) > self.MAX_RATIO:
+    if max(height, width) / min(height, width) > MAX_RATIO:
         raise ValueError(
-            f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
+            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
         )
     h_bar = max(factor, round_by_factor(height, factor))
     w_bar = max(factor, round_by_factor(width, factor))
@@ -89,10 +66,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     w_bar = ceil_by_factor(width * beta, factor)
     return h_bar, w_bar
 
-def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
+def resize_image(image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
     width, height = image.size
-    min_pixels = self.MIN_PIXELS
-    max_pixels = self.MAX_PIXELS
+    min_pixels = MIN_PIXELS
+    max_pixels = MAX_PIXELS
     resized_height, resized_width = smart_resize(
         height,
         width,
@@ -103,38 +81,183 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     image = image.resize((resized_width, resized_height))
     return image
 
 def round_by_factor(number: int, factor: int) -> int:
     """Returns the closest integer to 'number' that is divisible by 'factor'."""
     return round(number / factor) * factor
 
 def ceil_by_factor(number: int, factor: int) -> int:
     """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
     return math.ceil(number / factor) * factor
 
 def floor_by_factor(number: int, factor: int) -> int:
     """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
     return math.floor(number / factor) * factor
 
 async def resize_image_async(image):
     return resize_image(image)
 
+
+def smart_nframes(
+    ele: dict,
+    total_frames: int,
+    video_fps: int | float,
+) -> int:
+    """calculate the number of frames for video used for model inputs.
+
+    Args:
+        ele (dict): a dict contains the configuration of video.
+            support either `fps` or `nframes`:
+                - nframes: the number of frames to extract for model inputs.
+                - fps: the fps to extract frames for model inputs.
+                    - min_frames: the minimum number of frames of the video, only used when fps is provided.
+                    - max_frames: the maximum number of frames of the video, only used when fps is provided.
+        total_frames (int): the original total number of frames of the video.
+        video_fps (int | float): the original fps of the video.
+
+    Raises:
+        ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
+
+    Returns:
+        int: the number of frames for video used for model inputs.
+    """
+    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
+    if "nframes" in ele:
+        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+    else:
+        fps = ele.get("fps", FPS)
+        min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+        max_frames = floor_by_factor(
+            ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR
+        )
+        nframes = total_frames / video_fps * fps
+        if nframes > total_frames:
+            logger.warning(
+                f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]"
+            )
+        nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
+        nframes = floor_by_factor(nframes, FRAME_FACTOR)
+    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
+        raise ValueError(
+            f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
+        )
+    return nframes
+
+
+# process video, qwen-specific
+async def preprocess_video(
+    vr, image_factor: int = IMAGE_FACTOR
+    # vr: VideoReader, image_factor: int = IMAGE_FACTOR
+) -> torch.Tensor:
+    ele = {}
+    total_frames, video_fps = len(vr), vr.get_avg_fps()
+    nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
+    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
+    video = vr.get_batch(idx).asnumpy()
+    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
+    nframes, _, height, width = video.shape
+    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
+    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
+    max_pixels = max(
+        min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
+        int(min_pixels * 1.05),
+    )
+    max_pixels_supposed = ele.get("max_pixels", max_pixels)
+    if max_pixels_supposed > max_pixels:
+        logger.warning(
+            f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}]."
+        )
+    max_pixels = min(max_pixels_supposed, max_pixels)
+    if "resized_height" in ele and "resized_width" in ele:
+        resized_height, resized_width = smart_resize(
+            ele["resized_height"],
+            ele["resized_width"],
+            factor=image_factor,
+        )
+    else:
+        resized_height, resized_width = smart_resize(
+            height,
+            width,
+            factor=image_factor,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+    video = torchvision.transforms.functional.resize(
+        video,
+        [resized_height, resized_width],
+        interpolation=InterpolationMode.BICUBIC,
+        antialias=True,
+    ).float()
+    return video
+
+
+# Compatible with Qwen2VL and Qwen2_5VL
+class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
+    models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        # The single, pre-expanded image token.
+        self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
+        # The regex that matches expanded image tokens.
+        self.IMAGE_TOKEN_REGEX = re.compile(
+            r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
+        )
+        self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
+        self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.VIDEO_TOKEN_ID = hf_config.video_token_id
+        self.vision_start_token_id = hf_config.vision_start_token_id
+        self.vision_end_token_id = hf_config.vision_end_token_id
+        self.NUM_TOKEN_PER_FRAME = 770
+        self.IMAGE_FACTOR = 28
+        self.MIN_PIXELS = 4 * 28 * 28
+        self.MAX_PIXELS = 16384 * 28 * 28
+        self.MAX_RATIO = 200
+        # TODO(mick): move all MultimodalSpecialTokens initializations into processor init
+        self.mm_special_tokens = MultimodalSpecialTokens(
+            image_token=self.IMAGE_TOKEN,
+            image_token_regex=self.IMAGE_TOKEN_REGEX,
+            video_token=self.VIDEO_TOKEN_ID,
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
+    ):
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            video_data=request_obj.video_data,
+            multimodal_tokens=self.mm_special_tokens,
+            max_req_input_len=max_req_input_len,
+        )
 
         # Qwen-specific: resize images if they are raw Image objects
         if base_output.images and isinstance(base_output.images[0], Image.Image):
             resize_tasks = [resize_image_async(image) for image in base_output.images]
             base_output.images = await asyncio.gather(*resize_tasks)
 
-        video_grid_thw = None  # TODO
+        if base_output.videos:
+            base_output.videos = [
+                await preprocess_video(video) for video in base_output.videos
+            ]
 
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
-
-        if not mm_items:
-            # Note(Xinyuan): This is the case where image loading fails.
-            return None
-
-        combined_mm_item = mm_items[0]  # only image is supported for now
-
-        second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(base_output)
+
+        video_grid_thw = None  # TODO
+
         input_ids = input_ids.flatten()
         mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
             spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
             image_token_id=self.IM_TOKEN_ID,
@@ -145,9 +268,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                 self.hf_config.vision_config, "tokens_per_second", None
             ),
             input_ids=input_ids.unsqueeze(0),
-            image_grid_thw=combined_mm_item.image_grid_thw,
-            video_grid_thw=video_grid_thw,
-            second_per_grid_ts=second_per_grid_ts,
+            image_grid_thw=getattr(ret, "image_grid_thw", None),
+            video_grid_thw=getattr(ret, "video_grid_thw", None),
+            second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
         )
         mrope_positions = mrope_positions.squeeze(1)
```
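As a sanity check on the frame-sampling arithmetic above, here is a standalone sketch (slightly simplified: it clamps with FPS_MIN_FRAMES and FPS_MAX_FRAMES directly instead of re-rounding them) for a 10-second clip at 30 fps:

```python
import math

FRAME_FACTOR, FPS, FPS_MIN_FRAMES, FPS_MAX_FRAMES = 2, 2.0, 4, 768

def floor_by_factor(number, factor):
    return math.floor(number / factor) * factor

total_frames, video_fps = 300, 30.0               # a 10 s clip at 30 fps
nframes = total_frames / video_fps * FPS          # sample at 2 fps -> 20.0 frames
nframes = min(max(nframes, FPS_MIN_FRAMES), FPS_MAX_FRAMES, total_frames)
nframes = floor_by_factor(nframes, FRAME_FACTOR)  # round down to a multiple of 2
print(nframes)                                    # 20 frames get decoded and resized
```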
python/sglang/srt/multimodal/processors/vila.py

```diff
@@ -57,7 +57,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
             image_data=image_data,
         )
 
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
 
         return {
             "input_ids": input_ids.tolist(),
```
python/sglang/srt/utils.py

```diff
@@ -728,33 +728,6 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     return audio
 
 
-def encode_video(video_path, frame_count_limit=None):
-    # Lazy import because decord is not available on some arm platforms.
-    from decord import VideoReader, cpu
-
-    if not os.path.exists(video_path):
-        logger.error(f"Video {video_path} does not exist")
-        return []
-
-    if frame_count_limit == 0:
-        return []
-
-    def uniform_sample(l, n):
-        gap = len(l) / n
-        idxs = [int(i * gap + gap / 2) for i in range(n)]
-        return [l[i] for i in idxs]
-
-    vr = VideoReader(video_path, ctx=cpu(0))
-    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-    frame_indices = [i for i in range(0, len(vr), sample_fps)]
-    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-        frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-    frames = vr.get_batch(frame_indices).asnumpy()
-    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-    return frames
-
-
 def load_image(
     image_file: Union[Image.Image, str, bytes],
 ) -> tuple[Image.Image, tuple[int, int]]:
@@ -774,9 +747,6 @@ def load_image(
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
-    elif image_file.startswith("video:"):
-        image_file = image_file.replace("video:", "")
-        image, image_size = decode_video_base64(image_file)
     elif isinstance(image_file, str):
         image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     else:
@@ -785,6 +755,61 @@ def load_image(
     return image, image_size
 
 
+def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
+    # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
+    from decord import VideoReader, cpu, gpu
+
+    try:
+        from decord.bridge import decord_bridge
+
+        ctx = gpu(0)
+        _ = decord_bridge.get_ctx_device(ctx)
+    except Exception:
+        ctx = cpu(0)
+
+    tmp_file = None
+    vr = None
+    try:
+        if isinstance(video_file, bytes):
+            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+            tmp_file.write(video_file)
+            tmp_file.close()
+            vr = VideoReader(tmp_file.name, ctx=ctx)
+        elif isinstance(video_file, str):
+            if video_file.startswith(("http://", "https://")):
+                timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
+                response = requests.get(video_file, stream=True, timeout=timeout)
+                response.raise_for_status()
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                for chunk in response.iter_content(chunk_size=8192):
+                    tmp_file.write(chunk)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif video_file.startswith("data:"):
+                _, encoded = video_file.split(",", 1)
+                video_bytes = base64.b64decode(encoded)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif os.path.isfile(video_file):
+                vr = VideoReader(video_file, ctx=ctx)
+            else:
+                video_bytes = base64.b64decode(video_file)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+        else:
+            raise ValueError(f"Unsupported video input type: {type(video_file)}")
+        return vr
+    finally:
+        if tmp_file and os.path.exists(tmp_file.name):
+            os.unlink(tmp_file.name)
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
```
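A usage sketch for the new helper (assumes decord is installed; the file path below is hypothetical). load_video accepts raw bytes, an http(s) URL, a data: URI, a local path, or a bare base64 string, and returns a decord VideoReader:

```python
from sglang.srt.utils import load_video

vr = load_video("/tmp/example.mp4")       # hypothetical local file
print(len(vr), vr.get_avg_fps())          # total frame count and average fps

with open("/tmp/example.mp4", "rb") as f:
    vr_from_bytes = load_video(f.read())  # the same clip passed as raw bytes
```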
test/srt/test_jinja_template_utils.py

```diff
@@ -3,7 +3,6 @@ Unit tests for Jinja chat template utils.
 """
 
 import unittest
-from unittest.mock import patch
 
 from sglang.srt.jinja_template_utils import (
     detect_jinja_template_content_format,
@@ -76,11 +75,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []
 
         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
         )
 
         # Check that image_data was extracted
@@ -111,11 +111,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []
 
         result = process_content_for_template_format(
-            msg_dict, "string", image_data, audio_data, modalities
+            msg_dict, "string", image_data, video_data, audio_data, modalities
         )
 
         # For string format, should flatten to text only
@@ -139,11 +140,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []
 
         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
         )
 
         # Check that audio_data was extracted
@@ -162,11 +164,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         msg_dict = {"role": "user", "content": "Hello world"}
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []
 
         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
        )
 
         # Should pass through unchanged
@@ -188,11 +191,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []
 
         result = process_content_for_template_format(
-            msg_dict, "openai", image_data, audio_data, modalities
+            msg_dict, "openai", image_data, video_data, audio_data, modalities
         )
 
         # Check that modalities was extracted
@@ -209,11 +213,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
         }
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []
 
         result = process_content_for_template_format(
-            msg_dict, "string", image_data, audio_data, modalities
+            msg_dict, "string", image_data, video_data, audio_data, modalities
         )
 
         # None values should be filtered out
```
test/srt/test_vision_openai_server_a.py

```diff
@@ -35,6 +35,9 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"
 
+    def test_video_chat_completion(self):
+        self._test_video_chat_completion()
+
 
 class TestQwen2_5_VLServer(TestOpenAIVisionServer):
     @classmethod
@@ -54,6 +57,9 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"
 
+    def test_video_chat_completion(self):
+        self._test_video_chat_completion()
+
 
 class TestVLMContextLengthIssue(CustomTestCase):
     @classmethod
```
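For context, the new test_video_chat_completion cases presumably exercise a request along these lines. This is a hedged sketch only: the "video_url" content-part name, the model name, and the URL are assumptions, not taken from this commit.

```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",  # assumed: any video-capable VLM served by sglang
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
                {"type": "text", "text": "Describe what happens in this clip."},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)
```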
test/srt/test_vision_openai_server_b.py

```diff
@@ -93,7 +93,7 @@ class TestJanusProServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"
 
-    def test_video_chat_completion(self):
+    def test_video_images_chat_completion(self):
         pass
 
     def test_single_image_chat_completion(self):
@@ -170,7 +170,7 @@ class TestKimiVLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"
 
-    def test_video_chat_completion(self):
+    def test_video_images_chat_completion(self):
         pass
```