Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d62856b9
Unverified
Commit
d62856b9
authored
Mar 09, 2026
by
Cyrus Leung
Committed by
GitHub
Mar 09, 2026
Browse files
[Misc] Move processors to `transformers_utils` (#35953)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
bd2659a5
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
509 additions
and
597 deletions
+509
-597
vllm/model_executor/models/glm4v.py
vllm/model_executor/models/glm4v.py
+8
-73
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+100
-180
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+12
-111
vllm/model_executor/models/qwen_vl.py
vllm/model_executor/models/qwen_vl.py
+6
-89
vllm/model_executor/models/voxtral.py
vllm/model_executor/models/voxtral.py
+40
-135
vllm/model_executor/models/voxtral_realtime.py
vllm/model_executor/models/voxtral_realtime.py
+2
-1
vllm/multimodal/processing/context.py
vllm/multimodal/processing/context.py
+5
-5
vllm/transformers_utils/processor.py
vllm/transformers_utils/processor.py
+10
-3
vllm/transformers_utils/processors/__init__.py
vllm/transformers_utils/processors/__init__.py
+8
-0
vllm/transformers_utils/processors/glm4v.py
vllm/transformers_utils/processors/glm4v.py
+35
-0
vllm/transformers_utils/processors/pixtral.py
vllm/transformers_utils/processors/pixtral.py
+116
-0
vllm/transformers_utils/processors/qwen_vl.py
vllm/transformers_utils/processors/qwen_vl.py
+48
-0
vllm/transformers_utils/processors/voxtral.py
vllm/transformers_utils/processors/voxtral.py
+119
-0
No files found.
vllm/model_executor/models/glm4v.py
View file @
d62856b9
...
@@ -13,11 +13,7 @@ import numpy as np
...
@@ -13,11 +13,7 @@ import numpy as np
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
torch.nn
import
LayerNorm
from
torchvision
import
transforms
from
transformers
import
BatchFeature
from
torchvision.transforms
import
InterpolationMode
from
transformers
import
BatchFeature
,
PreTrainedTokenizer
,
TensorType
from
transformers.image_utils
import
ImageInput
from
transformers.tokenization_utils_base
import
TextInput
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.config.multimodal
import
BaseDummyOptions
...
@@ -50,7 +46,8 @@ from vllm.multimodal.processing import (
...
@@ -50,7 +46,8 @@ from vllm.multimodal.processing import (
PromptUpdate
,
PromptUpdate
,
)
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
ChatGLMConfig
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.processors.glm4v
import
GLM4VProcessor
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.chatglm
import
ChatGLMBaseModel
,
ChatGLMModel
,
GLMTransformer
from
.chatglm
import
ChatGLMBaseModel
,
ChatGLMModel
,
GLMTransformer
...
@@ -386,81 +383,19 @@ class GLM4VModel(ChatGLMModel):
...
@@ -386,81 +383,19 @@ class GLM4VModel(ChatGLMModel):
)
)
class
GLM4VProcessor
:
"""
This model doesn't define its own HF processor,
so we implement our own one here.
"""
def
__init__
(
self
,
config
:
ChatGLMConfig
,
tokenizer
:
PreTrainedTokenizer
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
tokenizer
=
tokenizer
vision_config
=
config
.
vision_config
image_size
=
vision_config
[
"image_size"
]
self
.
image_transform
=
transforms
.
Compose
(
[
transforms
.
Resize
(
(
image_size
,
image_size
),
interpolation
=
InterpolationMode
.
BICUBIC
,
),
transforms
.
ToTensor
(),
transforms
.
Normalize
(
mean
=
(
0.48145466
,
0.4578275
,
0.40821073
),
std
=
(
0.26862954
,
0.26130258
,
0.27577711
),
),
]
)
def
__call__
(
self
,
text
:
TextInput
|
list
[
TextInput
]
|
None
=
None
,
images
:
ImageInput
|
list
[
ImageInput
]
|
None
=
None
,
return_tensors
:
str
|
TensorType
|
None
=
None
,
)
->
BatchFeature
:
if
text
is
None
:
text
=
[]
if
not
isinstance
(
text
,
list
):
text
=
[
text
]
if
images
is
None
:
images
=
[]
if
not
isinstance
(
images
,
list
):
images
=
[
images
]
text_inputs
=
self
.
tokenizer
(
text
)
if
len
(
images
)
==
0
:
image_inputs
=
{}
else
:
pixel_values
=
[
self
.
image_transform
(
image
)
for
image
in
images
]
image_inputs
=
{
"pixel_values"
:
torch
.
stack
(
pixel_values
)}
return
BatchFeature
(
{
**
text_inputs
,
**
image_inputs
,
},
tensor_type
=
return_tensors
,
)
class
GLM4VProcessingInfo
(
BaseProcessingInfo
):
class
GLM4VProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_config
(
self
):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
(
ChatGLMConfig
)
return
self
.
ctx
.
get_hf_config
(
ChatGLMConfig
)
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
GLM4VProcessor
:
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
GLM4VProcessor
:
config
=
self
.
get_hf_config
()
vision_config
=
config
.
vision_config
image_size
=
vision_config
[
"image_size"
]
return
self
.
ctx
.
init_processor
(
return
self
.
ctx
.
init_processor
(
GLM4VProcessor
,
GLM4VProcessor
,
config
=
self
.
get_hf_config
(),
tokenizer
=
self
.
get_tokenizer
(),
tokenizer
=
self
.
get_tokenizer
(),
**
kwargs
,
**
{
**
kwargs
,
"image_size"
:
image_size
},
)
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
...
...
vllm/model_executor/models/molmo.py
View file @
d62856b9
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
cached_property
,
partial
from
functools
import
partial
from
itertools
import
islice
from
itertools
import
islice
from
typing
import
Annotated
from
typing
import
Annotated
...
@@ -13,9 +13,11 @@ import torch
...
@@ -13,9 +13,11 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
einops
import
rearrange
from
transformers
import
BatchFeature
,
PretrainedConfig
,
ProcessorMixin
,
TensorType
from
transformers
import
(
from
transformers.image_utils
import
ImageInput
BaseImageProcessor
,
from
transformers.tokenization_utils_base
import
TextInput
BatchFeature
,
PretrainedConfig
,
)
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
...
@@ -1017,117 +1019,28 @@ def select_tiling(
...
@@ -1017,117 +1019,28 @@ def select_tiling(
return
candidate_tilings
[
ix
]
return
candidate_tilings
[
ix
]
class
MolmoProcessorWrapper
:
def
_as_2tuple
(
x
:
int
|
tuple
[
int
,
int
])
->
tuple
[
int
,
int
]:
"""
if
isinstance
(
x
,
int
):
Wraps `MolmoProcessor` so that it can be called directly.
return
x
,
x
The original definition can be found here:
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
"""
def
__init__
(
self
,
processor
:
ProcessorMixin
):
super
().
__init__
()
self
.
processor
=
processor
@
cached_property
def
vocab
(
self
)
->
dict
[
str
,
int
]:
return
self
.
processor
.
tokenizer
.
vocab
# type: ignore
@
cached_property
def
max_crops
(
self
)
->
int
:
image_processor
=
self
.
processor
.
image_processor
# type: ignore
max_crops
=
image_processor
.
max_crops
assert
isinstance
(
max_crops
,
int
)
return
max_crops
@
cached_property
def
base_image_input_size
(
self
)
->
tuple
[
int
,
int
]:
image_processor
=
self
.
processor
.
image_processor
# type: ignore
base_image_input_size
=
image_processor
.
base_image_input_size
if
isinstance
(
base_image_input_size
,
int
):
return
base_image_input_size
,
base_image_input_size
return
tuple
(
base_image_input_size
)
@
cached_property
def
image_patch_size
(
self
)
->
int
:
image_processor
=
self
.
processor
.
image_processor
# type: ignore
image_patch_size
=
image_processor
.
image_patch_size
assert
isinstance
(
image_patch_size
,
int
)
return
image_patch_size
@
cached_property
def
overlap_margins
(
self
)
->
tuple
[
int
,
int
]:
image_processor
=
self
.
processor
.
image_processor
# type: ignore
left_margin
,
right_margin
=
image_processor
.
overlap_margins
assert
isinstance
(
left_margin
,
int
)
assert
isinstance
(
right_margin
,
int
)
return
left_margin
,
right_margin
@
cached_property
def
image_token_length_w
(
self
)
->
int
:
image_processor
=
self
.
processor
.
image_processor
# type: ignore
image_token_length_w
=
image_processor
.
image_token_length_w
assert
isinstance
(
image_token_length_w
,
int
)
return
image_token_length_w
@
cached_property
def
image_token_length_h
(
self
)
->
int
:
image_processor
=
self
.
processor
.
image_processor
# type: ignore
image_token_length_h
=
image_processor
.
image_token_length_h
assert
isinstance
(
image_token_length_h
,
int
)
return
image_token_length_h
@
property
def
message_format
(
self
)
->
str
|
None
:
return
"role"
@
property
def
always_start_with_space
(
self
)
->
bool
:
return
True
@
cached_property
def
image_patch_id
(
self
)
->
int
:
return
self
.
vocab
[
IMAGE_PATCH_TOKEN
]
@
cached_property
def
im_col_id
(
self
)
->
int
:
return
self
.
vocab
[
IM_COL_TOKEN
]
@
cached_property
return
x
def
im_start_id
(
self
)
->
int
:
return
self
.
vocab
[
IM_START_TOKEN
]
@
cached_property
def
im_end_id
(
self
)
->
int
:
return
self
.
vocab
[
IM_END_TOKEN
]
@
property
class
MolmoProcessingInfo
(
BaseProcessingInfo
):
def
pooling_size
(
self
)
->
int
:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]
:
return
POOLING_SIZE
return
{
"image"
:
None
}
def
select_tiling
(
def
select_tiling
(
self
,
self
,
*
,
*
,
image_width
:
int
,
image_width
:
int
,
image_height
:
int
,
image_height
:
int
,
image_processor
:
BaseImageProcessor
,
)
->
tuple
[
int
,
int
]:
)
->
tuple
[
int
,
int
]:
max_crops
=
self
.
max_crops
max_crops
=
image_processor
.
max_crops
left_margin
,
right_margin
=
self
.
overlap_margins
left_margin
,
right_margin
=
image_processor
.
overlap_margins
base_image_input_size
=
self
.
base_image_input_size
base_image_input_size
=
_as_2tuple
(
image_processor
.
base_image_input_size
)
base_image_input_d
=
self
.
image_patch_size
base_image_input_d
=
image_processor
.
image_patch_size
total_margin_pixels
=
base_image_input_d
*
(
right_margin
+
left_margin
)
total_margin_pixels
=
base_image_input_d
*
(
right_margin
+
left_margin
)
crop_patches
=
base_image_input_size
[
0
]
//
base_image_input_d
crop_patches
=
base_image_input_size
[
0
]
//
base_image_input_d
...
@@ -1147,16 +1060,18 @@ class MolmoProcessorWrapper:
...
@@ -1147,16 +1060,18 @@ class MolmoProcessorWrapper:
*
,
*
,
image_width
:
int
,
image_width
:
int
,
image_height
:
int
,
image_height
:
int
,
image_processor
:
BaseImageProcessor
,
)
->
tuple
[
int
,
int
]:
)
->
tuple
[
int
,
int
]:
left_margin
,
right_margin
=
self
.
overlap_margins
left_margin
,
right_margin
=
image_processor
.
overlap_margins
base_image_input_size
=
self
.
base_image_input_size
base_image_input_size
=
_as_2tuple
(
image_processor
.
base_image_input_size
)
base_image_input_d
=
self
.
image_patch_size
base_image_input_d
=
image_processor
.
image_patch_size
pooling_size
=
self
.
pooling_size
pooling_size
=
POOLING_SIZE
crop_patches
=
base_image_input_size
[
0
]
//
base_image_input_d
crop_patches
=
base_image_input_size
[
0
]
//
base_image_input_d
tiling_w
,
tiling_h
=
self
.
select_tiling
(
tiling_w
,
tiling_h
=
self
.
select_tiling
(
image_height
=
image_height
,
image_height
=
image_height
,
image_width
=
image_width
,
image_width
=
image_width
,
image_processor
=
image_processor
,
)
)
nrows
,
ncols
=
get_patches_grid_size
(
nrows
,
ncols
=
get_patches_grid_size
(
...
@@ -1170,70 +1085,22 @@ class MolmoProcessorWrapper:
...
@@ -1170,70 +1085,22 @@ class MolmoProcessorWrapper:
return
ncols
,
nrows
return
ncols
,
nrows
def
__call__
(
self
,
text
:
TextInput
|
list
[
TextInput
]
|
None
=
None
,
images
:
ImageInput
|
list
[
ImageInput
]
|
None
=
None
,
return_tensors
:
str
|
TensorType
|
None
=
None
,
**
kwargs
,
)
->
BatchFeature
:
outputs
=
self
.
processor
.
process
(
# type: ignore
text
,
images
,
**
kwargs
)
if
images
is
None
:
images
=
[]
if
not
isinstance
(
images
,
list
):
images
=
[
images
]
input_ids
:
torch
.
Tensor
=
outputs
.
pop
(
"input_ids"
)
outputs
[
"input_ids"
]
=
input_ids
.
unsqueeze
(
0
)
image_input_idx
=
outputs
.
pop
(
"image_input_idx"
,
None
)
if
image_input_idx
is
not
None
:
feat_is_patch
=
image_input_idx
>=
0
tilings
=
[
self
.
select_tiling
(
image_width
=
image
.
size
[
0
],
image_height
=
image
.
size
[
1
],
)
for
image
in
images
]
# For each image: tiling_h * tiling_w + extra
num_crops
=
torch
.
tensor
(
tilings
).
prod
(
-
1
)
+
1
assert
num_crops
.
sum
()
==
len
(
feat_is_patch
)
outputs
[
"image_input_idx"
]
=
image_input_idx
outputs
[
"num_crops"
]
=
num_crops
outputs
[
"img_patch_id"
]
=
self
.
image_patch_id
return
BatchFeature
(
outputs
)
class
MolmoProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
MolmoProcessorWrapper
:
processor
=
self
.
ctx
.
get_hf_processor
(
**
kwargs
)
return
MolmoProcessorWrapper
(
processor
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
return
{
"image"
:
None
}
def
get_num_image_tokens
(
def
get_num_image_tokens
(
self
,
self
,
*
,
*
,
image_width
:
int
,
image_width
:
int
,
image_height
:
int
,
image_height
:
int
,
processor
:
Molmo
Processor
Wrapper
,
image_
processor
:
BaseImage
Processor
,
)
->
int
:
)
->
int
:
ncols
,
nrows
=
processor
.
get_patches_grid_size
(
ncols
,
nrows
=
self
.
get_patches_grid_size
(
image_width
=
image_width
,
image_width
=
image_width
,
image_height
=
image_height
,
image_height
=
image_height
,
image_processor
=
image_processor
,
)
)
pooling_size
=
processor
.
pooling_size
pooling_size
=
POOLING_SIZE
image_token_length_w
=
processor
.
image_token_length_w
image_token_length_w
=
image_
processor
.
image_token_length_w
image_token_length_h
=
processor
.
image_token_length_h
image_token_length_h
=
image_
processor
.
image_token_length_h
# Calculate total tokens: 2 for start/end + (w+1)*h for column separators
# Calculate total tokens: 2 for start/end + (w+1)*h for column separators
extra
=
2
+
(
image_token_length_w
+
1
)
*
image_token_length_h
extra
=
2
+
(
image_token_length_w
+
1
)
*
image_token_length_h
...
@@ -1243,9 +1110,10 @@ class MolmoProcessingInfo(BaseProcessingInfo):
...
@@ -1243,9 +1110,10 @@ class MolmoProcessingInfo(BaseProcessingInfo):
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
processor
=
self
.
get_hf_processor
()
processor
=
self
.
get_hf_processor
()
image_processor
=
processor
.
image_processor
tilings
=
get_candidate_tilings
(
processor
.
max_crops
)
tilings
=
get_candidate_tilings
(
image_
processor
.
max_crops
)
base_h
,
base_w
=
processor
.
base_image_input_size
base_h
,
base_w
=
_as_2tuple
(
image_
processor
.
base_image_input_size
)
largest_feature_size
,
largest_feature_pinpoint
=
0
,
None
largest_feature_size
,
largest_feature_pinpoint
=
0
,
None
for
wr
,
hr
in
tilings
:
for
wr
,
hr
in
tilings
:
...
@@ -1254,7 +1122,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
...
@@ -1254,7 +1122,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
feat_size
=
self
.
get_num_image_tokens
(
feat_size
=
self
.
get_num_image_tokens
(
image_width
=
width
,
image_width
=
width
,
image_height
=
height
,
image_height
=
height
,
processor
=
processor
,
image_
processor
=
image_
processor
,
)
)
if
feat_size
>
largest_feature_size
:
if
feat_size
>
largest_feature_size
:
largest_feature_size
=
feat_size
largest_feature_size
=
feat_size
...
@@ -1292,6 +1160,54 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
...
@@ -1292,6 +1160,54 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
class
MolmoMultiModalProcessor
(
BaseMultiModalProcessor
[
MolmoProcessingInfo
]):
class
MolmoMultiModalProcessor
(
BaseMultiModalProcessor
[
MolmoProcessingInfo
]):
def
_call_hf_processor
(
self
,
prompt
:
str
,
mm_data
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
tok_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
processed_outputs
=
self
.
info
.
ctx
.
call_hf_processor
(
hf_processor
.
process
,
dict
(
text
=
prompt
,
**
mm_data
),
dict
(
**
mm_kwargs
,
**
tok_kwargs
),
)
tokenizer
=
hf_processor
.
tokenizer
image_patch_id
=
tokenizer
.
vocab
[
IMAGE_PATCH_TOKEN
]
image_processor
=
hf_processor
.
image_processor
input_ids
:
torch
.
Tensor
=
processed_outputs
.
pop
(
"input_ids"
)
processed_outputs
[
"input_ids"
]
=
input_ids
.
unsqueeze
(
0
)
if
(
images
:
=
mm_data
.
get
(
"images"
))
is
not
None
:
mm_items
=
self
.
info
.
parse_mm_data
({
"image"
:
images
},
validate
=
False
)
parsed_images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_sizes
=
[
parsed_images
.
get_image_size
(
i
)
for
i
in
range
(
len
(
parsed_images
))
]
feat_is_patch
=
processed_outputs
[
"image_input_idx"
]
>=
0
tilings
=
[
self
.
info
.
select_tiling
(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
image_processor
=
image_processor
,
)
for
image_size
in
image_sizes
]
# For each image: tiling_h * tiling_w + extra
num_crops
=
torch
.
tensor
(
tilings
).
prod
(
-
1
)
+
1
assert
num_crops
.
sum
()
==
len
(
feat_is_patch
)
processed_outputs
[
"num_crops"
]
=
num_crops
processed_outputs
[
"img_patch_id"
]
=
image_patch_id
return
processed_outputs
def
_apply_hf_processor_tokens_only
(
def
_apply_hf_processor_tokens_only
(
self
,
self
,
prompt_tokens
:
list
[
int
],
prompt_tokens
:
list
[
int
],
...
@@ -1301,18 +1217,19 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
...
@@ -1301,18 +1217,19 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
# The chat template is already applied to the prompt tokens
# The chat template is already applied to the prompt tokens
# Use message_format="none" to avoid applying it again
# Use message_format="none" to avoid applying it again
# Prepend an empty space if `always_start_with_space` is True
# Prepend an empty space if `always_start_with_space` is True
tokens
=
processor
.
processor
.
get_tokens_input
(
# type: ignore
tokens
=
processor
.
get_tokens_input
(
self
.
info
.
get_tokenizer
().
decode
(
prompt_tokens
),
self
.
info
.
get_tokenizer
().
decode
(
prompt_tokens
),
message_format
=
"none"
,
message_format
=
"none"
,
always_start_with_space
=
processor
.
always_start_with_spac
e
,
always_start_with_space
=
Tru
e
,
)
)
# Prepend a BOS token id to the tokens
# Prepend a BOS token id to the tokens
processed_data
=
self
.
info
.
ctx
.
call_hf_processor
(
processed_data
=
self
.
info
.
ctx
.
call_hf_processor
(
processor
,
# type: ignore
processor
.
process
,
dict
(
tokens
=
tokens
),
dict
(
tokens
=
tokens
),
)
)
(
prompt_ids
,)
=
processed_data
.
pop
(
"input_ids"
).
tolist
()
prompt_ids
=
processed_data
.
pop
(
"input_ids"
).
tolist
()
print
(
prompt_ids
,
len
(
prompt_ids
))
return
prompt_ids
return
prompt_ids
...
@@ -1338,16 +1255,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
...
@@ -1338,16 +1255,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargsItems
,
out_mm_kwargs
:
MultiModalKwargsItems
,
)
->
Sequence
[
PromptUpdate
]:
)
->
Sequence
[
PromptUpdate
]:
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
tokenizer
=
self
.
info
.
get_tokenizer
()
vocab
=
tokenizer
.
get_vocab
()
image_token_length_w
=
processor
.
image_token_length_w
img_patch_id
=
vocab
[
IMAGE_PATCH_TOKEN
]
image_token_length_h
=
processor
.
image_token_length_h
img_col_id
=
vocab
[
IM_COL_TOKEN
]
pooling_size
=
processor
.
pooling_size
img_start_id
=
vocab
[
IM_START_TOKEN
]
img_end_id
=
vocab
[
IM_END_TOKEN
]
img_patch_id
=
processor
.
image_patch_id
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
img_col_id
=
processor
.
im_col_id
image_processor
=
processor
.
image_processor
img_start_id
=
processor
.
im_start_id
image_token_length_w
=
image_processor
.
image_token_length_w
img_end_id
=
processor
.
im_end_id
image_token_length_h
=
image_processor
.
image_token_length_h
pooling_size
=
POOLING_SIZE
extra_row
=
[
img_patch_id
]
*
image_token_length_w
+
[
img_col_id
]
extra_row
=
[
img_patch_id
]
*
image_token_length_w
+
[
img_col_id
]
extra_joint
=
[
img_start_id
]
+
extra_row
*
image_token_length_h
+
[
img_end_id
]
extra_joint
=
[
img_start_id
]
+
extra_row
*
image_token_length_h
+
[
img_end_id
]
...
@@ -1356,9 +1275,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
...
@@ -1356,9 +1275,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_size
=
images
.
get_image_size
(
item_idx
)
image_size
=
images
.
get_image_size
(
item_idx
)
ncols
,
nrows
=
processor
.
get_patches_grid_size
(
ncols
,
nrows
=
self
.
info
.
get_patches_grid_size
(
image_width
=
image_size
.
width
,
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
image_height
=
image_size
.
height
,
image_processor
=
image_processor
,
)
)
joint_row
=
[
img_patch_id
]
*
((
ncols
+
1
)
//
pooling_size
)
+
[
img_col_id
]
joint_row
=
[
img_patch_id
]
*
((
ncols
+
1
)
//
pooling_size
)
+
[
img_col_id
]
...
...
vllm/model_executor/models/pixtral.py
View file @
d62856b9
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
functools
import
cached_property
from
typing
import
Annotated
,
Literal
from
typing
import
Annotated
,
Literal
import
torch
import
torch
...
@@ -13,10 +12,7 @@ import torch.nn.functional as F
...
@@ -13,10 +12,7 @@ import torch.nn.functional as F
from
mistral_common.protocol.instruct.chunk
import
ImageChunk
,
TextChunk
from
mistral_common.protocol.instruct.chunk
import
ImageChunk
,
TextChunk
from
mistral_common.protocol.instruct.messages
import
UserMessage
from
mistral_common.protocol.instruct.messages
import
UserMessage
from
mistral_common.protocol.instruct.request
import
ChatCompletionRequest
from
mistral_common.protocol.instruct.request
import
ChatCompletionRequest
from
mistral_common.tokens.tokenizers.multimodal
import
ImageEncoder
from
transformers
import
PixtralVisionConfig
from
PIL
import
Image
from
transformers
import
BatchFeature
,
PixtralVisionConfig
,
TensorType
from
transformers.image_utils
import
ImageInput
from
transformers.models.pixtral.image_processing_pixtral
import
(
from
transformers.models.pixtral.image_processing_pixtral
import
(
_num_image_tokens
as
_get_pixtral_hf_num_image_tokens
,
_num_image_tokens
as
_get_pixtral_hf_num_image_tokens
,
)
)
...
@@ -25,7 +21,6 @@ from transformers.models.pixtral.modeling_pixtral import (
...
@@ -25,7 +21,6 @@ from transformers.models.pixtral.modeling_pixtral import (
apply_rotary_pos_emb
,
apply_rotary_pos_emb
,
position_ids_in_meshgrid
,
position_ids_in_meshgrid
,
)
)
from
transformers.tokenization_utils_base
import
TextInput
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.config.multimodal
import
BaseDummyOptions
...
@@ -66,6 +61,7 @@ from vllm.platforms import current_platform
...
@@ -66,6 +61,7 @@ from vllm.platforms import current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.transformers_utils.processors.pixtral
import
MistralCommonPixtralProcessor
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
(
from
.interfaces
import
(
...
@@ -121,93 +117,6 @@ class PixtralImagePixelInputs(TensorSchema):
...
@@ -121,93 +117,6 @@ class PixtralImagePixelInputs(TensorSchema):
]
]
class
PixtralProcessorAdapter
:
"""
Provide a HF-compatible interface for
`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
"""
def
__init__
(
self
,
tokenizer
:
MistralTokenizer
)
->
None
:
super
().
__init__
()
self
.
tokenizer
=
tokenizer
@
property
def
image_processor
(
self
)
->
ImageEncoder
:
image_encoder
=
self
.
tokenizer
.
instruct
.
mm_encoder
assert
isinstance
(
image_encoder
,
ImageEncoder
)
return
image_encoder
@
cached_property
def
image_break_id
(
self
)
->
int
:
return
self
.
image_processor
.
special_ids
.
img_break
@
cached_property
def
image_token_id
(
self
)
->
int
:
return
self
.
image_processor
.
special_ids
.
img
@
cached_property
def
image_end_id
(
self
)
->
int
:
return
self
.
image_processor
.
special_ids
.
img_end
@
cached_property
def
image_size
(
self
)
->
int
:
return
self
.
image_processor
.
mm_config
.
max_image_size
@
cached_property
def
patch_size
(
self
)
->
int
:
return
self
.
image_processor
.
mm_config
.
image_patch_size
def
__call__
(
self
,
text
:
TextInput
|
list
[
TextInput
]
|
None
=
None
,
images
:
ImageInput
|
list
[
ImageInput
]
|
None
=
None
,
return_tensors
:
str
|
TensorType
|
None
=
None
,
**
kwargs
,
)
->
Mapping
[
str
,
NestedTensors
]:
if
text
is
None
:
text
=
[]
if
not
isinstance
(
text
,
list
):
text
=
[
text
]
if
images
is
None
:
images
=
[]
if
not
isinstance
(
images
,
list
):
images
=
[
images
]
if
not
images
:
input_ids
=
self
.
tokenizer
(
text
).
input_ids
return
{
"input_ids"
:
torch
.
tensor
(
input_ids
)}
# Allow dummy text, which is used for profiling as well as token inputs
if
any
(
len
(
t
)
>
0
for
t
in
text
):
raise
ValueError
(
"You've passed text inputs instead of token inputs. "
"Make sure to process your input via `mistral_common`'s "
"tokenizer or pass a chat completion request. "
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
)
images_processed
=
list
[
torch
.
Tensor
]()
images_tokens
=
list
[
torch
.
Tensor
]()
for
image
in
images
:
image_inputs
=
self
.
image_processor
(
ImageChunk
(
image
=
image
))
image_processed
=
torch
.
tensor
(
image_inputs
.
image
)
image_tokens
=
torch
.
tensor
(
image_inputs
.
tokens
)
images_processed
.
append
(
image_processed
)
images_tokens
.
append
(
image_tokens
)
return
BatchFeature
(
{
"input_ids"
:
torch
.
cat
(
images_tokens
)[
None
].
expand
(
len
(
text
),
-
1
),
"images"
:
images_processed
,
}
)
class
PixtralProcessingInfo
(
BaseProcessingInfo
):
class
PixtralProcessingInfo
(
BaseProcessingInfo
):
def
get_tokenizer
(
self
)
->
MistralTokenizer
:
def
get_tokenizer
(
self
)
->
MistralTokenizer
:
tokenizer
=
cached_tokenizer_from_config
(
self
.
ctx
.
model_config
)
tokenizer
=
cached_tokenizer_from_config
(
self
.
ctx
.
model_config
)
...
@@ -216,28 +125,19 @@ class PixtralProcessingInfo(BaseProcessingInfo):
...
@@ -216,28 +125,19 @@ class PixtralProcessingInfo(BaseProcessingInfo):
return
tokenizer
return
tokenizer
def
get_hf_processor
(
self
)
->
PixtralProcessorAdapter
:
def
get_hf_processor
(
self
,
**
kwargs
)
->
MistralCommonPixtralProcessor
:
return
PixtralProcessorAdapter
(
self
.
get_tokenizer
())
return
self
.
ctx
.
init_processor
(
MistralCommonPixtralProcessor
,
tokenizer
=
self
.
get_tokenizer
(),
**
kwargs
,
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
return
{
"image"
:
None
}
return
{
"image"
:
None
}
def
get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
PixtralProcessorAdapter
,
)
->
int
:
ncols
,
nrows
=
processor
.
image_processor
.
_image_to_num_tokens
(
Image
.
new
(
"RGB"
,
(
image_width
,
image_height
))
)
return
ncols
*
nrows
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
image_processor
=
self
.
get_hf_processor
().
image_processor
image_processor
=
self
.
get_hf_processor
().
image_processor
max_image_size
=
image_processor
.
mm_config
.
max_image_size
max_image_size
=
image_processor
.
mm_
encoder
.
mm_
config
.
max_image_size
return
ImageSize
(
width
=
max_image_size
,
height
=
max_image_size
)
return
ImageSize
(
width
=
max_image_size
,
height
=
max_image_size
)
...
@@ -321,8 +221,9 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
...
@@ -321,8 +221,9 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_size
=
images
.
get_image_size
(
item_idx
)
image_size
=
images
.
get_image_size
(
item_idx
)
ncols
,
nrows
=
processor
.
image_processor
.
_image_to_num_tokens
(
_
,
nrows
,
ncols
=
processor
.
image_processor
.
get_number_of_image_patches
(
Image
.
new
(
"RGB"
,
(
image_size
.
width
,
image_size
.
height
))
image_size
.
height
,
image_size
.
width
,
)
)
tokens
=
([
image_token_id
]
*
ncols
+
[
image_break_id
])
*
nrows
tokens
=
([
image_token_id
]
*
ncols
+
[
image_break_id
])
*
nrows
...
...
vllm/model_executor/models/qwen_vl.py
View file @
d62856b9
...
@@ -14,11 +14,7 @@ from typing import Annotated, Literal, TypeAlias
...
@@ -14,11 +14,7 @@ from typing import Annotated, Literal, TypeAlias
import
regex
as
re
import
regex
as
re
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torchvision
import
transforms
from
transformers
import
BatchFeature
from
torchvision.transforms
import
InterpolationMode
from
transformers
import
BatchFeature
,
PretrainedConfig
,
PreTrainedTokenizer
,
TensorType
from
transformers.image_utils
import
ImageInput
from
transformers.tokenization_utils_base
import
TextInput
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.config.multimodal
import
BaseDummyOptions
...
@@ -48,6 +44,7 @@ from vllm.multimodal.processing import (
...
@@ -48,6 +44,7 @@ from vllm.multimodal.processing import (
PromptUpdateDetails
,
PromptUpdateDetails
,
)
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.processors.qwen_vl
import
QwenVLProcessor
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
(
from
.interfaces
import
(
...
@@ -434,96 +431,16 @@ class QwenVLModel(QWenModel):
...
@@ -434,96 +431,16 @@ class QwenVLModel(QWenModel):
)
)
class
QwenVLProcessor
:
class
QwenVLProcessingInfo
(
BaseProcessingInfo
):
"""
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
QwenVLProcessor
:
This model doesn't define its own HF processor,
config
=
self
.
get_hf_config
()
so we implement our own one here.
We call the wrapped tokenizer to automatically insert image pad tokens:
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245
The image processor is defined here:
https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
"""
def
__init__
(
self
,
config
:
PretrainedConfig
,
tokenizer
:
PreTrainedTokenizer
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
tokenizer
=
tokenizer
vision_config
=
config
.
visual
vision_config
=
config
.
visual
image_size
=
vision_config
[
"image_size"
]
image_size
=
vision_config
[
"image_size"
]
self
.
image_transform
=
transforms
.
Compose
(
[
transforms
.
Resize
(
(
image_size
,
image_size
),
interpolation
=
InterpolationMode
.
BICUBIC
,
),
transforms
.
ToTensor
(),
transforms
.
Normalize
(
mean
=
(
0.48145466
,
0.4578275
,
0.40821073
),
std
=
(
0.26862954
,
0.26130258
,
0.27577711
),
),
]
)
@
property
def
image_start_tag
(
self
)
->
str
:
return
self
.
tokenizer
.
image_start_tag
# type: ignore
@
property
def
image_end_tag
(
self
)
->
str
:
return
self
.
tokenizer
.
image_end_tag
# type: ignore
@
property
def
image_pad_tag
(
self
)
->
str
:
return
self
.
tokenizer
.
image_pad_tag
# type: ignore
def
__call__
(
self
,
text
:
TextInput
|
list
[
TextInput
]
|
None
=
None
,
images
:
ImageInput
|
list
[
ImageInput
]
|
None
=
None
,
return_tensors
:
str
|
TensorType
|
None
=
None
,
)
->
BatchFeature
:
if
text
is
None
:
text
=
[]
if
not
isinstance
(
text
,
list
):
text
=
[
text
]
if
images
is
None
:
images
=
[]
if
not
isinstance
(
images
,
list
):
images
=
[
images
]
text_inputs
=
self
.
tokenizer
(
text
)
if
len
(
images
)
==
0
:
image_inputs
=
{}
else
:
pixel_values
=
[
self
.
image_transform
(
image
)
for
image
in
images
]
image_inputs
=
{
"pixel_values"
:
torch
.
stack
(
pixel_values
)}
return
BatchFeature
(
{
**
text_inputs
,
**
image_inputs
,
},
tensor_type
=
return_tensors
,
)
class
QwenVLProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
QwenVLProcessor
:
return
self
.
ctx
.
init_processor
(
return
self
.
ctx
.
init_processor
(
QwenVLProcessor
,
QwenVLProcessor
,
config
=
self
.
get_hf_config
(),
tokenizer
=
self
.
get_tokenizer
(),
tokenizer
=
self
.
get_tokenizer
(),
**
kwargs
,
**
{
**
kwargs
,
"image_size"
:
image_size
},
)
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
...
...
vllm/model_executor/models/voxtral.py
View file @
d62856b9
...
@@ -3,25 +3,19 @@
...
@@ -3,25 +3,19 @@
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
,
partial
from
functools
import
partial
from
math
import
ceil
from
typing
import
Literal
,
cast
from
typing
import
Literal
,
cast
import
numpy
as
np
import
numpy
as
np
import
regex
as
re
import
regex
as
re
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mistral_common.audio
import
mel_filter_bank
from
mistral_common.audio
import
Audio
,
mel_filter_bank
from
mistral_common.protocol.instruct.chunk
import
AudioChunk
,
RawAudio
,
TextChunk
from
mistral_common.protocol.instruct.chunk
import
AudioChunk
,
RawAudio
,
TextChunk
from
mistral_common.protocol.instruct.messages
import
UserMessage
from
mistral_common.protocol.instruct.messages
import
UserMessage
from
mistral_common.protocol.instruct.request
import
ChatCompletionRequest
from
mistral_common.protocol.instruct.request
import
ChatCompletionRequest
from
mistral_common.protocol.transcription.request
import
TranscriptionRequest
from
mistral_common.protocol.transcription.request
import
TranscriptionRequest
from
mistral_common.tokens.tokenizers.audio
import
(
from
transformers
import
BatchFeature
,
WhisperConfig
Audio
,
AudioEncoder
,
)
from
transformers
import
BatchFeature
,
TensorType
,
WhisperConfig
from
transformers.tokenization_utils_base
import
TextInput
from
vllm.config
import
ModelConfig
,
SpeechToTextConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
SpeechToTextConfig
,
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.config.multimodal
import
BaseDummyOptions
...
@@ -62,6 +56,7 @@ from vllm.multimodal.processing.processor import (
...
@@ -62,6 +56,7 @@ from vllm.multimodal.processing.processor import (
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.transformers_utils.processors.voxtral
import
MistralCommonVoxtralProcessor
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsTranscription
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsTranscription
from
.utils
import
init_vllm_registered_model
,
maybe_prefix
from
.utils
import
init_vllm_registered_model
,
maybe_prefix
...
@@ -81,98 +76,6 @@ ISO639_1_SUPPORTED_LANGS = {
...
@@ -81,98 +76,6 @@ ISO639_1_SUPPORTED_LANGS = {
}
}
class
VoxtralProcessorAdapter
:
"""
Provide a HF-compatible interface for
:class:`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`.
"""
def
__init__
(
self
,
tokenizer
:
MistralTokenizer
)
->
None
:
super
().
__init__
()
self
.
tokenizer
=
tokenizer
@
cached_property
def
_audio_processor
(
self
)
->
AudioEncoder
:
audio_encoder
=
self
.
tokenizer
.
instruct
.
audio_encoder
assert
isinstance
(
audio_encoder
,
AudioEncoder
)
return
audio_encoder
@
cached_property
def
audio_token_id
(
self
)
->
int
:
return
self
.
_audio_processor
.
special_ids
.
audio
@
cached_property
def
begin_audio_token_id
(
self
)
->
int
:
return
self
.
_audio_processor
.
special_ids
.
begin_audio
@
cached_property
def
sampling_rate
(
self
)
->
int
:
return
self
.
_audio_processor
.
audio_config
.
sampling_rate
@
cached_property
def
frame_rate
(
self
)
->
float
:
return
self
.
_audio_processor
.
audio_config
.
frame_rate
def
get_num_audio_tokens
(
self
,
audio_length
:
int
,
)
->
int
:
return
ceil
(
audio_length
/
(
self
.
sampling_rate
//
self
.
frame_rate
))
def
__call__
(
self
,
text
:
TextInput
|
list
[
TextInput
]
|
None
=
None
,
audios
:
np
.
ndarray
|
list
[
np
.
ndarray
]
|
None
=
None
,
return_tensors
:
str
|
TensorType
|
None
=
None
,
**
kwargs
,
)
->
Mapping
[
str
,
NestedTensors
]:
if
text
is
None
:
text
=
[]
if
not
isinstance
(
text
,
list
):
text
=
[
text
]
if
audios
is
None
:
audios
=
[]
if
not
isinstance
(
audios
,
list
):
audios
=
[
audios
]
if
not
audios
:
input_ids
=
self
.
tokenizer
(
text
).
input_ids
return
{
"input_ids"
:
torch
.
tensor
(
input_ids
)}
# Allow dummy text, which is used for profiling as well as token inputs
if
any
(
len
(
t
)
>
0
for
t
in
text
):
raise
ValueError
(
"You've passed text inputs instead of token inputs. "
"Make sure to process your input via `mistral_common`'s "
"tokenizer or pass a chat completion request. "
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
)
audios_tokens
=
list
[
torch
.
Tensor
]()
audios_processed
=
list
[
torch
.
Tensor
]()
for
audio
in
audios
:
assert
isinstance
(
audio
,
np
.
ndarray
)
assert
audio
.
ndim
==
1
if
not
self
.
_audio_processor
.
audio_config
.
is_streaming
:
audio
=
self
.
_audio_processor
.
pad
(
audio
,
self
.
sampling_rate
)
audio_tokens
=
[
self
.
begin_audio_token_id
]
+
[
self
.
audio_token_id
]
*
self
.
get_num_audio_tokens
(
len
(
audio
))
audios_tokens
.
append
(
torch
.
tensor
(
audio_tokens
))
audios_processed
.
append
(
torch
.
tensor
(
audio
))
return
BatchFeature
(
{
"input_ids"
:
torch
.
cat
(
audios_tokens
)[
None
].
expand
(
len
(
text
),
-
1
),
"audio_arrays"
:
audios_processed
,
}
)
class
VoxtralProcessingInfo
(
BaseProcessingInfo
):
class
VoxtralProcessingInfo
(
BaseProcessingInfo
):
def
get_tokenizer
(
self
)
->
MistralTokenizer
:
def
get_tokenizer
(
self
)
->
MistralTokenizer
:
tokenizer
=
cached_tokenizer_from_config
(
self
.
ctx
.
model_config
)
tokenizer
=
cached_tokenizer_from_config
(
self
.
ctx
.
model_config
)
...
@@ -181,12 +84,18 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
...
@@ -181,12 +84,18 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
return
tokenizer
return
tokenizer
def
get_hf_processor
(
self
)
->
VoxtralProcessorAdapter
:
def
get_hf_processor
(
self
,
**
kwargs
)
->
MistralCommonVoxtralProcessor
:
return
VoxtralProcessorAdapter
(
self
.
get_tokenizer
())
return
self
.
ctx
.
init_processor
(
MistralCommonVoxtralProcessor
,
tokenizer
=
self
.
get_tokenizer
(),
**
kwargs
,
)
def
get_data_parser
(
self
):
def
get_data_parser
(
self
):
feature_extractor
=
self
.
get_hf_processor
().
feature_extractor
return
MultiModalDataParser
(
return
MultiModalDataParser
(
target_sr
=
self
.
get_hf_processor
()
.
sampling_rate
,
target_sr
=
feature_extractor
.
sampling_rate
,
target_channels
=
1
,
target_channels
=
1
,
expected_hidden_size
=
self
.
_get_expected_hidden_size
(),
expected_hidden_size
=
self
.
_get_expected_hidden_size
(),
)
)
...
@@ -205,9 +114,10 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
...
@@ -205,9 +114,10 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
return
self
.
ctx
.
model_config
.
max_model_len
return
self
.
ctx
.
model_config
.
max_model_len
def
get_max_audio_array_len
(
self
)
->
int
:
def
get_max_audio_array_len
(
self
)
->
int
:
processor
=
self
.
get_hf_processor
()
feature_extractor
=
self
.
get_hf_processor
().
feature_extractor
return
self
.
get_max_audio_tokens
()
*
int
(
return
self
.
get_max_audio_tokens
()
*
int
(
process
or
.
sampling_rate
//
process
or
.
frame_rate
feature_extract
or
.
sampling_rate
//
feature_extract
or
.
frame_rate
)
)
...
@@ -242,6 +152,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
...
@@ -242,6 +152,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
mm_options
:
Mapping
[
str
,
BaseDummyOptions
],
mm_options
:
Mapping
[
str
,
BaseDummyOptions
],
)
->
ProcessorInputs
:
)
->
ProcessorInputs
:
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
feature_extractor
=
self
.
info
.
get_hf_processor
().
feature_extractor
dummy_text
=
self
.
get_dummy_text
(
mm_counts
)
dummy_text
=
self
.
get_dummy_text
(
mm_counts
)
dummy_mm_data
=
self
.
get_dummy_mm_data
(
seq_len
,
mm_counts
,
mm_options
)
dummy_mm_data
=
self
.
get_dummy_mm_data
(
seq_len
,
mm_counts
,
mm_options
)
...
@@ -252,7 +163,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
...
@@ -252,7 +163,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
for
audio
in
dummy_audios
:
for
audio
in
dummy_audios
:
audio_item
=
Audio
(
audio_item
=
Audio
(
audio_array
=
audio
,
audio_array
=
audio
,
sampling_rate
=
self
.
info
.
get_hf_processor
()
.
sampling_rate
,
sampling_rate
=
feature_extractor
.
sampling_rate
,
format
=
format
,
format
=
format
,
)
)
chunk
=
AudioChunk
(
input_audio
=
RawAudio
.
from_audio
(
audio_item
))
chunk
=
AudioChunk
(
input_audio
=
RawAudio
.
from_audio
(
audio_item
))
...
@@ -292,33 +203,26 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
...
@@ -292,33 +203,26 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
# skip validation here
# skip validation here
...
...
def
_
apply
_hf_processor
_mm_only
(
def
_
call
_hf_processor
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
prompt
:
str
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
mm_data
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
tok_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
)
->
BatchFeature
:
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
mm_data
=
dict
(
mm_data
)
processor_data
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_items
)
audios
=
mm_data
.
pop
(
"audios"
,
[])
audios
=
processor_data
.
get
(
"audios"
,
[])
if
not
isinstance
(
audios
,
list
):
if
audios
:
audios
=
[
audios
]
# MistralCommonVoxtralProcessor accepts "audio"
mm_data
[
"audio"
]
=
audios
audio_config
=
processor
.
_audio_processor
.
audio_config
audio_tensors
:
list
[
torch
.
Tensor
]
=
[]
return
super
().
_call_hf_processor
(
for
audio
in
audios
:
prompt
=
prompt
,
audio
=
np
.
asarray
(
audio
,
dtype
=
np
.
float32
).
ravel
()
mm_data
=
mm_data
,
if
not
audio_config
.
is_streaming
:
mm_kwargs
=
mm_kwargs
,
audio
=
processor
.
_audio_processor
.
pad
(
tok_kwargs
=
tok_kwargs
,
audio
,
)
processor
.
sampling_rate
,
audio_config
.
is_streaming
,
)
audio_tensors
.
append
(
torch
.
tensor
(
audio
))
result
=
BatchFeature
({
"audio_arrays"
:
audio_tensors
}
if
audio_tensors
else
{})
result
.
update
(
passthrough_data
)
return
result
def
_get_prompt_updates
(
def
_get_prompt_updates
(
self
,
self
,
...
@@ -327,6 +231,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
...
@@ -327,6 +231,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
out_mm_kwargs
:
MultiModalKwargsItems
,
out_mm_kwargs
:
MultiModalKwargsItems
,
)
->
Sequence
[
PromptUpdate
]:
)
->
Sequence
[
PromptUpdate
]:
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
feature_extractor
=
processor
.
feature_extractor
audio_id
=
processor
.
audio_token_id
audio_id
=
processor
.
audio_token_id
out_mm_data
=
out_mm_kwargs
.
require_data
()
out_mm_data
=
out_mm_kwargs
.
require_data
()
...
@@ -348,7 +253,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
...
@@ -348,7 +253,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
audios
=
mm_items
.
get_items
(
"audio"
,
AudioProcessorItems
)
audios
=
mm_items
.
get_items
(
"audio"
,
AudioProcessorItems
)
audio_len
=
audios
.
get_audio_length
(
item_idx
)
audio_len
=
audios
.
get_audio_length
(
item_idx
)
nb_audio_tokens
=
process
or
.
get_num_audio_tokens
(
audio_len
)
nb_audio_tokens
=
feature_extract
or
.
get_num_audio_tokens
(
audio_len
)
return
[
audio_id
]
*
nb_audio_tokens
return
[
audio_id
]
*
nb_audio_tokens
...
@@ -560,8 +465,8 @@ class VoxtralForConditionalGeneration(
...
@@ -560,8 +465,8 @@ class VoxtralForConditionalGeneration(
This is used for estimating the amount of processing for this audio.
This is used for estimating the amount of processing for this audio.
"""
"""
tokenizer
=
cached_tokenizer_from_config
(
model_config
)
tokenizer
=
cached_tokenizer_from_config
(
model_config
)
adapter
=
VoxtralProcessor
Adapter
(
tokenizer
)
adapter
=
MistralCommon
VoxtralProcessor
(
tokenizer
)
return
adapter
.
get_num_audio_tokens
(
return
adapter
.
feature_extractor
.
get_num_audio_tokens
(
int
(
audio_duration_s
*
stt_config
.
sample_rate
)
int
(
audio_duration_s
*
stt_config
.
sample_rate
)
)
)
...
...
vllm/model_executor/models/voxtral_realtime.py
View file @
d62856b9
...
@@ -8,12 +8,13 @@ from typing import Literal
...
@@ -8,12 +8,13 @@ from typing import Literal
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
mistral_common.audio
import
Audio
from
mistral_common.protocol.instruct.chunk
import
RawAudio
from
mistral_common.protocol.instruct.chunk
import
RawAudio
from
mistral_common.protocol.transcription.request
import
(
from
mistral_common.protocol.transcription.request
import
(
StreamingMode
,
StreamingMode
,
TranscriptionRequest
,
TranscriptionRequest
,
)
)
from
mistral_common.tokens.tokenizers.audio
import
Audio
,
AudioConfig
from
mistral_common.tokens.tokenizers.audio
import
AudioConfig
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
ModelConfig
,
SpeechToTextConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
SpeechToTextConfig
,
VllmConfig
...
...
vllm/multimodal/processing/context.py
View file @
d62856b9
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
import
time
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
collections.abc
import
Mapping
from
collections.abc
import
Callable
,
Mapping
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
functools
import
cached_property
from
functools
import
cached_property
...
@@ -241,13 +241,13 @@ class InputProcessingContext:
...
@@ -241,13 +241,13 @@ class InputProcessingContext:
def
call_hf_processor
(
def
call_hf_processor
(
self
,
self
,
hf_processor
:
ProcessorMixin
,
hf_processor
:
Callable
[...,
BatchFeature
]
|
ProcessorMixin
,
data
:
Mapping
[
str
,
object
],
data
:
Mapping
[
str
,
object
],
kwargs
:
Mapping
[
str
,
object
]
=
{},
kwargs
:
Mapping
[
str
,
object
]
=
{},
*
,
*
,
num_tries
:
int
=
1
,
num_tries
:
int
=
1
,
max_tries
:
int
=
5
,
max_tries
:
int
=
5
,
)
->
BatchFeature
|
JSONTree
:
)
->
BatchFeature
:
"""
"""
Call `hf_processor` on the prompt `data`
Call `hf_processor` on the prompt `data`
(text, image, audio...) with configurable options `kwargs`.
(text, image, audio...) with configurable options `kwargs`.
...
@@ -300,7 +300,7 @@ class InputProcessingContext:
...
@@ -300,7 +300,7 @@ class InputProcessingContext:
if
isinstance
(
output
,
BatchFeature
):
if
isinstance
(
output
,
BatchFeature
):
output_
=
self
.
_postprocess_output
(
output
.
data
)
output_
=
self
.
_postprocess_output
(
output
.
data
)
return
BatchFeature
(
output_
)
return
BatchFeature
(
output_
)
# type: ignore
logger
.
warning_once
(
logger
.
warning_once
(
"%s did not return `BatchFeature`. "
"%s did not return `BatchFeature`. "
...
@@ -309,7 +309,7 @@ class InputProcessingContext:
...
@@ -309,7 +309,7 @@ class InputProcessingContext:
type
(
hf_processor
).
__name__
,
type
(
hf_processor
).
__name__
,
)
)
return
self
.
_postprocess_output
(
output
)
return
self
.
_postprocess_output
(
output
)
# type: ignore
class
BaseProcessingInfo
:
class
BaseProcessingInfo
:
...
...
vllm/transformers_utils/processor.py
View file @
d62856b9
...
@@ -241,12 +241,13 @@ def get_processor_kwargs_type(
...
@@ -241,12 +241,13 @@ def get_processor_kwargs_type(
call_kwargs_annotations
=
call_kwargs
.
annotation
if
call_kwargs
else
None
call_kwargs_annotations
=
call_kwargs
.
annotation
if
call_kwargs
else
None
# if the processor has explicit kwargs annotation, use it
# if the processor has explicit kwargs annotation, use it
if
call_kwargs_annotations
not
in
(
None
,
inspect
.
_empty
):
if
call_kwargs_annotations
not
in
(
None
,
inspect
.
_empty
):
# noqa: SIM102
# get_type_hints will parse all type annotations at runtime,
# get_type_hints will parse all type annotations at runtime,
# and if an annotation refers to a type or
# and if an annotation refers to a type or
# name that hasn’t been imported or defined, it will raise an error.
# name that hasn’t been imported or defined, it will raise an error.
# So we use __annotations__ to get the raw annotations directly.
# So we use __annotations__ to get the raw annotations directly.
return
get_args
(
call_kwargs_annotations
)[
0
]
if
anno_args
:
=
get_args
(
call_kwargs_annotations
):
return
anno_args
[
0
]
# otherwise, try to get from ProcessorKwargs
# otherwise, try to get from ProcessorKwargs
module_name
=
type
(
processor
).
__module__
module_name
=
type
(
processor
).
__module__
...
@@ -266,7 +267,13 @@ def get_processor_kwargs_keys(
...
@@ -266,7 +267,13 @@ def get_processor_kwargs_keys(
kwargs_cls
:
type
[
processing_utils
.
ProcessingKwargs
],
kwargs_cls
:
type
[
processing_utils
.
ProcessingKwargs
],
)
->
set
[
str
]:
)
->
set
[
str
]:
dynamic_kwargs
:
set
[
str
]
=
set
()
dynamic_kwargs
:
set
[
str
]
=
set
()
modality_kwargs
=
{
"text_kwargs"
,
"images_kwargs"
,
"videos_kwargs"
,
"audio_kwargs"
}
modality_kwargs
=
{
"text_kwargs"
,
"images_kwargs"
,
"videos_kwargs"
,
"audio_kwargs"
,
"common_kwargs"
,
}
try
:
try
:
# get kwargs annotations in processor
# get kwargs annotations in processor
...
...
vllm/transformers_utils/processors/__init__.py
View file @
d62856b9
...
@@ -15,10 +15,14 @@ _CLASS_TO_MODULE: dict[str, str] = {
...
@@ -15,10 +15,14 @@ _CLASS_TO_MODULE: dict[str, str] = {
"DeepseekVLV2Processor"
:
"vllm.transformers_utils.processors.deepseek_vl2"
,
"DeepseekVLV2Processor"
:
"vllm.transformers_utils.processors.deepseek_vl2"
,
"FireRedASR2Processor"
:
"vllm.transformers_utils.processors.fireredasr2"
,
"FireRedASR2Processor"
:
"vllm.transformers_utils.processors.fireredasr2"
,
"FunASRProcessor"
:
"vllm.transformers_utils.processors.funasr"
,
"FunASRProcessor"
:
"vllm.transformers_utils.processors.funasr"
,
"GLM4VProcessor"
:
"vllm.transformers_utils.processors.glm4v"
,
"HunYuanVLProcessor"
:
"vllm.transformers_utils.processors.hunyuan_vl"
,
"HunYuanVLProcessor"
:
"vllm.transformers_utils.processors.hunyuan_vl"
,
"HunYuanVLImageProcessor"
:
"vllm.transformers_utils.processors.hunyuan_vl_image"
,
"HunYuanVLImageProcessor"
:
"vllm.transformers_utils.processors.hunyuan_vl_image"
,
"MistralCommonPixtralProcessor"
:
"vllm.transformers_utils.processors.pixtral"
,
"MistralCommonVoxtralProcessor"
:
"vllm.transformers_utils.processors.voxtral"
,
"OvisProcessor"
:
"vllm.transformers_utils.processors.ovis"
,
"OvisProcessor"
:
"vllm.transformers_utils.processors.ovis"
,
"Ovis2_5Processor"
:
"vllm.transformers_utils.processors.ovis2_5"
,
"Ovis2_5Processor"
:
"vllm.transformers_utils.processors.ovis2_5"
,
"QwenVLProcessor"
:
"vllm.transformers_utils.processors.qwen_vl"
,
"Qwen3ASRProcessor"
:
"vllm.transformers_utils.processors.qwen3_asr"
,
"Qwen3ASRProcessor"
:
"vllm.transformers_utils.processors.qwen3_asr"
,
}
}
...
@@ -28,10 +32,14 @@ __all__ = [
...
@@ -28,10 +32,14 @@ __all__ = [
"DeepseekVLV2Processor"
,
"DeepseekVLV2Processor"
,
"FireRedASR2Processor"
,
"FireRedASR2Processor"
,
"FunASRProcessor"
,
"FunASRProcessor"
,
"GLM4VProcessor"
,
"HunYuanVLProcessor"
,
"HunYuanVLProcessor"
,
"HunYuanVLImageProcessor"
,
"HunYuanVLImageProcessor"
,
"MistralCommonPixtralProcessor"
,
"MistralCommonVoxtralProcessor"
,
"OvisProcessor"
,
"OvisProcessor"
,
"Ovis2_5Processor"
,
"Ovis2_5Processor"
,
"QwenVLProcessor"
,
"Qwen3ASRProcessor"
,
"Qwen3ASRProcessor"
,
]
]
...
...
vllm/transformers_utils/processors/glm4v.py
0 → 100644
View file @
d62856b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
transformers
import
PreTrainedTokenizer
from
transformers.image_processing_utils_fast
import
BaseImageProcessorFast
from
transformers.image_utils
import
PILImageResampling
from
transformers.processing_utils
import
ProcessorMixin
class
GLM4VImageProcessorFast
(
BaseImageProcessorFast
):
"""
Port of https://huggingface.co/zai-org/glm-4v-9b/blob/main/tokenization_chatglm.py#L177
to HF Transformers.
"""
resample
=
PILImageResampling
.
BICUBIC
image_mean
=
[
0.48145466
,
0.4578275
,
0.40821073
]
image_std
=
[
0.26862954
,
0.26130258
,
0.27577711
]
size
=
{
"height"
:
1120
,
"width"
:
1120
}
do_resize
=
True
do_rescale
=
True
do_normalize
=
True
class
GLM4VProcessor
(
ProcessorMixin
):
attributes
=
[
"image_processor"
,
"tokenizer"
]
def
__init__
(
self
,
tokenizer
:
PreTrainedTokenizer
,
image_size
:
int
,
)
->
None
:
self
.
tokenizer
=
tokenizer
self
.
image_processor
=
GLM4VImageProcessorFast
(
size
=
{
"width"
:
image_size
,
"height"
:
image_size
}
)
vllm/transformers_utils/processors/pixtral.py
0 → 100644
View file @
d62856b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
mistral_common.protocol.instruct.chunk
import
ImageChunk
from
mistral_common.tokens.tokenizers.multimodal
import
ImageEncoder
from
PIL
import
Image
from
transformers
import
BatchFeature
,
ProcessorMixin
,
TensorType
from
transformers.audio_utils
import
AudioInput
from
transformers.image_utils
import
ImageInput
from
transformers.tokenization_utils_base
import
PreTokenizedInput
,
TextInput
from
transformers.video_utils
import
VideoInput
from
vllm.tokenizers.mistral
import
MistralTokenizer
class
MistralCommonImageProcessor
:
"""
Provide a HF-compatible interface for
`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
"""
def
__init__
(
self
,
mm_encoder
:
ImageEncoder
)
->
None
:
self
.
mm_encoder
=
mm_encoder
def
__call__
(
self
,
images
:
ImageInput
,
return_tensors
:
str
|
TensorType
|
None
=
None
,
**
kwargs
,
)
->
BatchFeature
:
images_lst
=
[
images
]
if
not
isinstance
(
images
,
list
)
else
images
images_processed
=
list
[
torch
.
Tensor
]()
for
image
in
images_lst
:
image_inputs
=
self
.
mm_encoder
(
ImageChunk
(
image
=
image
))
image_processed
=
torch
.
tensor
(
image_inputs
.
image
)
images_processed
.
append
(
image_processed
)
return
BatchFeature
({
"images"
:
images_processed
},
tensor_type
=
return_tensors
)
def
get_number_of_image_patches
(
self
,
height
:
int
,
width
:
int
,
)
->
tuple
[
int
,
int
,
int
]:
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
))
ncols
,
nrows
=
self
.
mm_encoder
.
_image_to_num_tokens
(
image
)
return
ncols
*
nrows
,
nrows
,
ncols
class
MistralCommonPixtralProcessor
(
ProcessorMixin
):
attributes
=
[
"image_processor"
,
"tokenizer"
]
def
__init__
(
self
,
tokenizer
:
MistralTokenizer
)
->
None
:
self
.
tokenizer
=
tokenizer
.
transformers_tokenizer
self
.
image_processor
=
MistralCommonImageProcessor
(
tokenizer
.
instruct
.
mm_encoder
)
self
.
_image_special_ids
=
self
.
image_processor
.
mm_encoder
.
special_ids
@
property
def
image_break_id
(
self
)
->
int
:
return
self
.
_image_special_ids
.
img_break
@
property
def
image_token_id
(
self
)
->
int
:
return
self
.
_image_special_ids
.
img
@
property
def
image_end_id
(
self
)
->
int
:
return
self
.
_image_special_ids
.
img_end
def
__call__
(
self
,
images
:
ImageInput
|
None
=
None
,
text
:
TextInput
|
PreTokenizedInput
|
list
[
TextInput
]
|
list
[
PreTokenizedInput
]
|
None
=
None
,
videos
:
VideoInput
|
None
=
None
,
audio
:
AudioInput
|
None
=
None
,
**
kwargs
,
):
if
images
is
None
and
text
is
None
and
videos
is
None
and
audio
is
None
:
raise
ValueError
(
f
"You need to provide at least one input to "
f
"call
{
self
.
__class__
.
__name__
}
"
)
kwargs
=
self
.
_merge_kwargs
(
self
.
valid_processor_kwargs
,
tokenizer_init_kwargs
=
{},
**
kwargs
,
)
kwargs
[
"text_kwargs"
][
"return_tensors"
]
=
"pt"
kwargs
[
"images_kwargs"
][
"return_tensors"
]
=
None
# Avoid padding issue
attribute_to_kwargs
=
{
"tokenizer"
:
(
text
,
"text_kwargs"
),
"image_processor"
:
(
images
,
"images_kwargs"
),
"video_processor"
:
(
videos
,
"videos_kwargs"
),
"feature_extractor"
:
(
audio
,
"audio_kwargs"
),
}
outputs
=
{}
for
attribute_name
in
self
.
attributes
:
attribute
=
getattr
(
self
,
attribute_name
,
None
)
input_data
,
input_kwargs
=
attribute_to_kwargs
[
attribute_name
]
if
input_data
is
not
None
and
attribute
is
not
None
:
attribute_output
=
attribute
(
input_data
,
**
kwargs
[
input_kwargs
])
outputs
.
update
(
attribute_output
)
return
BatchFeature
(
outputs
)
vllm/transformers_utils/processors/qwen_vl.py
0 → 100644
View file @
d62856b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
transformers.image_processing_utils_fast
import
BaseImageProcessorFast
from
transformers.image_utils
import
PILImageResampling
from
transformers.processing_utils
import
ProcessorMixin
from
vllm.tokenizers.qwen_vl
import
QwenVLTokenizer
class
QwenVLImageProcessorFast
(
BaseImageProcessorFast
):
"""
Port of https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
to HF Transformers.
"""
resample
=
PILImageResampling
.
BICUBIC
image_mean
=
[
0.48145466
,
0.4578275
,
0.40821073
]
image_std
=
[
0.26862954
,
0.26130258
,
0.27577711
]
size
=
{
"height"
:
448
,
"width"
:
448
}
do_resize
=
True
do_rescale
=
True
do_normalize
=
True
class
QwenVLProcessor
(
ProcessorMixin
):
attributes
=
[
"image_processor"
,
"tokenizer"
]
def
__init__
(
self
,
tokenizer
:
QwenVLTokenizer
,
image_size
:
int
,
)
->
None
:
self
.
tokenizer
=
tokenizer
self
.
image_processor
=
QwenVLImageProcessorFast
(
size
=
{
"width"
:
image_size
,
"height"
:
image_size
}
)
@
property
def
image_start_tag
(
self
)
->
str
:
return
self
.
tokenizer
.
image_start_tag
# type: ignore[attr-defined]
@
property
def
image_end_tag
(
self
)
->
str
:
return
self
.
tokenizer
.
image_end_tag
# type: ignore[attr-defined]
@
property
def
image_pad_tag
(
self
)
->
str
:
return
self
.
tokenizer
.
image_pad_tag
# type: ignore[attr-defined]
vllm/transformers_utils/processors/voxtral.py
0 → 100644
View file @
d62856b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
math
import
ceil
import
numpy
as
np
import
torch
from
mistral_common.tokens.tokenizers.audio
import
AudioEncoder
from
transformers
import
BatchFeature
,
ProcessorMixin
,
TensorType
from
transformers.audio_utils
import
AudioInput
from
transformers.image_utils
import
ImageInput
from
transformers.tokenization_utils_base
import
PreTokenizedInput
,
TextInput
from
transformers.video_utils
import
VideoInput
from
vllm.tokenizers.mistral
import
MistralTokenizer
class
MistralCommonFeatureExtractor
:
"""
Provide a HF-compatible interface for
`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`.
"""
def
__init__
(
self
,
audio_encoder
:
AudioEncoder
)
->
None
:
self
.
audio_encoder
=
audio_encoder
@
property
def
sampling_rate
(
self
):
return
self
.
audio_encoder
.
audio_config
.
sampling_rate
@
property
def
frame_rate
(
self
):
return
self
.
audio_encoder
.
audio_config
.
frame_rate
def
__call__
(
self
,
audios
:
AudioInput
,
return_tensors
:
str
|
TensorType
|
None
=
None
,
**
kwargs
,
)
->
BatchFeature
:
audios_lst
=
[
audios
]
if
not
isinstance
(
audios
,
list
)
else
audios
audios_processed
=
list
[
torch
.
Tensor
]()
for
audio
in
audios_lst
:
audio
=
np
.
asarray
(
audio
,
dtype
=
np
.
float32
).
ravel
()
if
not
self
.
audio_encoder
.
audio_config
.
is_streaming
:
audio
=
self
.
audio_encoder
.
pad
(
audio
,
self
.
sampling_rate
)
audios_processed
.
append
(
torch
.
tensor
(
audio
))
return
BatchFeature
(
{
"audio_arrays"
:
audios_processed
},
tensor_type
=
return_tensors
)
def
get_num_audio_tokens
(
self
,
audio_length
:
int
)
->
int
:
return
ceil
(
audio_length
/
(
self
.
sampling_rate
//
self
.
frame_rate
))
class
MistralCommonVoxtralProcessor
(
ProcessorMixin
):
attributes
=
[
"feature_extractor"
,
"tokenizer"
]
def
__init__
(
self
,
tokenizer
:
MistralTokenizer
)
->
None
:
self
.
tokenizer
=
tokenizer
.
transformers_tokenizer
self
.
feature_extractor
=
MistralCommonFeatureExtractor
(
tokenizer
.
instruct
.
audio_encoder
)
self
.
_audio_special_ids
=
self
.
feature_extractor
.
audio_encoder
.
special_ids
@
property
def
audio_token_id
(
self
)
->
int
:
return
self
.
_audio_special_ids
.
audio
@
property
def
begin_audio_token_id
(
self
)
->
int
:
return
self
.
_audio_special_ids
.
begin_audio
def
__call__
(
self
,
images
:
ImageInput
|
None
=
None
,
text
:
TextInput
|
PreTokenizedInput
|
list
[
TextInput
]
|
list
[
PreTokenizedInput
]
|
None
=
None
,
videos
:
VideoInput
|
None
=
None
,
audio
:
AudioInput
|
None
=
None
,
**
kwargs
,
):
if
images
is
None
and
text
is
None
and
videos
is
None
and
audio
is
None
:
raise
ValueError
(
f
"You need to provide at least one input to "
f
"call
{
self
.
__class__
.
__name__
}
"
)
kwargs
=
self
.
_merge_kwargs
(
self
.
valid_processor_kwargs
,
tokenizer_init_kwargs
=
{},
**
kwargs
,
)
kwargs
[
"text_kwargs"
][
"return_tensors"
]
=
"pt"
kwargs
[
"audio_kwargs"
][
"return_tensors"
]
=
None
# Avoid padding issue
attribute_to_kwargs
=
{
"tokenizer"
:
(
text
,
"text_kwargs"
),
"image_processor"
:
(
images
,
"images_kwargs"
),
"video_processor"
:
(
videos
,
"videos_kwargs"
),
"feature_extractor"
:
(
audio
,
"audio_kwargs"
),
}
outputs
=
{}
for
attribute_name
in
self
.
attributes
:
attribute
=
getattr
(
self
,
attribute_name
,
None
)
input_data
,
input_kwargs
=
attribute_to_kwargs
[
attribute_name
]
if
input_data
is
not
None
and
attribute
is
not
None
:
attribute_output
=
attribute
(
input_data
,
**
kwargs
[
input_kwargs
])
outputs
.
update
(
attribute_output
)
return
BatchFeature
(
outputs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment