Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
528dbcac
Unverified
Commit
528dbcac
authored
Jan 22, 2025
by
zhou fan
Committed by
GitHub
Jan 22, 2025
Browse files
[Model][Bugfix]: correct Aria model output (#12309)
Signed-off-by:
xffxff
<
1247714429@qq.com
>
parent
cd7b6f08
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
3 deletions
+54
-3
examples/offline_inference/vision_language.py
examples/offline_inference/vision_language.py
+2
-1
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+52
-2
No files found.
examples/offline_inference/vision_language.py
View file @
528dbcac
...
@@ -28,9 +28,10 @@ def run_aria(question: str, modality: str):
...
@@ -28,9 +28,10 @@ def run_aria(question: str, modality: str):
llm
=
LLM
(
model
=
model_name
,
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
4096
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
dtype
=
"bfloat16"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
prompt
=
(
f
"<|im_start|>user
\n
<fim_prefix><|img|><fim_suffix>
\n
{
question
}
"
prompt
=
(
f
"<|im_start|>user
\n
<fim_prefix><|img|><fim_suffix>
{
question
}
"
"<|im_end|>
\n
<|im_start|>assistant
\n
"
)
"<|im_end|>
\n
<|im_start|>assistant
\n
"
)
stop_token_ids
=
[
93532
,
93653
,
944
,
93421
,
1019
,
93653
,
93519
]
stop_token_ids
=
[
93532
,
93653
,
944
,
93421
,
1019
,
93653
,
93519
]
...
...
vllm/model_executor/models/aria.py
View file @
528dbcac
...
@@ -30,6 +30,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
...
@@ -30,6 +30,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
# yapf: disable
# yapf: disable
from
.idefics2_vision_model
import
Idefics2VisionConfig
from
.idefics2_vision_model
import
(
from
.idefics2_vision_model
import
(
Idefics2VisionTransformer
as
Idefics3VisionTransformer
)
Idefics2VisionTransformer
as
Idefics3VisionTransformer
)
# yapf: enable
# yapf: enable
...
@@ -50,6 +51,53 @@ class AriaImagePixelInputs(TypedDict):
...
@@ -50,6 +51,53 @@ class AriaImagePixelInputs(TypedDict):
"""
"""
class AriaVisionTransformer(Idefics3VisionTransformer):
    """Idefics3 vision transformer variant used as Aria's vision tower.

    The only structural difference from the parent class is that the
    normalization applied after the final layer is disabled: Aria does not
    use it, so ``post_layernorm`` is swapped for an identity module and any
    checkpoint tensors belonging to it are ignored when loading weights.
    """

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the parent transformer, then neutralize ``post_layernorm``.

        Args:
            config: Vision configuration forwarded to the parent class.
            quant_config: Optional quantization configuration, forwarded
                unchanged to the parent class.
            prefix: Name prefix forwarded unchanged to the parent class.
        """
        super().__init__(config, quant_config, prefix)
        # The parent applies a LayerNorm after its last layer; Aria has no
        # such normalization, so make that step a no-op.
        self.post_layernorm = nn.Identity()

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` checkpoint entries are
        routed into the fused ``qkv_proj`` parameter through that
        parameter's ``weight_loader``, which also receives the shard id.
        Entries whose name contains ``post_layernorm`` are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded. For q/k/v entries
            this is the rewritten (``qkv_proj``) name.

        Raises:
            KeyError: If a (possibly rewritten) name is not one of this
                module's named parameters.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        qkv_shards = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        named_params = dict(self.named_parameters())
        loaded: Set[str] = set()

        for name, tensor in weights:
            # NOTE: post_layernorm is not used in Aria
            if "post_layernorm" in name:
                continue

            # First shard entry whose checkpoint name occurs in ``name``.
            shard = next(
                (entry for entry in qkv_shards if entry[1] in name),
                None,
            )

            if shard is None:
                # Regular parameter: use its own loader when it has one,
                # otherwise fall back to the default loader.
                target = named_params[name]
                load = getattr(target, "weight_loader", default_weight_loader)
                load(target, tensor)
            else:
                fused_name, shard_name, shard_id = shard
                name = name.replace(shard_name, fused_name)
                target = named_params[name]
                load = target.weight_loader
                load(target, tensor, shard_id)

            loaded.add(name)

        return loaded
class
AriaProjectorMLP
(
nn
.
Module
):
class
AriaProjectorMLP
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
@@ -228,8 +276,10 @@ class AriaTextMoELayer(nn.Module):
...
@@ -228,8 +276,10 @@ class AriaTextMoELayer(nn.Module):
router_output
=
torch
.
nn
.
functional
.
linear
(
hidden_states
,
router_output
=
torch
.
nn
.
functional
.
linear
(
hidden_states
,
self
.
router_weight
)
self
.
router_weight
)
hidden_states_copy
=
hidden_states
.
clone
()
# NOTE: hidden_states will be modified inplace by `FusedMoE`
sparse_expert_output
=
self
.
experts
(
hidden_states
,
router_output
)
sparse_expert_output
=
self
.
experts
(
hidden_states
,
router_output
)
shared_expert_output
=
self
.
shared_experts
(
hidden_states
)
shared_expert_output
=
self
.
shared_experts
(
hidden_states
_copy
)
return
sparse_expert_output
+
shared_expert_output
return
sparse_expert_output
+
shared_expert_output
...
@@ -445,7 +495,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -445,7 +495,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
vision_tower
=
Idefics3
VisionTransformer
(
self
.
vision_tower
=
Aria
VisionTransformer
(
config
.
vision_config
,
config
.
vision_config
,
quant_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.vision_tower"
,
prefix
=
f
"
{
prefix
}
.vision_tower"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment