Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b37d8279
Unverified
Commit
b37d8279
authored
Jan 20, 2025
by
Cyrus Leung
Committed by
GitHub
Jan 20, 2025
Browse files
[Model] Upgrade Aria to transformers 4.48 (#12203)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
3127e975
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
178 additions
and
379 deletions
+178
-379
examples/offline_inference/vision_language.py
examples/offline_inference/vision_language.py
+0
-3
tests/models/decoder_only/vision_language/test_models.py
tests/models/decoder_only/vision_language/test_models.py
+2
-5
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+5
-7
tests/models/registry.py
tests/models/registry.py
+62
-5
tests/models/test_initialization.py
tests/models/test_initialization.py
+2
-12
tests/models/test_registry.py
tests/models/test_registry.py
+3
-0
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+100
-175
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+4
-5
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+0
-2
vllm/transformers_utils/configs/aria.py
vllm/transformers_utils/configs/aria.py
+0
-165
No files found.
examples/offline_inference/vision_language.py
View file @
b37d8279
...
...
@@ -26,11 +26,8 @@ def run_aria(question: str, modality: str):
# NOTE: Need L40 (or equivalent) to avoid OOM
llm
=
LLM
(
model
=
model_name
,
tokenizer_mode
=
"slow"
,
dtype
=
"bfloat16"
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
trust_remote_code
=
True
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
prompt
=
(
f
"<|im_start|>user
\n
<fim_prefix><|img|><fim_suffix>
\n
{
question
}
"
...
...
tests/models/decoder_only/vision_language/test_models.py
View file @
b37d8279
...
...
@@ -10,7 +10,6 @@ from typing import Type
import
pytest
from
transformers
import
AutoModelForVision2Seq
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
transformers.utils
import
is_flash_attn_2_available
from
vllm.platforms
import
current_platform
from
vllm.utils
import
identity
...
...
@@ -140,9 +139,7 @@ VLM_TEST_SETTINGS = {
#### Extended model tests
"aria"
:
VLMTestInfo
(
models
=
[
"rhymes-ai/Aria"
],
tokenizer_mode
=
"slow"
,
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
),
dtype
=
"bfloat16"
,
prompt_formatter
=
lambda
img_prompt
:
f
"<|im_start|>user
\n
{
img_prompt
}
<|im_end|>
\n
<|im_start|>assistant
\n
"
,
# noqa: E501
img_idx_to_prompt
=
lambda
idx
:
"<fim_prefix><|img|><fim_suffix>
\n
"
,
max_model_len
=
4096
,
...
...
@@ -158,8 +155,8 @@ VLM_TEST_SETTINGS = {
max_tokens
=
64
,
marks
=
[
pytest
.
mark
.
skipif
(
not
is_flash_attn_2_available
()
,
reason
=
"
M
odel
needs flash-attn for numeric convergence.
"
,
TRANSFORMERS_VERSION
<
"4.48.0"
,
reason
=
"
HF m
odel
requires transformers>=4.48.0
"
,
),
large_gpu_mark
(
min_gb
=
64
),
],
...
...
tests/models/multimodal/processing/test_common.py
View file @
b37d8279
...
...
@@ -11,6 +11,7 @@ from vllm.multimodal.processing import ProcessingCache
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
....multimodal.utils
import
random_audio
,
random_image
,
random_video
from
...registry
import
HF_EXAMPLE_MODELS
def
_test_processing_correctness
(
...
...
@@ -20,12 +21,9 @@ def _test_processing_correctness(
num_batches
:
int
,
simplify_rate
:
float
,
):
if
model_id
==
"TIGER-Lab/Mantis-8B-siglip-llama3"
:
hf_overrides
=
{
"architectures"
:
[
"MantisForConditionalGeneration"
]}
elif
model_id
==
"deepseek-ai/deepseek-vl2-tiny"
:
hf_overrides
=
{
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]}
else
:
hf_overrides
=
{}
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model_id
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
limit_mm_per_prompt
=
{
modality
:
3
if
supports_multi
else
1
...
...
@@ -41,7 +39,7 @@ def _test_processing_correctness(
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
hf_overrides
=
hf_overrides
,
hf_overrides
=
model_info
.
hf_overrides
,
limit_mm_per_prompt
=
limit_mm_per_prompt
,
)
...
...
tests/models/registry.py
View file @
b37d8279
from
dataclasses
import
dataclass
,
field
from
typing
import
AbstractSet
,
Mapping
,
Optional
from
typing
import
AbstractSet
,
Any
,
Literal
,
Mapping
,
Optional
import
pytest
from
packaging.version
import
Version
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
@
dataclass
(
frozen
=
True
)
...
...
@@ -38,6 +42,50 @@ class _HfExamplesInfo:
trust_remote_code
:
bool
=
False
"""The ``trust_remote_code`` level required to load the model."""
hf_overrides
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
"""The ``hf_overrides`` required to load the model."""
def
check_transformers_version
(
self
,
*
,
on_fail
:
Literal
[
"error"
,
"skip"
],
)
->
None
:
"""
If the installed transformers version does not meet the requirements,
perform the given action.
"""
if
self
.
min_transformers_version
is
None
:
return
current_version
=
TRANSFORMERS_VERSION
required_version
=
self
.
min_transformers_version
if
Version
(
current_version
)
<
Version
(
required_version
):
msg
=
(
f
"You have `transformers==
{
current_version
}
` installed, but "
f
"`transformers>=
{
required_version
}
` is required to run this "
"model"
)
if
on_fail
==
"error"
:
raise
RuntimeError
(
msg
)
else
:
pytest
.
skip
(
msg
)
def
check_available_online
(
self
,
*
,
on_fail
:
Literal
[
"error"
,
"skip"
],
)
->
None
:
"""
If the model is not available online, perform the given action.
"""
if
not
self
.
is_available_online
:
msg
=
"Model is not available online"
if
on_fail
==
"error"
:
raise
RuntimeError
(
msg
)
else
:
pytest
.
skip
(
msg
)
# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS
=
{
...
...
@@ -48,8 +96,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
"ArcticForCausalLM"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-instruct"
,
trust_remote_code
=
True
),
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
,
trust_remote_code
=
True
),
"BaiChuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan-7B"
,
trust_remote_code
=
True
),
"BaichuanForCausalLM"
:
_HfExamplesInfo
(
"baichuan-inc/Baichuan2-7B-chat"
,
...
...
@@ -176,6 +222,8 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
_MULTIMODAL_EXAMPLE_MODELS
=
{
# [Decoder-only]
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
,
min_transformers_version
=
"4.48"
),
"Blip2ForConditionalGeneration"
:
_HfExamplesInfo
(
"Salesforce/blip2-opt-2.7b"
),
# noqa: E501
"ChameleonForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/chameleon-7b"
),
# noqa: E501
"ChatGLMModel"
:
_HfExamplesInfo
(
"THUDM/glm-4v-9b"
,
...
...
@@ -183,7 +231,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
"ChatGLMForConditionalGeneration"
:
_HfExamplesInfo
(
"chatglm2-6b"
,
is_available_online
=
False
),
"DeepseekVLV2ForCausalLM"
:
_HfExamplesInfo
(
"deepseek-ai/deepseek-vl2-tiny"
),
# noqa: E501
"DeepseekVLV2ForCausalLM"
:
_HfExamplesInfo
(
"deepseek-ai/deepseek-vl2-tiny"
,
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]}),
# noqa: E501
"FuyuForCausalLM"
:
_HfExamplesInfo
(
"adept/fuyu-8b"
),
"H2OVLChatModel"
:
_HfExamplesInfo
(
"h2oai/h2ovl-mississippi-800m"
),
"InternVLChatModel"
:
_HfExamplesInfo
(
"OpenGVLab/InternVL2-1B"
,
...
...
@@ -194,7 +243,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"LlavaNextForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-v1.6-mistral-7b-hf"
),
# noqa: E501
"LlavaNextVideoForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/LLaVA-NeXT-Video-7B-hf"
),
# noqa: E501
"LlavaOnevisionForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
),
# noqa: E501
"MantisForConditionalGeneration"
:
_HfExamplesInfo
(
"TIGER-Lab/Mantis-8B-siglip-llama3"
),
# noqa: E501
"MantisForConditionalGeneration"
:
_HfExamplesInfo
(
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"MantisForConditionalGeneration"
]}),
# noqa: E501
"MiniCPMV"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-Llama3-V-2_5"
,
trust_remote_code
=
True
),
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
...
...
@@ -247,5 +297,12 @@ class HfExampleModels:
def
get_hf_info
(
self
,
model_arch
:
str
)
->
_HfExamplesInfo
:
return
self
.
hf_models
[
model_arch
]
def
find_hf_info
(
self
,
model_id
:
str
)
->
_HfExamplesInfo
:
for
info
in
self
.
hf_models
.
values
():
if
info
.
default
==
model_id
:
return
info
raise
ValueError
(
f
"No example model defined for
{
model_id
}
"
)
HF_EXAMPLE_MODELS
=
HfExampleModels
(
_EXAMPLE_MODELS
)
tests/models/test_initialization.py
View file @
b37d8279
from
unittest.mock
import
patch
import
pytest
from
packaging.version
import
Version
from
transformers
import
PretrainedConfig
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
vllm
import
LLM
...
...
@@ -13,16 +11,8 @@ from .registry import HF_EXAMPLE_MODELS
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
HF_EXAMPLE_MODELS
.
get_supported_archs
())
def
test_can_initialize
(
model_arch
):
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
if
not
model_info
.
is_available_online
:
pytest
.
skip
(
"Model is not available online"
)
if
model_info
.
min_transformers_version
is
not
None
:
current_version
=
TRANSFORMERS_VERSION
required_version
=
model_info
.
min_transformers_version
if
Version
(
current_version
)
<
Version
(
required_version
):
pytest
.
skip
(
f
"You have `transformers==
{
current_version
}
` installed, but "
f
"`transformers>=
{
required_version
}
` is required to run this "
"model"
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
# Avoid OOM
def
hf_overrides
(
hf_config
:
PretrainedConfig
)
->
PretrainedConfig
:
...
...
tests/models/test_registry.py
View file @
b37d8279
...
...
@@ -21,6 +21,9 @@ from .registry import HF_EXAMPLE_MODELS
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
ModelRegistry
.
get_supported_archs
())
def
test_registry_imports
(
model_arch
):
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
# Ensure all model classes can be imported successfully
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
model_arch
)
...
...
vllm/model_executor/models/aria.py
View file @
b37d8279
from
typing
import
(
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn
as
nn
from
transformers
import
BatchFeature
,
PretrainedConfig
from
transformers
import
AriaConfig
,
AriaTextConfig
,
BatchFeature
from
transformers.models.aria.modeling_aria
import
AriaCrossAttention
from
transformers.models.aria.processing_aria
import
AriaProcessor
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
QuantizationConfig
,
VllmConfig
...
...
@@ -26,10 +28,11 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo
,
PromptReplacement
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.aria
import
(
AriaMoELMConfig
,
AriaVisionConfig
)
from
.idefics2_vision_model
import
Idefics2VisionTransformer
# yapf: disable
from
.idefics2_vision_model
import
(
Idefics2VisionTransformer
as
Idefics3VisionTransformer
)
# yapf: enable
from
.interfaces
import
SupportsMultiModal
from
.llama
import
LlamaDecoderLayer
,
LlamaMLP
,
LlamaModel
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
...
...
@@ -47,87 +50,22 @@ class AriaImagePixelInputs(TypedDict):
"""
class
AriaVisionTransformer
(
Idefics2VisionTransformer
):
"""
AriaVisionTransformer is a modified version of Idefics2VisionTransformer
that replaces the post-layernorm with an identity layer.
"""
class
AriaProjectorMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
AriaVisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
(
config
,
quant_config
,
prefix
)
self
.
post_layernorm
=
nn
.
Identity
()
class
AriaVisionModel
(
nn
.
Module
):
config_class
=
AriaVisionConfig
def
__init__
(
self
,
config
:
AriaVisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
*
,
prefix
:
str
=
""
,
in_features
:
int
,
hidden_features
:
int
,
output_dim
:
int
,
)
->
None
:
super
().
__init__
()
self
.
vision_model
=
AriaVisionTransformer
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.vision_model"
,
)
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
,
pixel_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
patch_attention_mask
=
self
.
_create_patch_attention_mask
(
pixel_mask
)
vit_oup
=
self
.
vision_model
(
pixel_values
=
pixel_values
,
patch_attention_mask
=
patch_attention_mask
,
)
image_atts
=
self
.
_create_image_attention_mask
(
patch_attention_mask
)
return
vit_oup
,
image_atts
def
_create_patch_attention_mask
(
self
,
pixel_mask
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
if
pixel_mask
is
None
:
return
None
patches_subgrid
=
pixel_mask
.
unfold
(
dimension
=
1
,
size
=
self
.
vision_model
.
config
.
patch_size
,
step
=
self
.
vision_model
.
config
.
patch_size
,
).
unfold
(
dimension
=
2
,
size
=
self
.
vision_model
.
config
.
patch_size
,
step
=
self
.
vision_model
.
config
.
patch_size
,
)
return
(
patches_subgrid
.
sum
(
dim
=
(
-
1
,
-
2
))
>
0
).
bool
()
def
_create_image_attention_mask
(
self
,
patch_attention_mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
patch_attention_mask
is
None
:
return
None
flattened_mask
=
patch_attention_mask
.
flatten
(
1
)
return
torch
.
logical_not
(
flattened_mask
)
class
FFN
(
nn
.
Module
):
def
__init__
(
self
,
embed_dim
:
int
,
ff_dim
:
int
,
output_dim
:
int
)
->
None
:
super
().
__init__
()
self
.
linear_in
=
ColumnParallelLinear
(
embed_dim
,
ff_dim
,
bias
=
False
)
self
.
linear_out
=
RowParallelLinear
(
ff_dim
,
output_dim
,
bias
=
False
)
self
.
linear_in
=
ColumnParallelLinear
(
in_features
,
hidden_features
,
bias
=
False
)
self
.
linear_out
=
RowParallelLinear
(
hidden_features
,
output_dim
,
bias
=
False
)
self
.
act
=
get_act_fn
(
"gelu_new"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -137,46 +75,6 @@ class FFN(nn.Module):
return
hidden_states
class
CrossAttention
(
nn
.
Module
):
def
__init__
(
self
,
kv_dim
:
int
,
embed_dim
:
int
,
num_heads
:
int
)
->
None
:
super
().
__init__
()
self
.
num_heads
=
num_heads
self
.
q_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
,
bias
=
False
)
self
.
k_proj
=
nn
.
Linear
(
kv_dim
,
embed_dim
,
bias
=
False
)
self
.
v_proj
=
nn
.
Linear
(
kv_dim
,
embed_dim
,
bias
=
False
)
self
.
multihead_attn
=
nn
.
MultiheadAttention
(
embed_dim
,
num_heads
)
self
.
linear
=
nn
.
Linear
(
embed_dim
,
embed_dim
)
self
.
layer_norm
=
nn
.
LayerNorm
(
embed_dim
)
self
.
ln_kv
=
nn
.
LayerNorm
(
kv_dim
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
normed_hidden_states
=
self
.
layer_norm
(
hidden_states
)
query
=
self
.
q_proj
(
normed_hidden_states
).
permute
(
1
,
0
,
2
)
x
=
self
.
ln_kv
(
x
)
key
=
self
.
k_proj
(
x
).
permute
(
1
,
0
,
2
)
value
=
self
.
v_proj
(
x
).
permute
(
1
,
0
,
2
)
attn_output
,
_
=
self
.
multihead_attn
(
query
,
key
,
value
,
attn_mask
=
attn_mask
)
attn_output
=
attn_output
.
permute
(
1
,
0
,
2
)
attn_output
=
self
.
linear
(
attn_output
)
return
attn_output
class
AriaProjector
(
nn
.
Module
):
"""
A projection module with one cross attention layer and one FFN layer, which
...
...
@@ -198,42 +96,42 @@ class AriaProjector(nn.Module):
A tensor with the shape of (batch_size, query_number, output_dim)
"""
def
__init__
(
self
,
patch_to_query_dict
:
dict
[
int
,
int
],
embed_dim
:
int
,
num_heads
:
int
,
kv_dim
:
int
,
ff_dim
:
int
,
output_dim
:
int
,
norm_layer
:
Callable
[[
int
],
nn
.
Module
]
=
nn
.
LayerNorm
,
)
->
None
:
def
__init__
(
self
,
config
:
AriaConfig
)
->
None
:
super
().
__init__
()
self
.
patch_to_query_dict
=
patch_to_query_dict
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
patch_to_query_dict
=
config
.
projector_patch_to_query_dict
self
.
in_features
=
config
.
vision_config
.
hidden_size
self
.
num_heads
=
config
.
vision_config
.
num_attention_heads
self
.
kv_dim
=
config
.
vision_config
.
hidden_size
self
.
hidden_features
=
config
.
text_config
.
hidden_size
self
.
output_dim
=
config
.
text_config
.
hidden_size
self
.
query
=
nn
.
Parameter
(
torch
.
empty
(
max
(
patch_to_query_dict
.
values
()),
self
.
embed_dim
))
torch
.
empty
(
config
.
max_value_projector_patch_to_query_dict
,
self
.
in_features
))
self
.
cross_attn
=
CrossAttention
(
kv_dim
,
embed_dim
,
num_heads
)
self
.
cross_attn
=
Aria
CrossAttention
(
config
)
self
.
ln_ffn
=
norm_layer
(
embed_dim
)
self
.
ffn
=
FFN
(
embed_dim
,
ff_dim
,
output_dim
)
self
.
layer_norm
=
nn
.
LayerNorm
(
self
.
in_features
)
self
.
feed_forward
=
AriaProjectorMLP
(
self
.
in_features
,
self
.
hidden_features
,
self
.
output_dim
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
bs
=
x
.
shape
[
0
]
queries
=
self
.
query
.
unsqueeze
(
0
).
repeat
(
bs
,
1
,
1
)
batch_size
,
num_patches
=
x
.
shape
[
0
],
x
.
shape
[
1
]
if
num_patches
not
in
self
.
patch_to_query_dict
:
raise
KeyError
(
f
"Number of patches
{
num_patches
}
not found in "
"patch_to_query_dict amongst possible values "
f
"
{
self
.
patch_to_query_dict
.
keys
()
}
."
)
query_num
=
self
.
patch_to_query_dict
.
get
(
x
.
shape
[
1
],
None
)
assert
(
query_num
is
not
None
),
f
"Query number for
{
x
.
shape
[
1
]
}
patches is not provided"
query_num
=
self
.
patch_to_query_dict
[
num_patches
]
queries
=
queries
[:,
:
query_num
,
:]
queries
=
self
.
query
[:
query_num
].
unsqueeze
(
0
).
repeat
(
batch_size
,
1
,
1
)
if
attn_mask
is
not
None
:
attn_mask
=
attn_mask
.
repeat_interleave
(
self
.
num_heads
,
0
)
...
...
@@ -241,7 +139,7 @@ class AriaProjector(nn.Module):
attention_out
=
self
.
cross_attn
(
x
,
queries
,
attn_mask
=
attn_mask
)
out
=
self
.
f
fn
(
self
.
ln_ffn
(
attention_out
))
out
=
self
.
f
eed_forward
(
self
.
layer_norm
(
attention_out
))
return
out
...
...
@@ -278,7 +176,7 @@ class AriaFusedMoE(FusedMoE):
param
.
data
.
copy_
(
loaded_weight
.
transpose
(
1
,
2
))
class
MoELayer
(
nn
.
Module
):
class
AriaText
MoELayer
(
nn
.
Module
):
"""
Mixture of Experts (MoE) Layer for the AriaMoE model.
...
...
@@ -289,7 +187,7 @@ class MoELayer(nn.Module):
def
__init__
(
self
,
config
:
Aria
MoELM
Config
,
config
:
Aria
Text
Config
,
quant_config
:
Optional
[
QuantizationConfig
],
)
->
None
:
super
().
__init__
()
...
...
@@ -303,15 +201,16 @@ class MoELayer(nn.Module):
num_experts
=
config
.
moe_num_experts
,
top_k
=
config
.
moe_topk
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
,
reduce_results
=
True
,
)
self
.
shared_experts
=
LlamaMLP
(
config
.
hidden_size
,
config
.
moe_
intermediate_size
*
config
.
moe_num_shared_experts
,
config
.
intermediate_size
*
config
.
moe_num_shared_experts
,
"silu"
,
quant_config
=
quant_config
,
bias
=
config
.
mlp_bias
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -329,13 +228,13 @@ class MoELayer(nn.Module):
router_output
=
torch
.
nn
.
functional
.
linear
(
hidden_states
,
self
.
router_weight
)
shared_expert_output
=
self
.
shared_experts
(
hidden_states
)
sparse_expert_output
=
self
.
experts
(
hidden_states
,
router_output
)
shared_expert_output
=
self
.
shared_experts
(
hidden_states
)
return
sparse_expert_output
+
shared_expert_output
class
MoE
DecoderLayer
(
LlamaDecoderLayer
):
class
AriaText
DecoderLayer
(
LlamaDecoderLayer
):
"""
Custom Decoder Layer for the AriaMoE model which modifies the standard
`LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of
...
...
@@ -344,16 +243,16 @@ class MoEDecoderLayer(LlamaDecoderLayer):
def
__init__
(
self
,
config
:
Aria
MoELM
Config
,
config
:
Aria
Text
Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
(
config
,
cache_config
,
quant_config
,
prefix
)
self
.
mlp
=
MoELayer
(
config
,
quant_config
=
quant_config
)
self
.
mlp
=
AriaText
MoELayer
(
config
,
quant_config
=
quant_config
)
class
Aria
MoELM
Model
(
LlamaModel
):
class
Aria
Text
Model
(
LlamaModel
):
"""
Custom LlamaModel for the AriaMoE model which modifies the standard
LlamaModel by replacing the `LlamaDecoderLayer` with `AriaTextDecoderLayer`.
...
...
@@ -362,7 +261,7 @@ class AriaMoELMModel(LlamaModel):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
layer_type
=
MoE
DecoderLayer
)
layer_type
=
AriaText
DecoderLayer
)
# Adapted from LlamaModel.load_weights with the modification of adding
# the expert weights mapping to `stacked_params_mapping`
...
...
@@ -434,25 +333,23 @@ class AriaMoELMModel(LlamaModel):
return
loaded_params
def
build_mm_projector
(
config
:
PretrainedConfig
):
return
AriaProjector
(
patch_to_query_dict
=
config
.
projector_patch_to_query_dict
,
embed_dim
=
config
.
vision_config
.
hidden_size
,
num_heads
=
config
.
vision_config
.
num_attention_heads
,
kv_dim
=
config
.
vision_config
.
hidden_size
,
ff_dim
=
config
.
text_config
.
hidden_size
,
output_dim
=
config
.
text_config
.
hidden_size
,
)
class
AriaProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
()
return
self
.
ctx
.
get_hf_config
(
AriaConfig
)
def
get_vision_config
(
self
)
->
AriaVisionConfig
:
def
get_vision_config
(
self
):
return
self
.
get_hf_config
().
vision_config
def
get_hf_processor
(
self
):
processor
=
self
.
ctx
.
get_hf_processor
(
AriaProcessor
)
# Patch for https://github.com/huggingface/transformers/issues/35768
processor
.
tokenizer
.
image_token
=
"<|img|>"
processor
.
image_token
=
"<|img|>"
return
processor
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
...
...
@@ -554,10 +451,14 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
vision_tower
=
AriaVisionModel
(
config
.
vision_config
)
self
.
multi_modal_projector
=
build_mm_projector
(
config
)
self
.
vision_tower
=
Idefics3VisionTransformer
(
config
.
vision_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.vision_tower"
,
)
self
.
multi_modal_projector
=
AriaProjector
(
config
)
self
.
vocab_size
=
config
.
text_config
.
vocab_size
self
.
language_model
=
Aria
MoELM
Model
(
self
.
language_model
=
Aria
Text
Model
(
vllm_config
=
vllm_config
.
with_hf_config
(
config
.
text_config
),
prefix
=
maybe_prefix
(
prefix
,
"language_model.model"
),
)
...
...
@@ -608,6 +509,22 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_mask
=
pixel_mask
,
)
def
_create_patch_attention_mask
(
self
,
pixel_mask
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
if
pixel_mask
is
None
:
return
None
patches_subgrid
=
pixel_mask
.
unfold
(
dimension
=
1
,
size
=
self
.
vision_tower
.
config
.
patch_size
,
step
=
self
.
vision_tower
.
config
.
patch_size
,
).
unfold
(
dimension
=
2
,
size
=
self
.
vision_tower
.
config
.
patch_size
,
step
=
self
.
vision_tower
.
config
.
patch_size
,
)
return
(
patches_subgrid
.
sum
(
dim
=
(
-
1
,
-
2
))
>
0
).
bool
()
def
_process_image_input
(
self
,
image_input
:
AriaImagePixelInputs
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -616,9 +533,18 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values
=
image_input
[
'pixel_values'
]
pixel_mask
=
image_input
[
'pixel_mask'
]
image_feature
,
image_attn_mask
=
self
.
vision_tower
(
pixel_values
,
pixel_mask
=
pixel_mask
)
return
self
.
multi_modal_projector
(
image_feature
,
image_attn_mask
)
patch_attention_mask
=
self
.
_create_patch_attention_mask
(
pixel_mask
)
image_outputs
=
self
.
vision_tower
(
pixel_values
=
pixel_values
,
patch_attention_mask
=
patch_attention_mask
,
)
image_attn_mask
=
None
if
patch_attention_mask
is
not
None
:
flattened_mask
=
patch_attention_mask
.
flatten
(
1
)
image_attn_mask
=
torch
.
logical_not
(
flattened_mask
)
return
self
.
multi_modal_projector
(
image_outputs
,
image_attn_mask
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
@@ -683,6 +609,5 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
vllm/transformers_utils/config.py
View file @
b37d8279
...
...
@@ -22,10 +22,10 @@ from vllm.envs import VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.transformers_utils.configs
import
(
Aria
Config
,
C
hatGLM
Config
,
Cohere2
Config
,
D
brx
Config
,
DeepseekVLV2
Config
,
E
AGLE
Config
,
ExaoneConfig
,
H2OVLChatConfig
,
from
vllm.transformers_utils.configs
import
(
ChatGLM
Config
,
C
ohere2
Config
,
Dbrx
Config
,
D
eepseekVLV2
Config
,
EAGLE
Config
,
E
xaone
Config
,
H2OVLChatConfig
,
InternVLChatConfig
,
JAISConfig
,
MedusaConfig
,
MllamaConfig
,
MLPSpeculatorConfig
,
MPTConfig
,
...
...
@@ -52,7 +52,6 @@ _CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
}
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]]
=
{
"aria"
:
AriaConfig
,
"chatglm"
:
ChatGLMConfig
,
"cohere2"
:
Cohere2Config
,
"dbrx"
:
DbrxConfig
,
...
...
vllm/transformers_utils/configs/__init__.py
View file @
b37d8279
from
vllm.transformers_utils.configs.aria
import
AriaConfig
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.configs.cohere2
import
Cohere2Config
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
...
...
@@ -24,7 +23,6 @@ from vllm.transformers_utils.configs.telechat2 import Telechat2Config
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
__all__
=
[
"AriaConfig"
,
"ChatGLMConfig"
,
"Cohere2Config"
,
"DbrxConfig"
,
...
...
vllm/transformers_utils/configs/aria.py
deleted
100644 → 0
View file @
3127e975
# Copyright 2024 Rhymes AI. All rights reserved.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from
typing
import
Mapping
from
transformers
import
PretrainedConfig
from
transformers.models.idefics2.configuration_idefics2
import
(
Idefics2VisionConfig
)
from
transformers.models.llama.configuration_llama
import
LlamaConfig
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class
AriaVisionConfig
(
Idefics2VisionConfig
):
model_type
=
"aria_vision_model"
class
AriaMoELMConfig
(
LlamaConfig
):
"""
Configuration class for AriaMoE language model.
This class extends the LlamaConfig to include additional parameters specific
to the Mixture of Experts (MoE) architecture.
"""
model_type
=
"aria_moe_lm"
def
__init__
(
self
,
moe_intermediate_size
:
int
=
4096
,
moe_num_experts
:
int
=
8
,
moe_topk
:
int
=
2
,
moe_num_shared_experts
:
int
=
2
,
**
kwargs
,
):
"""
Initialize the AriaMoELMConfig.
Args:
moe_intermediate_size (int): The intermediate size for MoE layers.
Default is 4096.
moe_num_experts (int): The number of experts in the MoE layer.
Default is 8.
moe_topk (int): The number of top experts to route to for each
token. Default is 2.
moe_num_shared_experts (int): The number of shared experts. Default
is 2.
**kwargs: Additional keyword arguments to be passed to the parent
LlamaConfig.
"""
super
().
__init__
(
**
kwargs
)
self
.
moe_intermediate_size
=
moe_intermediate_size
self
.
moe_num_experts
=
moe_num_experts
self
.
moe_topk
=
moe_topk
self
.
moe_num_shared_experts
=
moe_num_shared_experts
class
AriaConfig
(
PretrainedConfig
):
"""
Configuration class for Aria model.
This class handles the configuration for both vision and text components of
the Aria model,
as well as additional parameters for image token handling and projector
mapping.
Args:
vision_config (AriaVisionConfig or dict): Configuration for the vision
component.
text_config (AriaMoELMConfig or dict): Configuration for the text
component.
projector_patch_to_query_dict (dict): Mapping of patch sizes to query
dimensions.
ignore_index (int): Index to ignore in loss calculation.
image_token_index (int): Index used to represent image tokens.
**kwargs: Additional keyword arguments passed to the parent class.
Attributes:
model_type (str): Type of the model, set to "aria".
is_composition (bool): Whether the model is a composition of multiple
components.
ignore_index (int): Index to ignore in loss calculation.
image_token_index (int): Index used to represent image tokens.
projector_patch_to_query_dict (dict): Mapping of patch sizes to query
dimensions.
vision_config (AriaVisionConfig): Configuration for the vision
component.
text_config (AriaMoELMConfig): Configuration for the text component.
"""
model_type
=
"aria"
is_composition
=
False
def
__init__
(
self
,
vision_config
:
AriaVisionConfig
=
AriaVisionConfig
(),
# noqa: B008
text_config
:
AriaMoELMConfig
=
AriaMoELMConfig
(),
# noqa: B008
projector_patch_to_query_dict
:
Mapping
[
int
,
int
]
=
{
1225
:
128
,
4900
:
256
,
},
ignore_index
=-
100
,
image_token_index
=
32000
,
tie_word_embeddings
=
False
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
self
.
ignore_index
=
ignore_index
self
.
image_token_index
=
image_token_index
self
.
tie_word_embeddings
=
tie_word_embeddings
attn_implementation
=
kwargs
.
pop
(
"attn_implementation"
,
None
)
# Set the default attention implementation to flash_attention_2 if not
# specified
self
.
_attn_implementation
=
(
"flash_attention_2"
if
attn_implementation
is
None
else
attn_implementation
)
# Convert the keys and values of projector_patch_to_query_dict to
# integers
# This ensures consistency even if they were provided as strings
self
.
projector_patch_to_query_dict
=
{
int
(
k
):
int
(
v
)
for
k
,
v
in
projector_patch_to_query_dict
.
items
()
}
if
isinstance
(
vision_config
,
dict
)
and
"model_type"
in
vision_config
:
vision_config
=
AriaVisionConfig
(
**
vision_config
)
if
attn_implementation
is
None
:
vision_attn_implementation
=
"flash_attention_2"
elif
attn_implementation
==
"sdpa"
:
logger
.
warning
(
"SDPA is not supported for vit, using "
"flash_attention_2 instead"
)
vision_attn_implementation
=
"flash_attention_2"
else
:
vision_attn_implementation
=
attn_implementation
vision_config
.
_attn_implementation
=
vision_attn_implementation
self
.
vision_config
=
vision_config
if
isinstance
(
text_config
,
dict
)
and
"model_type"
in
text_config
:
text_attn_implementation
=
(
"sdpa"
if
attn_implementation
is
None
else
attn_implementation
)
text_config
=
AriaMoELMConfig
(
**
text_config
)
text_config
.
_attn_implementation
=
text_attn_implementation
self
.
text_config
=
text_config
# This is needed for the static kv cache
self
.
num_hidden_layers
=
self
.
text_config
.
num_hidden_layers
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment