change / sglang · Commits

Commit c998d04b (unverified)
Authored Apr 24, 2025 by Mick; committed by GitHub Apr 23, 2025
Parent: 7d0edf3c

vlm: enable radix cache for qwen-vl models (#5349)

Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>

Changes: 26 files in the commit; this page shows 20 changed files with 413 additions and 301 deletions (+413 -301).
Changed files (additions / deletions):

  benchmark/mmmu/eval_utils.py                                          +45   -13
  python/sglang/srt/configs/model_config.py                              +6    -7
  python/sglang/srt/layers/rotary_embedding.py                         +150  -114
  python/sglang/srt/managers/io_struct.py                                +2    -0
  python/sglang/srt/managers/mm_utils.py                                +85   -28
  python/sglang/srt/managers/multimodal_processors/base_processor.py   +14    -1
  python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py    +9    -2
  python/sglang/srt/managers/multimodal_processors/gemma3.py            +2    -5
  python/sglang/srt/managers/multimodal_processors/janus_pro.py         +2    -2
  python/sglang/srt/managers/multimodal_processors/minicpm.py           +3    -2
  python/sglang/srt/managers/multimodal_processors/qwen_vl.py          +38   -13
  python/sglang/srt/managers/schedule_batch.py                         +14   -11
  python/sglang/srt/managers/tokenizer_manager.py                       +4    -1
  python/sglang/srt/model_executor/forward_batch_info.py               +23   -74
  python/sglang/srt/model_executor/model_runner.py                      +0    -9
  python/sglang/srt/models/deepseek_vl2.py                              +3    -3
  python/sglang/srt/models/minicpmo.py                                  +5    -1
  python/sglang/srt/models/mllama4.py                                   +2    -2
  python/sglang/srt/models/qwen2_5_vl.py                                +3    -6
  python/sglang/srt/models/qwen2_vl.py                                  +3    -7
benchmark/mmmu/eval_utils.py

@@ -89,7 +89,7 @@ def set_seed(seed_value):
 def prepare_samples(eval_args: EvalArgs):
-    print("preparing samples...")
+    print("Preparing samples...")
     # Build prompts
     set_seed(eval_args.seed)
@@ -105,15 +105,40 @@ def prepare_samples(eval_args: EvalArgs):
         assert len(value) == 1, "key {} has more than one value".format(key)
         eval_args.config[key] = value[0]

-    # run for each subject
+    # run for each subject in parallel
     sub_dataset_list = []
-    for subject in tqdm(CAT_SHORT2LONG.values()):
-        sub_dataset = load_dataset(eval_args.dataset_path, subject, split=eval_args.split)
-        sub_dataset_list.append(sub_dataset)
-        # break
+    subjects = list(CAT_SHORT2LONG.values())  # Get a fixed list of subjects
+    print(f"Loading datasets for {len(subjects)} subjects...")
+    with ThreadPoolExecutor() as executor:
+        # Submit all load_dataset tasks
+        future_to_subject = {
+            executor.submit(
+                load_dataset, eval_args.dataset_path, subject, split=eval_args.split
+            ): subject
+            for subject in subjects
+        }
+        # Collect results as they complete
+        results = {}
+        for future in tqdm(
+            as_completed(future_to_subject),
+            total=len(subjects),
+            desc="Loading datasets",
+        ):
+            subject = future_to_subject[future]
+            try:
+                results[subject] = future.result()
+            except Exception as exc:
+                print(f"{subject} generated an exception: {exc}")
+
+    # Ensure datasets are added in the original order for consistency
+    for subject in subjects:
+        if subject in results:
+            sub_dataset_list.append(results[subject])
+        else:
+            # Handle cases where a dataset failed to load (optional, depends on desired behavior)
+            print(f"Warning: Dataset for subject '{subject}' could not be loaded.")

     # merge all dataset
     dataset = concatenate_datasets(sub_dataset_list)
@@ -133,18 +158,25 @@ def prepare_samples(eval_args: EvalArgs):
         width, height = image.size
         if width * height >= eval_args.image_pixels_limit:
             return None, True
-        image_path = f"{images_path}/image_{i}.png"
+        # Use a unique identifier for the image path to avoid potential collisions if indices reset
+        image_path = f"{images_path}/image_{sample['id']}.png"
         if not os.path.exists(image_path):
             image.save(image_path)
         sample["image_path"] = image_path
         return sample, False

+    print("Processing samples...")
     with ThreadPoolExecutor() as executor:
+        # Pass the sample itself to process_sample, index is less reliable now
         futures = [
             executor.submit(process_sample, i, sample)
+            # Keep index i for tqdm maybe? Or remove it. Let's keep it for now.
             for i, sample in enumerate(dataset)
         ]
-        for future in tqdm(as_completed(futures), total=len(futures)):
+        for future in tqdm(
+            as_completed(futures), total=len(dataset), desc="Processing samples"
+        ):
             sample, skipped = future.result()
             if skipped:
                 skip_count += 1
@@ -152,9 +184,9 @@ def prepare_samples(eval_args: EvalArgs):
             samples.append(sample)

     print(
-        f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
+        f"Skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
     )
-    print("samples have been prepared")
+    print("Samples have been prepared")
     return samples
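Note: the refactor above is a reusable pattern: submit all loads at once, harvest them with as_completed for progress reporting, then reassemble the results in a fixed subject order so the final concatenation stays deterministic. A minimal self-contained sketch of the same idea, with a stand-in fetch function in place of load_dataset:

from concurrent.futures import ThreadPoolExecutor, as_completed

def fetch(name):  # stand-in for load_dataset(path, subject, split=...)
    return f"dataset:{name}"

subjects = ["art", "biology", "chemistry"]  # fixed order matters for reproducibility
with ThreadPoolExecutor() as executor:
    future_to_subject = {executor.submit(fetch, s): s for s in subjects}
    results = {}
    for future in as_completed(future_to_subject):  # completion order is arbitrary
        results[future_to_subject[future]] = future.result()

# Reassemble in the original subject order so downstream concatenation is deterministic
ordered = [results[s] for s in subjects if s in results]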
python/sglang/srt/configs/model_config.py

@@ -73,15 +73,14 @@ class ModelConfig:
         )
         if enable_multimodal is None:
-            if self.hf_config.architectures[0] == "Llama4ForConditionalGeneration":
-                enable_multimodal = False
-                logger.info(
-                    "Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
-                )
-            elif self.hf_config.architectures[0] == "Gemma3ForConditionalGeneration":
-                enable_multimodal = False
-                logger.info(
-                    "Multimodal is disabled for Gemma3. To enable it, set --enable-gemma3-multimodal."
-                )
+            mm_disabled_models = [
+                "Gemma3ForConditionalGeneration",
+                "Llama4ForConditionalGeneration",
+            ]
+            if self.hf_config.architectures[0] in mm_disabled_models:
+                enable_multimodal = False
+                logger.info(
+                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
+                )
             else:
                 enable_multimodal = True
python/sglang/srt/layers/rotary_embedding.py

@@ -877,127 +877,163 @@ class MRotaryEmbedding(RotaryEmbedding):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

+    # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
     @staticmethod
-    def get_input_positions(
-        input_tokens: List[int],
-        image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
-        image_token_id: int,
-        video_token_id: int,
-        vision_start_token_id: int,
-        vision_end_token_id: int,
-        spatial_merge_size: int,
-        context_len: int = 0,
-        seq_len: Optional[int] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
-        tokens_per_second: Optional[int] = None,
-    ) -> Tuple[List[List[int]], int]:
-        """
-        Get mrope input positions and delta value.
-
-        :arg
-            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
-                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
-        """
-        if isinstance(image_grid_thw, torch.Tensor):
-            image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()
-
-        input_tokens_tensor = torch.tensor(input_tokens)
-        vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
-        llm_pos_ids_list: list = []
-
-        st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                second_per_grid_t = 0
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                if second_per_grid_ts is not None:
-                    second_per_grid_t = second_per_grid_ts[video_index]
-                else:
-                    second_per_grid_t = 1.0
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
-            )
-            text_len = ed - st
-
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-            t_index = (
-                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
-                * second_per_grid_t
-                * tokens_per_second
-            ).flatten()
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
-            )
-            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
-        if st < len(input_tokens):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        llm_positions = llm_positions[:, context_len:seq_len]
-
-        return llm_positions.tolist(), mrope_position_delta
+    def get_rope_index(
+        spatial_merge_size: int,
+        image_token_id: int,
+        video_token_id: int,
+        vision_start_token_id: int,
+        model_type: str,
+        tokens_per_second: Optional[int] = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                image_nums, video_nums = 0, 0
+                vision_start_indices = torch.argwhere(
+                    input_ids == vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum()
+                video_nums = (vision_tokens == video_token_id).sum()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        second_per_grid_t = 0
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        if second_per_grid_ts is not None:
+                            second_per_grid_t = second_per_grid_ts[video_index]
+                        else:
+                            second_per_grid_t = 1.0
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+                    if model_type == "qwen2_5_vl":
+                        range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+                        expanded_range = range_tensor.expand(
+                            -1, llm_grid_h * llm_grid_w
+                        )
+                        time_tensor = (
+                            expanded_range * second_per_grid_t * tokens_per_second
+                        )
+                        time_tensor_long = time_tensor.long()
+                        t_index = time_tensor_long.flatten()
+                    elif model_type == "qwen2_vl":
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                    else:
+                        raise RuntimeError("Unimplemented")
+                    h_index = (
+                        torch.arange(llm_grid_h)
+                        .view(1, -1, 1)
+                        .expand(llm_grid_t, -1, llm_grid_w)
+                        .flatten()
+                    )
+                    w_index = (
+                        torch.arange(llm_grid_w)
+                        .view(1, 1, -1)
+                        .expand(llm_grid_t, llm_grid_h, -1)
+                        .flatten()
+                    )
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, :] = llm_positions.to(position_ids.device)
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            s = input_ids.shape[1]
+            position_ids = torch.arange(s)
+            position_ids = (
+                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
+            )
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                -1, keepdim=True
+            )[0]
+            mrope_position_deltas = max_position_ids + 1 - s
+            return position_ids, mrope_position_deltas

     @staticmethod
     def get_next_input_positions(
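Note: the scheme implemented by get_rope_index gives every text token the same index on all three rotary axes, while vision tokens get separate temporal/height/width coordinates offset past the preceding text. A self-contained toy illustration of that indexing (not the sglang API itself), for one image grid (t=1, h=4, w=4) with spatial_merge_size=2 following three text tokens:

import torch

t, h, w, merge = 1, 4, 4, 2
llm_t, llm_h, llm_w = t, h // merge, w // merge  # 1 x 2 x 2 = 4 vision tokens

text_len = 3
text_pos = torch.arange(text_len).view(1, -1).expand(3, -1)  # same index on all 3 axes

t_idx = torch.arange(llm_t).view(-1, 1).expand(-1, llm_h * llm_w).flatten()
h_idx = torch.arange(llm_h).view(1, -1, 1).expand(llm_t, -1, llm_w).flatten()
w_idx = torch.arange(llm_w).view(1, 1, -1).expand(llm_t, llm_h, -1).flatten()
vision_pos = torch.stack([t_idx, h_idx, w_idx]) + text_len  # offset past the text

positions = torch.cat([text_pos, vision_pos], dim=1)  # shape (3, 7)
print(positions)
# tensor([[0, 1, 2, 3, 3, 3, 3],
#         [0, 1, 2, 3, 3, 4, 4],
#         [0, 1, 2, 3, 4, 3, 4]])

The returned delta is then max(position) + 1 - seq_len, which lets decode continue the position sequence without recomputing any of this.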
python/sglang/srt/managers/io_struct.py

@@ -463,6 +463,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
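The new field mirrors image_data's flexible typing: a single reference or a list per request. A hypothetical construction (file names invented; assumes the dataclass's existing text field, as used by image requests):

from sglang.srt.managers.io_struct import EmbeddingReqInput

# Hypothetical requests: audio_data accepts a single reference or a list.
single = EmbeddingReqInput(text="transcribe this", audio_data="clip.wav")
batch = EmbeddingReqInput(text=["a", "b"], audio_data=["x.wav", "y.wav"])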
python/sglang/srt/managers/mm_utils.py

@@ -10,12 +10,13 @@ import torch
 from torch import nn

 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import print_warning_once
+from sglang.srt.utils import flatten_nested_list, print_warning_once

 logger = logging.getLogger(__name__)
@@ -97,31 +98,80 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         return padded_ids


-class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
+class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
     """In this pattern, data tokens should be represented as repetitions of a single token
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """

-    def __init__(self, image_token_id: torch.Tensor) -> None:
-        self.image_token_id = image_token_id
+    def __init__(self, token_ids: List[int]) -> None:
+        self.token_ids = token_ids

-    def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
+    def pad_input_tokens(
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
+    ) -> List[int]:
         """
-        This function will replace the data-tokens in between with pad_values accordingly
+        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
+        and replaces each region with the corresponding `pad_value` from
+        `mm_inputs.mm_items`.
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
-        assert len(pad_values) != 0
+        if not pad_values:
+            # No multimodal items, return original input_ids
+            return input_ids
+        if not input_ids:
+            return []

         input_ids_tensor = torch.tensor(input_ids)
-        mask = torch.isin(input_ids_tensor, self.image_token_id)
-
-        num_image_tokens = mask.sum().item()
-        repeated_pad_values = torch.tensor(pad_values).repeat(
-            num_image_tokens // len(pad_values) + 1
-        )[:num_image_tokens]
-
-        input_ids_tensor[mask] = repeated_pad_values
-        return input_ids_tensor.tolist()
+        device = input_ids_tensor.device
+        token_ids_tensor = torch.tensor(self.token_ids, device=device)
+        mask = torch.isin(input_ids_tensor, token_ids_tensor)
+
+        if not mask.any():
+            # No tokens match token_ids, return original input_ids
+            return input_ids
+
+        # Find contiguous regions
+        padded_mask = torch.cat(
+            (
+                torch.tensor([False], device=device),
+                mask,
+                torch.tensor([False], device=device),
+            )
+        )
+        # Find indices where the mask value changes
+        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
+
+        # Start indices are where False changes to True
+        starts = diff_indices[::2]
+        # End indices are where True changes to False (exclusive index)
+        ends = diff_indices[1::2]
+
+        # Check if the number of regions matches the number of pad values
+        if len(starts) != len(pad_values):
+            # Maybe log a warning here?
+            num_regions = len(starts)
+            num_pad_values = len(pad_values)
+            if num_regions > 0 and num_pad_values > 0:
+                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
+                    :num_regions
+                ]
+            else:
+                # If no regions or no pad_values, this loop won't run anyway.
+                # Ensure pad_values is empty if starts is empty
+                pad_values = []
+
+        # Create a copy to modify
+        output_ids_tensor = input_ids_tensor.clone()
+
+        # Replace tokens in each region with the corresponding pad value
+        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
+        for i in range(min(len(starts), len(pad_values))):
+            start_idx = starts[i]
+            end_idx = ends[i]
+            pad_value = pad_values[i]
+            # Ensure pad_value is not None before assignment
+            if pad_value is not None:
+                output_ids_tensor[start_idx:end_idx] = pad_value
+            else:
+                logger.warning(f"Skipping region {i} due to None pad_value.")
+
+        return output_ids_tensor.tolist()


 def get_embedding_and_mask(
@@ -150,7 +200,6 @@ def get_embedding_and_mask(
     ).unsqueeze(-1)
-
     num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -190,13 +239,13 @@ def embed_mm_inputs(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-    placeholder_token_ids: List[int] = None,
+    placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
     """
     Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations

     Args:
-        placeholder_token_ids: denoting the token of multimodal data in input_ids.
+        placeholder_tokens: denoting the token of multimodal data in input_ids.
             If none, the pad_values of multimodal items are used

     Returns:
@@ -208,9 +257,17 @@ def embed_mm_inputs(
         # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
         # we assume that multimodal data are represented with its pad_values in input_ids
-        placeholder_token_ids = placeholder_token_ids or [
-            item.pad_value for item in mm_inputs.mm_items
-        ]
+        # See `pad_input_ids` for more detail
+        # if placeholder_tokens is specified
+        if placeholder_tokens is not None:
+            placeholder_token_ids = flatten_nested_list(
+                [placeholder_token for placeholder_token in placeholder_tokens.values()]
+            )
+        else:
+            placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
+
+        assert isinstance(placeholder_token_ids[0], int)

         placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
@@ -233,7 +290,7 @@ def embed_mm_inputs(
         using_all_items = False
         if len(appearing_items) == 0:
             # This happens mostly when arg placeholder_token_ids is passed
-            logger.warning_once(
+            logger.warning(
                 "No multimodal data item's pad value exist in placeholder ids. Using all items"
             )
             using_all_items = True
@@ -253,7 +310,8 @@ def embed_mm_inputs(
                 data_embedding_func=image_data_embedding_func,
                 embedding_items=items,
                 placeholder_tensor=(
-                    placeholder_tensor
+                    # use the specified modality token to identify the location to embed
+                    placeholder_tokens[Modality.IMAGE]
                     if using_all_items
                     else torch.tensor(
                         [item.pad_value for item in items],
@@ -275,7 +333,7 @@ def embed_mm_inputs(
                 data_embedding_func=audio_data_embedding_func,
                 embedding_items=items,
                 placeholder_tensor=(
-                    placeholder_tensor
+                    placeholder_tokens[Modality.AUDIO]
                    if using_all_items
                    else torch.tensor(
                        [item.pad_value for item in items],
@@ -296,7 +354,7 @@ def embed_mm_inputs(
     input_ids.clamp_(min=0, max=vocab_size - 1)
     inputs_embeds = input_embedding(input_ids)

-    # 4. scatter embeddings into input embedding
+    # 4. Scatter embeddings into input embedding
     for embedding, mask in zip(embeddings, masks):
         mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
         inputs_embeds = inputs_embeds.masked_scatter(
@@ -316,7 +374,7 @@ def general_mm_embed_routine(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-    placeholder_token_ids: List[int] = None,
+    placeholder_tokens: dict[Modality, List[int]] = None,
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -328,7 +386,6 @@ def general_mm_embed_routine(
         audio_data_embedding_func : the function returning the image embedding

     Returns:
-        inputs_embedding
         forwarded hidden states
     """
@@ -346,9 +403,9 @@ def general_mm_embed_routine(
             input_embedding=embed_tokens,
             image_data_embedding_func=image_data_embedding_func,
             audio_data_embedding_func=audio_data_embedding_func,
-            placeholder_token_ids=placeholder_token_ids,
+            placeholder_tokens=placeholder_tokens,
         )
-        # once used, mm_inputs is useless
+        # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
         # just being defensive here
         forward_batch.mm_inputs = None
     else:
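Note: the rewritten pad_input_tokens finds runs of multimodal tokens by diffing a False-padded boolean mask, so each contiguous run (one image or audio item) gets its own pad value instead of every matching token being padded independently. A minimal demo of that run-detection trick:

import torch

# Pretend 9 is the multimodal token id; 100 + i stands in for each item's
# pad_value (an image hash in sglang).
input_ids = torch.tensor([5, 9, 9, 9, 6, 9, 9, 7])
mask = input_ids == 9

# Pad the mask with False on both ends, then mark positions where it flips.
padded = torch.cat((torch.tensor([False]), mask, torch.tensor([False])))
flips = torch.where(padded[1:] != padded[:-1])[0]
starts, ends = flips[::2], flips[1::2]  # each (start, end) pair brackets one run

out = input_ids.clone()
for i, (s, e) in enumerate(zip(starts, ends)):
    out[s:e] = 100 + i
print(out)  # tensor([  5, 100, 100, 100,   6, 101, 101,   7])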
python/sglang/srt/managers/multimodal_processors/base_processor.py

@@ -8,6 +8,7 @@ from typing import List, Optional
 import numpy as np
 import PIL
 from PIL import Image
+from transformers import BaseImageProcessorFast

 from sglang.srt.managers.schedule_batch import Modality
@@ -92,7 +93,12 @@ class BaseMultimodalProcessor(ABC):
     @abstractmethod
     async def process_mm_data_async(
-        self, image_data, input_text, max_req_input_len, **kwargs
+        self,
+        image_data,
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
     ):
         pass
@@ -104,6 +110,8 @@ class BaseMultimodalProcessor(ABC):
         from decord import VideoReader, cpu

+        # Before processing inputs
+        if not image_data or len(image_data) == 0:
+            return []
         estimated_frames_list = []
         for image in image_data:
             if isinstance(image, str) and image.startswith("video:"):
@@ -215,6 +223,9 @@ class BaseMultimodalProcessor(ABC):
             discard_alpha_channel: if True, discards the alpha channel in the returned images
         """
+        if image_data is None:
+            image_data = []
+
         if isinstance(multimodal_tokens.image_token, int):
             multimodal_tokens.image_token = (
                 self._processor.tokenizer.convert_ids_to_tokens(
@@ -229,6 +240,8 @@ class BaseMultimodalProcessor(ABC):
             prompt = self._processor.tokenizer.decode(prompt)
         else:
             prompt = prompt
+        assert isinstance(prompt, str)
+
         if return_text:
             import re
python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py

@@ -16,6 +16,7 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+from typing import List, Union

 import torch
@@ -35,7 +36,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         self.IMAGE_TOKEN = "<image>"

     async def process_mm_data_async(
-        self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
     ):
         if not image_data:
             return None
@@ -45,7 +52,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids,
+            input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
python/sglang/srt/managers/multimodal_processors/gemma3.py

 from typing import List, Union

-from transformers.utils import logging
-
 from sglang.srt.managers.multimodal_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -13,7 +11,6 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
 # will be removed in the future
-logger = logging.get_logger(__name__)


 class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
@@ -28,7 +25,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
@@ -41,7 +38,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
python/sglang/srt/managers/multimodal_processors/janus_pro.py

@@ -17,7 +17,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
@@ -31,7 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         processor = self._processor
         base_out = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=processor.image_token
python/sglang/srt/managers/multimodal_processors/minicpm.py

@@ -51,9 +51,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
     ):
         audio_data = request_obj.audio_data

         if not image_data and not audio_data:
@@ -64,7 +65,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data = [audio_data]

         base_output = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
python/sglang/srt/managers/multimodal_processors/qwen_vl.py

@@ -5,6 +5,7 @@ from typing import List, Union
 import torch
 from PIL import Image

+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -27,6 +28,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
         self.image_token_id = hf_config.image_token_id
         self.video_token_id = hf_config.video_token_id
+        self.vision_start_token_id = hf_config.vision_start_token_id
+        self.vision_end_token_id = hf_config.vision_end_token_id
         self.NUM_TOKEN_PER_FRAME = 770
         self.IMAGE_FACTOR = 28
         self.MIN_PIXELS = 4 * 28 * 28
@@ -36,20 +39,18 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        prompt,
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
         if not image_data:
             return None
         if isinstance(image_data, str):
             image_data = [image_data]

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=prompt,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
@@ -116,29 +117,53 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         async def resize_image_async(image):
             return resize_image(image)

-        resize_tasks = [resize_image_async(image) for image in base_output.images]
-        resized_images = await asyncio.gather(*resize_tasks)
+        if base_output.images:
+            resize_tasks = [
+                resize_image_async(image) for image in base_output.images
+            ]
+            base_output.images = await asyncio.gather(*resize_tasks)

         ret = self.process_mm_data(
             input_text=base_output.input_text,
-            images=resized_images,
+            images=base_output.images,
         )

-        image_grid_thws = torch.concat([ret["image_grid_thw"]])
+        items = []
+        input_ids = ret["input_ids"].flatten().tolist()
+        if "pixel_values" in ret:
+            items += [
+                MultimodalDataItem(
+                    pixel_values=ret["pixel_values"],
+                    image_grid_thws=torch.concat([ret["image_grid_thw"]]),
+                    # TODO
+                    video_grid_thws=None,
+                    second_per_grid_ts=ret.get("second_per_grid_ts", None),
+                    modality=Modality.IMAGE,
+                )
+            ]
+
+        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
+            spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            vision_start_token_id=self.vision_start_token_id,
+            model_type=self.hf_config.model_type,
+            tokens_per_second=getattr(
+                self.hf_config.vision_config, "tokens_per_second", None
+            ),
+            input_ids=torch.tensor(input_ids).unsqueeze(0),
+            image_grid_thw=ret.get("image_grid_thw", None),
+            video_grid_thw=ret.get("video_grid_thw", None),
+            second_per_grid_ts=ret.get("second_per_grid_ts", None),
+        )
+        mrope_positions = mrope_positions.squeeze(1)
+
         return {
-            "input_ids": ret["input_ids"].flatten().tolist(),
-            "mm_items": [
-                MultimodalDataItem(
-                    pixel_values=ret["pixel_values"],
-                    image_grid_thws=image_grid_thws,
-                    video_grid_thws=None,
-                    second_per_grid_ts=ret.get("second_per_grid_ts", None),
-                    modality=Modality.IMAGE,
-                )
-            ],
+            "input_ids": input_ids,
+            "mm_items": items,
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+            "im_token_id": self.image_token_id,
+            "video_token_id": self.video_token_id,
+            "mrope_positions": mrope_positions,
+            "mrope_position_delta": mrope_position_delta,
         }
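Note: this is the core of the change. M-RoPE positions and the position delta are now computed once, at preprocessing time, from the request's own padded input_ids, so they travel with the request as plain metadata. The delta is what keeps decode cheap: after prefill, every new token sits at the same index on all three axes, namely its sequence position plus mrope_position_delta. A toy restatement of that continuation rule (not the sglang helper itself):

import torch

def next_positions(delta: int, seq_len: int, num_new: int) -> torch.Tensor:
    # All three M-RoPE axes advance together for text tokens, offset by the
    # (possibly negative) per-request delta computed during preprocessing.
    base = torch.arange(seq_len, seq_len + num_new) + delta
    return base.view(1, -1).expand(3, -1)

# e.g. prefill covered 7 tokens but the max 3D position was 4 -> delta = 5 - 7 = -2
print(next_positions(-2, 7, 3))
# tensor([[5, 6, 7],
#         [5, 6, 7],
#         [5, 6, 7]])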
python/sglang/srt/managers/schedule_batch.py

@@ -285,6 +285,7 @@ class MultimodalInputs:
     num_image_tokens: Optional[int] = None

     # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[torch.Tensor] = None

     # image
@@ -310,16 +311,12 @@ class MultimodalInputs:
         assert isinstance(ret.mm_items, list)
         ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
         assert len(ret.mm_items) != 0

-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
         for item in ret.mm_items:
             item.set_pad_value()

         optional_args = [
             "mrope_positions",
             "mrope_position_delta",
             "im_token_id",
             "im_start_id",
             "im_end_id",
@@ -350,20 +347,26 @@ class MultimodalInputs:
         merge image inputs when requests are being merged
         """
-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.

         # args needed to be merged
         optional_args = [
             "mm_items",
             "image_pad_len",
-            "mrope_position_delta",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
             if self_arg is not None:
                 setattr(self, arg, self_arg + getattr(other, arg))
+
+        mrope_positions = self.mrope_positions
+        if mrope_positions is not None:
+            if other.mrope_positions is None:
+                self.mrope_positions = mrope_positions
+            else:
+                self.mrope_positions = torch.cat(
+                    [self.mrope_positions, other.mrope_positions], dim=1
+                )
         # other args would be kept intact
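Note: because each request's positions are precomputed independently and stored as a (3, seq_len) tensor, merging two requests is a plain concatenation along the sequence axis. For instance:

import torch

a = torch.zeros(3, 4, dtype=torch.int64)  # cached positions of the first request
b = torch.ones(3, 2, dtype=torch.int64)   # cached positions of the second request
merged = torch.cat([a, b], dim=1)
print(merged.shape)  # torch.Size([3, 6])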
python/sglang/srt/managers/tokenizer_manager.py

@@ -419,7 +419,10 @@ class TokenizerManager:
             input_ids = self.tokenizer.encode(input_text)

         image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-            obj.image_data, input_text or input_ids, obj, self.max_req_input_len
+            image_data=obj.image_data,
+            input_text=input_text or input_ids,
+            request_obj=obj,
+            max_req_input_len=self.max_req_input_len,
         )
         if image_inputs and "input_ids" in image_inputs:
             input_ids = image_inputs["input_ids"]
python/sglang/srt/model_executor/forward_batch_info.py

@@ -407,8 +407,6 @@ class ForwardBatch:
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
-        device = model_runner.device
-        hf_config = model_runner.model_config.hf_config
         mrope_positions_list = [None] * self.seq_lens.shape[0]
         if self.forward_mode.is_decode():
             for i, _ in enumerate(mrope_positions_list):
@@ -417,93 +415,44 @@ class ForwardBatch:
                     if batch.multimodal_inputs[i] is None
                     else batch.multimodal_inputs[i].mrope_position_delta
                 )
-                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    mrope_position_delta,
-                    int(self.seq_lens[i]) - 1,
-                    int(self.seq_lens[i]),
-                )
+                mrope_positions_list[i] = torch.tensor(
+                    MRotaryEmbedding.get_next_input_positions(
+                        mrope_position_delta,
+                        int(self.seq_lens[i]) - 1,
+                        int(self.seq_lens[i]),
+                    )
+                )
         elif self.forward_mode.is_extend():
-            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
             for i, mm_input in enumerate(batch.multimodal_inputs):
-                extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    extend_start_loc_cpu[i],
+                extend_seq_len, extend_prefix_len = (
                     batch.extend_seq_lens[i],
                     batch.extend_prefix_lens[i],
                 )
                 if mm_input is None:
                     # text only
-                    mrope_positions = [
-                        [
-                            pos
-                            for pos in range(
-                                extend_prefix_len, extend_prefix_len + extend_seq_len
-                            )
-                        ]
-                    ] * 3
+                    mrope_positions = torch.tensor(
+                        [
+                            [
+                                pos
+                                for pos in range(
+                                    extend_prefix_len,
+                                    extend_prefix_len + extend_seq_len,
+                                )
+                            ]
+                        ]
+                        * 3
+                    )
                 else:
-                    image_grid_thws_list = [
-                        item.image_grid_thws
-                        for item in mm_input.mm_items
-                        if item.image_grid_thws is not None
-                    ]
-                    image_grid_thw = (
-                        None
-                        if len(image_grid_thws_list) == 0
-                        else torch.cat(image_grid_thws_list, dim=0)
-                    )
-                    video_grid_thws_list = [
-                        item.video_grid_thws
-                        for item in mm_input.mm_items
-                        if item.video_grid_thws is not None
-                    ]
-                    video_grid_thw = (
-                        None
-                        if len(video_grid_thws_list) == 0
-                        else torch.cat(video_grid_thws_list, dim=0)
-                    )
-                    second_per_grid_ts_list = [
-                        item.second_per_grid_ts
-                        for item in mm_input.mm_items
-                        if item.second_per_grid_ts is not None
-                    ]
-                    second_per_grid_ts = (
-                        None
-                        if len(second_per_grid_ts_list) == 0
-                        else torch.cat(second_per_grid_ts_list, dim=0)
-                    )
-                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
-                    mrope_positions, mrope_position_delta = (
-                        MRotaryEmbedding.get_input_positions(
-                            input_tokens=self.input_ids[
-                                extend_start_loc : extend_start_loc + extend_seq_len
-                            ].tolist(),
-                            image_grid_thw=image_grid_thw,
-                            video_grid_thw=video_grid_thw,
-                            image_token_id=hf_config.image_token_id,
-                            video_token_id=hf_config.video_token_id,
-                            vision_start_token_id=hf_config.vision_start_token_id,
-                            vision_end_token_id=hf_config.vision_end_token_id,
-                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
-                            context_len=0,
-                            seq_len=len(self.input_ids),
-                            second_per_grid_ts=second_per_grid_ts,
-                            tokens_per_second=getattr(
-                                hf_config.vision_config, "tokens_per_second", None
-                            ),
-                        )
-                    )
-                    batch.multimodal_inputs[i].mrope_position_delta = (
-                        mrope_position_delta
-                    )
+                    mrope_positions = mm_input.mrope_positions[
+                        :,
+                        extend_prefix_len : extend_prefix_len + extend_seq_len,
+                    ]
                 mrope_positions_list[i] = mrope_positions

         self.mrope_positions = torch.cat(
-            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
-            axis=1,
-        )
+            [pos.to(device=model_runner.device) for pos in mrope_positions_list],
+            dim=1,
+        )
         self.mrope_positions = self.mrope_positions.to(torch.int64)

     def get_max_chunk_capacity(self):
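Note: this hunk is the payoff of computing positions in the processor. An extend pass no longer reruns the grid math per chunk (the old "qwen2-vl does not support radix cache" TODO); it just slices the request's cached (3, seq_len) tensor by the prefix the radix cache already covered. A minimal illustration with invented values:

import torch

seq_len = 10
mrope_positions = torch.arange(3 * seq_len).view(3, seq_len)  # stand-in cache

extend_prefix_len, extend_seq_len = 4, 3  # this chunk covers tokens 4..6
chunk = mrope_positions[:, extend_prefix_len : extend_prefix_len + extend_seq_len]
print(chunk.shape)  # torch.Size([3, 3])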
python/sglang/srt/model_executor/model_runner.py

@@ -310,15 +310,6 @@ class ModelRunner:
             )
             server_args.chunked_prefill_size = -1

-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ] or self.model_config.hf_config.architectures == [
-            "Qwen2_5_VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info("Automatically disable radix cache for qwen-vl series.")
-            server_args.disable_radix_cache = True
-
         if server_args.enable_deepep_moe:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
python/sglang/srt/models/deepseek_vl2.py

@@ -12,7 +12,7 @@ from sglang.srt.configs.deepseekvl2 import (
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternImageTokens,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -249,8 +249,8 @@ class DeepseekVL2ForCausalLM(nn.Module):
             weights_loader(param, loaded_weight)

     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        helper = MultiModalityDataPaddingPatternImageTokens(
-            image_token_id=image_inputs.im_token_id
-        )
+        helper = MultiModalityDataPaddingPatternMultimodalTokens(
+            [image_inputs.im_token_id]
+        )
         return helper.pad_input_tokens(input_ids, image_inputs)
python/sglang/srt/models/minicpmo.py

@@ -43,6 +43,7 @@ from sglang.srt.managers.mm_utils import (
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     flatten_nested_list,
@@ -1834,7 +1835,10 @@ class MiniCPMO(MiniCPMBaseModel):
             language_model=self.llm,
             image_data_embedding_func=self.get_image_feature,
             audio_data_embedding_func=self.get_audio_feature,
-            placeholder_token_ids=placeholder_token_ids,
+            placeholder_tokens={
+                Modality.IMAGE: placeholder_token_ids,
+                Modality.AUDIO: placeholder_token_ids,
+            },
             positions=positions,
         )
         return hidden_states
python/sglang/srt/models/mllama4.py

@@ -10,7 +10,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternImageTokens,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -53,7 +53,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         # Get all special token IDs
         im_token_id: int = mm_inputs.im_token_id

-        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(
python/sglang/srt/models/qwen2_5_vl.py

@@ -49,7 +49,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternTokenPairs,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -488,11 +488,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = mm_inputs.im_start_id
-        im_end_id: int = mm_inputs.im_end_id
-
-        media_token_pairs = [(im_start_id, im_end_id)]
-        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        im_token_id: int = mm_inputs.im_token_id
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
python/sglang/srt/models/qwen2_vl.py

@@ -42,7 +42,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternTokenPairs,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -490,15 +490,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    # Use grid_t * grid_w * grid_h to pad tokens for each image
-    # add replaced padding by unique image hash
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = mm_inputs.im_start_id
-        im_end_id: int = mm_inputs.im_end_id
         im_token_id: int = mm_inputs.im_token_id

-        media_token_pairs = [(im_start_id, im_end_id)]
-        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
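Note: with every VLM now on MultiModalityDataPaddingPatternMultimodalTokens, radix-cache keys are built uniformly: each run of repeated image tokens in input_ids is rewritten to that item's pad value (derived from the image hash), so two requests sharing the same image prefix produce identical token-id prefixes and can hit the cache. A toy illustration with invented ids:

IM_TOKEN = 9000            # invented stand-in for the model's image token id
prompt = [1, 2, IM_TOKEN, IM_TOKEN, IM_TOKEN, 3]
image_hash = 987654321     # invented stand-in for the item's pad_value

padded = [image_hash if t == IM_TOKEN else t for t in prompt]
print(padded)  # [1, 2, 987654321, 987654321, 987654321, 3]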