Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7d6fb905
Unverified
Commit
7d6fb905
authored
Oct 02, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 02, 2025
Browse files
[Model] Use `merge_by_field_config` for MM models (A-C) (#26073)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
418d111f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
29 additions
and
24 deletions
+29
-24
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+13
-6
vllm/model_executor/models/aya_vision.py
vllm/model_executor/models/aya_vision.py
+4
-3
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+4
-9
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+4
-3
vllm/model_executor/models/cohere2_vision.py
vllm/model_executor/models/cohere2_vision.py
+4
-3
No files found.
vllm/model_executor/models/aria.py
View file @
7d6fb905
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Annotated
,
Optional
,
Union
from
typing
import
Annotated
,
Literal
,
Optional
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -38,8 +38,8 @@ from .idefics2_vision_model import (
# yapf: enable
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsQuant
from
.llama
import
LlamaDecoderLayer
,
LlamaMLP
,
LlamaModel
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
is_pp_missing_parameter
,
maybe_prefix
)
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
is_pp_missing_parameter
,
maybe_prefix
)
class
AriaImagePixelInputs
(
TensorSchema
):
...
...
@@ -52,6 +52,8 @@ class AriaImagePixelInputs(TensorSchema):
- w: Width of each image
"""
type
:
Literal
[
"pixel_values"
]
pixel_values
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
3
,
"h"
,
"w"
),
...
...
@@ -485,6 +487,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
This model combines a vision tower, a multi-modal projector, and a language
model to perform tasks that involve both image and text inputs.
"""
merge_by_field_config
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
# mapping for new names in checkpoint saved after transformers v4.52
...
...
@@ -551,12 +555,15 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return
None
return
AriaImagePixelInputs
(
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
),
pixel_mask
=
flatten_bn
(
pixel_mask
,
concat
=
True
),
type
=
"pixel_values"
,
pixel_values
=
pixel_values
,
pixel_mask
=
pixel_mask
,
)
def
_create_patch_attention_mask
(
self
,
pixel_mask
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
self
,
pixel_mask
:
Optional
[
torch
.
Tensor
],
)
->
Optional
[
torch
.
Tensor
]:
if
pixel_mask
is
None
:
return
None
...
...
vllm/model_executor/models/aya_vision.py
View file @
7d6fb905
...
...
@@ -31,7 +31,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
maybe_prefix
)
...
...
@@ -295,6 +295,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
dummy_inputs
=
AyaVisionDummyInputsBuilder
)
class
AyaVisionForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
merge_by_field_config
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
...
...
@@ -379,8 +380,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return
AyaVisionImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
),
num_patches
=
flatten_bn
(
num_patches
,
concat
=
True
),
pixel_values
=
pixel_values
,
num_patches
=
num_patches
,
resolve_bindings
=
{
"h"
:
self
.
config
.
vision_config
.
image_size
,
"w"
:
self
.
config
.
vision_config
.
image_size
,
...
...
vllm/model_executor/models/blip2.py
View file @
7d6fb905
...
...
@@ -26,12 +26,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from
.blip
import
BlipVisionModel
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
)
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
_IMAGE_TOKEN_ID
=
50265
from
.utils
import
AutoWeightsLoader
,
init_vllm_registered_model
,
maybe_prefix
class
Blip2ImagePixelInputs
(
TensorSchema
):
...
...
@@ -514,6 +509,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
dummy_inputs
=
Blip2DummyInputsBuilder
)
class
Blip2ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
):
merge_by_field_config
=
True
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
...
...
@@ -570,8 +566,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
if
pixel_values
is
not
None
:
expected_h
=
expected_w
=
self
.
config
.
vision_config
.
image_size
return
Blip2ImagePixelInputs
(
type
=
"pixel_values"
,
data
=
flatten_bn
(
pixel_values
,
concat
=
True
),
data
=
pixel_values
,
resolve_bindings
=
{
"h"
:
expected_h
,
"w"
:
expected_w
...
...
@@ -580,7 +575,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
if
image_embeds
is
not
None
:
return
Blip2ImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
flatten_bn
(
image_embeds
,
concat
=
True
),
data
=
image_embeds
,
)
raise
AssertionError
(
"This line should be unreachable."
)
...
...
vllm/model_executor/models/chameleon.py
View file @
7d6fb905
...
...
@@ -42,7 +42,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
)
from
.utils
import
(
flatten_bn
,
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -935,6 +935,8 @@ class ChameleonModel(nn.Module):
dummy_inputs
=
ChameleonDummyInputsBuilder
)
class
ChameleonForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
):
merge_by_field_config
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
...
...
@@ -981,8 +983,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
expected_h
=
expected_w
=
vq_config
.
resolution
return
ChameleonImagePixelInputs
(
type
=
"pixel_values"
,
data
=
flatten_bn
(
pixel_values
,
concat
=
True
),
data
=
pixel_values
,
resolve_bindings
=
{
"h"
:
expected_h
,
"w"
:
expected_w
...
...
vllm/model_executor/models/cohere2_vision.py
View file @
7d6fb905
...
...
@@ -36,7 +36,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
maybe_prefix
)
...
...
@@ -317,6 +317,7 @@ class Cohere2VisionMultiModalProcessor(
dummy_inputs
=
Cohere2VisionDummyInputsBuilder
)
class
Cohere2VisionForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
merge_by_field_config
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
...
...
@@ -399,8 +400,8 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return
Cohere2VisionImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
),
num_patches
=
flatten_bn
(
num_patches
,
concat
=
True
),
pixel_values
=
pixel_values
,
num_patches
=
num_patches
,
resolve_bindings
=
{
"h"
:
self
.
config
.
vision_config
.
image_size
,
"w"
:
self
.
config
.
vision_config
.
image_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment