change / sglang / Commits / 1e86457c

Unverified commit 1e86457c, authored Mar 25, 2025 by Mick; committed by GitHub Mar 24, 2025.

model: Minicpmo (#3023)

parent 64129fa6

Showing 20 changed files (of 38 changes) with 641 additions and 178 deletions.
Changed files:

  benchmark/mmmu/bench_hf.py                                           +24   -7
  docs/references/supported_models.md                                   +4   -4
  python/pyproject.toml                                                 +1   -0
  python/sglang/lang/chat_template.py                                  +20   -7
  python/sglang/srt/configs/model_config.py                             +7   -6
  python/sglang/srt/conversation.py                                    +30   -0
  python/sglang/srt/managers/io_struct.py                              +12   -2
  python/sglang/srt/managers/mm_utils.py                              +107  -37
  python/sglang/srt/managers/multimodal_processor.py                   +68   -0
  python/sglang/srt/managers/multimodal_processors/base_processor.py   +77  -41
  python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py   +28   -8
  python/sglang/srt/managers/multimodal_processors/gemma3.py           +10  -10
  python/sglang/srt/managers/multimodal_processors/janus_pro.py        +15  -12
  python/sglang/srt/managers/multimodal_processors/llava.py             +8   -8
  python/sglang/srt/managers/multimodal_processors/minicpm.py         +167   -0
  python/sglang/srt/managers/multimodal_processors/mlama.py             +5   -5
  python/sglang/srt/managers/multimodal_processors/qwen_vl.py          +13   -9
  python/sglang/srt/managers/schedule_batch.py                         +38  -15
  python/sglang/srt/managers/scheduler.py                               +6   -6
  python/sglang/srt/managers/session_controller.py                      +1   -1
benchmark/mmmu/bench_hf.py  View file @ 1e86457c

 import argparse

+import PIL.Image
 import torch
 from data_utils import save_json
 from eval_utils import (
 ...
@@ -10,22 +11,38 @@ from eval_utils import (
     process_result,
 )
 from tqdm import tqdm
-from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig
+from transformers import AutoModel, AutoProcessor, GenerationConfig


 @torch.no_grad()
 def eval_mmmu(args):
     eval_args = EvalArgs.from_cli_args(args)
-    model = AutoModelForImageTextToText.from_pretrained(
-        args.model_path,
-        torch_dtype="auto",
-        trust_remote_code=True,
-    )
+    try:
+        from transformers import AutoModelForImageTextToText
+
+        model = AutoModelForImageTextToText.from_pretrained(
+            args.model_path,
+            torch_dtype="auto",
+            trust_remote_code=True,
+        )
+    except Exception as first_exception:
+        try:
+            model = AutoModel.from_pretrained(
+                args.model_path,
+                torch_dtype="auto",
+                trust_remote_code=True,
+                init_tts=False,
+            )
+        except Exception as second_exception:
+            raise RuntimeError(
+                f"Failed to load model: First attempt failed with {first_exception}, "
+                f"second attempt failed with {second_exception}"
+            ) from second_exception
     model = model.eval().cuda()
     processor = AutoProcessor.from_pretrained(
-        args.model_path, torch_dtype="auto", device_map="auto"
+        args.model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
    )

     samples = prepare_samples(eval_args)
 ...
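The new loading path generalizes into a reusable pattern: prefer the task-specific auto class and fall back to plain `AutoModel` for checkpoints (such as MiniCPM-o) that only load through remote code. A minimal standalone sketch of that pattern, not part of the commit; `model_path` is a placeholder, and `init_tts=False` is accepted by MiniCPM-o's remote code but may be rejected by other models:

from transformers import AutoModel, AutoModelForImageTextToText

def load_vlm(model_path: str):
    try:
        # Preferred path: the dedicated image-text-to-text auto class.
        return AutoModelForImageTextToText.from_pretrained(
            model_path, torch_dtype="auto", trust_remote_code=True
        )
    except Exception:
        # Fallback: plain AutoModel; init_tts=False skips MiniCPM-o's TTS head.
        return AutoModel.from_pretrained(
            model_path, torch_dtype="auto", trust_remote_code=True, init_tts=False
        )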
docs/references/supported_models.md  View file @ 1e86457c

@@ -24,7 +24,7 @@
 - InternLM 2
 - Exaone 3
 - BaiChuan2
-- MiniCPM / MiniCPM 3 / MiniCPMV
+- MiniCPM / MiniCPM 3 / MiniCPM-v / MiniCPM-o
 - XVERSE / XVERSE MoE
 - SmolLM
 - GLM-4
 ...
@@ -70,9 +70,9 @@ LLM.
 1. **Register your new model as multimodal**: Extend `is_multimodal_model` in
    [`model_config.py`](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/configs/model_config.py) to
    return True for your model.
-2. **Process Images**: Create a new `ImageProcessor` class that inherits from `BaseImageProcessor` and register this
-   processor as your model's dedicated processor. See
-   [`image_processor.py`](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/image_processor.py)
-   for more details.
+2. **Process Images**: Define a new `Processor` class that inherits from `BaseProcessor` and register this
+   processor as your model's dedicated processor. See
+   [`multimodal_processor.py`](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/multimodal_processor.py)
+   for more details.
 3. **Handle Image Tokens**: Implement a `pad_input_ids` function for your new model, in which image tokens in the prompt
    should be expanded and replaced with image-hashes, so that SGLang can recognize different images for
 ...
@@ -80,7 +80,7 @@ LLM.
 4. Replace Multi-headed `Attention` of ViT with SGLang's `VisionAttention`.
    You can refer [Qwen2VL](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/qwen2_vl.py) or other
-   vLMs. These models demonstrate how to properly handle both visual and textual inputs.
+   vLMs. These models demonstrate how to properly handle both multimodal and textual inputs.
 You should test the new vLM locally against hf models. See [`mmmu`](https://github.com/sgl-project/sglang/tree/main/benchmark/mmmu) for an example.
 ...
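For illustration, steps 1-2 boil down to listing the architecture name and shipping a processor class. A hypothetical sketch, not part of the commit; `MyVLMForCausalLM`, its import path, and the `"<image>"` token string are all placeholders:

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)
from my_package.models.my_vlm import MyVLMForCausalLM  # placeholder import

class MyVLMProcessor(BaseMultimodalProcessor):
    # The `models` attribute is how import_processors() registers this class
    # as the dedicated processor for the architecture (step 2).
    models = [MyVLMForCausalLM]

    async def process_mm_data_async(
        self, image_data, input_ids, request_obj, max_req_input_len, **kwargs
    ):
        base_output = self.load_mm_data(
            input_ids,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token="<image>"),
            max_req_input_len=max_req_input_len,
        )
        ...  # run the HF processor and return the dict SGLang expects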
python/pyproject.toml  View file @ 1e86457c

@@ -34,6 +34,7 @@ runtime_common = [
     "pydantic",
     "python-multipart",
     "pyzmq>=25.1.2",
+    "soundfile==0.13.1",
     "torchao>=0.7.0",
     "transformers==4.50.0",
     "uvicorn",
 ...
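The new `soundfile` pin backs audio decoding for MiniCPM-o. A quick sanity check of the dependency (a sketch; `sample.wav` is a placeholder path):

import soundfile as sf

data, sample_rate = sf.read("sample.wav")  # placeholder path
print(data.shape, sample_rate)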
python/sglang/lang/chat_template.py  View file @ 1e86457c

@@ -15,6 +15,7 @@ class ChatTemplate:
     role_prefix_and_suffix: Dict[str, Tuple[str, str]]
     stop_str: List[str] = ()
     image_token: str = "<image>"
+    audio_token: str = "<audio>"

     style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

     def get_prefix_and_suffix(
 ...
@@ -253,6 +254,22 @@ register_chat_template(
     )
 )

+# https://huggingface.co/openbmb/MiniCPM-o-2_6
+register_chat_template(
+    ChatTemplate(
+        name="minicpmo",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", " "),
+            "user": ("user:", " "),
+            "assistant": ("assistant:", "</s>"),
+        },
+        stop_str=("<|im_end|>", "<|endoftext|>"),
+        image_token="(<image>./</image>)",
+        audio_token="(<audio>./</audio>)",
+    )
+)
+
 # The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
 register_chat_template(
     ChatTemplate(
 ...
@@ -474,12 +491,6 @@ def match_chat_ml(model_path: str):
     return get_chat_template("chatml-llava")


-@register_chat_template_matching_function
-def match_chat_minicpm(model_path: str):
-    if "minicpm" in model_path:
-        return get_chat_template("minicpmv")
-
-
 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
 ...
@@ -499,8 +510,10 @@ def match_gemma_it(model_path: str):
 @register_chat_template_matching_function
 def match_openbmb_minicpm(model_path: str):
     model_path = model_path.lower()
-    if "minicpm" in model_path:
+    if "minicpm-v" in model_path:
         return get_chat_template("minicpmv")
+    elif "minicpm-o" in model_path:
+        return get_chat_template("minicpmo")


 @register_chat_template_matching_function
 ...
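With the generic `match_chat_minicpm` matcher removed, routing now keys on the `-v` / `-o` suffix. A usage sketch (assuming no earlier registered matcher claims these paths first):

from sglang.lang.chat_template import get_chat_template_by_model_path

print(get_chat_template_by_model_path("openbmb/MiniCPM-V-2_6").name)  # -> "minicpmv"
print(get_chat_template_by_model_path("openbmb/MiniCPM-o-2_6").name)  # -> "minicpmo"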
python/sglang/srt/configs/model_config.py  View file @ 1e86457c

@@ -462,18 +462,19 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
 multimodal_model_archs = [
     "DeepseekVL2ForCausalLM",
-    "LlavaLlamaForCausalLM",
-    "LlavaQwenForCausalLM",
-    "LlavaMistralForCausalLM",
-    "LlavaVidForCausalLM",
     "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
+    "LlavaLlamaForCausalLM",
+    "LlavaMistralForCausalLM",
+    "LlavaQwenForCausalLM",
+    "LlavaVidForCausalLM",
+    "MiniCPMO",
+    "MiniCPMV",
+    "MultiModalityCausalLM",
     "MllamaForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
-    "MiniCPMV",
-    "MultiModalityCausalLM",
 ]
 ...
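This list is what step 1 of the docs refers to: adding `"MiniCPMO"` is all it takes for the config layer to treat the model as multimodal. A sketch mirroring the membership test that `is_multimodal_model` performs against it (the real helper lives in model_config.py):

multimodal_model_archs = ["MiniCPMO", "MiniCPMV"]  # abbreviated

def is_multimodal(model_architectures):
    return any(arch in multimodal_model_archs for arch in model_architectures)

assert is_multimodal(["MiniCPMO"])
assert not is_multimodal(["LlamaForCausalLM"])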
python/sglang/srt/conversation.py  View file @ 1e86457c

@@ -73,11 +73,14 @@ class Conversation:
     stop_str: Union[str, List[str]] = None
     # The string that represents an image token in the prompt
     image_token: str = "<image>"
+    audio_token: str = "<audio>"

     image_data: Optional[List[str]] = None
     modalities: Optional[List[str]] = None
     stop_token_ids: Optional[int] = None
+    audio_data: Optional[List[str]] = None

     def get_prompt(self) -> str:
         """Get the prompt for generation."""
         system_prompt = self.system_template.format(system_message=self.system_message)
 ...
@@ -327,6 +330,10 @@ class Conversation:
         """Append a new message."""
         self.image_data.append(image)

+    def append_audio(self, audio: str):
+        """Append a new message."""
+        self.audio_data.append(audio)
+
     def update_last_message(self, message: str):
         """Update the last output.
 ...
@@ -373,6 +380,7 @@ class Conversation:
             sep2=self.sep2,
             stop_str=self.stop_str,
             image_token=self.image_token,
+            audio_token=self.audio_token,
         )

     def dict(self):
 ...
@@ -459,8 +467,10 @@ def generate_chat_conv(
         sep2=conv.sep2,
         stop_str=conv.stop_str,
         image_data=[],
+        audio_data=[],
         modalities=[],
         image_token=conv.image_token,
+        audio_token=conv.audio_token,
     )

     if isinstance(request.messages, str):
 ...
@@ -498,6 +508,7 @@ def generate_chat_conv(
                     if conv.name != "qwen2-vl"
                     else conv.image_token
                 )
+                audio_token = conv.audio_token
                 for content in message.content:
                     if content.type == "text":
                         if num_image_url > 16:
 ...
@@ -507,6 +518,10 @@ def generate_chat_conv(
                         # NOTE: Only works for llava
                         real_content += image_token
                         conv.append_image(content.image_url.url)
+                    elif content.type == "audio_url":
+                        real_content += audio_token
+                        conv.append_audio(content.audio_url.url)
             conv.append_message(conv.roles[0], real_content)
         elif msg_role == "assistant":
             parsed_content = ""
 ...
@@ -704,3 +719,18 @@ register_conv_template(
         image_token="<image_placeholder>",
     )
 )
+
+# Reference: https://huggingface.co/openbmb/MiniCPM-o-2_6#usage
+register_conv_template(
+    Conversation(
+        name="minicpmo",
+        system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
+        system_template="<|im_start|>system\n{system_message}",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=("<|im_end|>", "<|endoftext|>"),
+        image_token="(<image>./</image>)",
+        audio_token="(<audio>./</audio>)",
+    )
+)
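A sketch (not part of the commit) of an OpenAI-style message that exercises the new `audio_url` branch; the URL is a placeholder:

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Transcribe this clip."},
            {"type": "audio_url", "audio_url": {"url": "https://example.com/clip.wav"}},
        ],
    }
]
# For the "minicpmo" template, generate_chat_conv appends "(<audio>./</audio>)"
# to the user turn and records the URL via conv.append_audio(...).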
python/sglang/srt/managers/io_struct.py  View file @ 1e86457c

@@ -45,6 +45,8 @@ class GenerateReqInput:
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
 ...
@@ -167,6 +169,13 @@ class GenerateReqInput:
         elif isinstance(self.image_data, list):
             pass

+        if self.audio_data is None:
+            self.audio_data = [None] * num
+        elif not isinstance(self.audio_data, list):
+            self.audio_data = [self.audio_data] * num
+        elif isinstance(self.audio_data, list):
+            pass
+
         if self.sampling_params is None:
             self.sampling_params = [{}] * num
         elif not isinstance(self.sampling_params, list):
 ...
@@ -231,6 +240,7 @@ class GenerateReqInput:
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
             image_data=self.image_data[i],
+            audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
             return_logprob=self.return_logprob[i],
 ...
@@ -259,8 +269,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The image inputs
-    image_inputs: dict
+    # The multimodal inputs
+    mm_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
 ...
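The `audio_data` normalization mirrors the existing `image_data` rule: a single item is broadcast to the batch size, a list passes through. A standalone sketch of that rule:

def normalize(audio_data, num):
    if audio_data is None:
        return [None] * num
    if not isinstance(audio_data, list):
        return [audio_data] * num
    return audio_data

assert normalize(None, 3) == [None, None, None]
assert normalize("a.wav", 3) == ["a.wav", "a.wav", "a.wav"]
assert normalize(["a.wav", "b.wav"], 2) == ["a.wav", "b.wav"]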
python/sglang/srt/managers/mm_utils.py  View file @ 1e86457c

@@ -9,7 +9,7 @@ import torch
 from torch import nn

 from sglang.srt.managers.schedule_batch import (
-    ImageInputs,
+    MultimodalInputs,
     global_server_args_dict,
     logger,
 )
 ...
@@ -26,7 +26,7 @@ class MultiModalityDataPaddingPattern:
     @abstractmethod
     def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
+        self, input_ids: List[int], image_inputs: MultimodalInputs
     ) -> List[int]:
         """
         Pad the input ids sequence containing data tokens, and replace them with pad_values
 ...
@@ -44,16 +44,16 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         self.data_token_id_pairs = data_token_pairs

     def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
         This function will replace the data-tokens inbetween with pad_values accordingly
         """
-        pad_values = image_inputs.pad_values
+        pad_values = mm_inputs.pad_values
         data_token_pairs = self.data_token_id_pairs
-        image_inputs.image_offsets = []
+        mm_inputs.image_offsets = []
         if data_token_pairs is None:
-            data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
+            data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
         if data_token_pairs is None:
             logger.warning(
                 "No data_token_pairs provided, RadixAttention might be influenced."
 ...
@@ -61,8 +61,6 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
             return input_ids
         start_token_ids = [s for s, _e in data_token_pairs]
         end_tokens_ids = [e for _s, e in data_token_pairs]
-        # First start token marks new data
-        data_start_token = start_token_ids[0]

         padded_ids = []
         last_idx = 0
 ...
@@ -77,9 +75,12 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         for start_idx, end_idx in zip(start_indices, end_indices):
             padded_ids.extend(input_ids[last_idx : start_idx + 1])

-            if input_ids[start_idx] == data_start_token:
+            if input_ids[start_idx] in start_token_ids:
                 data_idx += 1
-                image_inputs.image_offsets += [start_idx]
+                mm_inputs.image_offsets += [start_idx]
+
+            if data_idx >= len(mm_inputs.pad_values):
+                data_idx = len(mm_inputs.pad_values) - 1

             num_tokens = end_idx - start_idx - 1
             pad_value = pad_values[data_idx]
 ...
@@ -89,7 +90,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         padded_ids.extend(input_ids[last_idx:])

-        assert len(input_ids) == len(padded_ids)
+        assert len(input_ids) == len(padded_ids), "Length validation fails"
         return padded_ids
 ...
@@ -107,26 +108,25 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
         self.num_data_token_calc_func = num_data_token_calc_func

     def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
         This function will follow the procedure of:
         1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
         2. the padded data tokens will be replaced with their pad_values
         """
-        image_grid_thws = image_inputs.image_grid_thws
-        pad_values = image_inputs.pad_values
+        image_grid_thws = mm_inputs.image_grid_thws
+        pad_values = mm_inputs.pad_values

         image_indices = [
             idx
             for idx, token in enumerate(input_ids)
-            if token == image_inputs.im_token_id
+            if token == mm_inputs.im_token_id
         ]

-        image_inputs.image_offsets = []
+        mm_inputs.image_offsets = []

         input_ids_with_image = []
         for image_cnt, _ in enumerate(image_grid_thws):
-            # print(f"image_cnt {image_cnt}")
             num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
             if image_cnt == 0:
                 non_image_tokens = input_ids[: image_indices[image_cnt]]
 ...
@@ -135,7 +135,7 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
                     image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
                 ]
             input_ids_with_image.extend(non_image_tokens)
-            image_inputs.image_offsets.append(len(input_ids_with_image))
+            mm_inputs.image_offsets.append(len(input_ids_with_image))
             pad_ids = pad_values * (
                 (num_image_tokens + len(pad_values)) // len(pad_values)
             )
 ...
@@ -170,11 +170,11 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
         return input_ids_tensor.tolist()


-def embed_image_inputs(
-    image_input: ImageInputs,
+def embed_mm_inputs(
+    mm_input: MultimodalInputs,
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
-    image_embedding_func,
+    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
     placeholder_token_ids: List[int] = None,
 ) -> Optional[torch.Tensor]:
     """
 ...
@@ -184,10 +184,10 @@ def embed_image_inputs(
     Returns:
         final embedding: Optional[torch.Tensor]
     """
-    if image_input is None:
+    if mm_input is None:
         return None
-    placeholder_token_ids = placeholder_token_ids or image_input.pad_values
+    placeholder_token_ids = placeholder_token_ids or mm_input.pad_values

     # boolean masking the special tokens
     special_image_mask = torch.isin(
 ...
@@ -196,12 +196,18 @@ def embed_image_inputs(
     ).unsqueeze(-1)
     num_image_tokens_in_input_ids = special_image_mask.sum()
+    # print(f"{num_image_tokens_in_input_ids}")
+    # print(f"{input_ids}")
+    # return
     if num_image_tokens_in_input_ids == 0:
         # unexpected
         inputs_embeds = input_embedding(input_ids)
     else:
-        image_embedding = image_embedding_func(image_input)
+        # print(f"Getting image feature")
+        image_embedding = mm_data_embedding_func(mm_input)
+        # print(f"image_embedding: {image_embedding.shape}")

         if image_embedding.dim() == 2:
             num_image_tokens_in_embedding = image_embedding.shape[0]
 ...
@@ -273,31 +279,95 @@ def embed_image_embedding(
 def general_mm_embed_routine(
     input_ids: torch.Tensor,
-    positions: torch.Tensor,
     forward_batch: ForwardBatch,
     embed_tokens: nn.Embedding,
-    image_embedding_func: Callable[[ImageInputs], torch.Tensor],
+    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
     placeholder_token_ids: List[int] = None,
 ):
     """
     a general wrapper function to get final input embeds from multimodal models
     with a language model as causal model
+
+    Args:
+        placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
     """
-    if (
-        forward_batch.forward_mode.is_decode()
-        or not forward_batch.contains_image_inputs()
-    ):
-        inputs_embeds = embed_tokens(input_ids)
-    else:
-        image = forward_batch.merge_image_inputs()
-        inputs_embeds = embed_image_inputs(
-            image_input=image,
-            input_ids=input_ids,
-            input_embedding=embed_tokens,
-            image_embedding_func=image_embedding_func,
-            placeholder_token_ids=placeholder_token_ids,
-        )
-        # once used, image_inputs is useless
-        # just being defensive here
-        forward_batch.image_inputs = None
+    if (
+        not forward_batch.forward_mode.is_decode()
+        and forward_batch.contains_mm_inputs()
+    ):
+        image = forward_batch.merge_mm_inputs()
+        inputs_embeds = embed_mm_inputs(
+            mm_input=image,
+            input_ids=input_ids,
+            input_embedding=embed_tokens,
+            mm_data_embedding_func=mm_data_embedding_func,
+            placeholder_token_ids=placeholder_token_ids,
+        )
+        # once used, mm_inputs is useless
+        # just being defensive here
+        forward_batch.mm_inputs = None
+    else:
+        inputs_embeds = embed_tokens(input_ids)
     return inputs_embeds
+
+
+def get_multimodal_data_bounds(
+    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
+) -> torch.Tensor:
+    """
+    Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)
+
+    Returns:
+        [bounds_count, 2]
+    """
+    # All the images in the batch should share the same special image
+    # bound token ids.
+    start_tokens = [s for s, _e in token_pairs]
+    end_tokens = [e for _s, e in token_pairs]
+
+    assert all(isinstance(t, int) for t in start_tokens)
+    assert all(isinstance(t, int) for t in end_tokens)
+
+    # print(input_ids)
+    start_cond = torch.isin(
+        input_ids, torch.tensor(start_tokens, device=input_ids.device)
+    )
+    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
+
+    (data_start_tokens,) = torch.where(start_cond)
+    (data_end_tokens,) = torch.where(end_cond)
+
+    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
+    if len(data_start_tokens) != len(data_end_tokens):
+        if (
+            len(data_start_tokens) + 1 == len(data_end_tokens)
+            and input_ids[0] in pad_values
+            and data_end_tokens[0] < data_start_tokens[0]
+        ):
+            data_start_tokens = torch.cat(
+                [
+                    torch.tensor([0], device=data_start_tokens.device),
+                    data_start_tokens,
+                ]
+            )
+    valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
+
+    if valid_image_nums == 0:
+        return torch.zeros((0, 2), device=input_ids.device)
+
+    # Filter out pairs where start_token >= end_token
+    valid_pairs = []
+    for i in range(valid_image_nums):
+        start_token = data_start_tokens[i]
+        end_token = data_end_tokens[i]
+        if start_token < end_token:
+            valid_pairs.append((start_token + 1, end_token - 1))
+
+    if not valid_pairs:
+        return torch.zeros((0, 2), device=input_ids.device)
+
+    # Convert valid pairs to tensor
+    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+    return valid_pairs_tensor
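A toy check of the new bound-finding helper (a sketch; the token ids are made up, and the function is assumed importable from the module as added above). The returned bounds exclude the start/end marker tokens themselves:

import torch
from sglang.srt.managers.mm_utils import get_multimodal_data_bounds

input_ids = torch.tensor([7, 100, 5, 5, 5, 101, 9])
# One data region delimited by the (100, 101) pair; pad value 5.
bounds = get_multimodal_data_bounds(input_ids, pad_values=[5], token_pairs=[(100, 101)])
print(bounds)  # tensor([[2, 4]])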
python/sglang/srt/managers/image_processor.py → python/sglang/srt/managers/multimodal_processor.py  View file @ 1e86457c

@@ -4,46 +4,41 @@ import inspect
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Union
-
-from torch import Tensor
-from transformers import PROCESSOR_MAPPING
-from transformers import IMAGE_PROCESSOR_MAPPING

-from sglang.srt.managers.image_processors.base_image_processor import (
-    BaseImageProcessor,
-    DummyImageProcessor,
-)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+)
 from sglang.srt.server_args import ServerArgs

 logger = logging.getLogger(__name__)

-IMAGE_PROCESSOR_MAPPING = {}
+PROCESSOR_MAPPING = {}


+class DummyMultimodalProcessor(BaseMultimodalProcessor):
+    def __init__(self):
+        pass
+
+    async def process_mm_data_async(self, *args, **kwargs):
+        return None
+
+
-def get_image_processor(hf_config, server_args, processor) -> BaseImageProcessor:
-    for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
-        if model_cls.__name__ in hf_config.architectures:
-            return processor_cls(hf_config, server_args, processor)
-    raise ValueError(
-        f"No image processor found for architecture: {hf_config.architectures}"
-    )
-
-
-def get_dummy_image_processor():
-    return DummyImageProcessor()
+def get_dummy_processor():
+    return DummyMultimodalProcessor()


 @lru_cache()
-def import_image_processors():
-    package_name = "sglang.srt.managers.image_processors"
+def import_processors():
+    package_name = "sglang.srt.managers.multimodal_processors"
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
             try:
                 module = importlib.import_module(name)
             except Exception as e:
                 logger.warning(f"Ignore import error when loading {name}: " f"{e}")
                 continue
             all_members = inspect.getmembers(module, inspect.isclass)
             classes = [
 ...
@@ -51,11 +46,23 @@ def import_image_processors():
                 for name, member in all_members
                 if member.__module__ == module.__name__
             ]
-            for cls in classes:
-                if issubclass(cls, BaseImageProcessor):
-                    for arch in getattr(cls, "models"):
-                        IMAGE_PROCESSOR_MAPPING[arch] = cls
+            for cls in (
+                cls for cls in classes if issubclass(cls, BaseMultimodalProcessor)
+            ):
+                assert hasattr(cls, "models")
+                for arch in getattr(cls, "models"):
+                    PROCESSOR_MAPPING[arch] = cls
+
+
+def get_mm_processor(
+    hf_config, server_args: ServerArgs, processor
+) -> BaseMultimodalProcessor:
+    for model_cls, processor_cls in PROCESSOR_MAPPING.items():
+        if model_cls.__name__ in hf_config.architectures:
+            return processor_cls(hf_config, server_args, processor)
+    raise ValueError(
+        f"No processor registered for architecture: {hf_config.architectures}.\n"
+        f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
+    )


-# also register processors
-import_image_processors()
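A sketch of how the renamed registry is consumed (assuming `import_processors()` has run and the optional processor dependencies are importable, so `PROCESSOR_MAPPING` is populated; `FakeHFConfig` is a stand-in for a real HF config):

from sglang.srt.managers.multimodal_processor import (
    PROCESSOR_MAPPING,
    import_processors,
)

import_processors()

class FakeHFConfig:
    architectures = ["MiniCPMO"]  # architecture name, as in the model's config.json

# Dispatch matches on the model class name, just like get_mm_processor.
for model_cls, processor_cls in PROCESSOR_MAPPING.items():
    if model_cls.__name__ in FakeHFConfig.architectures:
        print(f"{model_cls.__name__} -> {processor_cls.__name__}")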
python/sglang/srt/managers/image_processors/base_image_processor.py → python/sglang/srt/managers/multimodal_processors/base_processor.py  View file @ 1e86457c

@@ -4,16 +4,16 @@ import dataclasses
 import multiprocessing as mp
 import os
 from abc import ABC, abstractmethod
-from typing import Optional, Union
+from typing import Optional

+import numpy as np
 import PIL
 import transformers
 from decord import VideoReader, cpu
 from openai import BadRequestError
 from PIL import Image

-from sglang.srt.utils import load_image
-from sglang.utils import logger
+from sglang.srt.utils import load_audio, load_image, logger

 global global_processor
 ...
@@ -24,21 +24,41 @@ def get_global_processor():
 @dataclasses.dataclass
-class BaseImageProcessorOutput:
-    image_hashes: list[int]
-    image_sizes: list[tuple[int, int]]
-    all_frames: [PIL.Image]
-    # input_text, with each frame of video/image represented with a image_token
+class BaseMultiModalProcessorOutput:
+    # input_text, with each frame of video/image represented as an image_token
     input_text: str

+    mm_data_hashes: Optional[list[int]]
+    # images
+    image_sizes: Optional[list[int]]
+    # frames loaded from image and video, in given order
+    images: Optional[list[PIL.Image]] = None
+    # audios
+    audios: Optional[list[np.ndarray]] = None
+
     def normalize(self):
-        for field_name in ["data_hashes", "image_sizes", "all_frames"]:
+        for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
             field = getattr(self, field_name, None)
             if field is not None and isinstance(field, list) and len(field) == 0:
                 setattr(self, field_name, None)


-class BaseImageProcessor(ABC):
+@dataclasses.dataclass
+class MultimodalSpecialTokens:
+    image_token: Optional[str] = None
+    video_token: Optional[str] = None
+    audio_token: Optional[str] = None
+
+    def collect(self) -> list[str]:
+        return [
+            token
+            for token in [self.image_token, self.video_token, self.audio_token]
+            if token
+        ]
+
+
+class BaseMultimodalProcessor(ABC):
     models = []

     def __init__(self, hf_config, server_args, _processor):
 ...
@@ -72,7 +92,7 @@ class BaseImageProcessor(ABC):
         )

     @abstractmethod
-    async def process_images_async(
+    async def process_mm_data_async(
         self, image_data, input_text, max_req_input_len, **kwargs
     ):
         pass
 ...
@@ -120,29 +140,33 @@ class BaseImageProcessor(ABC):
         frames = [Image.fromarray(v.astype("uint8")) for v in frames]
         return frames

-    def load_images(
+    def load_mm_data(
         self,
         input_ids: list[int],
-        image_data,
-        image_token: Union[int, str],
+        multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
+        image_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
         return_text: Optional[bool] = True,
         discard_alpha_channel: bool = True,
-    ) -> BaseImageProcessorOutput:
+    ) -> BaseMultiModalProcessorOutput:
         """
         Each frame of video/image will be replaced by a single image token

         Args:
-            image_token: The token ID representing the image placeholder.
+            multimodal_tokens (list[str]): list of special token which denoting a single multimodal data
+                e.g. image token or audio token
             discard_alpha_channel: if True, discards the alpha channel in the returned images
         """
-        if isinstance(image_token, int):
-            image_token_str = self._processor.tokenizer.convert_ids_to_tokens(
-                image_token
-            )
-        else:
-            image_token_str = image_token
+        if isinstance(multimodal_tokens.image_token, int):
+            multimodal_tokens.image_token = (
+                self._processor.tokenizer.convert_ids_to_tokens(
+                    multimodal_tokens.image_token
+                )
+            )
+        else:
+            multimodal_tokens.image_token = multimodal_tokens.image_token

         if isinstance(input_ids, list) and return_text:
             assert len(input_ids) and isinstance(input_ids[0], int)
 ...
@@ -152,7 +176,11 @@ class BaseImageProcessor(ABC):
         if return_text:
             import re

-            pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
+            pattern = (
+                "("
+                + "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
+                + ")"
+            )
             # split text into list of normal text and special tokens
             text_parts = re.split(pattern, input_text)
 ...
@@ -162,7 +190,7 @@ class BaseImageProcessor(ABC):
         total_frame_count = sum(estimated_frames_list)
         # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
         # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        _scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))

         assert len(image_data) == len(estimated_frames_list)
 ...
@@ -171,9 +199,16 @@ class BaseImageProcessor(ABC):
         new_text = ""
         for index, text_part in enumerate(text_parts):
             try:
-                if text_part == image_token:
+                if text_part == multimodal_tokens.image_token:
                     # load as image
-                    frames_to_process = estimated_frames_list[image_index]
+                    if len(images) >= MAX_NUM_FRAMES:
+                        frames_to_process = 0
+                    else:
+                        estimated_frames = estimated_frames_list[image_index]
+                        frames_to_process = max(
+                            1, int(estimated_frames * scaling_factor)
+                        )

                     if frames_to_process == 0:
                         frames = []
                     else:
 ...
@@ -183,7 +218,7 @@ class BaseImageProcessor(ABC):
                     ):
                         # video
                         path = image_file[len("video:") :]
-                        frames = self.encode_video(
+                        frames = BaseMultimodalProcessor.encode_video(
                             path, frame_count_limit=frames_to_process
                         )
                     else:
 ...
@@ -200,40 +235,41 @@ class BaseImageProcessor(ABC):
                     images += frames
                     image_index += 1
                     if frames_to_process != 0:
-                        new_text += image_token * len(frames)
+                        new_text += multimodal_tokens.image_token * len(frames)
                     assert frames_to_process == len(frames)
+                elif text_part == multimodal_tokens.audio_token:
+                    # load as audio
+                    audio_file = audio_data[audio_index]
+                    audio = load_audio(audio_file)
+                    hashes += [hash(audio_file)]
+                    audios += [audio]
+                    audio_index += 1
+                    new_text += multimodal_tokens.audio_token
                 else:
                     # TODO(mick): handle video
                     # normal text
                     new_text += text_part
             except Exception as e:
                 logger.error(f"An exception occurred while loading images: {e}")
                 raise BadRequestError(
                     f"An exception occurred while loading images: {e}"
                 )

-        return BaseImageProcessorOutput(
-            image_hashes=hashes,
-            image_sizes=image_sizes,
-            all_frames=images,
-            input_text=new_text,
-        )
+        out = BaseMultiModalProcessorOutput(
+            mm_data_hashes=hashes,
+            image_sizes=image_sizes,
+            images=images,
+            audios=audios,
+            input_text=new_text,
+        )
+        out.normalize()
+        return out


-class DummyImageProcessor(BaseImageProcessor):
-    def __init__(self):
-        pass
-
-    async def process_images_async(self, *args, **kwargs):
-        return None
-
-
-def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args):
-    """Init the global processor for multi-modal models."""
+def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
+    """
+    Init the global processor for multimodal models."""
     global global_processor
     transformers.logging.set_verbosity_error()
-    global_processor = sglang_image_processor._build_processor(server_args=server_args)
+    global_processor = sglang_processor._build_processor(server_args=server_args)
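The regex that `load_mm_data` builds from `MultimodalSpecialTokens.collect()` splits the prompt into plain-text runs and special-token markers. A standalone check of that pattern (token strings taken from the minicpmo template above):

import re

tokens = ["(<image>./</image>)", "(<audio>./</audio>)"]  # collect() output
pattern = "(" + "|".join(re.escape(sep) for sep in tokens) + ")"
text = "look (<image>./</image>) then listen (<audio>./</audio>)"
print(re.split(pattern, text))
# ['look ', '(<image>./</image>)', ' then listen ', '(<audio>./</audio>)', '']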
python/sglang/srt/managers/image_processors/deepseek_vl_v2.py → python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py  View file @ 1e86457c

@@ -20,14 +20,15 @@ import asyncio
 import torch

-from sglang.srt.managers.image_processor import BaseImageProcessor
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
     get_global_processor,
 )
 from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM


-class DeepseekVL2ImageProcessor(BaseImageProcessor):
+class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
     models = [DeepseekVL2ForCausalLM]

     def __init__(self, hf_config, server_args, _processor):
 ...
@@ -63,7 +64,23 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
         return image_inputs

-    async def process_images_async(
+    async def _process_images(self, image_data, input_text, max_req_input_len):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                DeepseekVL2ImageProcessor._process_images_task,
+                image_data,
+                input_text,
+                max_req_input_len,
+            )
+        else:
+            image_inputs = self._process_images_task(
+                image_data, input_text, max_req_input_len
+            )
+
+        return image_inputs
+
+    async def process_mm_data_async(
         self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
     ):
         if not image_data:
 ...
@@ -75,11 +92,14 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
         images, image_sizes = [], []
         image_token = self.IMAGE_TOKEN
-        base_output = self.load_images(
-            input_ids, image_data, image_token, max_req_input_len
+        base_output = self.load_mm_data(
+            input_ids,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+            max_req_input_len=max_req_input_len,
         )
         res = await self._process_images(
-            base_output.all_frames, base_output.input_text, max_req_input_len
+            base_output.images, base_output.input_text, max_req_input_len
         )
         images_seq_mask = res["images_seq_mask"]
         images_spatial_crop = res["images_spatial_crop"]
 ...
@@ -91,7 +111,7 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
             "input_ids": res["input_ids"].tolist(),
             "pixel_values": res["images"],
             "im_token_id": res["im_token_id"],
-            "image_hashes": base_output.image_hashes,
+            "data_hashes": base_output.mm_data_hashes,
             "image_sizes": image_sizes,
             "images_emb_mask": images_seq_mask,
             "image_spatial_crop": batched_images_spatial_crop,
 ...
python/sglang/srt/managers/image_processors/gemma3.py → python/sglang/srt/managers/multimodal_processors/gemma3.py  View file @ 1e86457c

 import asyncio
 from typing import List, Union

 from transformers.utils import logging

-from sglang.srt.managers.image_processor import (
-    BaseImageProcessor as SGLangBaseImageProcessor,
+from sglang.srt.managers.multimodal_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
 )
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    MultimodalSpecialTokens,
     get_global_processor,
 )
 from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
 ...
@@ -16,7 +16,7 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
 logger = logging.get_logger(__name__)


-class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
+class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     models = [Gemma3ForConditionalGeneration]

     def __init__(self, hf_config, server_args, _processor):
 ...
@@ -47,7 +47,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
             "pixel_values": pixel_values,
         }

-    async def process_images_async(
+    async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
         input_ids,
 ...
@@ -62,22 +62,22 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
             image_data = [image_data]

         image_token = self.IMAGE_TOKEN
-        base_output = self.load_images(
+        base_output = self.load_mm_data(
             input_ids=input_ids,
             image_data=image_data,
-            image_token=image_token,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )

         ret = await self._process_single_image(
-            input_text=base_output.input_text, images=base_output.all_frames
+            input_text=base_output.input_text, images=base_output.images
         )

         return {
             "input_ids": ret["input_ids"].flatten().tolist(),
             "pixel_values": ret["pixel_values"],
-            "image_hashes": base_output.image_hashes,
+            "data_hashes": base_output.mm_data_hashes,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }
python/sglang/srt/managers/image_processors/janus_pro.py → python/sglang/srt/managers/multimodal_processors/janus_pro.py  View file @ 1e86457c

 import asyncio
 from typing import List, Union

-from sglang.srt.managers.image_processors.base_image_processor import (
-    BaseImageProcessor as SGLangBaseImageProcessor,
-)
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
     get_global_processor,
 )
 from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM


-class JanusProProcessor(SGLangBaseImageProcessor):
+class JanusProImageProcessor(BaseMultimodalProcessor):
     models = [MultiModalityCausalLM]

     def __init__(self, hf_config, server_args, _processor):
 ...
@@ -36,7 +35,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):
         loop = asyncio.get_event_loop()
         image_inputs = await loop.run_in_executor(
             self.executor,
-            JanusProProcessor._process_images_task,
+            JanusProImageProcessor._process_images_task,
             images,
             input_text,
         )
 ...
@@ -47,7 +46,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):
         return image_inputs

-    async def process_images_async(
+    async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
         input_ids,
 ...
@@ -61,20 +60,24 @@ class JanusProProcessor(SGLangBaseImageProcessor):
         if not isinstance(image_data, list):
             image_data = [image_data]

-        base_out = self.load_images(
+        base_out = self.load_mm_data(
             input_ids=input_ids,
             image_data=image_data,
-            image_token="<image_placeholder>",
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token="<image_placeholder>"
+            ),
             max_req_input_len=max_req_input_len,
         )
-        images = base_out.all_frames
+        images = base_out.images
         res = await self._process_images(images=images, input_text=base_out.input_text)
+        # print(res)
+        # print(base_out)
+        # print("", res["images_emb_mask"].shape)
         return {
             "input_ids": res["input_ids"].flatten().tolist(),
             "pixel_values": res["pixel_values"],
             "images_emb_mask": res["images_emb_mask"],
-            "image_hashes": base_out.image_hashes,
+            "data_hashes": base_out.mm_data_hashes,
             "im_start_id": res["im_start_id"],
             "im_end_id": res["im_end_id"],
             "im_token_id": res["im_token_id"],
 ...
...
python/sglang/srt/managers/
image
_processors/llava.py
→
python/sglang/srt/managers/
multimodal
_processors/llava.py
View file @
1e86457c
...
@@ -3,8 +3,8 @@ from typing import List, Optional, Union
 import numpy as np

-from sglang.srt.managers.image_processor import BaseImageProcessor
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
     get_global_processor,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
...
@@ -14,7 +14,7 @@ from sglang.srt.utils import load_image, logger
 from sglang.utils import get_exception_traceback


-class LlavaImageProcessor(BaseImageProcessor):
+class LlavaImageProcessor(BaseMultimodalProcessor):
     models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]

     def __init__(self, hf_config, server_args, _processor):
...
@@ -86,7 +86,7 @@ class LlavaImageProcessor(BaseImageProcessor):
             image_data, aspect_ratio, grid_pinpoints
         )

-    async def process_images_async(
+    async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
         input_text,
...
@@ -113,7 +113,7 @@ class LlavaImageProcessor(BaseImageProcessor):
         if "multi-images" in modalities or "video" in modalities:
             # Multiple images
             aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-            pixel_values, image_hashes, image_sizes = [], [], []
+            pixel_values, data_hashes, image_sizes = [], [], []
             res = []
             for img_data in image_data:
                 res.append(
...
@@ -124,7 +124,7 @@ class LlavaImageProcessor(BaseImageProcessor):
             res = await asyncio.gather(*res)
             for pixel_v, image_h, image_s in res:
                 pixel_values.append(pixel_v)
-                image_hashes.append(image_h)
+                data_hashes.append(image_h)
                 image_sizes.append(image_s)

             if isinstance(pixel_values[0], np.ndarray):
...
@@ -134,14 +134,14 @@ class LlavaImageProcessor(BaseImageProcessor):
             pixel_values, image_hash, image_size = await self._process_single_image(
                 image_data[0], aspect_ratio, grid_pinpoints
             )
-            image_hashes = [image_hash]
+            data_hashes = [image_hash]
             image_sizes = [image_size]
         else:
             raise ValueError(f"Invalid image data: {image_data}")

         return {
             "pixel_values": pixel_values,
-            "image_hashes": image_hashes,
+            "data_hashes": data_hashes,
             "image_sizes": image_sizes,
             "modalities": request_obj.modalities or ["image"],
         }
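The multi-image branch above queues one processing coroutine per image and then awaits them together with `asyncio.gather`. A self-contained sketch of that fan-out pattern, with `process_one`/`process_all` as illustrative stand-ins rather than sglang APIs:

# Sketch: offload CPU-bound per-item work to a pool, await all results at once.
import asyncio
from concurrent.futures import ProcessPoolExecutor


def process_one(item: str) -> str:
    return item.upper()  # stand-in for expensive image preprocessing


async def process_all(items):
    loop = asyncio.get_running_loop()
    with ProcessPoolExecutor() as executor:
        # Schedule everything first, then gather, mirroring the res.append(...)
        # loop followed by `res = await asyncio.gather(*res)` above.
        futures = [loop.run_in_executor(executor, process_one, it) for it in items]
        return await asyncio.gather(*futures)


if __name__ == "__main__":
    print(asyncio.run(process_all(["a", "b", "c"])))  # ['A', 'B', 'C']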
python/sglang/srt/managers/image_processors/minicpmv.py → python/sglang/srt/managers/multimodal_processors/minicpm.py View file @ 1e86457c
...
@@ -3,82 +3,113 @@ from typing import List, Union
 import torch

-from sglang.srt.managers.image_processor import BaseImageProcessor
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
     get_global_processor,
 )
+from sglang.srt.models.minicpmo import MiniCPMO
 from sglang.srt.models.minicpmv import MiniCPMV


-class MiniCPMVImageProcessor(BaseImageProcessor):
-    models = [MiniCPMV]
+# Compatible with both 'O' and 'V'
+class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
+    models = [MiniCPMV, MiniCPMO]

     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
-        self.IMAGE_TOKEN = "(<image>./</image>)"
+        self.image_token = "(<image>./</image>)"
+        self.audio_token = "(<audio>./</audio>)"

     @staticmethod
-    def _process_images_task(images, input_text):
-        processor = get_global_processor()
-        result = processor.__call__(
-            text=input_text, images=images, return_tensors="pt"
-        )
+    def _process_data_task(input_text, images=None, audios=None):
+        if isinstance(images, list) and len(images) == 0:
+            images = None
+        if isinstance(audios, list) and len(audios) == 0:
+            audios = None
+        result = get_global_processor().__call__(
+            text=input_text,
+            images=images,
+            audios=audios,
+            return_tensors="pt",
+            chunk_input=True,
+        )
         return {
             "input_ids": result.input_ids,
-            "pixel_values": result.pixel_values,
-            "tgt_sizes": result.tgt_sizes,
+            "pixel_values": getattr(result, "pixel_values", None),
+            "tgt_sizes": getattr(result, "tgt_sizes", None),
+            "audio_features": getattr(result, "audio_features", None),
+            "audio_feature_lens": getattr(result, "audio_feature_lens", None),
+            "audio_bounds": getattr(result, "audio_bounds", None),
         }

-    async def _process_images(self, images, input_text):
+    async def _process_data(self, images, input_text, audios=None):
         if self.executor is not None:
             loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
+            multimodal_data_inputs = await loop.run_in_executor(
                 self.executor,
-                MiniCPMVImageProcessor._process_images_task,
-                images,
-                input_text,
+                MiniCPMMultimodalProcessor._process_data_task,
+                input_text,
+                images,
+                audios,
             )
         else:
-            image_inputs = self._processor(
-                images=images, text=input_text, return_tensors="pt"
+            multimodal_data_inputs = self._processor(
+                images=images, text=input_text, audios=audios, return_tensors="pt"
             )
-        return image_inputs
+        return multimodal_data_inputs

-    async def process_images_async(
+    async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
         input_ids,
         request_obj,
         max_req_input_len,
     ):
-        if not image_data:
+        audio_data = request_obj.audio_data
+        if not image_data and not audio_data:
             return None
         if not isinstance(image_data, list):
             image_data = [image_data]
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]

-        base_output = self.load_images(
+        base_output = self.load_mm_data(
             input_ids=input_ids,
-            image_data=image_data,
-            image_token=self.IMAGE_TOKEN,
             max_req_input_len=max_req_input_len,
+            audio_data=audio_data,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self.image_token, audio_token=self.audio_token
+            ),
         )
         if base_output is None:
             return None

-        if len(base_output.all_frames) == 0:
-            return None
-        res = await self._process_images(
-            images=base_output.all_frames,
-            input_text=base_output.input_text
+        res = await self._process_data(
+            images=base_output.images,
+            input_text=base_output.input_text,
+            audios=base_output.audios,
         )

         # Collect special token ids
         tokenizer = self._processor.tokenizer
-        im_start_id = tokenizer.im_start_id
-        im_token_id = tokenizer.unk_token_id
-        im_end_id = tokenizer.im_end_id
+        slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
+            None,
+            None,
+            None,
+            None,
+        )
         if tokenizer.slice_start_id:
             slice_start_id = tokenizer.slice_start_id
             slice_end_id = tokenizer.slice_end_id
+        if hasattr(tokenizer, "audio_start_id"):
+            audio_start_id = tokenizer.audio_start_id
+            audio_end_id = tokenizer.audio_end_id
+
+        im_token_id = tokenizer.unk_token_id
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]
...
@@ -98,8 +129,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
                 f"{len(pixel_values)} vs. {len(tgt_sizes)}"
             )
-        # tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
-        # tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
         pixel_values_flat: List[torch.Tensor] = []
         tgt_sizes_flat: List[torch.Tensor] = []
         for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
...
@@ -109,21 +138,30 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
                     "Inconsistent N lengths, found: "
                     f"{len(pixel_b)} vs {len(tgt_b)}"
                 )
             for pixel_n, tgt_n in zip(pixel_b, tgt_b):
+                # per patch
                 pixel_values_flat += [pixel_n]
                 tgt_sizes_flat += [tgt_n]

         pixel_values = pixel_values_flat
-        tgt_sizes = torch.stack(tgt_sizes_flat)
+        if len(tgt_sizes_flat) == 0:
+            tgt_sizes = None
+        else:
+            tgt_sizes = torch.stack(tgt_sizes_flat)
+
+        if not isinstance(res["audio_features"], list):
+            res["audio_features"] = [res["audio_features"]]

         return {
             "input_ids": res["input_ids"].flatten().tolist(),
             "pixel_values": pixel_values,
             "tgt_sizes": tgt_sizes,
-            "image_hashes": base_output.image_hashes,
+            "data_hashes": base_output.mm_data_hashes,
             "modalities": request_obj.modalities or ["image"],
-            "im_start_id": im_start_id,
-            "im_token_id": im_token_id,
-            "im_end_id": im_end_id,
+            "audio_start_id": audio_start_id,
+            "audio_end_id": audio_end_id,
+            "audio_features": res["audio_features"],
+            "audio_bounds": res["audio_bounds"],
+            "audio_feature_lens": res["audio_feature_lens"],
+            "im_token_id": im_token_id,
+            "im_start_id": tokenizer.im_start_id,
+            "im_end_id": tokenizer.im_end_id,
             "slice_start_id": slice_start_id,
             "slice_end_id": slice_end_id,
         }
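The flattening pass above collapses per-image lists of per-slice tensors into one flat batch, and the new empty-list guard keeps audio-only requests (which produce no image slices) from crashing `torch.stack`. A runnable sketch under those assumptions:

import torch

# Two images: the first split into two slices, the second into one.
pixel_values = [[torch.rand(3, 14, 14), torch.rand(3, 14, 14)], [torch.rand(3, 14, 14)]]
tgt_sizes = [[torch.tensor([1, 2]), torch.tensor([1, 2])], [torch.tensor([2, 2])]]

pixel_values_flat, tgt_sizes_flat = [], []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
    if len(pixel_b) != len(tgt_b):
        raise ValueError(f"Inconsistent N lengths, found: {len(pixel_b)} vs {len(tgt_b)}")
    for pixel_n, tgt_n in zip(pixel_b, tgt_b):  # per slice
        pixel_values_flat.append(pixel_n)
        tgt_sizes_flat.append(tgt_n)

# Audio-only requests yield no slices, hence the guard before stacking.
tgt_sizes_stacked = torch.stack(tgt_sizes_flat) if tgt_sizes_flat else None
print(len(pixel_values_flat), tgt_sizes_stacked.shape)  # 3 torch.Size([3, 2])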
python/sglang/srt/managers/image_processors/mlama.py → python/sglang/srt/managers/multimodal_processors/mlama.py View file @ 1e86457c
 import asyncio
 from typing import List, Union

-from sglang.srt.managers.image_processor import BaseImageProcessor
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
     get_global_processor,
 )
 from sglang.srt.models.mllama import MllamaForConditionalGeneration
 from sglang.srt.utils import load_image


-class MllamaImageProcessor(BaseImageProcessor):
+class MllamaImageProcessor(BaseMultimodalProcessor):
     models = [MllamaForConditionalGeneration]

     def __init__(self, hf_config, server_args, _processor):
...
@@ -34,7 +34,7 @@ class MllamaImageProcessor(BaseImageProcessor):
         return image_inputs

-    async def process_images_async(
+    async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
         if not image_data:
...
@@ -53,7 +53,7 @@ class MllamaImageProcessor(BaseImageProcessor):
             images = load_image(image_data[0])[0]
             image_inputs = await self._process_single_image(images, input_text)

-        image_inputs["image_hashes"] = [hash(str(image_data))]
+        image_inputs["data_hashes"] = [hash(str(image_data))]
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]

         return image_inputs
python/sglang/srt/managers/image_processors/qwen_vl.py → python/sglang/srt/managers/multimodal_processors/qwen_vl.py View file @ 1e86457c
 import asyncio
 import math
+import time
 from typing import List, Union

 import torch
 from PIL import Image

-from sglang.srt.managers.image_processor import BaseImageProcessor
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    MultimodalSpecialTokens,
     get_global_processor,
 )
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
...
@@ -14,7 +18,7 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration

 # Compatible with Qwen2VL and Qwen2_5VL
-class Qwen2_5VLImageProcessor(BaseImageProcessor):
+class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]

     def __init__(self, hf_config, server_args, _processor):
...
@@ -59,7 +63,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
         else:
             return self._process_images_task(images, input_text, self.hf_config)

-    async def process_images_async(
+    async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
         input_ids,
...
@@ -68,16 +72,17 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
         *args,
         **kwargs,
     ):
+        start = time.time()
         if not image_data:
             return None
         if isinstance(image_data, str):
             image_data = [image_data]

         image_token = self.IMAGE_TOKEN
-        base_output = self.load_images(
+        base_output = self.load_mm_data(
             input_ids=input_ids,
             image_data=image_data,
-            image_token=image_token,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
         )
...
@@ -139,7 +144,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
             """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
             return math.floor(number / factor) * factor

-        images = [resize_image(image) for image in base_output.all_frames]
+        images = [resize_image(image) for image in base_output.images]

         ret = await self._process_single_image(
             images=images, input_text=base_output.input_text
...
@@ -147,11 +152,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
         image_grid_thws = torch.concat([ret["image_grid_thw"]])
         video_grid_thws = None

         return {
             "input_ids": ret["input_ids"].flatten().tolist(),
             "pixel_values": ret["pixel_values"],
-            "image_hashes": base_output.image_hashes,
+            "data_hashes": base_output.mm_data_hashes,
             "modalities": request_obj.modalities or ["image"],
             "image_grid_thws": image_grid_thws,
             "video_grid_thws": video_grid_thws,
...
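The helper documented in the hunk above rounds a dimension down to the nearest multiple of the vision patch factor before resizing; the docstring is quoted from the diff, while the function name used here is an assumption:

import math


def floor_by_factor(number: float, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor


assert floor_by_factor(500, 28) == 476  # 476 == 17 * 28
assert floor_by_factor(28, 28) == 28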
python/sglang/srt/managers/schedule_batch.py View file @ 1e86457c
...
@@ -144,11 +144,11 @@ class FINISH_ABORT(BaseFinishReason):

 @dataclasses.dataclass
-class ImageInputs:
+class MultimodalInputs:
     """The image related inputs."""

     pixel_values: Union[torch.Tensor, np.array]
-    image_hashes: Optional[list] = None
+    data_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
     image_pad_len: Optional[list] = None
...
@@ -182,20 +182,27 @@ class ImageInputs:
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
     slice_end_id: Optional[int] = None
+    # [num_images, 2 (w, h)]
     tgt_sizes: Optional[list] = None

+    # audio
+    audio_start_id: Optional[torch.Tensor] = None
+    audio_end_id: Optional[torch.Tensor] = None
+    audio_features: Optional[List[torch.Tensor]] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+
     @staticmethod
     def from_dict(obj: dict):
-        ret = ImageInputs(
+        ret = MultimodalInputs(
             pixel_values=obj["pixel_values"],
-            image_hashes=obj["image_hashes"],
+            data_hashes=obj["data_hashes"],
         )

         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
+        ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]

         optional_args = [
             "image_sizes",
...
@@ -211,6 +218,10 @@ class ImageInputs:
             "slice_start_id",
             "slice_end_id",
             "tgt_sizes",
+            "audio_start_id",
+            "audio_end_id",
+            "audio_features",
+            "audio_feature_lens",
         ]
         for arg in optional_args:
             if arg in obj:
...
@@ -223,9 +234,19 @@ class ImageInputs:
             or isinstance(ret.pixel_values, list)
         )
+        assert ret.audio_features is None or isinstance(ret.audio_features, list)
         return ret

-    def merge(self, other: ImageInputs):
+    def contains_image_inputs(self) -> bool:
+        """ """
+        return self.pixel_values is not None and self.pixel_values != []
+
+    def contains_audio_inputs(self) -> bool:
+        """ """
+        return self.audio_features is not None and self.audio_features != []
+
+    def merge(self, other: MultimodalInputs):
         """
         merge image inputs when requests are being merged
         """
...
@@ -268,10 +289,12 @@ class ImageInputs:
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-        self.image_hashes += other.image_hashes
-        self.pad_values = [x % (1 << 30) for x in self.image_hashes]
+        self.data_hashes += other.data_hashes
+        self.pad_values = [x % (1 << 30) for x in self.data_hashes]

         # args needed to be merged
         optional_args = [
+            "audio_features",
             "image_sizes",
             "image_offsets",
             "image_pad_len",
...
@@ -362,7 +385,7 @@ class Req:
         self.decoded_text = ""

         # For multimodal inputs
-        self.image_inputs: Optional[ImageInputs] = None
+        self.multimodal_inputs: Optional[MultimodalInputs] = None

         # Prefix info
         # The indices to kv cache for the shared prefix.
...
@@ -458,10 +481,10 @@ class Req:
         return len(self.origin_input_ids) + len(self.output_ids)

     def extend_image_inputs(self, image_inputs):
-        if self.image_inputs is None:
-            self.image_inputs = image_inputs
+        if self.multimodal_inputs is None:
+            self.multimodal_inputs = image_inputs
         else:
-            self.image_inputs.merge(image_inputs)
+            self.multimodal_inputs.merge(image_inputs)

     def finished(self) -> bool:
         # Whether request reached finished condition
...
@@ -802,7 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.encoder_cached = []

         for req in self.reqs:
-            im = req.image_inputs
+            im = req.multimodal_inputs
             if im is None or im.num_image_tokens is None:
                 # No image input
                 self.encoder_lens_cpu.append(0)
...
@@ -1391,7 +1414,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            image_inputs=[r.image_inputs for r in self.reqs],
+            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
             encoder_cached=self.encoder_cached,
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
...
@@ -1474,7 +1497,7 @@ class ModelWorkerBatch:
     extend_input_logprob_token_ids: Optional[torch.Tensor]

     # For multimodal
-    image_inputs: Optional[List[ImageInputs]]
+    multimodal_inputs: Optional[List[MultimodalInputs]]

     # For encoder-decoder
     encoder_cached: Optional[List[bool]]
...
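The comments preserved above describe the pad-value scheme: multimodal data hashes stand in for token ids during radix-cache prefix matching, clamped into [0, 1 << 30) so they stay usable as dummy token ids. A toy illustration (the hash inputs are made up):

# Two requests that share a leading image produce identical fake prefix
# tokens, so the radix cache can match and reuse their KV entries.
data_hashes = [hash(b"fake-image-a"), hash(b"fake-image-b")]
pad_values = [x % (1 << 30) for x in data_hashes]  # Python's % is non-negative here
assert all(0 <= v < (1 << 30) for v in pad_values)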
python/sglang/srt/managers/scheduler.py View file @ 1e86457c
...
@@ -88,7 +88,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-    ImageInputs,
+    MultimodalInputs,
     Req,
     ScheduleBatch,
     global_server_args_dict,
...
@@ -841,8 +841,8 @@ class Scheduler(
             return

         # Handle multimodal inputs
-        if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+        if recv_req.mm_inputs is not None:
+            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
...
@@ -856,7 +856,7 @@ class Scheduler(
                 )
                 logger.error(error_msg)
                 req.origin_input_ids = [0]
-                req.image_inputs = None
+                req.multimodal_inputs = None
                 req.sampling_params.max_new_tokens = 0
                 req.finished_reason = FINISH_ABORT(
                     error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
...
@@ -960,7 +960,7 @@ class Scheduler(
         # Handle multimodal inputs
         if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
...
@@ -974,7 +974,7 @@ class Scheduler(
                 )
                 logger.error(error_msg)
                 req.origin_input_ids = [0]
-                req.image_inputs = None
+                req.multimodal_inputs = None
                 req.sampling_params.max_new_tokens = 0
                 req.finished_reason = FINISH_ABORT(
                     error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
...
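`pad_input_ids_func` above expands a single image placeholder into enough dummy tokens to hold the image's embeddings. The real pad functions are model-specific; this is only a generic sketch with assumed constants:

from typing import List

IMAGE_TOKEN_ID = 32000        # assumed placeholder id
EMBEDDINGS_PER_IMAGE = 4      # assumed embedding count per image
PAD_VALUE = 12345             # would come from MultimodalInputs.pad_values


def pad_input_ids(input_ids: List[int]) -> List[int]:
    out: List[int] = []
    for tok in input_ids:
        if tok == IMAGE_TOKEN_ID:
            # Reserve one slot per embedding the model will scatter in later.
            out.extend([PAD_VALUE] * EMBEDDINGS_PER_IMAGE)
        else:
            out.append(tok)
    return out


print(pad_input_ids([1, 32000, 2]))  # [1, 12345, 12345, 12345, 12345, 2]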
python/sglang/srt/managers/session_controller.py View file @ 1e86457c
...
@@ -138,7 +138,7 @@ class Session:
             token_ids_logprob=req.token_ids_logprob,
         )
         if last_req is not None:
-            new_req.image_inputs = last_req.image_inputs
+            new_req.multimodal_inputs = last_req.mm_inputs
             new_req.tokenizer = tokenizer
         if abort:
             new_req.to_abort = True
...