change / sglang · Commits · bcc213df

Unverified commit bcc213df, authored Feb 16, 2025 by Mick, committed via GitHub on Feb 16, 2025
Model: Support Qwen 2.5 vl (#3258)
parent 39416e39

Showing 11 changed files with 1999 additions and 261 deletions (+1999 −261)
  docs/references/supported_models.md               +2     −2
  python/sglang/lang/chat_template.py               +8     −0
  python/sglang/srt/configs/__init__.py             +6     −3
  python/sglang/srt/configs/model_config.py         +1     −0
  python/sglang/srt/configs/qwen2_5_vl_config.py    +1003  −0
  python/sglang/srt/configs/qwen2vl.py              +0     −130
  python/sglang/srt/hf_transformers_utils.py        +2     −3
  python/sglang/srt/managers/image_processor.py     +217   −122
  python/sglang/srt/models/qwen2_5_vl.py            +722   −0
  python/sglang/srt/models/qwen2_vl.py              +2     −1
  test/srt/test_vision_openai_server.py             +36    −0
docs/references/supported_models.md  (+2 −2)

@@ -4,7 +4,7 @@
 - Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2
 - Mistral / Mixtral / Mistral NeMo / Mistral Small 3
 - Gemma / Gemma 2
-- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL
+- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL / Qwen 2.5 VL
 - DeepSeek / DeepSeek 2 / [DeepSeek 3](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3)
 - OLMoE
 - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)

@@ -54,7 +54,7 @@ To support a new model in SGLang, you only need to add a single file under [SGLa
 You can learn from existing model implementations and create new files for the new models.
 For most models, you should be able to find a similar model to start with (e.g., starting from Llama).

-## How to Support a New vision LLM
+## How to Support a New vLM

 To support a new vision-language model (vLM) in SGLang, there are several key components in addition to the standard LLM.
python/sglang/lang/chat_template.py  (+8 −0)

@@ -427,6 +427,8 @@ def match_chat_ml(model_path: str):
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
     # Now the suffix for qwen2 chat model is "instruct"
+    if "qwen" in model_path and "vl" in model_path:
+        return get_chat_template("qwen2-vl")
     if "qwen" in model_path:
         if "vl" in model_path:
             return get_chat_template("qwen2-vl")

@@ -443,6 +445,12 @@ def match_chat_ml(model_path: str):
         return get_chat_template("chatml-llava")


+@register_chat_template_matching_function
+def match_chat_minicpm(model_path: str):
+    if "minicpm" in model_path:
+        return get_chat_template("minicpmv")
+
+
 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
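The two added lines route any lowercased model path that contains both "qwen" and "vl" to the existing "qwen2-vl" chat template, which is reused for Qwen 2.5 VL. A minimal check of that behaviour (illustrative only; it assumes an sglang checkout where the two helpers named in the hunk are importable):

    from sglang.lang.chat_template import get_chat_template, match_chat_ml

    # "qwen" and "vl" both occur in the lowercased path, so the ChatML-family
    # matcher resolves a Qwen2.5-VL checkpoint to the "qwen2-vl" template.
    template = match_chat_ml("qwen/qwen2.5-vl-7b-instruct")
    assert template == get_chat_template("qwen2-vl")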
python/sglang/srt/configs/__init__.py  (+6 −3)

 from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.exaone import ExaoneConfig
-from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig
+from sglang.srt.configs.qwen2_5_vl_config import (
+    Qwen2_5_VLConfig,
+    Qwen2_5_VLVisionConfig,
+)

 __all__ = [
     "ExaoneConfig",
-    "Qwen2VLConfig",
-    "Qwen2VLVisionConfig",
     "ChatGLMConfig",
     "DbrxConfig",
+    "Qwen2_5_VLConfig",
+    "Qwen2_5_VLVisionConfig",
 ]
python/sglang/srt/configs/model_config.py  (+1 −0)

@@ -403,6 +403,7 @@ def is_multimodal_model(model_architectures: List[str]):
         or "LlavaVidForCausalLM" in model_architectures
         or "MllamaForConditionalGeneration" in model_architectures
         or "Qwen2VLForConditionalGeneration" in model_architectures
+        or "Qwen2_5_VLForConditionalGeneration" in model_architectures
         or "MiniCPMV" in model_architectures
     ):
         return True
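With the extra architecture string in place, a Qwen 2.5 VL checkpoint is treated as multimodal by the runtime. A small illustration (sketch only; it assumes the module path shown in the hunk header):

    from sglang.srt.configs.model_config import is_multimodal_model

    print(is_multimodal_model(["Qwen2_5_VLForConditionalGeneration"]))  # True with this change
    print(is_multimodal_model(["Qwen2ForCausalLM"]))                    # False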
python/sglang/srt/configs/qwen2_5_vl_config.py  (new file, mode 100644, +1003 −0)
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""
from typing import Dict, Iterable, List, Optional, Union

import numpy as np
from transformers import (
    AutoImageProcessor,
    AutoProcessor,
    BaseImageProcessor,
    BatchFeature,
    PretrainedConfig,
    ProcessorMixin,
    TensorType,
)
from transformers.image_transforms import (
    convert_to_rgb,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    VideoInput,
    get_image_size,
    infer_channel_dimension_format,
    is_pil_image,
    is_valid_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from transformers.modeling_rope_utils import rope_config_validation
from transformers.models.mllama.image_processing_mllama import is_valid_list_of_images
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.processing_utils import ProcessingKwargs, Unpack, VideosKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
class Qwen2_5_VLVisionConfig(PretrainedConfig):
    model_type = "qwen2_5_vl"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=32,
        hidden_size=3584,
        hidden_act="silu",
        intermediate_size=3420,
        num_heads=16,
        in_channels=3,
        patch_size=14,
        spatial_merge_size=2,
        temporal_patch_size=2,
        tokens_per_second=4,
        window_size=112,
        out_hidden_size=3584,
        fullatt_block_indexes=[7, 15, 23, 31],
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.tokens_per_second = tokens_per_second
        self.window_size = window_size
        self.fullatt_block_indexes = fullatt_block_indexes
        self.out_hidden_size = out_hidden_size
class Qwen2_5_VLConfig(PretrainedConfig):
    r"""
This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 152064):
Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2_5_VLModel`]
hidden_size (`int`, *optional*, defaults to 8192):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 29568):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 80):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 64):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 80):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
vision_config (`Dict`, *optional*):
The config for the visual encoder initialization.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
```python
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
>>> # Initializing a Qwen2_5_VL style configuration
>>> configuration = Qwen2_5_VLConfig()
>>> # Initializing a model from the Qwen2-VL-7B style configuration
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""

    model_type = "qwen2_5_vl"
    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen2_5_VL`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        vocab_size=152064,
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=80,
        attention_dropout=0.0,
        vision_config=None,
        rope_scaling=None,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.rope_scaling = rope_scaling

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
        # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
        # TODO: @raushan update config in the hub
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
                self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self, ignore_keys={"mrope_section"})

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
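
# --- Illustration (not part of the diff): the backward-compatibility branch in
# __init__ above rewrites a legacy `rope_scaling` entry so that newer
# `transformers` RoPE validation accepts it.  The example values are assumptions:
#
#   >>> cfg = Qwen2_5_VLConfig(rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]})
#   >>> cfg.rope_scaling
#   {'type': 'default', 'mrope_section': [16, 24, 24], 'rope_type': 'default'}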

# FIXME: workaround of obsolete transformers version
class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
    fps: Union[List[float], float]


class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
    videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "videos_kwargs": {"fps": 2.0},
    }
class Qwen2_5_VLProcessor(ProcessorMixin):
    r"""
Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
[`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
[`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
Args:
image_processor ([`Qwen2VLImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        self.image_token = (
            "<|image_pad|>"
            if not hasattr(tokenizer, "image_token")
            else tokenizer.image_token
        )
        self.video_token = (
            "<|video_pad|>"
            if not hasattr(tokenizer, "video_token")
            else tokenizer.video_token
        )
        super().__init__(image_processor, tokenizer, chat_template=chat_template)
    def __call__(
        self,
        images: ImageInput = None,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
    ) -> BatchFeature:
        """
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
- **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
"""
        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(
                images=images, videos=None, **output_kwargs["images_kwargs"]
            )
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            videos_inputs = self.image_processor(
                images=None, videos=videos, **output_kwargs["images_kwargs"]
            )
            video_grid_thw = videos_inputs["video_grid_thw"]

            fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / fps
                ] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / tmp for tmp in fps
                ]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
        else:
            videos_inputs = {}
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]

        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    text[i] = text[i].replace(
                        self.image_token,
                        "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length),
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if video_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    text[i] = text[i].replace(
                        self.video_token,
                        "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length),
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.video_token)

        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
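
    # --- Illustration (not part of the diff): each image placeholder in the prompt
    # is expanded to grid_thw.prod() // merge_size**2 pad tokens.  With an assumed
    # patch grid of (t, h, w) = (1, 98, 146) and merge_size = 2:
    #
    #   >>> 1 * 98 * 146 // 2**2
    #   3577
    #
    # so one <|image_pad|> in the text becomes 3577 <|image_pad|> tokens.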
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        names_from_processor = list(
            dict.fromkeys(tokenizer_input_names + image_processor_input_names)
        )
        return names_from_processor + ["second_per_grid_ts"]
class Qwen2_5_VLImageProcessor(BaseImageProcessor):
    r"""
Constructs a Qwen2.5-VL image processor that dynamically resizes images based on the original images.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use when resizing the image.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
min_pixels (`int`, *optional*, defaults to `56 * 56`):
The min pixels of the image to resize the image.
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
The max pixels of the image to resize the image.
patch_size (`int`, *optional*, defaults to 14):
The spacial patch size of the vision encoder.
temporal_patch_size (`int`, *optional*, defaults to 2):
The temporal patch size of the vision encoder.
merge_size (`int`, *optional*, defaults to 2):
The merge size of the vision encoder to llm encoder.
"""

    model_input_names = [
        "pixel_values",
        "image_grid_thw",
        "pixel_values_videos",
        "video_grid_thw",
        "second_per_grid_ts",
    ]

    def __init__(
        self,
        do_resize: bool = True,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
        self.do_convert_rgb = do_convert_rgb
    def rescale(
        self,
        image: np.ndarray,
        scale: float,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`float`):
The scaling factor to rescale pixel values by.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The rescaled image.
"""
        return rescale(
            image,
            scale=scale,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `Iterable[float]`):
Image mean to use for normalization.
std (`float` or `Iterable[float]`):
Image standard deviation to use for normalization.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The normalized image.
"""
        return normalize(
            image,
            mean=mean,
            std=std,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
    def _preprocess(
        self,
        images: Union[ImageInput, VideoInput],
        do_resize: bool = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
Args:
images (`ImageInput`):
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
vision_info (`List[Dict]`, *optional*):
Optional list of dictionaries containing additional information about vision inputs.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Scale factor to use if rescaling the image.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
        images = make_list_of_images(images)

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        processed_images = []
        for image in images:
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=self.patch_size * self.merge_size,
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels,
                )
                image = resize(
                    image,
                    size=(resized_height, resized_width),
                    resample=resample,
                    input_data_format=input_data_format,
                )

            if do_rescale:
                image = self.rescale(
                    image, scale=rescale_factor, input_data_format=input_data_format
                )

            if do_normalize:
                image = self.normalize(
                    image=image,
                    mean=image_mean,
                    std=image_std,
                    input_data_format=input_data_format,
                )

            image = to_channel_dimension_format(
                image, data_format, input_channel_dim=input_data_format
            )
            processed_images.append(image)

        patches = np.array(processed_images)
        if data_format == ChannelDimension.LAST:
            patches = patches.transpose(0, 3, 1, 2)
        if patches.shape[0] % self.temporal_patch_size != 0:
            repeats = np.repeat(
                patches[-1][np.newaxis], self.temporal_patch_size - 1, axis=0
            )
            patches = np.concatenate([patches, repeats], axis=0)
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = (
            resized_height // self.patch_size,
            resized_width // self.patch_size,
        )
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.merge_size,
            self.merge_size,
            self.patch_size,
            grid_w // self.merge_size,
            self.merge_size,
            self.patch_size,
        )
        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w,
            channel * self.temporal_patch_size * self.patch_size * self.patch_size,
        )

        return flatten_patches, (grid_t, grid_h, grid_w)
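
    # --- Illustration (not part of the diff): shape bookkeeping for _preprocess.
    # For an assumed 644 x 1288 RGB input with patch_size=14, merge_size=2 and
    # temporal_patch_size=2, smart_resize keeps the size (both sides are already
    # multiples of 14 * 2 = 28), so:
    #
    #   >>> grid_t, grid_h, grid_w = 1, 644 // 14, 1288 // 14   # (1, 46, 92)
    #   >>> grid_t * grid_h * grid_w                            # flattened patch rows
    #   4232
    #   >>> 3 * 2 * 14 * 14                                     # length of each patch vector
    #   1176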
    def preprocess(
        self,
        images: ImageInput,
        videos: VideoInput = None,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
videos (`VideoInput`):
Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = (
            do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        )

        def make_flat_list_of_images(
            images: Union[List[ImageInput], ImageInput],
        ) -> ImageInput:
            """
Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of length 1.
If the input is a nested list of images, it is converted to a flat list of images.
Args:
images (`Union[List[ImageInput], ImageInput]`):
The input image.
Returns:
list: A list of images or a 4d array of images.
"""
            # If the input is a nested list of images, we flatten it
            if (
                isinstance(images, (list, tuple))
                and all(isinstance(images_i, (list, tuple)) for images_i in images)
                and all(is_valid_list_of_images(images_i) for images_i in images)
            ):
                return [img for img_list in images for img in img_list]

            if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
                if is_pil_image(images[0]) or images[0].ndim == 3:
                    return images
                if images[0].ndim == 4:
                    return [img for img_list in images for img in img_list]

            if is_valid_image(images):
                if is_pil_image(images) or images.ndim == 3:
                    return [images]
                if images.ndim == 4:
                    return list(images)

            raise ValueError(f"Could not make a flat list of images from {images}")

        def make_batched_videos(videos) -> VideoInput:
            """
Ensure that the input is a list of videos.
Args:
videos (`VideoInput`):
Video or videos to turn into a list of videos.
Returns:
list: A list of videos.
"""
            if (
                isinstance(videos, (list, tuple))
                and isinstance(videos[0], (list, tuple))
                and is_valid_image(videos[0][0])
            ):
                # case 1: nested batch of videos so we flatten it
                if not is_pil_image(videos[0][0]) and videos[0][0].ndim == 4:
                    videos = [
                        [video for batch_list in batched_videos for video in batch_list]
                        for batched_videos in videos
                    ]
                # case 2: list of videos represented as list of video frames
                return videos

            elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
                if is_pil_image(videos[0]) or videos[0].ndim == 3:
                    return [videos]
                elif videos[0].ndim == 4:
                    return [list(video) for video in videos]

            elif is_valid_image(videos):
                if is_pil_image(videos) or videos.ndim == 3:
                    return [[videos]]
                elif videos.ndim == 4:
                    return [list(videos)]

            raise ValueError(f"Could not make batched video from {videos}")

        if images is not None:
            images = make_flat_list_of_images(images)
        if videos is not None:
            videos = make_batched_videos(videos)

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        if images is not None:
            pixel_values, vision_grid_thws = [], []
            for image in images:
                patches, image_grid_thw = self._preprocess(
                    image,
                    do_resize=do_resize,
                    resample=resample,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    data_format=data_format,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(image_grid_thw)
            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)
            data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}

        if videos is not None:
            pixel_values, vision_grid_thws = [], []
            for images in videos:
                patches, video_grid_thw = self._preprocess(
                    images,
                    do_resize=do_resize,
                    resample=resample,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    data_format=data_format,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(video_grid_thw)
            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)
            data = {
                "pixel_values_videos": pixel_values,
                "video_grid_thw": vision_grid_thws,
            }

        return BatchFeature(data=data, tensor_type=return_tensors)

AutoImageProcessor.register(Qwen2_5_VLConfig, Qwen2_5_VLImageProcessor)
AutoProcessor.register(Qwen2_5_VLConfig, Qwen2_5_VLProcessor)
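
A quick way to exercise the new image processor end to end, without a checkpoint, is to feed it a dummy array. This is a sketch only; the input size is an arbitrary assumption, and the expected shapes follow from the default patch settings defined above:

    import numpy as np

    ip = Qwen2_5_VLImageProcessor()
    dummy = np.zeros((448, 448, 3), dtype=np.uint8)
    out = ip.preprocess(images=dummy, return_tensors="np")
    print(out["pixel_values"].shape)  # (1024, 1176): 1 * 32 * 32 patches of length 3 * 2 * 14 * 14
    print(out["image_grid_thw"])      # [[ 1 32 32]] for a 448 x 448 input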
python/sglang/srt/configs/qwen2vl.py  (deleted, mode 100644 → 0, −130; contents at parent 39416e39)
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""
import os
from typing import Union

from transformers import PretrainedConfig
class Qwen2VLVisionConfig(PretrainedConfig):
    model_type = "qwen2_vl"

    def __init__(
        self,
        depth=32,
        embed_dim=1280,
        hidden_size=3584,
        hidden_act="quick_gelu",
        mlp_ratio=4,
        num_heads=16,
        in_channels=3,
        patch_size=14,
        spatial_merge_size=2,
        temporal_patch_size=2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if config_dict.get("model_type") == "qwen2_vl":
            config_dict = config_dict["vision_config"]

        return cls.from_dict(config_dict, **kwargs)
class Qwen2VLConfig(PretrainedConfig):
    model_type = "qwen2_vl"

    def __init__(
        self,
        vocab_size=152064,
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=80,
        attention_dropout=0.0,
        vision_config=None,
        rope_scaling=None,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = Qwen2VLVisionConfig(**vision_config)
        elif vision_config is None:
            self.vision_config = Qwen2VLVisionConfig()

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.rope_scaling = rope_scaling

        # NOTE(HandH1998): This is necessary for configuring the `rope_type` of qwen2vl models after removing dependencies on vllm.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
                self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
python/sglang/srt/hf_transformers_utils.py  (+2 −3)
@@ -30,16 +30,15 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

-from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2VLConfig
+from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2_5_VLConfig

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
-    Qwen2VLConfig.model_type: Qwen2VLConfig,
+    Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
     with contextlib.suppress(ValueError):
         AutoConfig.register(name, cls)
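The registry loop keeps working as before; once this module is imported, the `qwen2_5_vl` model type resolves through `AutoConfig`. A hedged sketch (the checkpoint id, and whether your installed transformers already ships its own Qwen2.5-VL config, are assumptions about the environment):

    import sglang.srt.hf_transformers_utils  # noqa: F401  (runs the AutoConfig.register loop)
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    print(type(cfg).__name__)  # expected: Qwen2_5_VLConfig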
python/sglang/srt/managers/image_processor.py  (+217 −122)
 # TODO: also move pad_input_ids into this module
 import asyncio
 import concurrent.futures
+import dataclasses
 import logging
 import multiprocessing as mp
 import os
@@ -8,6 +9,7 @@ from abc import ABC, abstractmethod
 from typing import List, Optional, Union

 import numpy as np
+import PIL
 import transformers
 from decord import VideoReader, cpu
 from PIL import Image
@@ -34,11 +36,22 @@ def init_global_processor(server_args: ServerArgs):
     )


+@dataclasses.dataclass
+class BaseImageProcessorOutput:
+    image_hashes: list[int]
+    image_sizes: list[int]
+    all_frames: [PIL.Image]
+    # input_text, with each frame of video/image represented with a image_token
+    input_text: str
+
+
 class BaseImageProcessor(ABC):
     def __init__(self, hf_config, server_args, _processor):
         self.hf_config = hf_config
         self._processor = _processor
         self.server_args = server_args
+        # FIXME: not accurate, model and image specific
+        self.NUM_TOKEN_PER_FRAME = 330

         self.executor = concurrent.futures.ProcessPoolExecutor(
             initializer=init_global_processor,
...
@@ -48,9 +61,128 @@ class BaseImageProcessor(ABC):
...
@@ -48,9 +61,128 @@ class BaseImageProcessor(ABC):
)
)
@
abstractmethod
@
abstractmethod
async
def
process_images_async
(
self
,
image_data
,
input_text
,
**
kwargs
):
async
def
process_images_async
(
self
,
image_data
,
input_text
,
max_req_input_len
,
**
kwargs
    ):
        pass

    def get_estimated_frames_list(self, image_data):
        """
        estimate the total frame count from all visual input
        """
        # Before processing inputs
        estimated_frames_list = []
        for image in image_data:
            if isinstance(image, str) and image.startswith("video:"):
                path = image[len("video:") :]
                # Estimate frames for the video
                vr = VideoReader(path, ctx=cpu(0))
                num_frames = len(vr)
            else:
                # For images, each contributes one frame
                num_frames = 1
            estimated_frames_list.append(num_frames)

        return estimated_frames_list

    def encode_video(self, video_path, frame_count_limit=None):
        if not os.path.exists(video_path):
            logger.error(f"Video {video_path} does not exist")
            return []

        if frame_count_limit == 0:
            return []

        def uniform_sample(l, n):
            gap = len(l) / n
            idxs = [int(i * gap + gap / 2) for i in range(n)]
            return [l[i] for i in idxs]

        vr = VideoReader(video_path, ctx=cpu(0))
        sample_fps = round(vr.get_avg_fps() / 1)  # FPS
        frame_idx = [i for i in range(0, len(vr), sample_fps)]
        if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
            frame_idx = uniform_sample(frame_idx, frame_count_limit)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype("uint8")) for v in frames]
        return frames

    def load_images(
        self,
        max_req_input_len: int,
        input_ids: list,
        image_data,
        image_token: str,
    ) -> BaseImageProcessorOutput:
        """
        Each frame of video/image will be replaced by a single image token
        """
        image_hashes, image_sizes = [], []
        all_frames = []
        new_text_parts = []

        if isinstance(input_ids, list):
            assert len(input_ids) and isinstance(input_ids[0], int)
            input_text = self._processor.tokenizer.decode(input_ids)
        else:
            input_text = input_ids

        text_parts = input_text.split(image_token)

        # roughly calculate the max number of frames under the max_req_input_len limit
        def calculate_max_num_frames() -> int:
            ret = (max_req_input_len - len(input_ids)) // self.NUM_TOKEN_PER_FRAME
            return min(ret, 100)

        MAX_NUM_FRAMES = calculate_max_num_frames()
        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
        total_frame_count = sum(estimated_frames_list)
        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)

        # Process each input with allocated frames
        for image_index, (image, estimated_frames) in enumerate(
            zip(image_data, estimated_frames_list)
        ):
            if len(all_frames) >= MAX_NUM_FRAMES:
                frames_to_process = 0
            else:
                frames_to_process = max(1, int(estimated_frames * scaling_factor))

            if frames_to_process == 0:
                frames = []
            else:
                try:
                    if isinstance(image, str) and image.startswith("video:"):
                        path = image[len("video:") :]
                        frames = self.encode_video(
                            path, frame_count_limit=frames_to_process
                        )
                    else:
                        raw_image, _size = load_image(image)
                        frames = [raw_image]
                    if len(frames) == 0:
                        continue
                except FileNotFoundError as e:
                    print(e)
                    return None

                image_sizes += frames[0].size * len(frames)
                image_hashes += [hash(image)] * len(frames)
                all_frames += frames

            new_text_parts.append(text_parts[image_index])
            if frames_to_process != 0:
                new_text_parts.append(image_token * len(frames))
            assert frames_to_process == len(frames)

        new_text_parts.append(text_parts[-1])
        input_text = "".join(new_text_parts)
        return BaseImageProcessorOutput(
            image_hashes, image_sizes, all_frames, input_text
        )
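The frame budget in `load_images` is a proportional allocation: every visual input keeps roughly `scaling_factor` of its estimated frames, but never fewer than one, and the budget itself is capped at 100 frames or at whatever fits in the remaining token window. A minimal standalone sketch of that arithmetic, with hypothetical numbers rather than values taken from the commit:

# Sketch of the frame-budget heuristic above (illustrative numbers only).
max_req_input_len, prompt_len, num_token_per_frame = 8192, 512, 330

# At most 100 frames, bounded by the token budget left after the prompt.
max_num_frames = min((max_req_input_len - prompt_len) // num_token_per_frame, 100)  # 23

estimated_frames = [1, 60, 1]  # an image, a 60-frame video, another image
scaling_factor = min(1.0, max_num_frames / sum(estimated_frames))  # ~0.37

# Each input keeps at least one frame; the video is downsampled proportionally.
allocated = [max(1, int(n * scaling_factor)) for n in estimated_frames]  # [1, 22, 1]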
class DummyImageProcessor(BaseImageProcessor):
    def __init__(self):
...
@@ -248,9 +380,9 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             text=input_text, images=images, return_tensors="pt"
         )
         return {
-            "input_ids": result["input_ids"],
-            "pixel_values": result["pixel_values"],
-            "tgt_sizes": result["tgt_sizes"],
+            "input_ids": result.input_ids,
+            "pixel_values": result.pixel_values,
+            "tgt_sizes": result.tgt_sizes,
         }
    async def _process_images(self, images, input_text):
...
@@ -278,124 +410,20 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
     ):
         if not image_data:
             return None

         if not isinstance(image_data, list):
             image_data = [image_data]

-        image_hashes, image_sizes = [], []
-        all_frames = []
-
-        # roughly calculate the max number of frames under the max_req_input_len limit
-        def calculate_max_num_frames() -> int:
-            # Model-specific
-            NUM_TOKEN_PER_FRAME = 330
-
-            ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
-            return min(ret, 100)
-
-        MAX_NUM_FRAMES = calculate_max_num_frames()
-        # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
-
-        def get_estimated_frames_list():
-            """
-            estimate the total frame count from all visual input
-            """
-            # Before processing inputs
-            estimated_frames_list = []
-            for image in image_data:
-                if isinstance(image, str) and image.startswith("video:"):
-                    path = image[len("video:") :]
-                    # Estimate frames for the video
-                    vr = VideoReader(path, ctx=cpu(0))
-                    num_frames = len(vr)
-                else:
-                    # For images, each contributes one frame
-                    num_frames = 1
-                estimated_frames_list.append(num_frames)
-
-            return estimated_frames_list
-
-        estimated_frames_list = get_estimated_frames_list()
-        total_frame_count = sum(estimated_frames_list)
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
-
-        def encode_video(video_path, frame_count_limit=None):
-            if not os.path.exists(video_path):
-                logger.error(f"Video {video_path} does not exist")
-                return []
-
-            if frame_count_limit == 0:
-                return []
-
-            def uniform_sample(l, n):
-                gap = len(l) / n
-                idxs = [int(i * gap + gap / 2) for i in range(n)]
-                return [l[i] for i in idxs]
-
-            vr = VideoReader(video_path, ctx=cpu(0))
-            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-            frame_idx = [i for i in range(0, len(vr), sample_fps)]
-            if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
-                frame_idx = uniform_sample(frame_idx, frame_count_limit)
-            frames = vr.get_batch(frame_idx).asnumpy()
-            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-            return frames
-
-        if isinstance(input_ids, list):
-            assert len(input_ids) and isinstance(input_ids[0], int)
-            input_text = self._processor.tokenizer.decode(input_ids)
-        else:
-            input_text = input_ids
-        # MiniCPMV requires each frame of video as a single image token
-        text_parts = input_text.split(self.IMAGE_TOKEN)
-        new_text_parts = []
-
-        # Process each input with allocated frames
-        for image_index, (image, estimated_frames) in enumerate(
-            zip(image_data, estimated_frames_list)
-        ):
-            if len(all_frames) >= MAX_NUM_FRAMES:
-                frames_to_process = 0
-            else:
-                frames_to_process = max(1, int(estimated_frames * scaling_factor))
-
-            if frames_to_process == 0:
-                frames = []
-            else:
-                try:
-                    if isinstance(image, str) and image.startswith("video:"):
-                        path = image[len("video:") :]
-                        frames = encode_video(path, frame_count_limit=frames_to_process)
-                    else:
-                        raw_image, _size = load_image(image)
-                        frames = [raw_image]
-                    if len(frames) == 0:
-                        continue
-                except FileNotFoundError as e:
-                    print(e)
-                    return None
-
-                image_sizes += frames[0].size * len(frames)
-                image_hashes += [hash(image)] * len(frames)
-                all_frames += frames
-                assert frames_to_process == len(frames)
-
-            new_text_parts.append(text_parts[image_index])
-            if frames_to_process != 0:
-                new_text_parts.append(self.IMAGE_TOKEN * len(frames))
-
-        new_text_parts.append(text_parts[-1])
-        input_text = "".join(new_text_parts)
-
-        if len(all_frames) == 0:
-            return None
-
-        res = await self._process_images(images=all_frames, input_text=input_text)
-        pixel_values = res["pixel_values"]
-        tgt_sizes = res["tgt_sizes"]
-        input_ids = res["input_ids"]
+        base_output = self.load_images(
+            max_req_input_len, input_ids, image_data, self.IMAGE_TOKEN
+        )
+        if base_output is None:
+            return None
+        if len(base_output.all_frames) == 0:
+            return None
+        res = await self._process_images(
+            images=base_output.all_frames, input_text=base_output.input_text
+        )

         # Collect special token ids
         tokenizer = self._processor.tokenizer
...
@@ -405,10 +433,10 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         slice_start_id = [tokenizer.slice_start_id]
         slice_end_id = [tokenizer.slice_end_id]
         return {
-            "input_ids": input_ids.flatten().tolist(),
-            "pixel_values": pixel_values,
-            "tgt_sizes": tgt_sizes,
-            "image_hashes": image_hashes,
+            "input_ids": res["input_ids"].flatten().tolist(),
+            "pixel_values": res["pixel_values"],
+            "tgt_sizes": res["tgt_sizes"],
+            "image_hashes": base_output.image_hashes,
             "modalities": request_obj.modalities or ["image"],
             "im_start_id": im_start_id,
             "im_end_id": im_end_id,
@@ -536,13 +564,80 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
...
@@ -536,13 +564,80 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
}
}
class
Qwen2_5VLImageProcessor
(
BaseImageProcessor
):
def
__init__
(
self
,
hf_config
,
server_args
,
_processor
):
super
().
__init__
(
hf_config
,
server_args
,
_processor
)
self
.
IMAGE_TOKEN
=
"<|vision_start|><|image_pad|><|vision_end|>"
self
.
IM_START_TOKEN_ID
=
hf_config
.
vision_start_token_id
self
.
IM_END_TOKEN_ID
=
hf_config
.
vision_end_token_id
self
.
NUM_TOKEN_PER_FRAME
=
770
@
staticmethod
def
_process_images_task
(
images
,
input_text
):
result
=
global_processor
.
__call__
(
text
=
input_text
,
images
=
images
,
return_tensors
=
"pt"
)
return
{
"input_ids"
:
result
.
input_ids
,
"pixel_values"
:
result
.
pixel_values
,
"image_grid_thws"
:
result
.
image_grid_thw
,
}
async
def
_process_images
(
self
,
images
,
input_text
)
->
dict
:
if
self
.
executor
is
not
None
:
loop
=
asyncio
.
get_event_loop
()
return
await
loop
.
run_in_executor
(
self
.
executor
,
Qwen2_5VLImageProcessor
.
_process_images_task
,
images
,
input_text
,
)
else
:
return
self
.
_process_images_task
(
images
,
input_text
)
async
def
process_images_async
(
self
,
image_data
:
List
[
Union
[
str
,
bytes
]],
input_ids
,
request_obj
,
max_req_input_len
,
*
args
,
**
kwargs
,
):
if
not
image_data
:
return
None
if
isinstance
(
image_data
,
str
):
image_data
=
[
image_data
]
image_token
=
self
.
IMAGE_TOKEN
base_output
=
self
.
load_images
(
max_req_input_len
,
input_ids
,
image_data
,
image_token
)
ret
=
await
self
.
_process_images
(
base_output
.
all_frames
,
base_output
.
input_text
)
return
{
"input_ids"
:
ret
[
"input_ids"
].
flatten
().
tolist
(),
"pixel_values"
:
ret
[
"pixel_values"
],
"image_hashes"
:
base_output
.
image_hashes
,
"modalities"
:
request_obj
.
modalities
or
[
"image"
],
"image_grid_thws"
:
ret
[
"image_grid_thws"
],
"im_start_id"
:
self
.
IM_START_TOKEN_ID
,
"im_end_id"
:
self
.
IM_END_TOKEN_ID
,
}
 def get_image_processor(
     hf_config, server_args: ServerArgs, processor
 ) -> BaseImageProcessor:
     if "MllamaForConditionalGeneration" in hf_config.architectures:
         return MllamaImageProcessor(hf_config, server_args, processor)
     elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
-        return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
+        return Qwen2VLImageProcessor(hf_config, server_args, processor)
+    elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures:
+        return Qwen2_5VLImageProcessor(hf_config, server_args, processor)
     elif "MiniCPMV" in hf_config.architectures:
         return MiniCPMVImageProcessor(hf_config, server_args, processor)
     else:
...
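The dispatch above keys purely on the `architectures` list in the HuggingFace config, so a Qwen2.5-VL checkpoint is routed to the new processor without any extra server flags. A hedged illustration of that lookup (the config object below is a stand-in for illustration, not SGLang's real config class, and the token IDs are placeholders):

# Illustrative only: how get_image_processor picks the Qwen2.5-VL branch.
class _FakeHFConfig:
    architectures = ["Qwen2_5_VLForConditionalGeneration"]
    vision_start_token_id = 0  # placeholder; real IDs come from the checkpoint config
    vision_end_token_id = 0    # placeholder

# get_image_processor(_FakeHFConfig(), server_args, processor) would return a
# Qwen2_5VLImageProcessor, whose IMAGE_TOKEN is
# "<|vision_start|><|image_pad|><|vision_end|>".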
python/sglang/srt/models/qwen2_5_vl.py 0 → 100644
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoModel, Qwen2VLConfig
from transformers.activations import ACT2FN
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm

from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs

logger = logging.getLogger(__name__)
class Qwen2_5_VLMLP(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        bias: bool = True,
        hidden_act="silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.gate_proj = ColumnParallelLinear(
            in_features, hidden_features, bias=bias, quant_config=quant_config
        )
        self.up_proj = ColumnParallelLinear(
            in_features, hidden_features, bias=bias, quant_config=quant_config
        )
        self.down_proj = RowParallelLinear(
            hidden_features, in_features, bias=bias, quant_config=quant_config
        )
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_parallel_gate, _ = self.gate_proj(x)
        x_parallel_gate = self.act(x_parallel_gate)
        x_parallel_up, _ = self.up_proj(x)
        x_parallel = x_parallel_gate * x_parallel_up
        x, _ = self.down_proj(x_parallel)
        return x
class Qwen2_5_VisionBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        num_heads: int,
        hidden_act="silu",
        norm_layer: Type[nn.Module] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
        self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)

        if attn_implementation == "sdpa":
            use_context_forward = False
            use_full_precision_softmax = False
        elif attn_implementation == "flash_attention_2":
            use_full_precision_softmax = False
            use_context_forward = True
        elif attn_implementation == "eager":
            use_full_precision_softmax = True
            use_context_forward = False

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=False,
            use_context_forward=use_context_forward,
            use_full_precision_softmax=use_full_precision_softmax,
            flatten_batch=True,
            quant_config=quant_config,
        )
        self.mlp = Qwen2_5_VLMLP(
            dim, intermediate_dim, hidden_act=hidden_act, quant_config=quant_config
        )

    def forward(
        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = self.norm1(x)
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn
        norm2 = self.norm2(x)
        mlp = self.mlp(norm2)
        x = x + mlp
        return x
class Qwen2_5_VisionPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_chans: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        L, C = x.shape
        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
        x = self.proj(x).view(L, self.embed_dim)
        return x
class Qwen2_5_VisionPatchMerger(nn.Module):
    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size, dim, bias=True, quant_config=quant_config
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out
class Qwen2_5_VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        if seqlen > self._seq_len_cached:
            seqlen *= 2
            self._seq_len_cached = seqlen
            self.inv_freq = 1.0 / (
                self.theta
                ** (
                    torch.arange(
                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
                    )
                    / self.dim
                )
            )
            seq = torch.arange(
                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
            )
            freqs = torch.outer(seq, self.inv_freq)
            self._freqs_cached = freqs

    def forward(self, seqlen: int) -> torch.Tensor:
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]
class Qwen2_5_VisionTransformer(nn.Module):
    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        patch_size: int = vision_config.patch_size
        temporal_patch_size: int = vision_config.temporal_patch_size
        spatial_merge_size: int = vision_config.spatial_merge_size
        self.spatial_merge_size = spatial_merge_size
        self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
        in_chans: int = vision_config.in_chans
        hidden_size: int = vision_config.hidden_size
        depth: int = vision_config.depth
        num_heads: int = vision_config.num_heads
        self.fullatt_block_indexes = vision_config.fullatt_block_indexes
        self.window_size = vision_config.window_size
        self.patch_size = vision_config.patch_size
        mlp_hidden_size: int = vision_config.intermediate_size
        self.patch_embed = Qwen2_5_VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_chans=in_chans,
            embed_dim=hidden_size,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = hidden_size // num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Qwen2_5_VisionBlock(
                    dim=hidden_size,
                    intermediate_dim=mlp_hidden_size,
                    num_heads=num_heads,
                    hidden_act=vision_config.hidden_act,
                    norm_layer=norm_layer,
                    attn_implementation="sdpa",
                    quant_config=quant_config,
                )
                for _ in range(depth)
            ]
        )
        self.merger = Qwen2_5_VisionPatchMerger(
            dim=vision_config.out_hidden_size,
            context_dim=hidden_size,
            spatial_merge_size=spatial_merge_size,
            quant_config=quant_config,
        )

    def get_window_index(self, grid_thw):
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    @property
    def dtype(self) -> torch.dtype:
        return self.blocks[0].mlp.gate_proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.blocks[0].mlp.gate_proj.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = x.size()
        x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)

        # adapter
        x = self.merger(x)
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]

        return x
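`get_window_index` packs the merged patch grid into fixed-size attention windows; only the blocks listed in `fullatt_block_indexes` attend over the full sequence, while the rest attend within windows. A small standalone sketch of the window arithmetic, assuming Qwen2.5-VL-like config values (`window_size=112`, `spatial_merge_size=2`, `patch_size=14`); the grid below is hypothetical:

# Window-size arithmetic from get_window_index, with illustrative numbers.
window_size, spatial_merge_size, patch_size = 112, 2, 14
vit_merger_window_size = window_size // spatial_merge_size // patch_size  # 4 merged patches per side

llm_grid_h, llm_grid_w = 23, 32  # merged-patch grid for one hypothetical image
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size  # 1
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size  # 4 (a full extra window when already divisible)
num_windows = ((llm_grid_h + pad_h) // vit_merger_window_size) * (
    (llm_grid_w + pad_w) // vit_merger_window_size
)  # 6 * 9 = 54 windows, each at most 4 x 4 merged patches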
cached_get_processor = lru_cache(get_processor)
class Qwen2_5_VLForConditionalGeneration(nn.Module):
    def __init__(
        self,
        config: Qwen2VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.visual = Qwen2_5_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: Qwen2-VL vision encoder does not support any
            # quantization method now.
            quant_config=None,
        )

        self.model = Qwen2Model(config, quant_config)

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size, config.hidden_size, quant_config=quant_config
            )

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
        processor = cached_get_processor(self.config._name_or_path)
        grid_t, grid_h, grid_w = image_grid_thw
        num_image_tokens = (
            grid_t
            * grid_h
            * grid_w
            // processor.image_processor.merge_size
            // processor.image_processor.merge_size
        )
        return num_image_tokens
    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        new_input_ids = []
        last_idx = 0
        image_idx = -1
        image_inputs.image_offsets = []

        # Get all special token IDs
        im_start_id = image_inputs.im_start_id
        im_end_id = image_inputs.im_end_id

        # Find all start and end positions for both types
        start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id]
        end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id]

        if len(start_indices) != len(end_indices):
            return input_ids

        # Process each region (both image and slice)
        for start_idx, end_idx in zip(start_indices, end_indices):
            # Add non-image tokens before this region
            new_input_ids.extend(input_ids[last_idx : start_idx + 1])

            is_image_start = input_ids[start_idx] == im_start_id

            if is_image_start:
                image_inputs.image_offsets += [start_idx]
                image_idx += 1

            num_tokens = end_idx - start_idx - 1  # exclude start and end tokens

            # Generate pad_ids
            pad_values = [image_inputs.pad_values[image_idx]]

            pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
            pad_ids = pad_ids[:num_tokens]

            # Add pad_ids
            new_input_ids.extend(pad_ids)

            # Update last_idx to after end token
            last_idx = end_idx

        # Add remaining tokens after last region
        new_input_ids.extend(input_ids[last_idx:])
        assert len(input_ids) == len(new_input_ids)
        return new_input_ids
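`pad_input_ids` keeps the sequence length unchanged: the placeholder tokens between each vision-start/vision-end pair are overwritten with the per-image hash values so that radix-attention prefix matching can tell different images apart. A toy trace with assumed token IDs (the real IDs come from the tokenizer):

# Hypothetical IDs: 1000 = vision start, 1001 = vision end, 7 = an ordinary text token.
input_ids = [7, 1000, 0, 0, 0, 1001, 7]
pad_value = 42424242  # per-image hash supplied by the image processor

# pad_input_ids replaces the three placeholder tokens between start and end:
# -> [7, 1000, 42424242, 42424242, 42424242, 1001, 7]  (same length as before)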
    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
        return image_embeds

    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
        video_embeds = self.visual(
            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
        )
        return video_embeds
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run forward pass for Qwen2_5-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
                (Use input_metadata.mrope_positions to replace it)
        """
        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
            positions = forward_batch.mrope_positions

        image_inputs = None
        if forward_batch.image_inputs is not None:
            image_inputs = [
                img for img in forward_batch.image_inputs if img is not None
            ]

        if (
            forward_batch.forward_mode.is_decode()
            or image_inputs is None
            or len(image_inputs) == 0
        ):
            inputs_embeds = self.model.embed_tokens(input_ids)
        else:
            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )

            # Clamp input ids. This is because the input_ids for the image tokens are
            # filled with the hash values of the image for the prefix matching in the radix attention.
            # Their values are useless because their embeddings will be replaced by vision embeddings anyway.
            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

            # [B, s, hidden_size]
            inputs_embeds = self.model.embed_tokens(input_ids)
            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
            for i, image in enumerate(forward_batch.image_inputs):
                if image is None:
                    continue
                start_idx = extend_start_loc_cpu[i]
                prefix_len = prefix_lens_cpu[i]

                pixel_values = image.pixel_values.clone().detach().requires_grad_(False)
                image_grid_thws = torch.tensor(
                    np.array(image.image_grid_thws), device="cuda"
                )
                image_offsets = image.image_offsets
                image_input = Qwen2VLImageInputs(
                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
                )
                image_embeds = self._process_image_input(image_input)

                image_embeds_offset = 0
                for idx, image_offset in enumerate(image_offsets):
                    if image_offset < prefix_len:
                        continue
                    num_image_tokens = self.calculate_num_image_tokens(
                        image_grid_thws[idx]
                    )

                    left_idx = start_idx + (image_offset - prefix_len)
                    right_idx = left_idx + num_image_tokens

                    tp_size = get_tensor_model_parallel_world_size()
                    hidden_size = image_embeds.shape[-1]
                    if hidden_size % tp_size != 0:
                        padding_size = tp_size - (hidden_size % tp_size)
                        image_embeds = F.pad(image_embeds, (0, padding_size))
                        inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
                    hidden_chunk_size = image_embeds.shape[-1] // tp_size
                    rank = get_tensor_model_parallel_rank()
                    start_dim = rank * hidden_chunk_size
                    end_dim = (rank + 1) * hidden_chunk_size
                    inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
                        image_embeds[
                            image_embeds_offset : image_embeds_offset
                            + num_image_tokens,
                            ...,
                            start_dim:end_dim,
                        ]
                    )
                    image_embeds_offset += num_image_tokens

            input_ids = None

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
        )

        if not get_embedding:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return self.pooler(hidden_states, forward_batch)
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "visual" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name and "qkv.weight" in name:
                    visual_num_heads = self.config.vision_config.num_heads
                    visual_embed_dim = self.config.vision_config.hidden_size
                    head_size = visual_embed_dim // visual_num_heads
                    loaded_weight = loaded_weight.view(
                        3, visual_num_heads, head_size, visual_embed_dim
                    )
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
                elif "visual" in name and "qkv.bias" in name:
                    visual_num_heads = self.config.vision_config.num_heads
                    visual_embed_dim = self.config.vision_config.hidden_size
                    head_size = visual_embed_dim // visual_num_heads
                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)

                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                try:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                except KeyError:
                    print(params_dict.keys())
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
EntryClass = [Qwen2_5_VLForConditionalGeneration]

AutoModel.register(Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration)
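For reference, `calculate_num_image_tokens` above divides the flattened `(t, h, w)` patch grid by `merge_size` twice, so the number of sequence positions an image occupies shrinks by a factor of `merge_size**2`. A quick standalone check with an illustrative grid (the grid values are hypothetical; a merge size of 2 is the Qwen2-VL-family default):

# Illustrative token-count check, independent of the model classes above.
grid_t, grid_h, grid_w = 1, 98, 146  # hypothetical image_grid_thw from the processor
merge_size = 2
num_image_tokens = grid_t * grid_h * grid_w // merge_size // merge_size
assert num_image_tokens == 3577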
python/sglang/srt/models/qwen2_vl.py
...
@@ -31,8 +31,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+from transformers import Qwen2VLConfig
+from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
-from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
...
test/srt/test_vision_openai_server.py
...
@@ -252,6 +252,18 @@ class TestOpenAIVisionServer(unittest.TestCase):
         print("-" * 30)

         # Add assertions to validate the video response
+        assert "iPod" in video_response or "device" in video_response, video_response
+        assert (
+            "man" in video_response
+            or "person" in video_response
+            or "individual" in video_response
+        ), video_response
+        assert (
+            "present" in video_response
+            or "examine" in video_response
+            or "display" in video_response
+        )
+        assert "black" in video_response or "dark" in video_response
         self.assertIsNotNone(video_response)
         self.assertGreater(len(video_response), 0)
...
@@ -366,6 +378,30 @@ class TestQWen2VLServer(TestOpenAIVisionServer):
         cls.base_url += "/v1"


+class TestQWen2_5_VLServer(TestOpenAIVisionServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "Qwen/Qwen2.5-VL-7B-Instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=[
+                "--chat-template",
+                "qwen2-vl",
+                # FIXME: workaround to chunked prefill within image embeds
+                "--chunked-prefill-size",
+                "10000",
+                "--mem-fraction-static",
+                "0.4",
+            ],
+        )
+        cls.base_url += "/v1"
+
+
 class TestQWen2VLServerContextLengthIssue(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
...