Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0a89b8a3
Commit
0a89b8a3
authored
Feb 22, 2025
by
zhuwenwen
Browse files
support qwen2_5-vl
parent
47bd229c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
425 additions
and
267 deletions
+425
-267
requirements-common.txt
requirements-common.txt
+1
-1
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+79
-106
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+130
-111
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+57
-41
vllm/multimodal/parse.py
vllm/multimodal/parse.py
+62
-2
vllm/transformers_utils/processor.py
vllm/transformers_utils/processor.py
+96
-6
No files found.
requirements-common.txt
View file @
0a89b8a3
...
...
@@ -5,7 +5,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.4
8.2
# Required for Bamba model and Transformers backend.
transformers >= 4.4
9.0
# Required for Bamba model and Transformers backend.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
0a89b8a3
...
...
@@ -33,18 +33,18 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
transformers
import
BatchFeature
from
transformers.models.qwen2_5_vl
import
(
Qwen2_5_VLImageProcessor
,
Qwen2_5_VLProcessor
)
from
transformers.models.qwen2_5_vl
import
Qwen2_5_VLProcessor
from
transformers.models.qwen2_5_vl.configuration_qwen2_5_vl
import
(
Qwen2_5_VLConfig
,
Qwen2_5_VLVisionConfig
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.activation
import
_ACTIVATION_REGISTRY
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
@@ -207,11 +207,12 @@ class Qwen2_5_VisionAttention(nn.Module):
)
->
None
:
super
().
__init__
()
# Per attention head and per partition values.
world_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
parallel_state
.
get_tensor_model_parallel_rank
()
self
.
hidden_size_per_attention_head
=
dist_utils
.
divide
(
projection_size
,
num_heads
)
self
.
num_attention_heads_per_partition
=
dist_utils
.
divide
(
num_heads
,
world
_size
)
num_heads
,
self
.
tp
_size
)
self
.
qkv
=
ColumnParallelLinear
(
input_size
=
embed_dim
,
output_size
=
3
*
projection_size
,
...
...
@@ -231,6 +232,29 @@ class Qwen2_5_VisionAttention(nn.Module):
f
"Qwen2.5-VL does not support
{
self
.
attn_backend
}
backend now."
)
def
split_qkv
(
self
,
qkv
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
# [s, b, 3 * head * head_dim]
seq_len
,
bs
,
_
=
qkv
.
shape
if
self
.
tp_size
>
1
:
qkv
=
tensor_model_parallel_all_gather
(
qkv
)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=
2
)
# 3 * [s, b, head * head_dim]
if
self
.
tp_size
>
1
:
splitter
=
partial
(
dist_utils
.
split_tensor_along_last_dim
,
num_partitions
=
self
.
tp_size
)
q
=
splitter
(
q
)[
self
.
tp_rank
]
k
=
splitter
(
k
)[
self
.
tp_rank
]
v
=
splitter
(
v
)[
self
.
tp_rank
]
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
new_shape
=
(
seq_len
,
bs
,
self
.
num_attention_heads_per_partition
,
self
.
hidden_size_per_attention_head
)
q
,
k
,
v
=
(
x
.
view
(
*
new_shape
)
for
x
in
(
q
,
k
,
v
))
return
q
,
k
,
v
def
forward
(
self
,
x
:
torch
.
Tensor
,
...
...
@@ -240,22 +264,20 @@ class Qwen2_5_VisionAttention(nn.Module):
# [s, b, c] --> [s, b, head * 3 * head_dim]
x
,
_
=
self
.
qkv
(
x
)
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
new_x_shape
=
x
.
size
()[:
-
1
]
+
(
self
.
num_attention_heads_per_partition
,
3
*
self
.
hidden_size_per_attention_head
,
)
x
=
x
.
view
(
*
new_x_shape
)
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
q
,
k
,
v
=
dist_utils
.
split_tensor_along_last_dim
(
x
,
3
)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q
,
k
,
v
=
self
.
split_qkv
(
x
)
batch_size
=
q
.
shape
[
1
]
q
,
k
,
v
=
(
rearrange
(
x
,
"s b ... -> b s ..."
).
contiguous
()
for
x
in
(
q
,
k
,
v
))
if
rotary_pos_emb
is
not
None
:
q
=
apply_rotary_pos_emb_vision
(
q
,
rotary_pos_emb
)
k
=
apply_rotary_pos_emb_vision
(
k
,
rotary_pos_emb
)
use_flash_attn
=
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
q
=
apply_rotary_pos_emb_vision
(
q
,
rotary_pos_emb
,
use_flash_attn
=
use_flash_attn
)
k
=
apply_rotary_pos_emb_vision
(
k
,
rotary_pos_emb
,
use_flash_attn
=
use_flash_attn
)
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
# from vllm_flash_attn.flash_attn_interface import (
...
...
@@ -279,20 +301,23 @@ class Qwen2_5_VisionAttention(nn.Module):
"(b s) ... -> b s ..."
,
b
=
batch_size
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
seq_length
=
q
.
size
(
1
)
q
,
k
,
v
=
(
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
])
attention_mask
=
torch
.
zeros
([
1
,
seq_length
,
seq_length
],
device
=
q
.
device
,
dtype
=
torch
.
bool
)
# Execute attention entry by entry for speed & less VRAM.
outputs
=
[]
for
i
in
range
(
1
,
len
(
cu_seqlens
)):
attention_mask
[...,
cu_seqlens
[
i
-
1
]:
cu_seqlens
[
i
],
cu_seqlens
[
i
-
1
]:
cu_seqlens
[
i
]]
=
True
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attention_mask
,
start_idx
=
cu_seqlens
[
i
-
1
]
end_idx
=
cu_seqlens
[
i
]
q_i
=
q
[:,
start_idx
:
end_idx
]
k_i
=
k
[:,
start_idx
:
end_idx
]
v_i
=
v
[:,
start_idx
:
end_idx
]
q_i
,
k_i
,
v_i
=
(
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q_i
,
k_i
,
v_i
])
output_i
=
F
.
scaled_dot_product_attention
(
q_i
,
k_i
,
v_i
,
dropout_p
=
0.0
)
context_layer
=
rearrange
(
output
,
"b h s d -> b s h d "
)
output_i
=
rearrange
(
output_i
,
"b h s d -> b s h d "
)
outputs
.
append
(
output_i
)
context_layer
=
torch
.
cat
(
outputs
,
dim
=
1
)
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
...
...
@@ -310,25 +335,6 @@ class Qwen2_5_VisionAttention(nn.Module):
return
output
class
Qwen2RMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
eps
=
1e-6
):
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
self
.
variance_epsilon
=
eps
def
forward
(
self
,
hidden_states
):
input_dtype
=
hidden_states
.
dtype
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
variance
=
hidden_states
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
hidden_states
=
hidden_states
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
return
self
.
weight
*
hidden_states
.
to
(
input_dtype
)
def
extra_repr
(
self
):
return
f
"
{
tuple
(
self
.
weight
.
shape
)
}
, eps=
{
self
.
variance_epsilon
}
"
class
Qwen2_5_VisionBlock
(
nn
.
Module
):
def
__init__
(
...
...
@@ -499,8 +505,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
hidden_size
=
self
.
hidden_size
,
)
# NOTE: We use torch native RMSNorm here for precision purposes.
norm_layer
=
partial
(
Qwen2RMSNorm
,
eps
=
norm_eps
)
norm_layer
=
partial
(
RMSNorm
,
eps
=
norm_eps
)
head_dim
=
self
.
hidden_size
//
self
.
num_heads
self
.
rotary_pos_emb
=
Qwen2_5_VisionRotaryEmbedding
(
head_dim
//
2
)
...
...
@@ -665,24 +670,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
if
name
.
endswith
(
"qkv.weight"
):
visual_num_heads
=
self
.
num_heads
visual_embed_dim
=
self
.
hidden_size
head_size
=
visual_embed_dim
//
visual_num_heads
loaded_weight
=
loaded_weight
.
view
(
3
,
visual_num_heads
,
head_size
,
visual_embed_dim
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
visual_embed_dim
)
elif
name
.
endswith
(
"qkv.bias"
):
visual_num_heads
=
self
.
num_heads
visual_embed_dim
=
self
.
hidden_size
head_size
=
visual_embed_dim
//
visual_num_heads
loaded_weight
=
loaded_weight
.
view
(
3
,
visual_num_heads
,
head_size
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
)
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
@@ -701,39 +688,20 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo):
*
,
min_pixels
:
Optional
[
int
]
=
None
,
max_pixels
:
Optional
[
int
]
=
None
,
fps
:
Optional
[
float
]
=
2.0
,
size
:
Optional
[
dict
[
str
,
int
]]
=
None
,
fps
:
Optional
[
Union
[
float
,
List
[
float
]]]
=
None
,
**
kwargs
:
object
,
)
->
Qwen2_5_VLProcessor
:
hf_processor
=
self
.
ctx
.
get_hf_processor
(
Qwen2_5_VLProcessor
)
image_processor
=
hf_processor
.
image_processor
# type: ignore
assert
isinstance
(
image_processor
,
Qwen2_5_VLImageProcessor
)
if
min_pixels
:
image_processor
.
min_pixels
=
min_pixels
if
max_pixels
:
image_processor
.
max_pixels
=
max_pixels
if
max_pixels
or
min_pixels
:
image_processor
.
size
=
{
"min_pixels"
:
image_processor
.
min_pixels
,
"max_pixels"
:
image_processor
.
max_pixels
,
}
return
hf_processor
if
fps
is
not
None
:
kwargs
[
"fps"
]
=
fps
def
get_image_processor
(
self
,
*
,
min_pixels
:
Optional
[
int
]
=
None
,
max_pixels
:
Optional
[
int
]
=
None
,
fps
:
Optional
[
float
]
=
2.0
,
)
->
Qwen2_5_VLImageProcessor
:
hf_processor
=
self
.
get_hf_processor
(
min_pixels
=
min_pixels
,
return
self
.
ctx
.
get_hf_processor
(
Qwen2_5_VLProcessor
,
image_processor
=
self
.
get_image_processor
(
min_pixels
=
min_pixels
,
max_pixels
=
max_pixels
,
fps
=
fps
,
size
=
size
),
**
kwargs
,
)
image_processor
=
hf_processor
.
image_processor
# type: ignore
assert
isinstance
(
image_processor
,
Qwen2_5_VLImageProcessor
)
return
image_processor
class
Qwen2_5_VLMultiModalProcessor
(
Qwen2VLMultiModalProcessor
):
...
...
@@ -760,19 +728,23 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"q_proj"
,
"k_proj"
,
"v_proj"
,
]
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes, TODO: double check
# LoRA specific attributes
supported_lora_modules
=
[
# language model
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
"gate_proj"
"up_proj"
,
"down_proj"
,
# Same name with vision encoder
# vision tower
"qkv"
,
"gate_proj"
,
"up_proj"
,
"attn.proj"
,
# Distinguish patch_embed.proj
"fc1"
,
"fc2"
,
...
...
@@ -780,6 +752,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"mlp.0"
,
"mlp.2"
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
0a89b8a3
...
...
@@ -58,14 +58,17 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
MultiModalFieldConfig
,
MultiModalKwargs
,
VideoItem
)
from
vllm.multimodal.parse
import
(
ImageSize
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
)
from
vllm.multimodal.parse
import
(
DictEmbeddingItems
,
ImageSize
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.processor
import
(
cached_image_processor_from_config
)
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
...
...
@@ -231,11 +234,15 @@ def apply_rotary_emb_torch(x: torch.Tensor,
def
apply_rotary_pos_emb_vision
(
t
:
torch
.
Tensor
,
freqs
:
torch
.
Tensor
)
->
torch
.
Tensor
:
freqs
:
torch
.
Tensor
,
use_flash_attn
=
False
)
->
torch
.
Tensor
:
t_
=
t
.
float
()
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
output
=
apply_rotary_emb_torch
(
t_
,
cos
,
sin
).
type_as
(
t
)
apply_rotary_emb
=
apply_rotary_emb_torch
if
use_flash_attn
:
from
flash_attn.layers.rotary
import
apply_rotary_emb
output
=
apply_rotary_emb
(
t_
,
cos
,
sin
).
type_as
(
t
)
return
output
...
...
@@ -341,20 +348,23 @@ class Qwen2VisionAttention(nn.Module):
"(b s) ... -> b s ..."
,
b
=
batch_size
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
seq_length
=
q
.
size
(
1
)
q
,
k
,
v
=
(
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
])
attention_mask
=
torch
.
zeros
([
1
,
seq_length
,
seq_length
],
device
=
q
.
device
,
dtype
=
torch
.
bool
)
# Execute attention entry by entry for speed & less VRAM.
outputs
=
[]
for
i
in
range
(
1
,
len
(
cu_seqlens
)):
attention_mask
[...,
cu_seqlens
[
i
-
1
]:
cu_seqlens
[
i
],
cu_seqlens
[
i
-
1
]:
cu_seqlens
[
i
]]
=
True
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attention_mask
,
start_idx
=
cu_seqlens
[
i
-
1
]
end_idx
=
cu_seqlens
[
i
]
q_i
=
q
[:,
start_idx
:
end_idx
]
k_i
=
k
[:,
start_idx
:
end_idx
]
v_i
=
v
[:,
start_idx
:
end_idx
]
q_i
,
k_i
,
v_i
=
(
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q_i
,
k_i
,
v_i
])
output_i
=
F
.
scaled_dot_product_attention
(
q_i
,
k_i
,
v_i
,
dropout_p
=
0.0
)
context_layer
=
rearrange
(
output
,
"b h s d -> b s h d "
)
output_i
=
rearrange
(
output_i
,
"b h s d -> b s h d "
)
outputs
.
append
(
output_i
)
context_layer
=
torch
.
cat
(
outputs
,
dim
=
1
)
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
...
...
@@ -710,49 +720,25 @@ class Qwen2VisionTransformer(nn.Module):
return
loaded_params
class
Qwen2VLEmbeddingItems
(
ModalityDataItems
[
dict
[
str
,
torch
.
Tensor
],
dict
[
str
,
torch
.
Tensor
]]):
def
__init__
(
self
,
data
:
dict
,
modality
:
str
)
->
None
:
super
().
__init__
(
data
,
modality
)
grid_thw
=
data
[
f
"
{
modality
}
_grid_thw"
]
slice_idxs
=
[
0
]
+
grid_thw
.
prod
(
-
1
).
cumsum_
(
0
).
tolist
()
self
.
_slices
=
[
slice
(
slice_idxs
[
i
],
slice_idxs
[
i
+
1
])
for
i
in
range
(
len
(
grid_thw
))
]
def
get_count
(
self
)
->
int
:
return
len
(
self
.
data
[
f
"
{
self
.
modality
}
_grid_thw"
])
def
get
(
self
,
index
:
int
)
->
dict
[
str
,
torch
.
Tensor
]:
out
=
{}
for
k
,
v
in
self
.
data
.
items
():
if
v
!=
f
"
{
self
.
modality
}
_grid_thw"
:
v
=
v
[
self
.
_slices
[
index
]]
out
[
k
]
=
v
return
out
def
get_processor_data
(
self
)
->
Mapping
[
str
,
object
]:
return
{}
def
get_passthrough_data
(
self
)
->
Mapping
[
str
,
object
]:
return
self
.
data
class
Qwen2VLImageEmbeddingItems
(
Qwen2VLEmbeddingItems
):
def
__init__
(
self
,
data
:
dict
)
->
None
:
super
().
__init__
(
data
,
"image"
)
def
_qwen2vl_field_config
(
hf_inputs
:
Mapping
[
str
,
torch
.
Tensor
]):
image_grid_thw
=
hf_inputs
.
get
(
"image_grid_thw"
,
torch
.
empty
((
0
,
3
)))
image_grid_sizes
=
image_grid_thw
.
prod
(
-
1
)
class
Qwen2VLVideoEmbeddingItems
(
Qwen2VLEmbeddingItems
):
video_grid_thw
=
hf_inputs
.
get
(
"video_grid_thw"
,
torch
.
empty
((
0
,
3
)))
video_grid_sizes
=
video_grid_thw
.
prod
(
-
1
)
def
__init__
(
self
,
data
:
dict
)
->
None
:
super
().
__init__
(
data
,
"video"
)
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_grid_sizes
),
image_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_grid_sizes
),
image_grid_thw
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values_videos
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_grid_sizes
),
video_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_grid_sizes
),
video_grid_thw
=
MultiModalFieldConfig
.
batched
(
"video"
),
)
class
Qwen2VLMultiModalDataParser
(
MultiModalDataParser
):
...
...
@@ -762,7 +748,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
ImageItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
if
isinstance
(
data
,
dict
):
return
Qwen2VLEmbeddingItems
(
data
,
modality
=
"image"
)
return
DictEmbeddingItems
(
data
,
modality
=
"image"
,
required_fields
=
{
"image_embeds"
,
"image_grid_thw"
},
fields_factory
=
_qwen2vl_field_config
,
)
return
super
().
_parse_image_data
(
data
)
...
...
@@ -771,7 +762,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
VideoItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
if
isinstance
(
data
,
dict
):
return
Qwen2VLEmbeddingItems
(
data
,
modality
=
"video"
)
return
DictEmbeddingItems
(
data
,
modality
=
"video"
,
required_fields
=
{
"video_embeds"
,
"video_grid_thw"
},
fields_factory
=
_qwen2vl_field_config
,
)
return
super
().
_parse_video_data
(
data
)
...
...
@@ -786,34 +782,64 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
*
,
min_pixels
:
Optional
[
int
]
=
None
,
max_pixels
:
Optional
[
int
]
=
None
,
size
:
Optional
[
dict
[
str
,
int
]]
=
None
,
**
kwargs
:
object
,
)
->
Qwen2VLProcessor
:
hf_processor
=
self
.
ctx
.
get_hf_processor
(
Qwen2VLProcessor
)
image_processor
=
hf_processor
.
image_processor
# type: ignore
assert
isinstance
(
image_processor
,
Qwen2VLImageProcessor
)
if
min_pixels
:
image_processor
.
min_pixels
=
min_pixels
if
max_pixels
:
image_processor
.
max_pixels
=
max_pixels
if
max_pixels
or
min_pixels
:
image_processor
.
size
=
{
"min_pixels"
:
image_processor
.
min_pixels
,
"max_pixels"
:
image_processor
.
max_pixels
,
}
return
self
.
ctx
.
get_hf_processor
(
Qwen2VLProcessor
,
image_processor
=
self
.
get_image_processor
(
min_pixels
=
min_pixels
,
max_pixels
=
max_pixels
,
size
=
size
),
**
kwargs
,
)
return
hf_processor
def
_get_image_processor_kwargs
(
self
,
*
,
min_pixels
:
Optional
[
int
]
=
None
,
max_pixels
:
Optional
[
int
]
=
None
,
size
:
Optional
[
dict
[
str
,
int
]]
=
None
,
**
kwargs
:
object
,
):
if
self
.
ctx
.
model_config
.
mm_processor_kwargs
:
kwargs
.
update
(
self
.
ctx
.
model_config
.
mm_processor_kwargs
)
if
min_pixels
is
not
None
:
kwargs
[
"min_pixels"
]
=
min_pixels
if
size
is
None
:
size
=
{
"shortest_edge"
:
min_pixels
}
else
:
size
[
"shortest_edge"
]
=
min_pixels
if
max_pixels
is
not
None
:
kwargs
[
"max_pixels"
]
=
max_pixels
if
size
is
None
:
size
=
{
"longest_edge"
:
max_pixels
}
else
:
size
[
"longest_edge"
]
=
max_pixels
if
size
is
not
None
:
kwargs
[
"size"
]
=
size
return
kwargs
def
get_image_processor
(
self
,
*
,
min_pixels
:
Optional
[
int
]
=
None
,
max_pixels
:
Optional
[
int
]
=
None
,
size
:
Optional
[
dict
[
str
,
int
]]
=
None
,
**
kwargs
:
object
,
):
hf_processor
=
self
.
get_hf_processor
(
min_pixels
=
min_pixels
,
max_pixels
=
max_pixels
)
image_processor
=
hf_processor
.
image_processor
# type: ignore
assert
isinstance
(
image_processor
,
Qwen2VLImageProcessor
)
return
image_processor
return
cached_image_processor_from_config
(
self
.
ctx
.
model_config
,
**
self
.
_get_image_processor_kwargs
(
min_pixels
=
min_pixels
,
max_pixels
=
max_pixels
,
size
=
size
,
**
kwargs
),
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
,
"video"
:
None
}
...
...
@@ -860,7 +886,11 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
preprocessed_size
=
ImageSize
(
width
=
image_width
,
height
=
image_height
)
grid_t
=
max
(
num_frames
//
temporal_patch_size
,
1
)
# NOTE: Frames are padded to be divisible by `temporal_patch_size`
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
padded_num_frames
=
num_frames
+
num_frames
%
temporal_patch_size
grid_t
=
max
(
padded_num_frames
//
temporal_patch_size
,
1
)
grid_h
=
preprocessed_size
.
height
//
patch_size
grid_w
=
preprocessed_size
.
width
//
patch_size
...
...
@@ -945,14 +975,10 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
max_image_tokens
=
self
.
get_max_image_tokens
()
*
max_images
max_total_frames
=
self
.
_get_max_video_frames
(
seq_len
-
max_image_tokens
)
num
_frames
=
min
(
max
(
max
_total_frames
//
max
(
max_videos
,
1
),
1
),
max
_frames
_per_video
=
min
(
max_total_frames
//
max
(
max_videos
,
1
),
_MAX_FRAMES_PER_VIDEO
)
# Temporary workaround for https://github.com/huggingface/transformers/issues/35412
if
num_frames
>
1
and
num_frames
%
2
==
1
:
num_frames
+=
1
return
num_frames
return
max
(
max_frames_per_video
,
1
)
def
get_max_video_tokens
(
self
,
seq_len
:
int
)
->
int
:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
...
...
@@ -1010,6 +1036,18 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
def
_get_data_parser
(
self
)
->
MultiModalDataParser
:
return
Qwen2VLMultiModalDataParser
()
def
_call_hf_processor
(
self
,
prompt
:
str
,
mm_data
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
return
self
.
info
.
ctx
.
call_hf_processor
(
self
.
info
.
get_hf_processor
(
**
mm_kwargs
),
dict
(
text
=
prompt
,
**
mm_data
),
self
.
info
.
_get_image_processor_kwargs
(
**
mm_kwargs
),
)
def
_get_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
...
...
@@ -1022,8 +1060,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
tokenizer
=
self
.
info
.
get_tokenizer
()
vocab
=
tokenizer
.
get_vocab
()
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered
placeholder
=
{
"image"
:
vocab
[
hf_processor
.
image_token
],
"video"
:
vocab
[
hf_processor
.
video_token
],
...
...
@@ -1052,24 +1088,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
image_grid_thw
=
hf_inputs
.
get
(
"image_grid_thw"
,
torch
.
empty
((
0
,
3
)))
image_grid_sizes
=
image_grid_thw
.
prod
(
-
1
)
video_grid_thw
=
hf_inputs
.
get
(
"video_grid_thw"
,
torch
.
empty
((
0
,
3
)))
video_grid_sizes
=
video_grid_thw
.
prod
(
-
1
)
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_grid_sizes
),
image_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_grid_sizes
),
image_grid_thw
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values_videos
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_grid_sizes
),
video_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_grid_sizes
),
video_grid_thw
=
MultiModalFieldConfig
.
batched
(
"video"
),
)
return
_qwen2vl_field_config
(
hf_inputs
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
Qwen2VLMultiModalProcessor
,
...
...
vllm/multimodal/inputs.py
View file @
0a89b8a3
...
...
@@ -739,3 +739,19 @@ class MultiModalInputs(TypedDict):
For each modality, information about the placeholder tokens in
:code:`prompt_token_ids`.
"""
class
MultiModalEncDecInputs
(
MultiModalInputs
):
"""
Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
ready to be passed to vLLM internals.
"""
encoder_prompt
:
str
"""The processed encoder prompt text."""
encoder_prompt_token_ids
:
list
[
int
]
"""The processed token IDs of the encoder prompt."""
encoder_token_type_ids
:
NotRequired
[
list
[
int
]]
"""The token type IDs of the encoder prompt."""
\ No newline at end of file
vllm/multimodal/parse.py
View file @
0a89b8a3
...
...
@@ -9,13 +9,15 @@ from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
import
numpy
as
np
import
torch
from
PIL.Image
import
Image
from
transformers
import
BatchFeature
from
typing_extensions
import
TypeAlias
,
TypeGuard
,
assert_never
from
vllm.utils
import
is_list_of
from
.audio
import
resample_audio
from
.inputs
import
(
AudioItem
,
HfAudioItem
,
HfImageItem
,
HfVideoItem
,
ImageItem
,
ModalityData
,
MultiModalDataDict
,
VideoItem
)
ImageItem
,
ModalityData
,
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
,
VideoItem
)
_T
=
TypeVar
(
"_T"
)
_I
=
TypeVar
(
"_I"
)
...
...
@@ -111,6 +113,64 @@ class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
return
len
(
self
.
get
(
item_idx
))
class
DictEmbeddingItems
(
ModalityDataItems
[
Mapping
[
str
,
torch
.
Tensor
],
Mapping
[
str
,
torch
.
Tensor
]]):
"""
Base class for data items that are expressed as a dictionary of tensors.
Usually, the dictionary keys correspond to the outputs of HF processor.
"""
def
__init__
(
self
,
data
:
Mapping
[
str
,
torch
.
Tensor
],
modality
:
str
,
required_fields
:
set
[
str
],
fields_factory
:
Callable
[
[
Mapping
[
str
,
torch
.
Tensor
]],
Mapping
[
str
,
MultiModalFieldConfig
],
],
)
->
None
:
super
().
__init__
(
data
,
modality
)
missing_required_data_keys
=
required_fields
-
data
.
keys
()
if
missing_required_data_keys
:
data_keys
=
set
(
data
.
keys
())
msg
=
(
f
"The data should contain the fields:
{
required_fields
}
, "
f
"but only found the following keys:
{
data_keys
}
"
)
raise
ValueError
(
msg
)
fields_config
=
fields_factory
(
data
)
missing_required_fields
=
required_fields
-
fields_config
.
keys
()
if
missing_required_fields
:
fields
=
set
(
fields_config
.
keys
())
msg
=
f
"
{
required_fields
=
}
should be a subset of
{
fields
=
}
"
raise
ValueError
(
msg
)
self
.
fields_config
=
fields_config
self
.
required_fields
=
required_fields
self
.
_kwargs
=
MultiModalKwargs
.
from_hf_inputs
(
BatchFeature
(
dict
(
data
)),
fields_config
,
)
def
get_count
(
self
)
->
int
:
return
self
.
_kwargs
.
get_item_count
(
self
.
modality
)
def
get
(
self
,
index
:
int
)
->
Mapping
[
str
,
torch
.
Tensor
]:
return
{
k
:
v
.
data
for
k
,
v
in
self
.
_kwargs
.
get_item
(
self
.
modality
,
index
).
items
()
}
def
get_processor_data
(
self
)
->
Mapping
[
str
,
object
]:
return
{}
def
get_passthrough_data
(
self
)
->
Mapping
[
str
,
object
]:
return
self
.
data
class
AudioProcessorItems
(
ProcessorBatchItems
[
HfAudioItem
]):
def
__init__
(
self
,
data
:
Sequence
[
HfAudioItem
])
->
None
:
...
...
vllm/transformers_utils/processor.py
View file @
0a89b8a3
# SPDX-License-Identifier: Apache-2.0
from
functools
import
lru_cache
from
typing
import
Any
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Union
,
cast
from
transformers.processing_utils
import
ProcessorMixin
from
typing_extensions
import
TypeVar
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
_P
=
TypeVar
(
"_P"
,
bound
=
ProcessorMixin
,
default
=
ProcessorMixin
)
class
HashableDict
(
dict
):
"""
A dictionary that can be hashed by lru_cache.
"""
# NOTE: pythonic dict is not hashable,
# we override on it directly for simplicity
def
__hash__
(
self
)
->
int
:
# type: ignore[override]
return
hash
(
frozenset
(
self
.
items
()))
class
HashableList
(
list
):
"""
A list that can be hashed by lru_cache.
"""
def
__hash__
(
self
)
->
int
:
# type: ignore[override]
return
hash
(
tuple
(
self
))
def
_merge_mm_kwargs
(
model_config
:
"ModelConfig"
,
**
kwargs
):
base_kwargs
=
model_config
.
mm_processor_kwargs
if
base_kwargs
is
None
:
base_kwargs
=
{}
merged_kwargs
=
{
**
base_kwargs
,
**
kwargs
}
# NOTE: Pythonic dict is not hashable and will raise unhashable type
# error when calling `cached_get_processor`, therefore we need to
# wrap it to a hashable dict.
for
key
,
value
in
merged_kwargs
.
items
():
if
isinstance
(
value
,
dict
):
merged_kwargs
[
key
]
=
HashableDict
(
value
)
if
isinstance
(
value
,
list
):
merged_kwargs
[
key
]
=
HashableList
(
value
)
return
merged_kwargs
def
get_processor
(
processor_name
:
str
,
*
args
:
Any
,
trust_remote_code
:
bool
=
False
,
processor_cls
:
type
[
ProcessorMixin
]
=
ProcessorMixin
,
processor_cls
:
Union
[
type
[
_P
],
tuple
[
type
[
_P
],
...]
]
=
ProcessorMixin
,
**
kwargs
:
Any
,
):
)
->
_P
:
"""Load a processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from
transformers
import
AutoProcessor
processor_factory
=
(
AutoProcessor
i
f
processor_cls
==
ProcessorMixin
else
processor_cls
)
processor_factory
=
(
AutoProcessor
if
processor_cls
==
ProcessorMixin
or
i
sinstance
(
processor_cls
,
tuple
)
else
processor_cls
)
try
:
processor
=
processor_factory
.
from_pretrained
(
...
...
@@ -43,12 +87,30 @@ def get_processor(
else
:
raise
e
return
cast
(
ProcessorMixin
,
processor
)
if
not
isinstance
(
processor
,
processor_cls
):
raise
TypeError
(
"Invalid type of HuggingFace processor. "
f
"Expected type:
{
processor_cls
}
, but "
f
"found type:
{
type
(
processor
)
}
"
)
return
processor
cached_get_processor
=
lru_cache
(
get_processor
)
def
cached_processor_from_config
(
model_config
:
"ModelConfig"
,
processor_cls
:
Union
[
type
[
_P
],
tuple
[
type
[
_P
],
...]]
=
ProcessorMixin
,
**
kwargs
:
Any
,
)
->
_P
:
return
cached_get_processor
(
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
,
processor_cls
=
processor_cls
,
# type: ignore[arg-type]
**
_merge_mm_kwargs
(
model_config
,
**
kwargs
),
)
def
get_image_processor
(
processor_name
:
str
,
*
args
:
Any
,
...
...
@@ -85,6 +147,20 @@ def get_image_processor(
return
cast
(
BaseImageProcessor
,
processor
)
cached_get_image_processor
=
lru_cache
(
get_image_processor
)
def
cached_image_processor_from_config
(
model_config
:
"ModelConfig"
,
**
kwargs
:
Any
,
):
return
cached_get_image_processor
(
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
,
**
_merge_mm_kwargs
(
model_config
,
**
kwargs
),
)
def
get_video_processor
(
processor_name
:
str
,
*
args
:
Any
,
...
...
@@ -104,3 +180,17 @@ def get_video_processor(
)
return
cast
(
BaseImageProcessor
,
processor
.
video_processor
)
cached_get_video_processor
=
lru_cache
(
get_video_processor
)
def
cached_video_processor_from_config
(
model_config
:
"ModelConfig"
,
**
kwargs
:
Any
,
):
return
cached_get_video_processor
(
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
,
**
_merge_mm_kwargs
(
model_config
,
**
kwargs
),
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment