Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc253b73
Unverified
Commit
cc253b73
authored
Oct 02, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 02, 2025
Browse files
[Model] Use `merge_by_field_config` for MM models (D-F) (#26076)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
7d6fb905
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
99 additions
and
177 deletions
+99
-177
vllm/model_executor/models/deepseek_vl2.py
vllm/model_executor/models/deepseek_vl2.py
+25
-33
vllm/model_executor/models/dots_ocr.py
vllm/model_executor/models/dots_ocr.py
+24
-51
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/ernie45_vl.py
+24
-52
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+26
-41
No files found.
vllm/model_executor/models/deepseek_vl2.py
View file @
cc253b73
...
@@ -20,8 +20,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
...
@@ -20,8 +20,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from
vllm.model_executor.models.transformers
import
replace_linear_class
from
vllm.model_executor.models.transformers
import
replace_linear_class
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
MultiModalUUIDDict
,
MultiModalKwargsItems
,
MultiModalUUIDDict
)
NestedTensors
)
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
...
@@ -40,7 +39,7 @@ from vllm.utils import is_list_of
...
@@ -40,7 +39,7 @@ from vllm.utils import is_list_of
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
maybe_prefix
)
init_vllm_registered_model
,
maybe_prefix
)
# The image token id may be various
# The image token id may be various
...
@@ -50,15 +49,15 @@ _IMAGE_TOKEN = "<image>"
...
@@ -50,15 +49,15 @@ _IMAGE_TOKEN = "<image>"
class
DeepseekVL2ImagePixelInputs
(
TensorSchema
):
class
DeepseekVL2ImagePixelInputs
(
TensorSchema
):
"""
"""
Dimensions:
Dimensions:
- bn: Batch size * number of images
- bn
p
: Batch size * number of images
* number of patches
- p: Number of patches
- p: Number of patches
- c: Number of channels (3)
- c: Number of channels (3)
- h: Height of each image
- h: Height of each image
- w: Width of each image
- w: Width of each image
"""
"""
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
Annotated
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
data
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn
"
,
"
p"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"p"
})]
TensorShape
(
"bnp"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"
bn
p"
})]
images_spatial_crop
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
2
)]
images_spatial_crop
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
2
)]
...
@@ -228,12 +227,8 @@ class DeepseekVL2MultiModalProcessor(
...
@@ -228,12 +227,8 @@ class DeepseekVL2MultiModalProcessor(
tok_kwargs
=
tok_kwargs
,
tok_kwargs
=
tok_kwargs
,
)
)
pixel_values
=
processed_outputs
[
"pixel_values"
]
processed_outputs
[
"num_patches"
]
=
(
# split pixel values into patches corresponding to each image
processed_outputs
[
"images_spatial_crop"
].
prod
(
-
1
)
+
1
)
images_spatial_crop
=
processed_outputs
[
"images_spatial_crop"
]
patches_per_image
=
[
x
.
prod
().
item
()
+
1
for
x
in
images_spatial_crop
]
pixel_values
=
pixel_values
.
split
(
patches_per_image
)
processed_outputs
[
"pixel_values"
]
=
pixel_values
return
processed_outputs
return
processed_outputs
...
@@ -242,8 +237,11 @@ class DeepseekVL2MultiModalProcessor(
...
@@ -242,8 +237,11 @@ class DeepseekVL2MultiModalProcessor(
hf_inputs
:
BatchFeature
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
num_patches
=
hf_inputs
.
get
(
"num_patches"
,
torch
.
empty
(
0
))
return
dict
(
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_patches
),
images_spatial_crop
=
MultiModalFieldConfig
.
batched
(
"image"
),
images_spatial_crop
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
...
@@ -318,6 +316,7 @@ class DeepseekVL2MultiModalProcessor(
...
@@ -318,6 +316,7 @@ class DeepseekVL2MultiModalProcessor(
info
=
DeepseekVL2ProcessingInfo
,
info
=
DeepseekVL2ProcessingInfo
,
dummy_inputs
=
DeepseekVL2DummyInputsBuilder
)
dummy_inputs
=
DeepseekVL2DummyInputsBuilder
)
class
DeepseekVLV2ForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
DeepseekVLV2ForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
merge_by_field_config
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"language."
:
"language_model."
,
"language."
:
"language_model."
,
...
@@ -460,37 +459,30 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -460,37 +459,30 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
expected_h
=
expected_w
=
self
.
vision_config
.
image_size
expected_h
=
expected_w
=
self
.
vision_config
.
image_size
return
DeepseekVL2ImagePixelInputs
(
type
=
"pixel_values"
,
return
DeepseekVL2ImagePixelInputs
(
data
=
flatten_bn
(
pixel_values
),
type
=
"pixel_values"
,
images_spatial_crop
=
flatten_bn
(
data
=
pixel_values
,
images_spatial_crop
,
images_spatial_crop
=
images_spatial_crop
,
concat
=
True
),
resolve_bindings
=
{
resolve_bindings
=
{
"h"
:
expected_h
,
"h"
:
expected_h
,
"w"
:
expected_w
,
"w"
:
expected_w
,
})
})
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
return
DeepseekVL2VImageEmbeddingInputs
(
return
DeepseekVL2VImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
data
=
flatten_bn
(
image_embeds
)
,
data
=
image_embeds
,
)
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_pixel_values_to_embedding
(
def
_pixel_values_to_embedding
(
self
,
self
,
pixel_values
:
Nested
Tensor
s
,
pixel_values
:
torch
.
Tensor
,
images_spatial_crop
:
torch
.
Tensor
,
images_spatial_crop
:
torch
.
Tensor
,
)
->
NestedTensors
:
)
->
list
[
torch
.
Tensor
]:
# Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
total_tiles
=
[
x
for
x
in
pixel_values
]
# [batch_all_tiles, 3, height, width]
total_tiles
=
torch
.
cat
(
total_tiles
,
dim
=
0
)
# [batch_all_tiles, vit_seq_len, c]
# [batch_all_tiles, vit_seq_len, c]
images_feature
=
self
.
vision
.
forward_features
(
total_til
es
)
images_feature
=
self
.
vision
.
forward_features
(
pixel_valu
es
)
# [batch_all_tiles, hw, D]
# [batch_all_tiles, hw, D]
images_embeds
=
self
.
projector
(
images_feature
)
images_embeds
=
self
.
projector
(
images_feature
)
...
@@ -573,7 +565,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -573,7 +565,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return
vision_embeddings
return
vision_embeddings
def
_process_image_input
(
def
_process_image_input
(
self
,
image_input
:
DeepseekVL2ImageInputs
)
->
torch
.
Tensor
:
self
,
image_input
:
DeepseekVL2ImageInputs
)
->
list
[
torch
.
Tensor
]
:
if
image_input
[
"type"
]
==
"image_embeds"
:
if
image_input
[
"type"
]
==
"image_embeds"
:
image_data
=
image_input
[
"data"
]
image_data
=
image_input
[
"data"
]
if
is_list_of
(
image_data
,
torch
.
Tensor
):
if
is_list_of
(
image_data
,
torch
.
Tensor
):
...
...
vllm/model_executor/models/dots_ocr.py
View file @
cc253b73
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
,
Mapping
from
collections.abc
import
Iterable
,
Mapping
from
typing
import
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Literal
,
Optional
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -42,34 +42,38 @@ from vllm.platforms import _Backend
...
@@ -42,34 +42,38 @@ from vllm.platforms import _Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.dotsocr
import
(
DotsOCRConfig
,
from
vllm.transformers_utils.configs.dotsocr
import
(
DotsOCRConfig
,
DotsVisionConfig
)
DotsVisionConfig
)
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.vision
import
run_dp_sharded_mrope_vision_model
from
.vision
import
run_dp_sharded_mrope_vision_model
IMAGE_TOKEN
=
"<|imgpad|>"
IMAGE_TOKEN
=
"<|imgpad|>"
class
DotsOCRImagePixelInputs
(
TypedDict
):
class
DotsOCRImagePixelInputs
(
TensorSchema
):
type
:
Literal
[
"pixel_values"
,
"image_grid_thw"
]
"""
Dimensions:
- np: The total number of patches over each image over each prompt in
the batch
- ni: Number of images
- cps: Number of channels * patch_size * patch_size
"""
type
:
Literal
[
"pixel_values"
]
pixel_values
:
torch
.
Tensor
pixel_values
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"np"
,
"cps"
)]
image_grid_thw
:
torch
.
Tensor
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
class
DotsOCRImageEmbeddingInputs
(
TypedDict
):
class
DotsOCRImageEmbeddingInputs
(
TensorSchema
):
type
:
Literal
[
"image_embeds"
,
"image_grid_thw"
]
image_embeds
:
torch
.
Tensor
"""Supported types:
- List[`torch.Tensor`]: A list of tensors holding all images' features.
Each tensor holds an image's features.
- `torch.Tensor`: A tensor holding all images' features
(concatenation of all images' feature tensors).
Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
the number and resolution of the images.
- `hidden_size` must match the hidden size of language model backbone.
"""
"""
Dimensions:
- nf: Number of image features
- hs: Hidden size
- ni: Number of images
"""
type
:
Literal
[
"image_embeds"
]
image_grid_thw
:
torch
.
Tensor
image_embeds
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"nf"
,
"hs"
)]
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
DotsOCRImageInputs
=
Union
[
DotsOCRImagePixelInputs
,
DotsOCRImageInputs
=
Union
[
DotsOCRImagePixelInputs
,
...
@@ -654,6 +658,8 @@ class DotsVisionTransformer(nn.Module):
...
@@ -654,6 +658,8 @@ class DotsVisionTransformer(nn.Module):
)
)
class
DotsOCRForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
class
DotsOCRForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsLoRA
):
SupportsLoRA
):
merge_by_field_config
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
orig_to_new_substr
=
{
".attn.qkv_proj."
:
".attn.qkv."
,
".attn.qkv_proj."
:
".attn.qkv."
,
...
@@ -709,22 +715,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -709,22 +715,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
architectures
=
[
"Qwen2ForCausalLM"
],
architectures
=
[
"Qwen2ForCausalLM"
],
)
)
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. "
f
"Got type:
{
type
(
mm_input
)
}
"
)
if
isinstance
(
mm_input
,
torch
.
Tensor
):
if
mm_input
.
ndim
==
2
:
return
mm_input
if
mm_input
.
ndim
!=
3
:
raise
ValueError
(
f
"
{
name
}
should be 2D or batched 3D tensor. "
f
"Got ndim:
{
mm_input
.
ndim
}
"
f
"(shape=
{
mm_input
.
shape
}
)"
)
return
torch
.
concat
(
list
(
mm_input
))
else
:
return
torch
.
concat
(
mm_input
)
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
DotsOCRImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
DotsOCRImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
@@ -735,28 +725,11 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -735,28 +725,11 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return
None
return
None
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
pixel_values
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values
,
"image pixel values"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
"image grid_thw"
)
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
DotsOCRImagePixelInputs
(
type
=
"pixel_values"
,
return
DotsOCRImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
pixel_values
,
pixel_values
=
pixel_values
,
image_grid_thw
=
image_grid_thw
)
image_grid_thw
=
image_grid_thw
)
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
image_embeds
=
self
.
_validate_and_reshape_mm_tensor
(
image_embeds
,
"image embeds"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
"image grid_thw"
)
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
DotsOCRImageEmbeddingInputs
(
type
=
"image_embeds"
,
return
DotsOCRImageEmbeddingInputs
(
type
=
"image_embeds"
,
image_embeds
=
image_embeds
,
image_embeds
=
image_embeds
,
image_grid_thw
=
image_grid_thw
)
image_grid_thw
=
image_grid_thw
)
...
...
vllm/model_executor/models/ernie45_vl.py
View file @
cc253b73
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
partial
from
functools
import
partial
from
typing
import
Any
,
Callable
,
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Any
,
Callable
,
Literal
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -56,6 +56,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -56,6 +56,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.ernie45_vl_moe
import
Ernie4_5_VLMoeForCausalLM
from
.ernie45_vl_moe
import
Ernie4_5_VLMoeForCausalLM
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
...
@@ -579,38 +580,38 @@ class Ernie4_5_VisionTransformer(nn.Module):
...
@@ -579,38 +580,38 @@ class Ernie4_5_VisionTransformer(nn.Module):
# === Vision Inputs === #
# === Vision Inputs === #
class
Ernie4_5_VLImagePixelInputs
(
TypedDict
):
class
Ernie4_5_VLImagePixelInputs
(
TensorSchema
):
type
:
Literal
[
"pixel_values"
]
pixel_values
:
torch
.
Tensor
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
"""
"""
Dimensions:
grid_thw
:
torch
.
Tensor
- np: The total number of patches over each image over each prompt in
"""Shape: `(num_images, 3)`
the batch
This should be in `(grid_t, grid_h, grid_w)` format.
- ni: Number of images
- cps: Number of channels * patch_size * patch_size
"""
"""
type
:
Literal
[
"pixel_values"
]
pixel_values
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"np"
,
"cps"
)]
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
Ernie4_5_VLImageInputs
=
Ernie4_5_VLImagePixelInputs
Ernie4_5_VLImageInputs
=
Ernie4_5_VLImagePixelInputs
class
Ernie4_5_VLVideoPixelInputs
(
TypedDict
):
class
Ernie4_5_VLVideoPixelInputs
(
TensorSchema
):
type
:
Literal
[
"pixel_values_videos"
]
pixel_values_videos
:
torch
.
Tensor
"""Shape:
`(num_patches,
num_channels * temporal_patch_size * patch_size * patch_size)`
"""
"""
Dimensions:
video_grid_thw
:
torch
.
Tensor
- np: The total number of patches over each image over each prompt in
"""Shape: `(num_videos, 3)`
the batch
- ni: Number of images
This should be in `(grid_t, grid_h, grid_w)` format.
- cps: Number of channels * temporal_patch_size * patch_size *
patch_size
"""
"""
type
:
Literal
[
"pixel_values_videos"
]
pixel_values_videos
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"np"
,
"cps"
)]
video_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
Ernie4_5_VLVideoInputs
=
Ernie4_5_VL
Image
PixelInputs
Ernie4_5_VLVideoInputs
=
Ernie4_5_VL
Video
PixelInputs
# === Vision Processor === #
# === Vision Processor === #
...
@@ -1213,6 +1214,7 @@ class Ernie4_5_VLDummyInputsBuilder(
...
@@ -1213,6 +1214,7 @@ class Ernie4_5_VLDummyInputsBuilder(
dummy_inputs
=
Ernie4_5_VLDummyInputsBuilder
)
dummy_inputs
=
Ernie4_5_VLDummyInputsBuilder
)
class
Ernie4_5_VLMoeForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
class
Ernie4_5_VLMoeForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsLoRA
,
SupportsPP
):
SupportsLoRA
,
SupportsPP
):
merge_by_field_config
=
True
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
...
@@ -1325,22 +1327,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1325,22 +1327,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. "
f
"Got type:
{
type
(
mm_input
)
}
"
)
if
isinstance
(
mm_input
,
torch
.
Tensor
):
if
mm_input
.
ndim
==
2
:
return
mm_input
if
mm_input
.
ndim
!=
3
:
raise
ValueError
(
f
"
{
name
}
should be 2D or batched 3D tensor. "
f
"Got ndim:
{
mm_input
.
ndim
}
"
f
"(shape=
{
mm_input
.
shape
}
)"
)
return
mm_input
.
reshape
(
-
1
,
mm_input
.
shape
[
-
1
])
else
:
return
torch
.
concat
(
mm_input
)
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Ernie4_5_VLImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
Ernie4_5_VLImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
@@ -1350,15 +1336,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1350,15 +1336,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
return
None
return
None
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
pixel_values
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values
,
"image pixel values"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
"image grid_thw"
)
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
Ernie4_5_VLImagePixelInputs
(
type
=
"pixel_values"
,
return
Ernie4_5_VLImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
pixel_values
,
pixel_values
=
pixel_values
,
image_grid_thw
=
image_grid_thw
)
image_grid_thw
=
image_grid_thw
)
...
@@ -1372,11 +1349,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1372,11 +1349,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
return
None
return
None
if
pixel_values_videos
is
not
None
:
if
pixel_values_videos
is
not
None
:
pixel_values_videos
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values_videos
,
"video pixel values"
)
video_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
video_grid_thw
,
"video grid_thw"
)
return
Ernie4_5_VLVideoPixelInputs
(
return
Ernie4_5_VLVideoPixelInputs
(
type
=
"pixel_values_videos"
,
type
=
"pixel_values_videos"
,
pixel_values_videos
=
pixel_values_videos
,
pixel_values_videos
=
pixel_values_videos
,
...
...
vllm/model_executor/models/fuyu.py
View file @
cc253b73
...
@@ -59,17 +59,14 @@ class FuyuImagePatchInputs(TensorSchema):
...
@@ -59,17 +59,14 @@ class FuyuImagePatchInputs(TensorSchema):
type
:
Literal
[
"image_patches"
]
=
"image_patches"
type
:
Literal
[
"image_patches"
]
=
"image_patches"
flat_data
:
Annotated
[
image_patches_flat
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bnp"
,
"fn"
)]
torch
.
Tensor
,
TensorShape
(
"bnp"
,
"fn"
),
]
patches_per_image
:
Annotated
[
list
[
int
],
TensorShape
(
"bn"
)]
patches_per_image
:
Annotated
[
list
[
int
],
TensorShape
(
"bn"
)]
"""
"""
The number of total patches for each image in the batch.
The number of total patches for each image in the batch.
This is used to split the embeddings which has the first two dimensions
This is used to split the embeddings which has the first two dimensions
flattened just like `
flat_d
at
a
`.
flattened just like `
image_patches_fl
at`.
"""
"""
...
@@ -174,28 +171,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
...
@@ -174,28 +171,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
tok_kwargs
=
tok_kwargs
,
tok_kwargs
=
tok_kwargs
,
)
)
image_patches
=
processed_outputs
.
get
(
"image_patches"
)
image_patches
=
processed_outputs
[
"image_patches"
]
if
image_patches
is
not
None
:
processed_outputs
[
"image_patches"
]
=
flatten_bn
(
image_patches
)
images
=
mm_data
[
"images"
]
processed_outputs
[
"patches_per_image"
]
=
torch
.
tensor
(
assert
isinstance
(
images
,
list
)
[
len
(
p
)
for
p
in
image_patches
])
# Original output: (1, num_images, Pn, Px * Py * C)
# New output: (num_images, Pn, Px * Py * C)
# image_patches is a list with shape:
# (1, num_images, Pn, Px * Py * C)
# before Transformers 4.53
if
isinstance
(
image_patches
,
list
):
assert
len
(
image_patches
)
==
1
assert
(
isinstance
(
image_patches
[
0
],
torch
.
Tensor
)
and
len
(
image_patches
[
0
])
==
len
(
images
))
processed_outputs
[
"image_patches"
]
=
image_patches
[
0
]
# image_patches is a tensor with shape:
# (num_images, Pn, Px * Py * C)
# after Transformers 4.53
elif
isinstance
(
image_patches
,
torch
.
Tensor
):
assert
len
(
image_patches
)
==
len
(
images
)
else
:
raise
AssertionError
(
"This line should be unreachable."
)
return
processed_outputs
return
processed_outputs
...
@@ -218,7 +197,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
...
@@ -218,7 +197,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs
:
BatchFeature
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
image_patches
=
MultiModalFieldConfig
.
batched
(
"image"
))
patches_per_image
=
hf_inputs
.
get
(
"patches_per_image"
,
torch
.
empty
(
0
))
return
dict
(
image_patches
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
patches_per_image
),
patches_per_image
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
self
,
self
,
...
@@ -263,6 +248,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
...
@@ -263,6 +248,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
info
=
FuyuProcessingInfo
,
info
=
FuyuProcessingInfo
,
dummy_inputs
=
FuyuDummyInputsBuilder
)
dummy_inputs
=
FuyuDummyInputsBuilder
)
class
FuyuForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
FuyuForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
merge_by_field_config
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
orig_to_new_prefix
=
{
...
@@ -306,29 +292,28 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -306,29 +292,28 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
FuyuImagePatchInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
FuyuImagePatchInputs
]:
image_patches
=
kwargs
.
pop
(
"image_patches"
,
None
)
image_patches
=
kwargs
.
pop
(
"image_patches"
,
None
)
if
image_patches
is
not
None
:
patches_per_image
=
kwargs
.
pop
(
"patches_per_image"
,
None
)
image_patches_flat
=
flatten_bn
(
image_patches
)
flat_data
=
flatten_bn
(
image_patches_flat
,
concat
=
True
)
return
FuyuImagePatchInputs
(
type
=
"image_patches"
,
flat_data
=
flat_data
,
patches_per_image
=
[
x
.
size
(
0
)
for
x
in
image_patches_flat
],
resolve_bindings
=
{
"fn"
:
self
.
image_feature_size
},
)
return
None
if
image_patches
is
None
:
return
None
return
FuyuImagePatchInputs
(
type
=
"image_patches"
,
image_patches_flat
=
image_patches
,
patches_per_image
=
patches_per_image
,
resolve_bindings
=
{
"fn"
:
self
.
image_feature_size
},
)
def
_process_image_input
(
def
_process_image_input
(
self
,
image_input
:
FuyuImagePatchInputs
)
->
MultiModalEmbeddings
:
self
,
image_input
:
FuyuImagePatchInputs
)
->
MultiModalEmbeddings
:
image_patches_flat
=
image_input
[
"
flat_d
at
a
"
]
image_patches_flat
=
image_input
[
"
image_patches_fl
at"
]
patches_per_image
=
image_input
[
"patches_per_image"
]
patches_per_image
=
image_input
[
"patches_per_image"
]
assert
self
.
vision_embed_tokens
is
not
None
assert
self
.
vision_embed_tokens
is
not
None
vision_embeddings_flat
,
_
=
self
.
vision_embed_tokens
(
vision_embeddings_flat
,
_
=
self
.
vision_embed_tokens
(
image_patches_flat
)
image_patches_flat
)
return
vision_embeddings_flat
.
split
(
patches_per_image
,
dim
=
0
)
return
vision_embeddings_flat
.
split
(
patches_per_image
.
tolist
()
,
dim
=
0
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment