sglang · Commits · 5cb552b1

Commit 5cb552b1 (Unverified) · authored Apr 01, 2025 by Mick, committed by GitHub Mar 31, 2025

refactor: multimodal data (#4754)

parent c7457191 · Changes 36

Showing 20 changed files with 564 additions and 662 deletions (+564 −662)
benchmark/mmmu/bench_hf.py                                           +32  −11
benchmark/mmmu/eval_utils.py                                          +2   −0
python/sglang/srt/managers/mm_utils.py                              +202 −156
python/sglang/srt/managers/multimodal_processor.py                    +0   −2
python/sglang/srt/managers/multimodal_processors/base_processor.py   +36  −72
python/sglang/srt/managers/multimodal_processors/clip.py              +7  −26
python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py   +17  −58
python/sglang/srt/managers/multimodal_processors/gemma3.py           +12  −27
python/sglang/srt/managers/multimodal_processors/janus_pro.py        +21  −47
python/sglang/srt/managers/multimodal_processors/llava.py            +23  −12
python/sglang/srt/managers/multimodal_processors/minicpm.py          +35  −38
python/sglang/srt/managers/multimodal_processors/mlama.py            +10  −23
python/sglang/srt/managers/multimodal_processors/qwen_vl.py          +21  −44
python/sglang/srt/managers/schedule_batch.py                        +116  −94
python/sglang/srt/managers/scheduler.py                               +1   −1
python/sglang/srt/managers/utils.py                                   +1   −6
python/sglang/srt/model_executor/forward_batch_info.py                +0   −5
python/sglang/srt/model_executor/model_runner.py                      +6  −18
python/sglang/srt/models/clip.py                                     +12   −7
python/sglang/srt/models/deepseek_janus_pro.py                       +10  −15
benchmark/mmmu/bench_hf.py  View file @ 5cb552b1

@@ -72,7 +72,8 @@ def eval_mmmu(args):
         if suffix:
             contents += [{"type": "text", "text": suffix}]
         messages = [{"role": "user", "content": contents}]
-        model_inputs = processor.apply_chat_template(
-            messages,
-            tokenize=True,
-            return_dict=True,
+        try:
+            model_inputs = processor.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                return_dict=True,
...
@@ -80,9 +81,29 @@ def eval_mmmu(args):
-            return_tensors="pt",
-        ).to(model.device)
-        input_len = model_inputs["input_ids"].shape[-1]
-        generation = model.generate(**model_inputs, generation_config=generation_config)
-        generation = generation[0][input_len:]
-        response = processor.decode(generation, skip_special_tokens=True)
+                return_tensors="pt",
+            ).to(model.device)
+            input_len = model_inputs["input_ids"].shape[-1]
+            generation = model.generate(**model_inputs, generation_config=generation_config)
+            generation = generation[0][input_len:]
+            response = processor.decode(generation, skip_special_tokens=True)
+        except:
+            contents = []
+            if prefix:
+                contents += [prefix]
+            image = PIL.Image.open(sample["image_path"])
+            contents += [image]
+            if suffix:
+                contents += [suffix]
+            messages = [{"role": "user", "content": contents}]
+            response = model.chat(
+                msgs=messages,
+                tokenizer=processor.tokenizer,
+                sampling=False,
+                max_new_tokens=sampling_params["max_new_tokens"],
+                use_tts_template=False,
+                generate_audio=False,
+                temperature=0.0,
+            )
         print(f"response: {response}")
         process_result(response, sample, answer_dict, out_samples)
...
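Note: the bench_hf.py change wraps the HuggingFace chat-template path in a try/except so that models whose processors expose generation only through a bespoke `model.chat` API (MiniCPM-style interfaces) still run. A minimal, self-contained sketch of the same fallback shape — `run_hf_generate` and `run_model_chat` are invented stand-ins for the two code paths:

    def run_hf_generate(messages):
        # Stand-in for processor.tokenizer.apply_chat_template(...) + model.generate(...).
        raise AttributeError("this processor has no chat template")

    def run_model_chat(messages):
        # Stand-in for a bespoke model.chat(msgs=..., tokenizer=...) entry point.
        return "fallback response"

    def generate_response(messages):
        """Prefer the HF chat-template path; fall back to the model-specific chat API."""
        try:
            return run_hf_generate(messages)
        except Exception:
            return run_model_chat(messages)

    print(generate_response([{"role": "user", "content": "hi"}]))  # fallback response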
benchmark/mmmu/eval_utils.py  View file @ 5cb552b1

@@ -442,6 +442,8 @@ def calculate_ins_level_acc(results: Dict):
 def process_result(response, sample, answer_dict, out_samples):
+    if response is None:
+        return
     if sample["question_type"] == "multiple-choice":
         pred_ans = parse_multi_choice_response(
             response, sample["all_choices"], sample["index2ans"]
...
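Note: the early return added to process_result pairs with the fallback above — a request that fails both generation paths yields response=None, and skipping it keeps partial failures out of the accuracy counts. A toy stand-in (the real function also parses multiple-choice and open answers):

    def record_result(response, sample, answer_dict, out_samples):
        # Mirror of the added guard: drop missing responses instead of crashing.
        if response is None:
            return
        answer_dict[sample["id"]] = response
        out_samples.append(sample)

    answers, samples = {}, []
    record_result(None, {"id": "q1"}, answers, samples)  # skipped
    record_result("A", {"id": "q2"}, answers, samples)   # recorded
    print(answers, len(samples))  # {'q2': 'A'} 1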
python/sglang/srt/managers/mm_utils.py  View file @ 5cb552b1

 """
-Multimodality utils
+Multi-modality utils
 """

 from abc import abstractmethod
...
@@ -9,11 +9,13 @@ import torch
 from torch import nn

 from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
     logger,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import print_warning_once
 from sglang.utils import logger
...
@@ -26,7 +28,7 @@ class MultiModalityDataPaddingPattern:
     @abstractmethod
     def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: MultimodalInputs
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
         Pad the input ids sequence containing data tokens, and replace them with pad_values
...
@@ -49,13 +51,13 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         """
         This function will replace the data-tokens inbetween with pad_values accordingly
         """
-        pad_values = mm_inputs.pad_values
+        pad_values = [item.pad_value for item in mm_inputs.mm_items]
         data_token_pairs = self.data_token_id_pairs
-        mm_inputs.image_offsets = []
+        mm_inputs.data_offsets = []
         if data_token_pairs is None:
             data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
         if data_token_pairs is None:
-            logger.warning(
+            print_warning_once(
                 "No data_token_pairs provided, RadixAttention might be influenced."
             )
             return input_ids
...
@@ -77,10 +79,10 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
             if input_ids[start_idx] in start_token_ids:
                 data_idx += 1
-                mm_inputs.image_offsets += [start_idx]
-            if data_idx >= len(mm_inputs.pad_values):
-                data_idx = len(mm_inputs.pad_values) - 1
+                mm_inputs.data_offsets += [start_idx]
+            if data_idx >= len(pad_values):
+                data_idx = len(pad_values) - 1

             num_tokens = end_idx - start_idx - 1
             pad_value = pad_values[data_idx]
...
@@ -94,68 +96,19 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         return padded_ids


-class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
-    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
-    which needs first to be expanded to multiple tokens, then replaced with their padding values
-
-    This strategy should be used when a single data token represents content that should
-    be expanded to multiple tokens during processing.
-    """
-
-    def __init__(
-        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
-    ) -> None:
-        self.num_data_token_calc_func = num_data_token_calc_func
-
-    def pad_input_tokens(
-        self, input_ids: List[int], mm_inputs: MultimodalInputs
-    ) -> List[int]:
-        """
-        This function will follow the procedure of:
-            1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
-            2. the padded data tokens will be replaced with their pad_values
-        """
-        image_grid_thws = mm_inputs.image_grid_thws
-        pad_values = mm_inputs.pad_values
-
-        image_indices = [
-            idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
-        ]
-
-        mm_inputs.image_offsets = []
-
-        input_ids_with_image = []
-        for image_cnt, _ in enumerate(image_grid_thws):
-            # print(f"image_cnt {image_cnt}")
-            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
-            if image_cnt == 0:
-                non_image_tokens = input_ids[: image_indices[image_cnt]]
-            else:
-                non_image_tokens = input_ids[
-                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
-                ]
-            input_ids_with_image.extend(non_image_tokens)
-            mm_inputs.image_offsets.append(len(input_ids_with_image))
-            pad_ids = pad_values * (
-                (num_image_tokens + len(pad_values)) // len(pad_values)
-            )
-            input_ids_with_image.extend(pad_ids[:num_image_tokens])
-        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
-
-        return input_ids_with_image
-
-
 class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
-    """In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""
+    """In this pattern, data tokens should be represented as repetitions of a single token
+    e.g. <image><image>....<image>, or <audio><audio>...<audio>
+    """

     def __init__(self, image_token_id: torch.Tensor) -> None:
         self.image_token_id = image_token_id

-    def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
+    def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
         """
         This function will replace the data-tokens in between with pad_values accordingly
         """
-        pad_values = image_inputs.pad_values
+        pad_values = [item.pad_value for item in mm_inputs.mm_items]
         assert len(pad_values) != 0

         input_ids_tensor = torch.tensor(input_ids)
...
@@ -170,109 +123,183 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
     return input_ids_tensor.tolist()


+def get_embedding_and_mask(
+    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
+    embedding_items: List[MultimodalDataItem],
+    placeholder_tensor: torch.Tensor,
+    input_ids: torch.Tensor,
+):
+    """
+    Get the multimodal embedding and its mask from input_ids
+    """
+    # 1. Get the embedding
+    embedding = data_embedding_func(embedding_items)
+
+    # 2. Check the embedding
+    if embedding.dim() == 2:
+        num_mm_tokens_in_embedding = embedding.shape[0]
+    else:
+        num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
+
+    # the mask of multimodal tokens from input_ids
+    special_multimodal_mask = torch.isin(
+        input_ids,
+        placeholder_tensor,
+    ).unsqueeze(-1)
+
+    num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
+    if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
+        logger.warning(
+            f"Number of tokens in multimodal embedding does not match those in the input text."
+            f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
+            "tokens from multimodal embeddings."
+        )
+        if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
+            # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
+            # a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
+            # extend_start_loc and extend_seq_lens
+            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
+            if chunked_prefill_size != -1:
+                logger.warning(
+                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
+                )
+            # extract from the end: this is a compromise
+            if embedding.dim() == 2:
+                embedding = embedding[-num_mm_tokens_in_input_ids:, :]
+            else:
+                num_multimodal = num_mm_tokens_in_input_ids // embedding.shape[0]
+                embedding = embedding[-num_multimodal:, :]
+        else:
+            raise RuntimeError(
+                "Insufficient multimodal embedding length. This is an internal error"
+            )
+    return embedding, special_multimodal_mask
+
+
 def embed_mm_inputs(
-    mm_input: MultimodalInputs,
+    mm_inputs: MultimodalInputs,
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
-    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
+    image_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
+    audio_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
     placeholder_token_ids: List[int] = None,
 ) -> Optional[torch.Tensor]:
     """
-    Calculate the image embeddings if necessary, then scatter the result with
-    the help of a boolean mask denoting the embed locations
+    Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
+
+    Args:
+        placeholder_token_ids: denoting the token of multimodal data in input_ids.
+            If none, the pad_values of multimodal items are used

     Returns:
         final embedding: Optional[torch.Tensor]
     """
-    if mm_input is None:
+    if mm_inputs is None:
         return None

-    placeholder_token_ids = placeholder_token_ids or mm_input.pad_values
-
-    # boolean masking the special tokens
-    special_image_mask = torch.isin(
-        input_ids,
-        torch.tensor(placeholder_token_ids, device=input_ids.device),
-    ).unsqueeze(-1)
-
-    num_image_tokens_in_input_ids = special_image_mask.sum()
-    # print(f"{num_image_tokens_in_input_ids}")
-    # print(f"{input_ids}")
-    # return
-    if num_image_tokens_in_input_ids == 0:
-        # unexpected
+    # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
+    # we assume that multimodal data are represented with its pad_values in input_ids
+    placeholder_token_ids = placeholder_token_ids or [
+        item.pad_value for item in mm_inputs.mm_items
+    ]
+
+    # boolean masking the special tokens
+    placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
+
+    placeholder_masks = torch.isin(input_ids, placeholder_tensor)
+
+    appearing_pad_values = torch.unique(
+        input_ids[placeholder_masks], return_counts=False
+    )
+
+    if appearing_pad_values.numel() == 0:
+        # all been prefixed
         inputs_embeds = input_embedding(input_ids)
     else:
-        # print(f"Getting image feature")
-        image_embedding = mm_data_embedding_func(mm_input)
-        # print(f"image_embedding: {image_embedding.shape}")
-
-        if image_embedding.dim() == 2:
-            num_image_tokens_in_embedding = image_embedding.shape[0]
-        else:
-            num_image_tokens_in_embedding = (
-                image_embedding.shape[0] * image_embedding.shape[1]
-            )
-        if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
-            num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
-            image_embedding = image_embedding[:num_image, :]
-            logger.warning(
-                f"Number of images does not match number of special image tokens in the input text. "
-                f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
-                "tokens from image embeddings."
-            )
-            # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
-            # a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
-            # extend_start_loc and extend_seq_lens
-            if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
-                chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
-                if chunked_prefill_size != -1:
-                    logger.warning(
-                        "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
-                    )
+        appearing_items = [
+            item
+            for item in mm_inputs.mm_items
+            if item.pad_value is not None and item.pad_value in appearing_pad_values
+        ]
+
+        using_all_items = False
+        if len(appearing_items) == 0:
+            # This happens mostly when arg placeholder_token_ids is passed
+            logger.warning_once(
+                "No multimodal data item's pad value exist in placeholder ids. Using all items"
+            )
+            using_all_items = True
+            appearing_items = mm_inputs.mm_items
+
+        embeddings, masks = [], []
+
+        # 2. Get multimodal embedding separately
+        # TODO: make this more generic
+        # Try get image embedding if any
+        if (
+            any(True for item in appearing_items if item.is_image())
+            and image_data_embedding_func
+        ):
+            items = [item for item in appearing_items if item.is_image()]
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=image_data_embedding_func,
+                embedding_items=items,
+                placeholder_tensor=(
+                    placeholder_tensor
+                    if using_all_items
+                    else torch.tensor(
+                        [item.pad_value for item in items],
+                        device=input_ids.device,
+                    )
+                ),
+                input_ids=input_ids,
+            )
+            embeddings += [embedding]
+            masks += [mask]
+
+        # Try get audio embedding if any
+        if (
+            any(True for item in appearing_items if item.is_audio())
+            and audio_data_embedding_func
+        ):
+            items = [item for item in appearing_items if item.is_audio()]
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=audio_data_embedding_func,
+                embedding_items=items,
+                placeholder_tensor=(
+                    placeholder_tensor
+                    if using_all_items
+                    else torch.tensor(
+                        [item.pad_value for item in items],
+                        device=input_ids.device,
+                    )
+                ),
+                input_ids=input_ids,
+            )
+            embeddings += [embedding]
+            masks += [mask]

+        # 3. Get input embeddings
         vocab_size = input_embedding.num_embeddings
-        # Important: clamp after getting original image regions
-        # Clamp input ids. This is because the input_ids for the image tokens are
-        # filled with the hash values of the image for the prefix matching in the radix attention.
+        # Important: clamp after getting original multimodal regions
+        # Clamp input ids. This is because the input_ids for the multimodal tokens are
+        # filled with the hash values of the multimodal for the prefix matching in the radix attention.
         # There values are useless because their embeddings will be replaced by vision embeddings anyway.
         input_ids.clamp_(min=0, max=vocab_size - 1)
         inputs_embeds = input_embedding(input_ids)
-        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
-            inputs_embeds.device
-        )
-        inputs_embeds = inputs_embeds.masked_scatter(
-            special_image_mask,
-            image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-        )
-    return inputs_embeds
-
-
-def embed_image_embedding(
-    inputs_embeds: torch.Tensor,
-    image_embedding: torch.Tensor,
-    image_bounds: torch.Tensor,
-) -> torch.Tensor:
-    """
-    scatter image_embedding into inputs_embeds according to image_bounds
-    """
-    if len(image_bounds) > 0:
-        image_indices = torch.stack(
-            [
-                torch.arange(start, end, dtype=torch.long)
-                for start, end in image_bounds.tolist()
-            ]
-        ).to(inputs_embeds.device)
-
-        inputs_embeds.scatter_(
-            0,
-            image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
-            image_embedding.view(-1, image_embedding.shape[-1]),
-        )
+
+        # 4. scatter embeddings into input embedding
+        for embedding, mask in zip(embeddings, masks):
+            mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            inputs_embeds = inputs_embeds.masked_scatter(
+                mask,
+                embedding.to(inputs_embeds.device, inputs_embeds.dtype),
+            )
     return inputs_embeds
...
@@ -280,28 +307,43 @@ def embed_image_embedding(
 def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
-    embed_tokens: nn.Embedding,
-    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
+    language_model: nn.Module,
+    image_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
+    audio_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
     placeholder_token_ids: List[int] = None,
-):
+    **kwargs,
+) -> torch.Tensor:
     """
-    a general wrapper function to get final input embeds from multimodal models
-    with a language model as causal model
+    A general wrapper function to get final input embeds from multimodal models with a language model as causal model

     Args:
         placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
+        image_data_embedding_func : the function returning the image embedding
+        audio_data_embedding_func : the function returning the image embedding

     Returns:
-        inputs_embedding
+        forwarded hidden states
     """

+    assert hasattr(language_model, "get_input_embeddings")
+    embed_tokens = language_model.get_input_embeddings()
     if (
         not forward_batch.forward_mode.is_decode()
         and forward_batch.contains_mm_inputs()
     ):
-        image = forward_batch.merge_mm_inputs()
+        mm_input = forward_batch.merge_mm_inputs()
         inputs_embeds = embed_mm_inputs(
-            mm_input=image,
+            mm_inputs=mm_input,
             input_ids=input_ids,
             input_embedding=embed_tokens,
-            mm_data_embedding_func=mm_data_embedding_func,
+            image_data_embedding_func=image_data_embedding_func,
+            audio_data_embedding_func=audio_data_embedding_func,
             placeholder_token_ids=placeholder_token_ids,
         )
         # once used, mm_inputs is useless
...
@@ -310,7 +352,13 @@ def general_mm_embed_routine(
     else:
         inputs_embeds = embed_tokens(input_ids)

-    return inputs_embeds
+    hidden_states = language_model(
+        input_ids=None,
+        forward_batch=forward_batch,
+        input_embeds=inputs_embeds,
+        **kwargs,
+    )
+    return hidden_states


 def get_multimodal_data_bounds(
...
@@ -322,15 +370,13 @@ def get_multimodal_data_bounds(
     Returns:
         [bounds_count, 2]
     """
-    # All the images in the batch should share the same special image
-    # bound token ids.
+    # All the multimodal data in the batch should share the same special bound token ids.
     start_tokens = [s for s, _e in token_pairs]
     end_tokens = [e for _s, e in token_pairs]

     assert all(isinstance(t, int) for t in start_tokens)
     assert all(isinstance(t, int) for t in end_tokens)
-    # print(input_ids)
     start_cond = torch.isin(
         input_ids, torch.tensor(start_tokens, device=input_ids.device)
     )
...
@@ -339,7 +385,7 @@ def get_multimodal_data_bounds(
     (data_start_tokens,) = torch.where(start_cond)
     (data_end_tokens,) = torch.where(end_cond)

-    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
+    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
     if len(data_start_tokens) != len(data_end_tokens):
         if (
             len(data_start_tokens) + 1 == len(data_end_tokens)
...
@@ -352,14 +398,14 @@ def get_multimodal_data_bounds(
                     data_start_tokens,
                 ]
             )
-    valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
+    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))

-    if valid_image_nums == 0:
+    if valid_mm_data_nums == 0:
         return torch.zeros((0, 2), device=input_ids.device)

     # Filter out pairs where start_token >= end_token
     valid_pairs = []
-    for i in range(valid_image_nums):
+    for i in range(valid_mm_data_nums):
         start_token = data_start_tokens[i]
         end_token = data_end_tokens[i]
         if start_token < end_token:
...
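Note: the heart of this refactor is that MultimodalInputs now carries a list of per-item MultimodalDataItem objects, each with its own pad_value; embed_mm_inputs finds which pad values actually appear in input_ids, embeds image and audio items separately through get_embedding_and_mask, and scatters each embedding over a boolean mask. A runnable toy version of the mask-and-scatter step (shapes and pad values invented for illustration):

    import torch

    vocab_size, hidden = 100, 8
    embed = torch.nn.Embedding(vocab_size, hidden)

    # In sglang the pad values are hash-derived per multimodal item and exceed the vocab.
    pad_values = [1000, 1001]
    input_ids = torch.tensor([5, 1000, 1000, 7, 1001, 9])

    placeholder_tensor = torch.tensor(pad_values)
    mask = torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)  # [seq, 1]

    # Stand-in for the vision/audio tower output: one row per multimodal token.
    mm_embedding = torch.randn(int(mask.sum()), hidden)

    # Clamp before embedding: the out-of-vocab pad values exist only for
    # radix-cache prefix matching and carry no usable embedding of their own.
    inputs_embeds = embed(input_ids.clamp(min=0, max=vocab_size - 1))
    inputs_embeds = inputs_embeds.masked_scatter(mask.expand_as(inputs_embeds), mm_embedding)
    print(inputs_embeds.shape)  # torch.Size([6, 8])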
python/sglang/srt/managers/multimodal_processor.py  View file @ 5cb552b1

@@ -64,5 +64,3 @@ def get_mm_processor(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
     )
-
-    self.image_proce
python/sglang/srt/managers/multimodal_processors/base_processor.py  View file @ 5cb552b1

@@ -8,18 +8,10 @@ from typing import Optional
 import numpy as np
 import PIL
-import transformers
 from decord import VideoReader, cpu
 from PIL import Image

-from sglang.srt.utils import load_audio, load_image, logger
-
-global global_processor
-
-
-def get_global_processor():
-    global global_processor
-    return global_processor
+from sglang.srt.utils import encode_video, load_audio, load_image, logger


 @dataclasses.dataclass
...
@@ -27,9 +19,6 @@ class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with a image_token
     input_text: str

-    mm_data_hashes: Optional[list[int]]
-    # images
-    image_sizes: Optional[list[int]]
     # frames loaded from image and video, in given order
     images: Optional[list[PIL.Image]] = None
...
@@ -37,7 +26,7 @@ class BaseMultiModalProcessorOutput:
     audios: Optional[list[np.ndarray]] = None

     def normalize(self):
-        for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
+        for field_name in ["image_sizes", "images", "audios"]:
             field = getattr(self, field_name, None)
             if field is not None and isinstance(field, list) and len(field) == 0:
                 setattr(self, field_name, None)
...
@@ -67,28 +56,35 @@ class BaseMultimodalProcessor(ABC):
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330

-        # Initialize global processor first
-        init_global_processor(self, server_args)
-
-        self.executor = concurrent.futures.ProcessPoolExecutor(
-            initializer=init_global_processor,
+        self.io_executor = concurrent.futures.ThreadPoolExecutor(
+            max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
+        )
+        self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
             mp_context=mp.get_context("fork"),
-            initargs=(
-                self,
-                server_args,
-            ),
-            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
+            max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
         )

-    def _build_processor(self, server_args):
-        """Init the global processor for multi modal models."""
-        from sglang.srt.hf_transformers_utils import get_processor
-
-        return get_processor(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
-        )
+    def process_mm_data(
+        self, input_text, images=None, videos=None, audios=None, **kwargs
+    ):
+        """
+        process multimodal data with transformers AutoProcessor
+        """
+        if images is not None:
+            kwargs["images"] = images
+        if videos is not None:
+            kwargs["videos"] = videos
+        if audios is not None:
+            kwargs["audios"] = audios
+
+        processor = self._processor
+        result = processor.__call__(
+            text=[input_text],
+            padding=True,
+            return_tensors="pt",
+            **kwargs,
+        )
+        return result

     @abstractmethod
     async def process_mm_data_async(
...
@@ -115,33 +111,9 @@ class BaseMultimodalProcessor(ABC):
         return estimated_frames_list

-    @staticmethod
-    def encode_video(video_path, frame_count_limit=None):
-        if not os.path.exists(video_path):
-            logger.error(f"Video {video_path} does not exist")
-            return []
-
-        if frame_count_limit == 0:
-            return []
-
-        def uniform_sample(l, n):
-            gap = len(l) / n
-            idxs = [int(i * gap + gap / 2) for i in range(n)]
-            return [l[i] for i in idxs]
-
-        vr = VideoReader(video_path, ctx=cpu(0))
-        sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-        frame_indices = [i for i in range(0, len(vr), sample_fps)]
-        if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-            frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-        frames = vr.get_batch(frame_indices).asnumpy()
-        frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-        return frames
-
     def load_mm_data(
         self,
-        input_ids: list[int],
+        prompt: str,
         multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
         image_data: Optional[list] = None,
...
@@ -167,11 +139,13 @@ class BaseMultimodalProcessor(ABC):
         else:
             multimodal_tokens.image_token = multimodal_tokens.image_token

-        if isinstance(input_ids, list) and return_text:
-            assert len(input_ids) and isinstance(input_ids[0], int)
-            input_text = self._processor.tokenizer.decode(input_ids)
+        assert isinstance(prompt, str)
+        if isinstance(prompt, list) and return_text:
+            assert len(prompt) and isinstance(prompt[0], int)
+            prompt = self._processor.tokenizer.decode(prompt)
         else:
-            input_text = input_ids
+            prompt = prompt
         if return_text:
             import re
...
@@ -181,7 +155,7 @@ class BaseMultimodalProcessor(ABC):
                 + ")"
             )
             # split text into list of normal text and special tokens
-            text_parts = re.split(pattern, input_text)
+            text_parts = re.split(pattern, prompt)

         # TODO(mick): load from server_args, env, or sampling_params
         MAX_NUM_FRAMES = 30
...
@@ -217,7 +191,7 @@ class BaseMultimodalProcessor(ABC):
             ):
                 # video
                 path = image_file[len("video:") :]
-                frames = BaseMultimodalProcessor.encode_video(
+                frames = encode_video(
                     path, frame_count_limit=frames_to_process
                 )
             else:
...
@@ -254,19 +228,9 @@ class BaseMultimodalProcessor(ABC):
             raise RuntimeError(f"An exception occurred while loading images: {e}")

         out = BaseMultiModalProcessorOutput(
-            mm_data_hashes=hashes,
-            image_sizes=image_sizes,
             images=images,
             audios=audios,
             input_text=new_text,
         )
         out.normalize()
         return out
-
-
-def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
-    """Init the global processor for multimodal models."""
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = sglang_processor._build_processor(server_args=server_args)
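Note: the base class drops the fork-time global processor (get_global_processor / init_global_processor) in favor of calling the HF processor held on the instance. A minimal sketch of the new process_mm_data shape, with FakeProcessor standing in for a transformers AutoProcessor:

    class FakeProcessor:
        def __call__(self, text, padding=True, return_tensors="pt", **kwargs):
            return {"text": text, "extras": sorted(kwargs)}

    class Processor:
        def __init__(self, processor):
            self._processor = processor

        def process_mm_data(self, input_text, images=None, videos=None, audios=None, **kwargs):
            # Forward only the modalities that are present, mirroring base_processor.py.
            if images is not None:
                kwargs["images"] = images
            if videos is not None:
                kwargs["videos"] = videos
            if audios is not None:
                kwargs["audios"] = audios
            return self._processor(text=[input_text], padding=True, return_tensors="pt", **kwargs)

    p = Processor(FakeProcessor())
    print(p.process_mm_data("hello <image>", images=["img"]))
    # {'text': ['hello <image>'], 'extras': ['images']}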
python/sglang/srt/managers/multimodal_processors/clip.py  View file @ 5cb552b1

-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.clip import CLIPModel
 from sglang.srt.utils import load_image
...
@@ -15,29 +14,6 @@ class ClipImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_single_image_task(images, input_text):
-        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
-        return get_global_processor()(
-            images=images, text=input_text, return_tensors="pt"
-        )
-
-    async def _process_single_image(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                ClipImageProcessor._process_single_image_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(
-                images=images, text=[input_text], return_tensors="pt"
-            )
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
...
@@ -56,8 +32,13 @@ class ClipImageProcessor(BaseMultimodalProcessor):
         else:
             images = load_image(image_data[0])[0]

-        image_inputs = await self._process_single_image(images, input_text)
+        image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["data_hashes"] = [hash(str(image_data))]
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        image_inputs["mm_items"] = [
+            MultimodalDataItem(
+                pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
+            )
+        ]

         return image_inputs
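Note: every processor in this commit now returns its tensors wrapped in "mm_items" rather than as loose top-level keys. A self-contained sketch of that container (the real MultimodalDataItem lives in sglang.srt.managers.schedule_batch and has more fields; this toy version only mirrors the shape):

    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import Any

    class Modality(Enum):
        IMAGE = auto()
        AUDIO = auto()

    @dataclass
    class MultimodalDataItem:
        modality: Modality
        pixel_values: Any = None

    image_inputs = {"input_ids": [1, 2, 3], "pixel_values": "pixel-tensor"}
    image_inputs["mm_items"] = [
        MultimodalDataItem(pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE)
    ]
    print(image_inputs["mm_items"][0].modality)  # Modality.IMAGE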
python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py  View file @ 5cb552b1

@@ -16,15 +16,14 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-import asyncio
 import torch

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
...
@@ -35,51 +34,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "<image>"

-    @staticmethod
-    def _process_images_task(image, input_text, max_req_input_len):
-        processor = get_global_processor()
-        res = processor.__call__(
-            conversations=input_text, images=image, max_req_input_len=max_req_input_len
-        )
-
-        image_token_id = processor.image_token_id
-        res["im_token_id"] = image_token_id
-        return res
-
-    async def _process_images(self, image_data, input_text, max_req_input_len):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                DeepseekVL2ImageProcessor._process_images_task,
-                image_data,
-                input_text,
-                max_req_input_len,
-            )
-        else:
-            image_inputs = self._process_images_task(
-                image_data, input_text, max_req_input_len
-            )
-        return image_inputs
-
-    async def _process_images(self, image_data, input_text, max_req_input_len):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                DeepseekVL2ImageProcessor._process_images_task,
-                image_data,
-                input_text,
-                max_req_input_len,
-            )
-        else:
-            image_inputs = self._process_images_task(
-                image_data, input_text, max_req_input_len
-            )
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
     ):
...
@@ -89,8 +43,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         if not isinstance(image_data, list):
             image_data = [image_data]

-        images, image_sizes = [], []
-
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
             input_ids,
...
@@ -98,8 +50,11 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
         )
-        res = await self._process_images(
-            base_output.images, base_output.input_text, max_req_input_len
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
+            max_req_input_len=max_req_input_len,
+            conversations=base_output.input_text,
         )
         images_seq_mask = res["images_seq_mask"]
         images_spatial_crop = res["images_spatial_crop"]
...
@@ -107,13 +62,17 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
             batched_images_spatial_crop.append(images_spatial_crop)
         batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)

+        items = []
+        item = MultimodalDataItem(
+            pixel_values=res["images"],
+            modality=Modality.IMAGE,
+            image_emb_mask=images_seq_mask,
+            image_spatial_crop=batched_images_spatial_crop,
+        )
+        items += [item]
+
         return {
+            "mm_items": items,
             "input_ids": res["input_ids"].tolist(),
-            "pixel_values": res["images"],
-            "im_token_id": res["im_token_id"],
-            "data_hashes": base_output.mm_data_hashes,
-            "image_sizes": image_sizes,
-            "images_emb_mask": images_seq_mask,
-            "image_spatial_crop": batched_images_spatial_crop,
-            "modalities": request_obj.modalities or ["image"],
+            "im_token_id": self._processor.image_token_id,
         }
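Note: DeepSeek-VL2's HF processor is keyed on conversations= rather than text=, so the refactored code forwards it as an extra keyword through process_mm_data(**kwargs). A toy sketch of that passthrough — FakeDeepseekProcessor is invented for illustration:

    class FakeDeepseekProcessor:
        def __call__(self, text=None, conversations=None, images=None, **kwargs):
            # A conversations-style processor keys off its own argument name.
            return {"prompt": conversations, "num_images": len(images or [])}

    proc = FakeDeepseekProcessor()
    out = proc(text=["<image> describe"], conversations="<image> describe", images=["img"])
    print(out)  # {'prompt': '<image> describe', 'num_images': 1}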
python/sglang/srt/managers/multimodal_processors/gemma3.py  View file @ 5cb552b1

@@ -7,8 +7,8 @@ from sglang.srt.managers.multimodal_processor import (
 )
 from sglang.srt.managers.multimodal_processors.base_processor import (
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration

 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
...
@@ -25,28 +25,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index

-    async def _process_single_image(self, images, input_text) -> dict:
-        if isinstance(images, list) and len(images) == 0:
-            images = None
-        processor = get_global_processor()
-        result = processor.__call__(
-            text=[input_text],
-            images=images,
-            padding=True,
-            return_tensors="pt",
-            # if RGBA, this needs to be set
-            # images_kwargs={
-            #     "input_data_format": ChannelDimension.FIRST
-            # }
-        )
-        pixel_values = getattr(result, "pixel_values", None)
-
-        return {
-            "input_ids": result.input_ids,
-            "pixel_values": pixel_values,
-        }
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
...
@@ -63,21 +41,28 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )

-        ret = await self._process_single_image(
+        ret = self.process_mm_data(
             input_text=base_output.input_text, images=base_output.images
         )
+
+        items = []
+        for i, image in enumerate(base_output.images):
+            item = MultimodalDataItem(
+                pixel_values=ret["pixel_values"][i],
+                modality=Modality.IMAGE,
+            )
+            items += [item]
+
         return {
+            "mm_items": items,
             "input_ids": ret["input_ids"].flatten().tolist(),
-            "pixel_values": ret["pixel_values"],
-            "data_hashes": base_output.mm_data_hashes,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }
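Note: Gemma3 now emits one MultimodalDataItem per input image, indexing into the processor's batched pixel_values, and still reports the boi/eoi token ids consumed by the token-pair padding pattern. A toy sketch of the return shape (tensors replaced by lists, token ids invented):

    def build_result(input_ids, pixel_values, im_start_id=255999, im_end_id=256000):
        items = [{"modality": "image", "pixel_values": pv} for pv in pixel_values]
        return {
            "mm_items": items,
            "input_ids": input_ids,
            "im_start_id": im_start_id,
            "im_end_id": im_end_id,
        }

    print(build_result([1, 2, 3], [[0.1], [0.2]])["mm_items"][1])
    # {'modality': 'image', 'pixel_values': [0.2]}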
python/sglang/srt/managers/multimodal_processors/janus_pro.py  View file @ 5cb552b1

-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
...
@@ -15,37 +14,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_images_task(images, input_text):
-        processor = get_global_processor()
-        result = processor.__call__(
-            prompt=input_text, images=images, return_tensors="pt"
-        )
-        return {
-            "input_ids": result["input_ids"],
-            "pixel_values": result["pixel_values"],
-            "images_emb_mask": result["images_emb_mask"],
-            "im_start_id": processor.image_start_id,
-            "im_end_id": processor.image_end_id,
-            "im_token_id": processor.image_id,
-        }
-
-    async def _process_images(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                JanusProImageProcessor._process_images_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(
-                images=images, text=input_text, return_tensors="pt"
-            )
-        return image_inputs
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
...
@@ -60,25 +28,31 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         if not isinstance(image_data, list):
             image_data = [image_data]

+        processor = self._processor
         base_out = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
-                image_token="<image_placeholder>"
+                image_token=processor.image_tag
             ),
             max_req_input_len=max_req_input_len,
         )
         images = base_out.images
-        res = await self._process_images(images=images, input_text=base_out.input_text)
-        # print(res)
-        # print(base_out)
-        # print("", res["images_emb_mask"].shape)
+        res = self.process_mm_data(
+            input_text=base_out.input_text,
+            prompt=base_out.input_text,
+            images=images,
+        )
         return {
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=res["pixel_values"],
+                    image_emb_mask=res["images_emb_mask"],
+                    modality=Modality.IMAGE,
+                )
+            ],
             "input_ids": res["input_ids"].flatten().tolist(),
-            "pixel_values": res["pixel_values"],
-            "images_emb_mask": res["images_emb_mask"],
-            "data_hashes": base_out.mm_data_hashes,
-            "im_start_id": res["im_start_id"],
-            "im_end_id": res["im_end_id"],
-            "im_token_id": res["im_token_id"],
+            "im_start_id": processor.image_start_id,
+            "im_end_id": processor.image_end_id,
+            "im_token_id": processor.image_id,
         }
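Note: the im_start_id/im_end_id/im_token_id returned here feed the token-pair padding pattern in mm_utils.py, which swaps every token between a start/end pair for the owning item's pad value. A simplified, self-contained version of that step (the real logic is MultiModalityDataPaddingPatternTokenPairs.pad_input_tokens; ids below are invented):

    def pad_between_pairs(input_ids, start_id, end_id, pad_values):
        out, data_idx, i = [], -1, 0
        while i < len(input_ids):
            tok = input_ids[i]
            out.append(tok)
            if tok == start_id:
                data_idx = min(data_idx + 1, len(pad_values) - 1)
                j = i + 1
                while j < len(input_ids) and input_ids[j] != end_id:
                    out.append(pad_values[data_idx])  # replace data tokens with the item's pad value
                    j += 1
                if j < len(input_ids):
                    out.append(end_id)
                i = j
            i += 1
        return out

    print(pad_between_pairs([7, 100, 5, 5, 101, 8], start_id=100, end_id=101, pad_values=[42]))
    # [7, 100, 42, 42, 101, 8]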
python/sglang/srt/managers/multimodal_processors/llava.py  View file @ 5cb552b1

@@ -5,8 +5,8 @@ import numpy as np
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
 from sglang.srt.models.llavavid import LlavaVidForCausalLM
...
@@ -25,11 +25,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         image_data: Union[str, bytes],
         image_aspect_ratio: Optional[str] = None,
         image_grid_pinpoints: Optional[str] = None,
-        image_processor=None,
+        processor=None,
     ):
-        processor = get_global_processor()
-        image_processor = image_processor or processor.image_processor
+        image_processor = processor.image_processor
         try:
             image, image_size = load_image(image_data)
...
@@ -72,18 +71,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
     async def _process_single_image(
         self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
     ):
-        if self.executor is not None:
+        if self.cpu_executor is not None:
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(
-                self.executor,
+                self.cpu_executor,
                 LlavaImageProcessor._process_single_image_task,
                 image_data,
                 aspect_ratio,
                 grid_pinpoints,
+                self._processor,
             )
         else:
             return self._process_single_image_task(
-                image_data, aspect_ratio, grid_pinpoints
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+                self._processor.image_processor,
             )

     async def process_mm_data_async(
...
@@ -134,14 +137,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
             pixel_values, image_hash, image_size = await self._process_single_image(
                 image_data[0], aspect_ratio, grid_pinpoints
             )
-            data_hashes = [image_hash]
             image_sizes = [image_size]
         else:
             raise ValueError(f"Invalid image data: {image_data}")

+        modality = Modality.IMAGE
+        if isinstance(request_obj.modalities, list):
+            if request_obj.modalities[0] == "multi-images":
+                modality = Modality.MULTI_IMAGES
+            elif request_obj.modalities[0] == "video":
+                modality = Modality.VIDEO
+
         return {
-            "pixel_values": pixel_values,
-            "data_hashes": data_hashes,
-            "image_sizes": image_sizes,
-            "modalities": request_obj.modalities or ["image"],
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=pixel_values,
+                    image_sizes=image_sizes,
+                    modality=modality,
+                )
+            ],
         }
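Note: LLaVA maps the request's declared modalities onto the Modality enum when building its single MultimodalDataItem. Equivalent standalone logic, mirroring the added branch:

    from enum import Enum, auto

    class Modality(Enum):
        IMAGE = auto()
        MULTI_IMAGES = auto()
        VIDEO = auto()

    def pick_modality(modalities):
        modality = Modality.IMAGE
        if isinstance(modalities, list) and modalities:
            if modalities[0] == "multi-images":
                modality = Modality.MULTI_IMAGES
            elif modalities[0] == "video":
                modality = Modality.VIDEO
        return modality

    print(pick_modality(["video"]))  # Modality.VIDEO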
python/sglang/srt/managers/multimodal_processors/minicpm.py
View file @ 5cb552b1

-import asyncio
 from typing import List, Union

 import torch
+from transformers import BaseImageProcessorFast

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.minicpmo import MiniCPMO
 from sglang.srt.models.minicpmv import MiniCPMV
...
@@ -21,19 +21,23 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         self.image_token = "(<image>./</image>)"
         self.audio_token = "(<audio>./</audio>)"

-    @staticmethod
-    def _process_data_task(input_text, images=None, audios=None):
+    def process_data_task(self, input_text, images=None, audios=None):
         if isinstance(images, list) and len(images) == 0:
             images = None
         if isinstance(audios, list) and len(audios) == 0:
             audios = None
-        result = get_global_processor().__call__(
+        processor = self._processor
+        args = {}
+        if isinstance(processor, BaseImageProcessorFast):
+            args["device"] = "cuda"
+        result = self._processor.__call__(
             text=input_text,
             images=images,
             audios=audios,
             return_tensors="pt",
             chunk_input=True,
+            **args,
         )
         return {
             "input_ids": result.input_ids,
...
@@ -44,23 +48,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             "audio_bounds": getattr(result, "audio_bounds", None),
         }

-    async def _process_data(self, images, input_text, audios=None):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            multimodal_data_inputs = await loop.run_in_executor(
-                self.executor,
-                MiniCPMMultimodalProcessor._process_data_task,
-                input_text,
-                images,
-                audios,
-            )
-        else:
-            multimodal_data_inputs = self._processor(
-                images=images, text=input_text, audios=audios, return_tensors="pt"
-            )
-        return multimodal_data_inputs
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
...
@@ -77,7 +64,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data = [audio_data]

         base_output = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
...
@@ -88,9 +75,9 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         if base_output is None:
             return None

-        res = await self._process_data(
-            images=base_output.images,
+        res = self.process_mm_data(
             input_text=base_output.input_text,
+            images=base_output.images,
             audios=base_output.audios,
         )
...
@@ -142,23 +129,33 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
                 tgt_sizes_flat += [tgt_n]

         pixel_values = pixel_values_flat
-        if len(tgt_sizes_flat) == 0:
-            tgt_sizes = None
-        else:
-            tgt_sizes = torch.stack(tgt_sizes_flat)
-        if not isinstance(res["audio_features"], list):
-            res["audio_features"] = [res["audio_features"]]
+
+        items = []
+        if len(pixel_values) != 0:
+            item = MultimodalDataItem(
+                pixel_values=pixel_values,
+                tgt_size=tgt_sizes_flat,
+                modality=Modality.IMAGE,
+            )
+            items += [item]
+
+        if (
+            "audio_features" in res
+            and res["audio_features"] is not None
+            and len(res["audio_features"]) != 0
+        ):
+            item = MultimodalDataItem(
+                audio_features=[res["audio_features"]],
+                audio_feature_lens=res["audio_feature_lens"],
+                modality=Modality.AUDIO,
+            )
+            items += [item]

         return {
+            "mm_items": items,
             "input_ids": res["input_ids"].flatten().tolist(),
-            "pixel_values": pixel_values,
-            "tgt_sizes": tgt_sizes,
-            "data_hashes": base_output.mm_data_hashes,
-            "modalities": request_obj.modalities or ["image"],
             "audio_start_id": audio_start_id,
             "audio_end_id": audio_end_id,
-            "audio_features": res["audio_features"],
-            "audio_bounds": res["audio_bounds"],
-            "audio_feature_lens": res["audio_feature_lens"],
             "im_token_id": im_token_id,
             "im_start_id": tokenizer.im_start_id,
             "im_end_id": tokenizer.im_end_id,
...
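
A note on the new BaseImageProcessorFast branch above: fast (torch-backed) HuggingFace image processors accept a device argument, while the slow PIL-based ones do not. A standalone sketch of that dispatch, with the "cuda" target taken from the diff:

from transformers import BaseImageProcessorFast

def processor_call_args(processor) -> dict:
    # Only torch-backed "fast" processors can preprocess directly on the GPU;
    # passing a device kwarg to a slow processor would raise a TypeError.
    args = {}
    if isinstance(processor, BaseImageProcessorFast):
        args["device"] = "cuda"
    return args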
python/sglang/srt/managers/multimodal_processors/mlama.py
View file @ 5cb552b1

-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama import MllamaForConditionalGeneration
 from sglang.srt.utils import load_image
...
@@ -15,25 +14,6 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_single_image_task(images, input_text):
-        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
-        return get_global_processor()(images, input_text, return_tensors="pt")
-
-    async def _process_single_image(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                MllamaImageProcessor._process_single_image_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(images, input_text, return_tensors="pt")
-
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
...
@@ -52,8 +32,15 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
         else:
             images = load_image(image_data[0])[0]

-        image_inputs = await self._process_single_image(images, input_text)
-        image_inputs["data_hashes"] = [hash(str(image_data))]
+        image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        image_inputs["mm_items"] = [
+            MultimodalDataItem(
+                pixel_values=image_inputs["pixel_values"],
+                aspect_ratio_id=image_inputs["aspect_ratio_ids"],
+                aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
+                modality=Modality.IMAGE,
+            )
+        ]

         return image_inputs
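
Like minicpm.py, this file now funnels preprocessing through the shared process_mm_data helper on the base class. Its real body lives in base_processor.py (not shown in this hunk); the sketch below is only an assumed shape, illustrating what the deleted staticmethod-plus-executor plumbing collapses into:

def process_mm_data(processor, input_text, images=None, **kwargs):
    # Assumed shape, for illustration only: normalize empty image lists, then
    # make one synchronous HuggingFace processor call, replacing the deleted
    # global-processor + run_in_executor indirection.
    if isinstance(images, list) and len(images) == 0:
        images = None
    return processor(text=input_text, images=images, return_tensors="pt", **kwargs)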
python/sglang/srt/managers/multimodal_processors/qwen_vl.py
View file @ 5cb552b1

 import asyncio
 import math
-import time
 from typing import List, Union

 import torch
...
@@ -11,8 +10,8 @@ from sglang.srt.managers.multimodal_processor import (
 )
 from sglang.srt.managers.multimodal_processors.base_processor import (
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
...
@@ -34,45 +33,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.MAX_PIXELS = 16384 * 28 * 28
         self.MAX_RATIO = 200

-    @staticmethod
-    def _process_images_task(images, input_text, _hf_config):
-        if isinstance(images, list) and len(images) == 0:
-            images = None
-        result = get_global_processor().__call__(
-            text=[input_text], images=images, padding=True, return_tensors="pt"
-        )
-        return {
-            "input_ids": result.input_ids,
-            "pixel_values": getattr(result, "pixel_values", None),
-            "image_grid_thw": getattr(result, "image_grid_thw", None),
-            "second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
-            "video_grid_thws": getattr(result, "video_grid_thws", None),
-        }
-
-    async def _process_single_image(self, images, input_text) -> dict:
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            return await loop.run_in_executor(
-                self.executor,
-                Qwen2_5VLImageProcessor._process_images_task,
-                images,
-                input_text,
-                self.hf_config,
-            )
-        else:
-            return self._process_images_task(images, input_text, self.hf_config)
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        prompt,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
-        start = time.time()
         if not image_data:
             return None
         if isinstance(image_data, str):
...
@@ -80,7 +49,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=prompt,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
...
@@ -144,24 +113,32 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
             """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
             return math.floor(number / factor) * factor

-        images = [resize_image(image) for image in base_output.images]
+        async def resize_image_async(image):
+            return resize_image(image)

-        ret = await self._process_single_image(
-            images=images, input_text=base_output.input_text
+        resize_tasks = [resize_image_async(image) for image in base_output.images]
+        resized_images = await asyncio.gather(*resize_tasks)
+
+        ret = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=resized_images,
         )

         image_grid_thws = torch.concat([ret["image_grid_thw"]])
-        video_grid_thws = None
         return {
             "input_ids": ret["input_ids"].flatten().tolist(),
-            "pixel_values": ret["pixel_values"],
-            "data_hashes": base_output.mm_data_hashes,
-            "modalities": request_obj.modalities or ["image"],
-            "image_grid_thws": image_grid_thws,
-            "video_grid_thws": video_grid_thws,
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=ret["pixel_values"],
+                    image_grid_thws=image_grid_thws,
+                    # TODO
+                    video_grid_thws=None,
+                    second_per_grid_ts=ret.get("second_per_grid_ts", None),
+                    modality=Modality.IMAGE,
+                )
+            ],
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.image_token_id,
             "video_token_id": self.video_token_id,
-            "second_per_grid_ts": ret["second_per_grid_ts"],
         }
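
The resize path above swaps a plain list comprehension for asyncio.gather. A runnable miniature of the pattern (the 448x448 target size is illustrative, not from this diff); note that the wrapped coroutines contain no await, so this mostly preserves call-site symmetry and ordering rather than adding real parallelism, and CPU-bound resizing would still need an executor to overlap:

import asyncio
from PIL import Image

async def resize_all(images, size=(448, 448)):
    async def resize_one(image: Image.Image) -> Image.Image:
        # Synchronous PIL resize wrapped in a coroutine, as in the diff.
        return image.resize(size)

    return await asyncio.gather(*[resize_one(image) for image in images])

resized = asyncio.run(resize_all([Image.new("RGB", (1024, 768))]))
print(resized[0].size)  # (448, 448)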
python/sglang/srt/managers/schedule_batch.py
View file @ 5cb552b1

 from __future__ import annotations

+from enum import Enum, auto
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
...
@@ -143,165 +145,185 @@ class FINISH_ABORT(BaseFinishReason):
     }
+class Modality(Enum):
+    IMAGE = auto()
+    MULTI_IMAGES = auto()
+    VIDEO = auto()
+    AUDIO = auto()
+
+
 @dataclasses.dataclass
-class MultimodalInputs:
-    """The image related inputs."""
-
-    pixel_values: Union[torch.Tensor, np.array]
-    data_hashes: Optional[list] = None
-    image_sizes: Optional[list] = None
-    image_offsets: Optional[list] = None
-    image_pad_len: Optional[list] = None
-    pad_values: Optional[list] = None
-    modalities: Optional[list] = None
-    num_image_tokens: Optional[int] = None
-
-    # Llava related
-    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
-    # QWen2-VL related
-    # [num_of_images, t, h, w]
-    image_grid_thws: torch.Tensor = None
-    mrope_position_delta: Optional[torch.Tensor] = None
-
-    # Qwen2-VL video related
-    video_token_id: Optional[int] = None
-    video_grid_thws: List[Tuple[int, int, int]] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
-
-    # deepseek vl2 related
-    images_emb_mask: Optional[List[torch.Tensor]] = None
-    image_spatial_crop: Optional[List[torch.Tensor]] = None
-
-    # The id of the single-image placeholder token
-    im_token_id: Optional[torch.Tensor] = None
-
-    # All the images in the batch should share the same special image
-    # bound token ids.
-    im_start_id: Optional[int] = None
-    im_end_id: Optional[int] = None
-    slice_start_id: Optional[int] = None
-    slice_end_id: Optional[int] = None
-
-    # [num_images, 2 (w, h)]
-    tgt_sizes: Optional[list] = None
-
-    # audio
-    audio_start_id: Optional[torch.Tensor] = None
-    audio_end_id: Optional[torch.Tensor] = None
-    audio_features: Optional[List[torch.Tensor]] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None
+class MultimodalDataItem:
+    """
+    A single multimodal data, from a single image/video/audio or other
+    """
+
+    modality: Modality
+
+    hash: int = None
+    pad_value: int = None
+
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    image_sizes: Tuple[int, int] = None
+    image_offsets: Optional[list] = None
+
+    # the real data, pixel_values or audio_features
+    # data: Union[List[torch.Tensor], List[np.array]]
+    pixel_values: Union[torch.Tensor, np.array] = None
+    image_grid_thws: Union[torch.Tensor, np.array] = None
+    video_grid_thws: Union[torch.Tensor, np.array] = None
+
+    image_emb_mask: Optional[torch.Tensor] = None
+    image_spatial_crop: Optional[torch.Tensor] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None
+
+    # [num_images, (n, w, h)]
+    tgt_size: Tuple[int, int] = None
+
+    audio_features: Union[torch.Tensor, np.array] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def is_empty_list(l):
+        if l is None:
+            return True
+        return len([item for item in flatten_nested_list(l) if item is not None]) == 0
+
+    def set_pad_value(self):
+        """
+        Set the pad value after first hashing the data
+        """
+
+        def hash_feature(f):
+            if isinstance(f, list):
+                return hash(tuple(flatten_nested_list(f)))
+            elif isinstance(f, np.ndarray):
+                arr = np.ascontiguousarray(f)
+                arr_bytes = arr.tobytes()
+                return hash(arr_bytes)
+            return hash(f)
+
+        if self.is_audio():
+            self.hash = hash_feature(self.audio_features)
+        else:
+            self.hash = hash_feature(self.pixel_values)
+
+        assert self.hash is not None
+        self.pad_value = self.hash % (1 << 30)
+
+    def is_audio(self):
+        return (
+            self.modality == Modality.AUDIO
+        ) and not MultimodalDataItem.is_empty_list(self.audio_features)
+
+    def is_image(self):
+        return (
+            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
+        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+
+    def is_video(self):
+        return (
+            self.modality == Modality.VIDEO
+        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+
+    def validate(self):
+        ...
+        # TODO
+
+
+@dataclasses.dataclass
+class MultimodalInputs:
+    """The multimodal data related inputs."""
+
+    # items of data
+    mm_items: List[MultimodalDataItem]
+    image_pad_len: Optional[list] = None
+    num_image_tokens: Optional[int] = None
+
+    # QWen2-VL related
+    mrope_position_delta: Optional[torch.Tensor] = None
+
+    # image
+    im_token_id: Optional[torch.Tensor] = None
+    im_start_id: Optional[int] = None
+    im_end_id: Optional[int] = None
+    slice_start_id: Optional[int] = None
+    slice_end_id: Optional[int] = None
+
+    # video
+    video_token_id: Optional[int] = None
+
+    # audio
+    audio_start_id: Optional[torch.Tensor] = None
+    audio_end_id: Optional[torch.Tensor] = None
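
The hash-then-pad-value scheme in set_pad_value is the backbone of multimodal prefix caching. A self-contained demonstration of the same hashing rules (list, numpy array, fallback) and of the modulo that bounds the result to a fake-token range; the nested-list flattening is omitted here for brevity:

import numpy as np

def hash_feature(f):
    # Same rules as MultimodalDataItem.set_pad_value: lists hash as tuples
    # (nested flattening omitted), numpy arrays hash over their raw
    # contiguous bytes, everything else falls back to Python's hash().
    if isinstance(f, list):
        return hash(tuple(f))
    if isinstance(f, np.ndarray):
        return hash(np.ascontiguousarray(f).tobytes())
    return hash(f)

pad_value = hash_feature(np.zeros((2, 3), dtype=np.float32)) % (1 << 30)
assert 0 <= pad_value < 1 << 30  # Python's % is non-negative, so usable as a token id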
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
-            pixel_values=obj["pixel_values"],
-            data_hashes=obj["data_hashes"],
+            mm_items=obj["mm_items"],
         )
+
+        assert isinstance(ret.mm_items, list)
+        ret.mm_items = [
+            item
+            for item in ret.mm_items
+            if item.is_audio() or item.is_image() or item.is_video()
+        ]
+
+        assert len(ret.mm_items) != 0

         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-        ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
+        for item in ret.mm_items:
+            item.set_pad_value()

         optional_args = [
-            "image_sizes",
             "modalities",
-            "aspect_ratio_ids",
-            "aspect_ratio_mask",
-            "image_grid_thws",
-            "images_emb_mask",
-            "image_spatial_crop",
             "im_token_id",
             "im_start_id",
             "im_end_id",
             "slice_start_id",
             "slice_end_id",
-            "tgt_sizes",
             "audio_start_id",
             "audio_end_id",
-            "audio_features",
-            "audio_feature_lens",
         ]
         for arg in optional_args:
             if arg in obj:
                 setattr(ret, arg, obj[arg])

-        # validate
-        assert (
-            isinstance(ret.pixel_values, torch.Tensor)
-            or isinstance(ret.pixel_values, np.ndarray)
-            or isinstance(ret.pixel_values, list)
-        )
-        assert ret.audio_features is None or isinstance(ret.audio_features, list)
-
         return ret
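
As the comment in from_dict says, the per-item pad values act as fake token ids for radix-cache prefix matching and must be clamped before any real forward pass. A short illustration of that clamping (the vocab_size value is an arbitrary example):

import torch

vocab_size = 32000  # illustrative
pad_values = [987654321 % (1 << 30)]  # what set_pad_value would store per item
fake_input_ids = torch.tensor(pad_values)
# Clamp into [0, vocab_size) before any embedding lookup, as llava.py does.
safe_input_ids = fake_input_ids.clamp(min=0, max=vocab_size - 1)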
     def contains_image_inputs(self) -> bool:
         """ """
-        return self.pixel_values is not None and self.pixel_values != []
+        return any(item.is_image() for item in self.mm_items)

     def contains_audio_inputs(self) -> bool:
         """ """
-        return self.audio_features is not None and self.audio_features != []
+        return any(item.is_audio() for item in self.mm_items)
+
+    def collect_image_inputs(self) -> List[torch.Tensor]:
+        return [item.pixel_values for item in self.mm_items if item.is_image()]
     def merge(self, other: MultimodalInputs):
         """
         merge image inputs when requests are being merged
         """
-        if isinstance(self.pixel_values, list):
-            # in some rare cases, pixel values are list of patches with different shapes
-            # e.g. minicpm
-            self.pixel_values += other.pixel_values
-        else:
-            assert (
-                self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
-            ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
-            self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
-
-        # args would be stacked along first dim
-        # usually these are already tensors
-        stack_args = [
-            # TODO: merge with image_grid_thws, basically the same thing
-            "tgt_sizes",
-            "image_spatial_crop",
-        ]
-        for arg in stack_args:
-            if getattr(self, arg, None) is None:
-                setattr(self, arg, getattr(other, arg, None))
-            elif getattr(other, arg, None) is not None:
-                # self and other both not None
-                setattr(
-                    self,
-                    arg,
-                    torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
-                )
-
-        if self.image_grid_thws is None:
-            self.image_grid_thws = other.image_grid_thws
-        elif other.image_grid_thws is not None:
-            self.image_grid_thws = torch.concat(
-                [self.image_grid_thws, other.image_grid_thws]
-            )
-
         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-        self.data_hashes += other.data_hashes
-        self.pad_values = [x % (1 << 30) for x in self.data_hashes]

         # args needed to be merged
         optional_args = [
-            "audio_features",
-            "image_sizes",
+            "items",
             "image_offsets",
             "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
-            "aspect_ratio_ids",
-            "aspect_ratio_mask",
-            "images_emb_mask",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
             ...
python/sglang/srt/managers/scheduler.py
View file @ 5cb552b1
...
@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
...
python/sglang/srt/managers/utils.py
View file @ 5cb552b1

-import json
 import logging
-import time
-from collections import defaultdict
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple
-
-import torch
+from typing import Optional

 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @ 5cb552b1
...
@@ -355,11 +355,6 @@ class ForwardBatch:
         for mm_input in valid_inputs[1:]:
             merged.merge(mm_input)

-        if isinstance(merged.pixel_values, np.ndarray):
-            merged.pixel_values = torch.from_numpy(merged.pixel_values)
-        if isinstance(merged.audio_features, np.ndarray):
-            merged.audio_features = torch.from_numpy(merged.audio_features)
-
         return merged

     def contains_image_inputs(self) -> bool:
...
python/sglang/srt/model_executor/model_runner.py
View file @ 5cb552b1
...
@@ -251,16 +251,15 @@ class ModelRunner:
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

         if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
+            self.mem_fraction_static *= 0.90
             logger.info(
                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                 f"because this is a multimodal model."
             )
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
+            server_args.chunked_prefill_size = -1

-            if self.model_config.hf_config.architectures == [
-                "MllamaForConditionalGeneration"
-            ]:
-                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-                server_args.chunked_prefill_size = -1
-
             if self.model_config.hf_config.architectures == [
...
@@ -269,18 +268,7 @@ class ModelRunner:
                 "Qwen2_5_VLForConditionalGeneration"
             ]:
                 # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
-                )
-                server_args.chunked_prefill_size = -1
-                server_args.disable_radix_cache = True
-
-            if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
-                # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
-                )
-                server_args.chunked_prefill_size = -1
-                server_args.disable_radix_cache = True
+                logger.info("Automatically disable radix cache for qwen-vl series.")
+                server_args.disable_radix_cache = True

         if server_args.enable_deepep_moe:
...
python/sglang/srt/models/clip.py
View file @ 5cb552b1
...
@@ -17,7 +17,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list


 class CLIPVisionEmbeddings(nn.Module):
...
@@ -368,7 +368,6 @@ class CLIPVisionTransformer(nn.Module):
         self,
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
         hidden_states = self.embeddings(pixel_values.to(self.device))
         hidden_states = self.pre_layrnorm(hidden_states)
...
@@ -456,12 +455,18 @@ class CLIPModel(nn.Module):
         get_embedding: bool = True,
     ):
         assert get_embedding, "CLIPEmbeddingModel is only used for embedding"

-        image_inputs = None
+        mm_inputs = []
         if forward_batch.mm_inputs is not None:
-            image_inputs = forward_batch.mm_inputs
+            mm_inputs = forward_batch.mm_inputs

-        if image_inputs is not None and image_inputs[0] is not None:
-            vision_outputs = self.vision_model(image_inputs[0].pixel_values)
+        pixel_values_list = [
+            item.pixel_values
+            for item in flatten_nested_list(
+                [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
+            )
+        ]
+        if len(pixel_values_list) != 0:
+            pixel_values = torch.concat(pixel_values_list)
+            vision_outputs = self.vision_model(pixel_values)
             pooled_output = vision_outputs[:, 0, :]
             image_embeds = self.visual_projection(pooled_output)
             image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
...
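
The CLIP forward now gathers pixel values across every item of every request before a single vision-tower call. A miniature of the collection step (the helper here is a stand-in for sglang.srt.utils.flatten_nested_list, and the shapes are made up):

import torch

def flatten_nested_list(nested):
    # Stand-in for sglang.srt.utils.flatten_nested_list, included only so the
    # example runs on its own.
    if isinstance(nested, list):
        return [leaf for item in nested for leaf in flatten_nested_list(item)]
    return [nested]

per_request_items = [[torch.zeros(1, 3, 224, 224)], [torch.zeros(2, 3, 224, 224)]]
pixel_values = torch.concat(flatten_nested_list(per_request_items))
print(pixel_values.shape)  # torch.Size([3, 3, 224, 224])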
python/sglang/srt/models/deepseek_janus_pro.py
View file @ 5cb552b1
...
@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
...
@@ -1959,8 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)

-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
         bs, n = pixel_values.shape[0:2]
         pixel_values = pixel_values.to(
             device=self.vision_model.device, dtype=self.vision_model.dtype
...
@@ -1976,7 +1976,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         return images_embeds

     def get_input_embeddings(self) -> nn.Embedding:
-        return self.language_model.model.embed_tokens
+        return self.language_model.get_input_embeddings()

     @torch.no_grad()
     def forward(
...
@@ -1984,23 +1984,18 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        get_embedding: bool = False,
     ) -> torch.Tensor:
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            image_data_embedding_func=self.get_image_feature,
+            mm_data_embedding_func=self.get_image_feature,
+            language_model=self.language_model,
+            positions=positions,
         )
-        return self.language_model(
-            input_ids=None,
-            positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-            get_embedding=False,
-        )
+        return hidden_states

     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))
...
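
get_image_feature now takes a list of MultimodalDataItem and concatenates per-item tensors itself, instead of receiving one pre-merged MultimodalInputs bundle. A shape-level sketch of that step (the 4-view 384x384 layout is illustrative only):

import torch

# Two items contribute 1 and 2 images respectively; concatenation along the
# batch dim yields the (bs, n, ...) layout the vision tower expects.
item_pixel_values = [
    torch.zeros(1, 4, 3, 384, 384),
    torch.zeros(2, 4, 3, 384, 384),
]
pixel_values = torch.concat(item_pixel_values, dim=0)
bs, n = pixel_values.shape[0:2]
print(bs, n)  # 3 4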