Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
630eb5b5
Unverified
Commit
630eb5b5
authored
Jan 19, 2025
by
Cyrus Leung
Committed by
GitHub
Jan 18, 2025
Browse files
[Bugfix] Fix multi-modal processors for transformers 4.48 (#12187)
parent
4e94951b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
199 additions
and
36 deletions
+199
-36
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+24
-1
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+49
-23
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+1
-8
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+5
-4
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/aria.py
vllm/transformers_utils/configs/aria.py
+118
-0
No files found.
vllm/model_executor/models/llava.py
View file @
630eb5b5
...
@@ -5,9 +5,11 @@ from typing import (Final, Iterable, List, Literal, Mapping, Optional,
...
@@ -5,9 +5,11 @@ from typing import (Final, Iterable, List, Literal, Mapping, Optional,
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
packaging.version
import
Version
from
transformers
import
(
BatchFeature
,
CLIPVisionConfig
,
LlavaConfig
,
from
transformers
import
(
BatchFeature
,
CLIPVisionConfig
,
LlavaConfig
,
PixtralVisionConfig
,
PretrainedConfig
,
PixtralVisionConfig
,
PretrainedConfig
,
SiglipVisionConfig
)
SiglipVisionConfig
)
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
transformers.models.llava
import
LlavaProcessor
from
transformers.models.llava
import
LlavaProcessor
from
transformers.models.pixtral
import
PixtralProcessor
from
transformers.models.pixtral
import
PixtralProcessor
...
@@ -716,6 +718,27 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -716,6 +718,27 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis; builds a version-aware ``LlavaProcessor``."""

    def get_hf_processor(self):
        """Return a ``LlavaProcessor`` set up for the installed transformers.

        The processor is requested from ``self.ctx`` with the vision
        encoder's patch size and a ``vision_feature_select_strategy`` that
        depends on whether transformers is older than 4.48.
        """
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        predates_fix = Version(TRANSFORMERS_VERSION) < Version("4.48")

        # BUG (transformers < 4.48): num_additional_image_tokens = 0 but it
        # is treated as 1, so the strategy is forced to None to offset this.
        # FIXED upstream in:
        # https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150  # noqa: E501
        # From 4.48 on, the strategy from the HF config is used unchanged.
        strategy = (None if predates_fix else
                    config.vision_feature_select_strategy)

        return self.ctx.get_hf_processor(
            LlavaProcessor,
            patch_size=encoder_info.get_patch_size(),
            vision_feature_select_strategy=strategy,
        )
class
MantisMultiModalProcessor
(
LlavaMultiModalProcessor
):
class
MantisMultiModalProcessor
(
LlavaMultiModalProcessor
):
def
apply
(
def
apply
(
...
@@ -794,7 +817,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
...
@@ -794,7 +817,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
# To use this model, please use
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@
MULTIMODAL_REGISTRY
.
register_processor
(
MantisMultiModalProcessor
,
@
MULTIMODAL_REGISTRY
.
register_processor
(
MantisMultiModalProcessor
,
info
=
Llava
ProcessingInfo
,
info
=
Mantis
ProcessingInfo
,
dummy_inputs
=
LlavaDummyInputsBuilder
)
dummy_inputs
=
LlavaDummyInputsBuilder
)
class
MantisForConditionalGeneration
(
LlavaForConditionalGeneration
):
class
MantisForConditionalGeneration
(
LlavaForConditionalGeneration
):
pass
pass
vllm/model_executor/models/qwen2_audio.py
View file @
630eb5b5
...
@@ -36,8 +36,9 @@ from vllm.config import VllmConfig
...
@@ -36,8 +36,9 @@ from vllm.config import VllmConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalFieldConfig
,
MultiModalKwargs
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
NestedTensors
)
MultiModalInputsV2
,
MultiModalKwargs
,
NestedTensors
,
PlaceholderRange
)
from
vllm.multimodal.parse
import
(
AudioProcessorItems
,
MultiModalDataItems
,
from
vllm.multimodal.parse
import
(
AudioProcessorItems
,
MultiModalDataItems
,
MultiModalDataParser
)
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
...
@@ -153,29 +154,24 @@ class Qwen2AudioMultiModalProcessor(
...
@@ -153,29 +154,24 @@ class Qwen2AudioMultiModalProcessor(
mm_data
:
Mapping
[
str
,
object
],
mm_data
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
Any
],
mm_kwargs
:
Mapping
[
str
,
Any
],
)
->
BatchFeature
:
)
->
BatchFeature
:
mm_data
=
dict
(
mm_data
)
# Text-only input not supported in composite processor
audios
=
mm_data
.
pop
(
"audios"
,
[])
if
not
mm_data
or
not
mm_data
.
get
(
"audios"
,
[]):
prompt_ids
=
self
.
info
.
get_tokenizer
().
encode
(
prompt
)
if
audios
:
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt_ids
)
mm_data
[
"audios"
]
=
audios
return
BatchFeature
(
dict
(
input_ids
=
[
prompt_ids
]),
tensor_type
=
"pt"
)
feature_extractor
=
self
.
info
.
get_feature_extractor
(
**
mm_kwargs
)
feature_extractor
=
self
.
info
.
get_feature_extractor
(
**
mm_kwargs
)
mm_kwargs
=
dict
(
mm_kwargs
=
dict
(
**
mm_kwargs
,
**
mm_kwargs
,
sampling_rate
=
feature_extractor
.
sampling_rate
,
sampling_rate
=
feature_extractor
.
sampling_rate
,
)
)
else
:
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
pass
processed_outputs
=
super
().
_call_hf_processor
(
return
super
().
_call_hf_processor
(
prompt
=
prompt
,
prompt
=
prompt
,
mm_data
=
mm_data
,
mm_data
=
mm_data
,
mm_kwargs
=
mm_kwargs
,
mm_kwargs
=
mm_kwargs
,
)
)
return
processed_outputs
def
_get_mm_fields_config
(
def
_get_mm_fields_config
(
self
,
self
,
hf_inputs
:
BatchFeature
,
hf_inputs
:
BatchFeature
,
...
@@ -192,8 +188,14 @@ class Qwen2AudioMultiModalProcessor(
...
@@ -192,8 +188,14 @@ class Qwen2AudioMultiModalProcessor(
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
list
[
PromptReplacement
]:
hf_config
=
self
.
info
.
get_hf_config
()
processor
=
self
.
info
.
get_hf_processor
()
placeholder
=
hf_config
.
audio_token_index
# Use getattr with default to be compatible with transformers<4.48
audio_token
=
getattr
(
processor
,
"audio_token"
,
"<|AUDIO|>"
)
audio_bos_token
=
getattr
(
processor
,
"audio_bos_token"
,
"<|audio_bos|>"
)
audio_eos_token
=
getattr
(
processor
,
"audio_eos_token"
,
"<|audio_eos|>"
)
feature_attention_mask
=
out_mm_kwargs
.
get
(
"feature_attention_mask"
)
feature_attention_mask
=
out_mm_kwargs
.
get
(
"feature_attention_mask"
)
if
feature_attention_mask
is
None
:
if
feature_attention_mask
is
None
:
...
@@ -214,12 +216,16 @@ class Qwen2AudioMultiModalProcessor(
...
@@ -214,12 +216,16 @@ class Qwen2AudioMultiModalProcessor(
f
"The audio
{
audio
}
(len=
{
len
(
audio
)
}
) is too short "
f
"The audio
{
audio
}
(len=
{
len
(
audio
)
}
) is too short "
"to be represented inside the model"
)
"to be represented inside the model"
)
return
[
placeholder
]
*
num_placeholders
return
""
.
join
([
audio_bos_token
,
audio_token
*
num_placeholders
,
audio_eos_token
,
])
return
[
return
[
PromptReplacement
(
PromptReplacement
(
modality
=
"audio"
,
modality
=
"audio"
,
target
=
[
placeholder
]
,
target
=
audio_token
,
replacement
=
get_replacement_qwen2_audio
,
replacement
=
get_replacement_qwen2_audio
,
)
)
]
]
...
@@ -234,6 +240,26 @@ class Qwen2AudioMultiModalProcessor(
...
@@ -234,6 +240,26 @@ class Qwen2AudioMultiModalProcessor(
# tokens than the number of audio items)
# tokens than the number of audio items)
return
not
hasattr
(
self
.
info
.
get_hf_processor
(),
"audio_token"
)
return
not
hasattr
(
self
.
info
.
get_hf_processor
(),
"audio_token"
)
def apply(
    self,
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
    """Run the parent ``apply`` and narrow every placeholder range.

    Only <|AUDIO|> tokens should be considered as placeholders, so the
    audio_bos_token and audio_eos_token surrounding them are ignored:
    each range reported by the parent is rebuilt with its offset moved
    forward by one and its length reduced by two.
    """
    outputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

    narrowed = {}
    for modality, ranges in outputs["mm_placeholders"].items():
        narrowed[modality] = [
            PlaceholderRange(
                offset=item["offset"] + 1,
                length=item["length"] - 2,
            ) for item in ranges
        ]
    outputs["mm_placeholders"] = narrowed

    return outputs
@
MULTIMODAL_REGISTRY
.
register_processor
(
@
MULTIMODAL_REGISTRY
.
register_processor
(
Qwen2AudioMultiModalProcessor
,
Qwen2AudioMultiModalProcessor
,
...
...
vllm/model_executor/models/ultravox.py
View file @
630eb5b5
...
@@ -137,7 +137,7 @@ class UltravoxMultiModalProcessor(
...
@@ -137,7 +137,7 @@ class UltravoxMultiModalProcessor(
mm_kwargs
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
)
->
BatchFeature
:
# Text-only input not supported in composite processor
# Text-only input not supported in composite processor
if
not
mm_data
:
if
not
mm_data
or
not
mm_data
.
get
(
"audios"
,
[])
:
prompt_ids
=
self
.
info
.
get_tokenizer
().
encode
(
prompt
)
prompt_ids
=
self
.
info
.
get_tokenizer
().
encode
(
prompt
)
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt_ids
)
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt_ids
)
return
BatchFeature
(
dict
(
input_ids
=
[
prompt_ids
]),
tensor_type
=
"pt"
)
return
BatchFeature
(
dict
(
input_ids
=
[
prompt_ids
]),
tensor_type
=
"pt"
)
...
@@ -146,13 +146,6 @@ class UltravoxMultiModalProcessor(
...
@@ -146,13 +146,6 @@ class UltravoxMultiModalProcessor(
audios
=
mm_data
.
pop
(
"audios"
,
[])
audios
=
mm_data
.
pop
(
"audios"
,
[])
assert
isinstance
(
audios
,
list
)
assert
isinstance
(
audios
,
list
)
if
not
audios
:
return
super
().
_call_hf_processor
(
prompt
=
prompt
,
mm_data
=
mm_data
,
mm_kwargs
=
mm_kwargs
,
)
feature_extractor
=
self
.
info
.
get_feature_extractor
()
feature_extractor
=
self
.
info
.
get_feature_extractor
()
mm_kwargs
=
dict
(
mm_kwargs
=
dict
(
**
mm_kwargs
,
**
mm_kwargs
,
...
...
vllm/transformers_utils/config.py
View file @
630eb5b5
...
@@ -22,10 +22,10 @@ from vllm.envs import VLLM_USE_MODELSCOPE
...
@@ -22,10 +22,10 @@ from vllm.envs import VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
from
vllm.transformers_utils.configs
import
(
ChatGLM
Config
,
C
ohere2
Config
,
from
vllm.transformers_utils.configs
import
(
Aria
Config
,
C
hatGLM
Config
,
Dbrx
Config
,
D
eepseekVLV2
Config
,
Cohere2
Config
,
D
brx
Config
,
EAGLE
Config
,
E
xaone
Config
,
DeepseekVLV2
Config
,
E
AGLE
Config
,
H2OVLChatConfig
,
ExaoneConfig
,
H2OVLChatConfig
,
InternVLChatConfig
,
JAISConfig
,
InternVLChatConfig
,
JAISConfig
,
MedusaConfig
,
MllamaConfig
,
MedusaConfig
,
MllamaConfig
,
MLPSpeculatorConfig
,
MPTConfig
,
MLPSpeculatorConfig
,
MPTConfig
,
...
@@ -52,6 +52,7 @@ _CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
...
@@ -52,6 +52,7 @@ _CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
}
}
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]]
=
{
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]]
=
{
"aria"
:
AriaConfig
,
"chatglm"
:
ChatGLMConfig
,
"chatglm"
:
ChatGLMConfig
,
"cohere2"
:
Cohere2Config
,
"cohere2"
:
Cohere2Config
,
"dbrx"
:
DbrxConfig
,
"dbrx"
:
DbrxConfig
,
...
...
vllm/transformers_utils/configs/__init__.py
View file @
630eb5b5
from
vllm.transformers_utils.configs.aria
import
AriaConfig
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.configs.cohere2
import
Cohere2Config
from
vllm.transformers_utils.configs.cohere2
import
Cohere2Config
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
...
@@ -23,6 +24,7 @@ from vllm.transformers_utils.configs.telechat2 import Telechat2Config
...
@@ -23,6 +24,7 @@ from vllm.transformers_utils.configs.telechat2 import Telechat2Config
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
__all__
=
[
__all__
=
[
"AriaConfig"
,
"ChatGLMConfig"
,
"ChatGLMConfig"
,
"Cohere2Config"
,
"Cohere2Config"
,
"DbrxConfig"
,
"DbrxConfig"
,
...
...
vllm/transformers_utils/configs/aria.py
View file @
630eb5b5
# Copyright 2024 Rhymes AI. All rights reserved.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Mapping, Optional
from
transformers
import
PretrainedConfig
from
transformers.models.idefics2.configuration_idefics2
import
(
from
transformers.models.idefics2.configuration_idefics2
import
(
Idefics2VisionConfig
)
Idefics2VisionConfig
)
from
transformers.models.llama.configuration_llama
import
LlamaConfig
from
transformers.models.llama.configuration_llama
import
LlamaConfig
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class
AriaVisionConfig
(
Idefics2VisionConfig
):
class
AriaVisionConfig
(
Idefics2VisionConfig
):
model_type
=
"aria_vision_model"
model_type
=
"aria_vision_model"
...
@@ -45,3 +70,96 @@ class AriaMoELMConfig(LlamaConfig):
...
@@ -45,3 +70,96 @@ class AriaMoELMConfig(LlamaConfig):
self
.
moe_num_experts
=
moe_num_experts
self
.
moe_num_experts
=
moe_num_experts
self
.
moe_topk
=
moe_topk
self
.
moe_topk
=
moe_topk
self
.
moe_num_shared_experts
=
moe_num_shared_experts
self
.
moe_num_shared_experts
=
moe_num_shared_experts
class AriaConfig(PretrainedConfig):
    """
    Configuration class for Aria model.

    This class handles the configuration for both vision and text components
    of the Aria model, as well as additional parameters for image token
    handling and projector mapping.

    Args:
        vision_config (AriaVisionConfig or dict, optional): Configuration for
            the vision component. A dict that contains a ``"model_type"``
            key is converted to an ``AriaVisionConfig``. When omitted, a new
            default ``AriaVisionConfig`` is created for this instance.
        text_config (AriaMoELMConfig or dict, optional): Configuration for
            the text component. A dict that contains a ``"model_type"`` key
            is converted to an ``AriaMoELMConfig``. When omitted, a new
            default ``AriaMoELMConfig`` is created for this instance.
        projector_patch_to_query_dict (dict, optional): Mapping of patch
            sizes to query dimensions. Keys and values are coerced with
            ``int``. Defaults to ``{1225: 128, 4900: 256}``.
        ignore_index (int): Index to ignore in loss calculation.
        image_token_index (int): Index used to represent image tokens.
        tie_word_embeddings (bool): Stored on the config after the parent
            initializer has run.
        **kwargs: Additional keyword arguments passed to the parent class.
            ``attn_implementation`` is additionally read from here to pick
            the attention implementation of this config and of sub-configs
            built from dicts.

    Attributes:
        model_type (str): Type of the model, set to "aria".
        is_composition (bool): Whether the model is a composition of multiple
            components.
        ignore_index (int): Index to ignore in loss calculation.
        image_token_index (int): Index used to represent image tokens.
        projector_patch_to_query_dict (dict): Mapping of patch sizes to query
            dimensions.
        vision_config (AriaVisionConfig): Configuration for the vision
            component.
        text_config (AriaMoELMConfig): Configuration for the text component.
    """

    model_type = "aria"
    is_composition = False

    def __init__(
        self,
        vision_config: Optional[AriaVisionConfig] = None,
        text_config: Optional[AriaMoELMConfig] = None,
        projector_patch_to_query_dict: Optional[Mapping[int, int]] = None,
        ignore_index=-100,
        image_token_index=32000,
        tie_word_embeddings=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Defaults are built per call. Default values written in the
        # signature are evaluated once at definition time, which would make
        # every default-constructed AriaConfig share the same sub-config
        # objects.
        if vision_config is None:
            vision_config = AriaVisionConfig()
        if text_config is None:
            text_config = AriaMoELMConfig()
        if projector_patch_to_query_dict is None:
            projector_patch_to_query_dict = {
                1225: 128,
                4900: 256,
            }

        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.tie_word_embeddings = tie_word_embeddings

        # kwargs was already forwarded to the parent initializer above; here
        # the value is only read to configure this class and its sub-configs.
        attn_implementation = kwargs.pop("attn_implementation", None)

        # Set the default attention implementation to flash_attention_2 if not
        # specified
        self._attn_implementation = ("flash_attention_2"
                                     if attn_implementation is None else
                                     attn_implementation)

        # Convert the keys and values of projector_patch_to_query_dict to
        # integers
        # This ensures consistency even if they were provided as strings
        self.projector_patch_to_query_dict = {
            int(k): int(v)
            for k, v in projector_patch_to_query_dict.items()
        }

        if isinstance(vision_config, dict) and "model_type" in vision_config:
            vision_config = AriaVisionConfig(**vision_config)
            if attn_implementation is None:
                vision_attn_implementation = "flash_attention_2"
            elif attn_implementation == "sdpa":
                logger.warning("SDPA is not supported for vit, using "
                               "flash_attention_2 instead")
                vision_attn_implementation = "flash_attention_2"
            else:
                vision_attn_implementation = attn_implementation
            vision_config._attn_implementation = vision_attn_implementation

        self.vision_config = vision_config

        if isinstance(text_config, dict) and "model_type" in text_config:
            text_attn_implementation = ("sdpa" if attn_implementation is None
                                        else attn_implementation)
            text_config = AriaMoELMConfig(**text_config)
            text_config._attn_implementation = text_attn_implementation

        self.text_config = text_config

        # This is needed for the static kv cache
        self.num_hidden_layers = self.text_config.num_hidden_layers
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment