change / sglang · Commits · c998d04b

Unverified commit c998d04b, authored Apr 24, 2025 by Mick and committed by GitHub Apr 23, 2025.

vlm: enable radix cache for qwen-vl models (#5349)

Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>

Parent: 7d0edf3c
Changes: 26 files in total; this page shows 20 changed files with 413 additions and 301 deletions (+413 −301).
benchmark/mmmu/eval_utils.py  +45 −13
python/sglang/srt/configs/model_config.py  +6 −7
python/sglang/srt/layers/rotary_embedding.py  +150 −114
python/sglang/srt/managers/io_struct.py  +2 −0
python/sglang/srt/managers/mm_utils.py  +85 −28
python/sglang/srt/managers/multimodal_processors/base_processor.py  +14 −1
python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py  +9 −2
python/sglang/srt/managers/multimodal_processors/gemma3.py  +2 −5
python/sglang/srt/managers/multimodal_processors/janus_pro.py  +2 −2
python/sglang/srt/managers/multimodal_processors/minicpm.py  +3 −2
python/sglang/srt/managers/multimodal_processors/qwen_vl.py  +38 −13
python/sglang/srt/managers/schedule_batch.py  +14 −11
python/sglang/srt/managers/tokenizer_manager.py  +4 −1
python/sglang/srt/model_executor/forward_batch_info.py  +23 −74
python/sglang/srt/model_executor/model_runner.py  +0 −9
python/sglang/srt/models/deepseek_vl2.py  +3 −3
python/sglang/srt/models/minicpmo.py  +5 −1
python/sglang/srt/models/mllama4.py  +2 −2
python/sglang/srt/models/qwen2_5_vl.py  +3 −6
python/sglang/srt/models/qwen2_vl.py  +3 −7
benchmark/mmmu/eval_utils.py  (view file @ c998d04b)

@@ -89,7 +89,7 @@ def set_seed(seed_value):
 def prepare_samples(eval_args: EvalArgs):
-    print("preparing samples...")
+    print("Preparing samples...")
     # Build prompts
     set_seed(eval_args.seed)

@@ -105,15 +105,40 @@ def prepare_samples(eval_args: EvalArgs):
         assert len(value) == 1, "key {} has more than one value".format(key)
         eval_args.config[key] = value[0]

-    # run for each subject
+    # run for each subject in parallel
     sub_dataset_list = []
-    for subject in tqdm(CAT_SHORT2LONG.values()):
-        sub_dataset = load_dataset(
-            eval_args.dataset_path, subject, split=eval_args.split
-        )
-        sub_dataset_list.append(sub_dataset)
-        # break
+    subjects = list(CAT_SHORT2LONG.values())  # Get a fixed list of subjects
+    print(f"Loading datasets for {len(subjects)} subjects...")
+    with ThreadPoolExecutor() as executor:
+        # Submit all load_dataset tasks
+        future_to_subject = {
+            executor.submit(
+                load_dataset,
+                eval_args.dataset_path,
+                subject,
+                split=eval_args.split,
+            ): subject
+            for subject in subjects
+        }
+        # Collect results as they complete
+        results = {}
+        for future in tqdm(
+            as_completed(future_to_subject),
+            total=len(subjects),
+            desc="Loading datasets",
+        ):
+            subject = future_to_subject[future]
+            try:
+                results[subject] = future.result()
+            except Exception as exc:
+                print(f"{subject} generated an exception: {exc}")
+
+    # Ensure datasets are added in the original order for consistency
+    for subject in subjects:
+        if subject in results:
+            sub_dataset_list.append(results[subject])
+        else:
+            # Handle cases where a dataset failed to load (optional, depends on desired behavior)
+            print(f"Warning: Dataset for subject '{subject}' could not be loaded.")

     # merge all dataset
     dataset = concatenate_datasets(sub_dataset_list)

@@ -133,18 +158,25 @@ def prepare_samples(eval_args: EvalArgs):
             width, height = image.size
             if width * height >= eval_args.image_pixels_limit:
                 return None, True
-            image_path = f"{images_path}/image_{i}.png"
+            # Use a unique identifier for the image path to avoid potential collisions if indices reset
+            image_path = f"{images_path}/image_{sample['id']}.png"
             if not os.path.exists(image_path):
                 image.save(image_path)
             sample["image_path"] = image_path
         return sample, False

+    print("Processing samples...")
     with ThreadPoolExecutor() as executor:
+        # Pass the sample itself to process_sample, index is less reliable now
         futures = [
             executor.submit(process_sample, i, sample)
+            # Keep index i for tqdm maybe? Or remove it. Let's keep it for now.
             for i, sample in enumerate(dataset)
         ]
-        for future in tqdm(as_completed(futures), total=len(futures)):
+        for future in tqdm(
+            as_completed(futures), total=len(dataset), desc="Processing samples"
+        ):
             sample, skipped = future.result()
             if skipped:
                 skip_count += 1

@@ -152,9 +184,9 @@ def prepare_samples(eval_args: EvalArgs):
         samples.append(sample)
     print(
-        f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
+        f"Skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
     )
-    print("samples have been prepared")
+    print("Samples have been prepared")
     return samples
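Note on the loading change above: the submit/as_completed idiom fans out blocking calls to a thread pool and then restores a deterministic order afterwards. A minimal, self-contained sketch of the same pattern (the `load_one` callable and the toy keys below are hypothetical, not part of this commit):

from concurrent.futures import ThreadPoolExecutor, as_completed

def load_all(keys, load_one):
    # Submit one task per key and remember which future belongs to which key.
    with ThreadPoolExecutor() as executor:
        future_to_key = {executor.submit(load_one, k): k for k in keys}
        results = {}
        for future in as_completed(future_to_key):
            key = future_to_key[future]
            try:
                results[key] = future.result()
            except Exception as exc:
                print(f"{key} generated an exception: {exc}")
    # Re-emit in the original key order so downstream concatenation stays deterministic.
    return [results[k] for k in keys if k in results]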
python/sglang/srt/configs/model_config.py  (view file @ c998d04b)

@@ -73,15 +73,14 @@ class ModelConfig:
             )
         if enable_multimodal is None:
-            if self.hf_config.architectures[0] == "Llama4ForConditionalGeneration":
+            mm_disabled_models = [
+                "Gemma3ForConditionalGeneration",
+                "Llama4ForConditionalGeneration",
+            ]
+            if self.hf_config.architectures[0] in mm_disabled_models:
                 enable_multimodal = False
                 logger.info(
-                    "Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
+                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                 )
-            elif self.hf_config.architectures[0] == "Gemma3ForConditionalGeneration":
-                enable_multimodal = False
-                logger.info(
-                    "Multimodal is disabled for Gemma3. To enable it, set --enable-gemma3-multimodal."
-                )
             else:
                 enable_multimodal = True
python/sglang/srt/layers/rotary_embedding.py  (view file @ c998d04b)

@@ -877,127 +877,163 @@ class MRotaryEmbedding(RotaryEmbedding):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

-    # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
     @staticmethod
-    def get_input_positions(
-        input_tokens: List[int],
-        image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
-        image_token_id: int,
-        video_token_id: int,
-        vision_start_token_id: int,
-        vision_end_token_id: int,
-        spatial_merge_size: int,
-        context_len: int = 0,
-        seq_len: Optional[int] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
-        tokens_per_second: Optional[int] = None,
-    ) -> Tuple[List[List[int]], int]:
[The previous list-based body, about 110 lines, computed the same (t, h, w) grid positions directly from the `input_tokens` list, sliced the result with `llm_positions[:, context_len:seq_len]`, and returned `llm_positions.tolist()` together with an integer `mrope_position_delta`. It is replaced by the tensor-based method below.]
+    def get_rope_index(
+        spatial_merge_size: int,
+        image_token_id: int,
+        video_token_id: int,
+        vision_start_token_id: int,
+        model_type: str,
+        tokens_per_second: Optional[int] = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get mrope input positions and delta value.
+
+        :arg
+            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+        """
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                vision_start_indices = torch.argwhere(
+                    input_ids == vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum()
+                video_nums = (vision_tokens == video_token_id).sum()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        second_per_grid_t = 0
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        if second_per_grid_ts is not None:
+                            second_per_grid_t = second_per_grid_ts[video_index]
+                        else:
+                            second_per_grid_t = 1.0
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+                    if model_type == "qwen2_5_vl":
+                        range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+                        expanded_range = range_tensor.expand(
+                            -1, llm_grid_h * llm_grid_w
+                        )
+                        time_tensor = (
+                            expanded_range * second_per_grid_t * tokens_per_second
+                        )
+                        time_tensor_long = time_tensor.long()
+                        t_index = time_tensor_long.flatten()
+                    elif model_type == "qwen2_vl":
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                    else:
+                        raise RuntimeError("Unimplemented")
+                    h_index = (
+                        torch.arange(llm_grid_h)
+                        .view(1, -1, 1)
+                        .expand(llm_grid_t, -1, llm_grid_w)
+                        .flatten()
+                    )
+                    w_index = (
+                        torch.arange(llm_grid_w)
+                        .view(1, 1, -1)
+                        .expand(llm_grid_t, llm_grid_h, -1)
+                        .flatten()
+                    )
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, :] = llm_positions.to(position_ids.device)
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            s = input_ids.shape[1]
+            position_ids = torch.arange(s)
+            position_ids = (
+                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
+            )
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                -1, keepdim=True
+            )[0]
+            mrope_position_deltas = max_position_ids + 1 - s
+            return position_ids, mrope_position_deltas

     @staticmethod
     def get_next_input_positions(
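For intuition about what `get_rope_index` produces: text tokens advance all three mrope rows (t, h, w) together, while vision tokens enumerate a t x h x w grid starting from the current offset. A toy sketch with made-up sizes, two text tokens followed by one image whose merged grid is 1 x 2 x 2 (not taken from any real checkpoint):

import torch

# 2 text tokens: all three rows advance together -> positions 0, 1
text = torch.arange(2).view(1, -1).expand(3, -1)

# one image, llm_grid (t, h, w) = (1, 2, 2); grid positions start after the text (offset 2)
t_idx = torch.arange(1).view(-1, 1).expand(-1, 2 * 2).flatten()    # [0, 0, 0, 0]
h_idx = torch.arange(2).view(1, -1, 1).expand(1, -1, 2).flatten()  # [0, 0, 1, 1]
w_idx = torch.arange(2).view(1, 1, -1).expand(1, 2, -1).flatten()  # [0, 1, 0, 1]
vision = torch.stack([t_idx, h_idx, w_idx]) + 2

positions = torch.cat([text, vision], dim=1)
# rows are (t, h, w):
# tensor([[0, 1, 2, 2, 2, 2],
#         [0, 1, 2, 2, 3, 3],
#         [0, 1, 2, 3, 2, 3]])
# mrope_position_delta = positions.max() + 1 - seq_len = 4 - 6 = -2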
python/sglang/srt/managers/io_struct.py  (view file @ c998d04b)

@@ -463,6 +463,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
python/sglang/srt/managers/mm_utils.py  (view file @ c998d04b)

@@ -10,12 +10,13 @@ import torch
 from torch import nn

 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import print_warning_once
+from sglang.srt.utils import flatten_nested_list, print_warning_once

 logger = logging.getLogger(__name__)

@@ -97,31 +98,80 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         return padded_ids


-class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
+class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
     """In this pattern, data tokens should be represented as repetitions of a single token
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """

-    def __init__(self, image_token_id: torch.Tensor) -> None:
-        self.image_token_id = image_token_id
+    def __init__(self, token_ids: List[int]) -> None:
+        self.token_ids = token_ids

-    def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
+    def pad_input_tokens(
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
+    ) -> List[int]:
         """
-        This function will replace the data-tokens in between with pad_values accordingly
+        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
+        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
-        assert len(pad_values) != 0
+        if not pad_values:
+            # No multimodal items, return original input_ids
+            return input_ids
+        if not input_ids:
+            return []

         input_ids_tensor = torch.tensor(input_ids)
-        mask = torch.isin(input_ids_tensor, self.image_token_id)
+        device = input_ids_tensor.device
+        token_ids_tensor = torch.tensor(self.token_ids, device=device)
+        mask = torch.isin(input_ids_tensor, token_ids_tensor)

-        num_image_tokens = mask.sum().item()
-        repeated_pad_values = torch.tensor(pad_values).repeat(
-            num_image_tokens // len(pad_values) + 1
-        )[:num_image_tokens]
+        if not mask.any():
+            # No tokens match token_ids, return original input_ids
+            return input_ids

-        input_ids_tensor[mask] = repeated_pad_values
-        return input_ids_tensor.tolist()
+        # Find contiguous regions
+        padded_mask = torch.cat(
+            (
+                torch.tensor([False], device=device),
+                mask,
+                torch.tensor([False], device=device),
+            )
+        )
+        # Find indices where the mask value changes
+        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
+
+        # Start indices are where False changes to True
+        starts = diff_indices[::2]
+        # End indices are where True changes to False (exclusive index)
+        ends = diff_indices[1::2]
+
+        # Check if the number of regions matches the number of pad values
+        if len(starts) != len(pad_values):
+            # Maybe log a warning here?
+            num_regions = len(starts)
+            num_pad_values = len(pad_values)
+            if num_regions > 0 and num_pad_values > 0:
+                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
+                    :num_regions
+                ]
+            else:
+                pad_values = []  # Ensure pad_values is empty if starts is empty
+
+        # Create a copy to modify
+        output_ids_tensor = input_ids_tensor.clone()
+
+        # Replace tokens in each region with the corresponding pad value
+        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
+        for i in range(min(len(starts), len(pad_values))):
+            start_idx = starts[i]
+            end_idx = ends[i]
+            pad_value = pad_values[i]
+            if pad_value is not None:  # Ensure pad_value is not None before assignment
+                output_ids_tensor[start_idx:end_idx] = pad_value
+            else:
+                logger.warning(f"Skipping region {i} due to None pad_value.")
+
+        return output_ids_tensor.tolist()


 def get_embedding_and_mask(

@@ -150,7 +200,6 @@ def get_embedding_and_mask(
     ).unsqueeze(-1)
     num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."

@@ -190,13 +239,13 @@ def embed_mm_inputs(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-    placeholder_token_ids: List[int] = None,
+    placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
     """
     Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations

     Args:
-        placeholder_token_ids: denoting the token of multimodal data in input_ids.
+        placeholder_tokens: denoting the token of multimodal data in input_ids.
             If none, the pad_values of multimodal items are used

     Returns:

@@ -208,9 +257,17 @@ def embed_mm_inputs(
     # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
     # we assume that multimodal data are represented with its pad_values in input_ids
-    placeholder_token_ids = placeholder_token_ids or [
-        item.pad_value for item in mm_inputs.mm_items
-    ]
+    # See `pad_input_ids` for more detail
+
+    # if placeholder_tokens is specified
+    if placeholder_tokens is not None:
+        placeholder_token_ids = flatten_nested_list(
+            [placeholder_token for placeholder_token in placeholder_tokens.values()]
+        )
+    else:
+        placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
+
+    assert isinstance(placeholder_token_ids[0], int)

     placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)

@@ -233,7 +290,7 @@ def embed_mm_inputs(
     using_all_items = False
     if len(appearing_items) == 0:
         # This happens mostly when arg placeholder_token_ids is passed
-        logger.warning_once(
+        logger.warning(
             "No multimodal data item's pad value exist in placeholder ids. Using all items"
         )
         using_all_items = True

@@ -253,7 +310,8 @@ def embed_mm_inputs(
             data_embedding_func=image_data_embedding_func,
             embedding_items=items,
             placeholder_tensor=(
-                placeholder_tensor
+                # use the specified modality token to identify the location to embed
+                placeholder_tokens[Modality.IMAGE]
                 if using_all_items
                 else torch.tensor(
                     [item.pad_value for item in items],

@@ -275,7 +333,7 @@ def embed_mm_inputs(
             data_embedding_func=audio_data_embedding_func,
             embedding_items=items,
             placeholder_tensor=(
-                placeholder_tensor
+                placeholder_tokens[Modality.AUDIO]
                 if using_all_items
                 else torch.tensor(
                     [item.pad_value for item in items],

@@ -296,7 +354,7 @@ def embed_mm_inputs(
         input_ids.clamp_(min=0, max=vocab_size - 1)
         inputs_embeds = input_embedding(input_ids)

-    # 4. scatter embeddings into input embedding
+    # 4. Scatter embeddings into input embedding
     for embedding, mask in zip(embeddings, masks):
         mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
         inputs_embeds = inputs_embeds.masked_scatter(

@@ -316,7 +374,7 @@ def general_mm_embed_routine(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-    placeholder_token_ids: List[int] = None,
+    placeholder_tokens: dict[Modality, List[int]] = None,
     **kwargs,
 ) -> torch.Tensor:
     """

@@ -328,7 +386,6 @@ def general_mm_embed_routine(
         audio_data_embedding_func : the function returning the image embedding

     Returns:
-        inputs_embedding
         forwarded hidden states
     """

@@ -346,9 +403,9 @@ def general_mm_embed_routine(
             input_embedding=embed_tokens,
             image_data_embedding_func=image_data_embedding_func,
             audio_data_embedding_func=audio_data_embedding_func,
-            placeholder_token_ids=placeholder_token_ids,
+            placeholder_tokens=placeholder_tokens,
         )
-        # once used, mm_inputs is useless
+        # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
         # just being defensive here
         forward_batch.mm_inputs = None
     else:
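The region-finding step in `pad_input_tokens` can be hard to read in diff form. A standalone sketch of the same mask/boundary trick with toy token ids and hypothetical pad values:

import torch

input_ids = torch.tensor([5, 151655, 151655, 7, 151655, 151655, 151655, 9])  # toy ids
mm_token_ids = torch.tensor([151655])
pad_values = [111, 222]  # one fake "hash" id per multimodal item

mask = torch.isin(input_ids, mm_token_ids)
padded = torch.cat((torch.tensor([False]), mask, torch.tensor([False])))
diff = torch.where(padded[1:] != padded[:-1])[0]
starts, ends = diff[::2], diff[1::2]   # starts: [1, 4], ends (exclusive): [3, 7]

out = input_ids.clone()
for s, e, v in zip(starts, ends, pad_values):
    out[s:e] = v                       # each region collapses to one repeated pad value
print(out.tolist())                    # [5, 111, 111, 7, 222, 222, 222, 9]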
python/sglang/srt/managers/multimodal_processors/base_processor.py  (view file @ c998d04b)

@@ -8,6 +8,7 @@ from typing import List, Optional
 import numpy as np
 import PIL
+from PIL import Image
 from transformers import BaseImageProcessorFast

 from sglang.srt.managers.schedule_batch import Modality

@@ -92,7 +93,12 @@ class BaseMultimodalProcessor(ABC):
     @abstractmethod
     async def process_mm_data_async(
-        self, image_data, input_text, max_req_input_len, **kwargs
+        self,
+        image_data,
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
     ):
         pass

@@ -104,6 +110,8 @@ class BaseMultimodalProcessor(ABC):
         from decord import VideoReader, cpu

         # Before processing inputs
+        if not image_data or len(image_data) == 0:
+            return []
         estimated_frames_list = []
         for image in image_data:
             if isinstance(image, str) and image.startswith("video:"):

@@ -215,6 +223,9 @@ class BaseMultimodalProcessor(ABC):
             discard_alpha_channel: if True, discards the alpha channel in the returned images
         """
+        if image_data is None:
+            image_data = []
+
         if isinstance(multimodal_tokens.image_token, int):
             multimodal_tokens.image_token = (
                 self._processor.tokenizer.convert_ids_to_tokens(

@@ -229,6 +240,8 @@ class BaseMultimodalProcessor(ABC):
             prompt = self._processor.tokenizer.decode(prompt)
         else:
             prompt = prompt
+
+        assert isinstance(prompt, str)
         if return_text:
             import re
python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py  (view file @ c998d04b)

@@ -16,6 +16,7 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+from typing import List, Union

 import torch

@@ -35,7 +36,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         self.IMAGE_TOKEN = "<image>"

     async def process_mm_data_async(
-        self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
     ):
         if not image_data:
             return None

@@ -45,7 +52,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids,
+            input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
python/sglang/srt/managers/multimodal_processors/gemma3.py  (view file @ c998d04b)

 from typing import List, Union

-from transformers.utils import logging
-
 from sglang.srt.managers.multimodal_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )

@@ -13,7 +11,6 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
 # will be removed in the future
-logger = logging.get_logger(__name__)


 class Gemma3SGLangImageProcessor(SGLangBaseProcessor):

@@ -28,7 +25,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
         *args,

@@ -41,7 +38,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
python/sglang/srt/managers/multimodal_processors/janus_pro.py  (view file @ c998d04b)

@@ -17,7 +17,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
         **kwargs,

@@ -31,7 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         processor = self._processor

         base_out = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=processor.image_token
python/sglang/srt/managers/multimodal_processors/minicpm.py  (view file @ c998d04b)

@@ -51,9 +51,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
+        **kwargs,
     ):
         audio_data = request_obj.audio_data
         if not image_data and not audio_data:

@@ -64,7 +65,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data = [audio_data]

         base_output = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
python/sglang/srt/managers/multimodal_processors/qwen_vl.py  (view file @ c998d04b)

@@ -5,6 +5,7 @@ from typing import List, Union
 import torch
 from PIL import Image

+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )

@@ -27,6 +28,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
         self.image_token_id = hf_config.image_token_id
         self.video_token_id = hf_config.video_token_id
+        self.vision_start_token_id = hf_config.vision_start_token_id
+        self.vision_end_token_id = hf_config.vision_end_token_id
         self.NUM_TOKEN_PER_FRAME = 770
         self.IMAGE_FACTOR = 28
         self.MIN_PIXELS = 4 * 28 * 28

@@ -36,20 +39,18 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        prompt,
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
         if isinstance(image_data, str):
             image_data = [image_data]

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=prompt,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,

@@ -116,29 +117,53 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         async def resize_image_async(image):
             return resize_image(image)

-        resize_tasks = [resize_image_async(image) for image in base_output.images]
-        resized_images = await asyncio.gather(*resize_tasks)
+        if base_output.images:
+            resize_tasks = [resize_image_async(image) for image in base_output.images]
+            base_output.images = await asyncio.gather(*resize_tasks)

         ret = self.process_mm_data(
             input_text=base_output.input_text,
-            images=resized_images,
+            images=base_output.images,
         )

-        image_grid_thws = torch.concat([ret["image_grid_thw"]])
-        return {
-            "input_ids": ret["input_ids"].flatten().tolist(),
-            "mm_items": [
-                MultimodalDataItem(
-                    pixel_values=ret["pixel_values"],
-                    image_grid_thws=image_grid_thws,
-                    # TODO
-                    video_grid_thws=None,
-                    second_per_grid_ts=ret.get("second_per_grid_ts", None),
-                    modality=Modality.IMAGE,
-                )
-            ],
+        items = []
+        input_ids = ret["input_ids"].flatten().tolist()
+        if "pixel_values" in ret:
+            items += [
+                MultimodalDataItem(
+                    pixel_values=ret["pixel_values"],
+                    image_grid_thws=torch.concat([ret["image_grid_thw"]]),
+                    # TODO
+                    video_grid_thws=None,
+                    second_per_grid_ts=ret.get("second_per_grid_ts", None),
+                    modality=Modality.IMAGE,
+                )
+            ]
+
+        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
+            spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            vision_start_token_id=self.vision_start_token_id,
+            model_type=self.hf_config.model_type,
+            tokens_per_second=getattr(
+                self.hf_config.vision_config, "tokens_per_second", None
+            ),
+            input_ids=torch.tensor(input_ids).unsqueeze(0),
+            image_grid_thw=ret.get("image_grid_thw", None),
+            video_grid_thw=ret.get("video_grid_thw", None),
+            second_per_grid_ts=ret.get("second_per_grid_ts", None),
+        )
+        mrope_positions = mrope_positions.squeeze(1)
+
+        return {
+            "input_ids": input_ids,
+            "mm_items": items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.image_token_id,
             "video_token_id": self.video_token_id,
+            "mrope_positions": mrope_positions,
+            "mrope_position_delta": mrope_position_delta,
         }
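Precomputing `mrope_positions` and `mrope_position_delta` in the processor means the decode path no longer needs the original prompt tokens: the next position is a pure function of the sequence length and the delta. A simplified sketch of that relationship (mirroring what `get_next_input_positions` computes for a single decoded token; the helper name below is illustrative, not part of the commit):

import torch

def next_mrope_positions(mrope_position_delta: int, seq_len: int) -> torch.Tensor:
    # For decode, all three rows (t, h, w) advance together; seq_len counts the token
    # currently being generated, so its position is seq_len - 1 shifted by the delta.
    pos = seq_len - 1 + mrope_position_delta
    return torch.tensor([[pos], [pos], [pos]])

# continuing the toy example earlier: delta = -2, decoding the 7th token -> position 4
print(next_mrope_positions(-2, 7).tolist())  # [[4], [4], [4]]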
python/sglang/srt/managers/schedule_batch.py  (view file @ c998d04b)

@@ -285,6 +285,7 @@ class MultimodalInputs:
     num_image_tokens: Optional[int] = None

     # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[torch.Tensor] = None

     # image

@@ -310,16 +311,12 @@ class MultimodalInputs:
         assert isinstance(ret.mm_items, list)
         ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
-        assert len(ret.mm_items) != 0

-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
         for item in ret.mm_items:
             item.set_pad_value()

         optional_args = [
+            "mrope_positions",
+            "mrope_position_delta",
             "im_token_id",
             "im_start_id",
             "im_end_id",

@@ -350,20 +347,26 @@ class MultimodalInputs:
         merge image inputs when requests are being merged
         """
+        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
+        # Please note that if the `input_ids` is later used in the model forward,
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
+        # errors in cuda kernels. See also llava.py for example.
+
         # args needed to be merged
         optional_args = [
             "mm_items",
             "image_pad_len",
-            "mrope_position_delta",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
             if self_arg is not None:
                 setattr(self, arg, self_arg + getattr(other, arg))
+
+        mrope_positions = self.mrope_positions
+        if mrope_positions is not None:
+            if other.mrope_positions is None:
+                self.mrope_positions = mrope_positions
+            else:
+                self.mrope_positions = torch.cat(
+                    [self.mrope_positions, other.mrope_positions], dim=1
+                )

         # other args would be kept intact
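Why the retained comment about image hashes matters for this commit: the pad values written into `input_ids` act as deterministic fake token ids, so identical images produce identical prefixes that the radix cache can match across requests. A toy illustration (all ids hypothetical):

# Sketch of why pad values make multimodal prefixes cacheable.
# After pad_input_ids, the image region of the prompt is a run of one value derived
# from the image content hash, so two requests with the same image and prompt prefix
# produce identical token-id prefixes and can share KV cache through the radix tree.
req_a = [5, 111, 111, 111, 7, 8]  # 111 = pad value from hash(image_1)
req_b = [5, 111, 111, 111, 7, 9]  # same image_1 -> shared prefix [5, 111, 111, 111, 7]
req_c = [5, 222, 222, 222, 7, 8]  # different image -> different pad value -> no false sharing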
python/sglang/srt/managers/tokenizer_manager.py  (view file @ c998d04b)

@@ -419,7 +419,10 @@ class TokenizerManager:
             input_ids = self.tokenizer.encode(input_text)

         image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-            obj.image_data, input_text or input_ids, obj, self.max_req_input_len
+            image_data=obj.image_data,
+            input_text=input_text or input_ids,
+            request_obj=obj,
+            max_req_input_len=self.max_req_input_len,
         )
         if image_inputs and "input_ids" in image_inputs:
             input_ids = image_inputs["input_ids"]
python/sglang/srt/model_executor/forward_batch_info.py  (view file @ c998d04b)

@@ -407,8 +407,6 @@ class ForwardBatch:
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
-        device = model_runner.device
-        hf_config = model_runner.model_config.hf_config
         mrope_positions_list = [None] * self.seq_lens.shape[0]
         if self.forward_mode.is_decode():
             for i, _ in enumerate(mrope_positions_list):

@@ -417,93 +415,44 @@ class ForwardBatch:
                     if batch.multimodal_inputs[i] is None
                     else batch.multimodal_inputs[i].mrope_position_delta
                 )
-                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    mrope_position_delta,
-                    int(self.seq_lens[i]) - 1,
-                    int(self.seq_lens[i]),
-                )
+                mrope_positions_list[i] = torch.tensor(
+                    MRotaryEmbedding.get_next_input_positions(
+                        mrope_position_delta,
+                        int(self.seq_lens[i]) - 1,
+                        int(self.seq_lens[i]),
+                    )
+                )
         elif self.forward_mode.is_extend():
-            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
             for i, mm_input in enumerate(batch.multimodal_inputs):
-                extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    extend_start_loc_cpu[i],
-                    batch.extend_seq_lens[i],
-                    batch.extend_prefix_lens[i],
-                )
+                extend_seq_len, extend_prefix_len = (
+                    batch.extend_seq_lens[i],
+                    batch.extend_prefix_lens[i],
+                )
                 if mm_input is None:
                     # text only
-                    mrope_positions = [
-                        [
-                            pos
-                            for pos in range(
-                                extend_prefix_len, extend_prefix_len + extend_seq_len
-                            )
-                        ]
-                    ] * 3
+                    mrope_positions = torch.tensor(
+                        [
+                            [
+                                pos
+                                for pos in range(
+                                    extend_prefix_len,
+                                    extend_prefix_len + extend_seq_len,
+                                )
+                            ]
+                        ]
+                        * 3
+                    )
                 else:
-                    image_grid_thws_list = [
-                        item.image_grid_thws
-                        for item in mm_input.mm_items
-                        if item.image_grid_thws is not None
-                    ]
-                    image_grid_thw = (
-                        None
-                        if len(image_grid_thws_list) == 0
-                        else torch.cat(image_grid_thws_list, dim=0)
-                    )
-                    video_grid_thws_list = [
-                        item.video_grid_thws
-                        for item in mm_input.mm_items
-                        if item.video_grid_thws is not None
-                    ]
-                    video_grid_thw = (
-                        None
-                        if len(video_grid_thws_list) == 0
-                        else torch.cat(video_grid_thws_list, dim=0)
-                    )
-                    second_per_grid_ts_list = [
-                        item.second_per_grid_ts
-                        for item in mm_input.mm_items
-                        if item.second_per_grid_ts is not None
-                    ]
-                    second_per_grid_ts = (
-                        None
-                        if len(second_per_grid_ts_list) == 0
-                        else torch.cat(second_per_grid_ts_list, dim=0)
-                    )
-                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
-                    mrope_positions, mrope_position_delta = (
-                        MRotaryEmbedding.get_input_positions(
-                            input_tokens=self.input_ids[
-                                extend_start_loc : extend_start_loc + extend_seq_len
-                            ].tolist(),
-                            image_grid_thw=image_grid_thw,
-                            video_grid_thw=video_grid_thw,
-                            image_token_id=hf_config.image_token_id,
-                            video_token_id=hf_config.video_token_id,
-                            vision_start_token_id=hf_config.vision_start_token_id,
-                            vision_end_token_id=hf_config.vision_end_token_id,
-                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
-                            context_len=0,
-                            seq_len=len(self.input_ids),
-                            second_per_grid_ts=second_per_grid_ts,
-                            tokens_per_second=getattr(
-                                hf_config.vision_config, "tokens_per_second", None
-                            ),
-                        )
-                    )
-                    batch.multimodal_inputs[i].mrope_position_delta = (
-                        mrope_position_delta
-                    )
+                    mrope_positions = mm_input.mrope_positions[
+                        :,
+                        extend_prefix_len : extend_prefix_len + extend_seq_len,
+                    ]
                 mrope_positions_list[i] = mrope_positions
+
         self.mrope_positions = torch.cat(
-            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
-            axis=1,
+            [pos.to(device=model_runner.device) for pos in mrope_positions_list],
+            dim=1,
         )
         self.mrope_positions = self.mrope_positions.to(torch.int64)

     def get_max_chunk_capacity(self):
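With positions precomputed per request, the extend path only has to slice out the chunk that is not already covered by the cached prefix. A minimal sketch of that slice (toy shapes, assuming `mrope_positions` is laid out as `(3, prompt_len)` as above):

import torch

# Slicing precomputed per-request mrope positions for an extend (prefill) step.
mrope_positions = torch.arange(3 * 10).reshape(3, 10)  # toy stand-in for a 10-token prompt
extend_prefix_len, extend_seq_len = 6, 4                # 6 tokens already cached, 4 new
chunk = mrope_positions[:, extend_prefix_len : extend_prefix_len + extend_seq_len]
print(chunk.shape)  # torch.Size([3, 4])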
python/sglang/srt/model_executor/model_runner.py  (view file @ c998d04b)

@@ -310,15 +310,6 @@ class ModelRunner:
             )
             server_args.chunked_prefill_size = -1

-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ] or self.model_config.hf_config.architectures == [
-            "Qwen2_5_VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info("Automatically disable radix cache for qwen-vl series.")
-            server_args.disable_radix_cache = True
-
         if server_args.enable_deepep_moe:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
python/sglang/srt/models/deepseek_vl2.py  (view file @ c998d04b)

@@ -12,7 +12,7 @@ from sglang.srt.configs.deepseekvl2 import (
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternImageTokens,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs

@@ -249,8 +249,8 @@ class DeepseekVL2ForCausalLM(nn.Module):
             weights_loader(param, loaded_weight)

     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        helper = MultiModalityDataPaddingPatternImageTokens(
-            image_token_id=image_inputs.im_token_id
-        )
+        helper = MultiModalityDataPaddingPatternMultimodalTokens(
+            [image_inputs.im_token_id]
+        )
         return helper.pad_input_tokens(input_ids, image_inputs)
python/sglang/srt/models/minicpmo.py  (view file @ c998d04b)

@@ -43,6 +43,7 @@ from sglang.srt.managers.mm_utils import (
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     flatten_nested_list,

@@ -1834,7 +1835,10 @@ class MiniCPMO(MiniCPMBaseModel):
             language_model=self.llm,
             image_data_embedding_func=self.get_image_feature,
             audio_data_embedding_func=self.get_audio_feature,
-            placeholder_token_ids=placeholder_token_ids,
+            placeholder_tokens={
+                Modality.IMAGE: placeholder_token_ids,
+                Modality.AUDIO: placeholder_token_ids,
+            },
             positions=positions,
         )
         return hidden_states
python/sglang/srt/models/mllama4.py  (view file @ c998d04b)

@@ -10,7 +10,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternImageTokens,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs

@@ -53,7 +53,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         # Get all special token IDs
         im_token_id: int = mm_inputs.im_token_id

-        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(
python/sglang/srt/models/qwen2_5_vl.py  (view file @ c998d04b)

@@ -49,7 +49,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternTokenPairs,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs

@@ -488,11 +488,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = mm_inputs.im_start_id
-        im_end_id: int = mm_inputs.im_end_id
-
-        media_token_pairs = [(im_start_id, im_end_id)]
-        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        im_token_id: int = mm_inputs.im_token_id
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
python/sglang/srt/models/qwen2_vl.py  (view file @ c998d04b)

@@ -42,7 +42,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternTokenPairs,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs

@@ -490,15 +490,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    # Use grid_t * grid_w * grid_h to pad tokens for each image
-    # add replaced padding by unique image hash
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = mm_inputs.im_start_id
-        im_end_id: int = mm_inputs.im_end_id
-
-        media_token_pairs = [(im_start_id, im_end_id)]
-        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        im_token_id: int = mm_inputs.im_token_id
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: