Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc253b73
Unverified
Commit
cc253b73
authored
Oct 02, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 02, 2025
Browse files
[Model] Use `merge_by_field_config` for MM models (D-F) (#26076)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
7d6fb905
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
99 additions
and
177 deletions
+99
-177
vllm/model_executor/models/deepseek_vl2.py
vllm/model_executor/models/deepseek_vl2.py
+25
-33
vllm/model_executor/models/dots_ocr.py
vllm/model_executor/models/dots_ocr.py
+24
-51
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/ernie45_vl.py
+24
-52
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+26
-41
No files found.
vllm/model_executor/models/deepseek_vl2.py
View file @
cc253b73
...
...
@@ -20,8 +20,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from
vllm.model_executor.models.transformers
import
replace_linear_class
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
MultiModalUUIDDict
,
NestedTensors
)
MultiModalKwargsItems
,
MultiModalUUIDDict
)
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
...
...
@@ -40,7 +39,7 @@ from vllm.utils import is_list_of
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
maybe_prefix
)
# The image token id may be various
...
...
@@ -50,15 +49,15 @@ _IMAGE_TOKEN = "<image>"
class
DeepseekVL2ImagePixelInputs
(
TensorSchema
):
"""
Dimensions:
- bn: Batch size * number of images
- bn
p
: Batch size * number of images
* number of patches
- p: Number of patches
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
"""
type
:
Literal
[
"pixel_values"
]
data
:
Annotated
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
TensorShape
(
"bn
"
,
"
p"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"p"
})]
data
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bnp"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"
bn
p"
})]
images_spatial_crop
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
2
)]
...
...
@@ -228,12 +227,8 @@ class DeepseekVL2MultiModalProcessor(
tok_kwargs
=
tok_kwargs
,
)
pixel_values
=
processed_outputs
[
"pixel_values"
]
# split pixel values into patches corresponding to each image
images_spatial_crop
=
processed_outputs
[
"images_spatial_crop"
]
patches_per_image
=
[
x
.
prod
().
item
()
+
1
for
x
in
images_spatial_crop
]
pixel_values
=
pixel_values
.
split
(
patches_per_image
)
processed_outputs
[
"pixel_values"
]
=
pixel_values
processed_outputs
[
"num_patches"
]
=
(
processed_outputs
[
"images_spatial_crop"
].
prod
(
-
1
)
+
1
)
return
processed_outputs
...
...
@@ -242,8 +237,11 @@ class DeepseekVL2MultiModalProcessor(
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
num_patches
=
hf_inputs
.
get
(
"num_patches"
,
torch
.
empty
(
0
))
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_patches
),
images_spatial_crop
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
...
...
@@ -318,6 +316,7 @@ class DeepseekVL2MultiModalProcessor(
info
=
DeepseekVL2ProcessingInfo
,
dummy_inputs
=
DeepseekVL2DummyInputsBuilder
)
class
DeepseekVLV2ForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
merge_by_field_config
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"language."
:
"language_model."
,
...
...
@@ -460,37 +459,30 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if
pixel_values
is
not
None
:
expected_h
=
expected_w
=
self
.
vision_config
.
image_size
return
DeepseekVL2ImagePixelInputs
(
type
=
"pixel_values"
,
data
=
flatten_bn
(
pixel_values
),
images_spatial_crop
=
flatten_bn
(
images_spatial_crop
,
concat
=
True
),
resolve_bindings
=
{
"h"
:
expected_h
,
"w"
:
expected_w
,
})
return
DeepseekVL2ImagePixelInputs
(
type
=
"pixel_values"
,
data
=
pixel_values
,
images_spatial_crop
=
images_spatial_crop
,
resolve_bindings
=
{
"h"
:
expected_h
,
"w"
:
expected_w
,
})
if
image_embeds
is
not
None
:
return
DeepseekVL2VImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
flatten_bn
(
image_embeds
)
,
data
=
image_embeds
,
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_pixel_values_to_embedding
(
self
,
pixel_values
:
Nested
Tensor
s
,
pixel_values
:
torch
.
Tensor
,
images_spatial_crop
:
torch
.
Tensor
,
)
->
NestedTensors
:
# Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
total_tiles
=
[
x
for
x
in
pixel_values
]
# [batch_all_tiles, 3, height, width]
total_tiles
=
torch
.
cat
(
total_tiles
,
dim
=
0
)
)
->
list
[
torch
.
Tensor
]:
# [batch_all_tiles, vit_seq_len, c]
images_feature
=
self
.
vision
.
forward_features
(
total_til
es
)
images_feature
=
self
.
vision
.
forward_features
(
pixel_valu
es
)
# [batch_all_tiles, hw, D]
images_embeds
=
self
.
projector
(
images_feature
)
...
...
@@ -573,7 +565,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return
vision_embeddings
def
_process_image_input
(
self
,
image_input
:
DeepseekVL2ImageInputs
)
->
torch
.
Tensor
:
self
,
image_input
:
DeepseekVL2ImageInputs
)
->
list
[
torch
.
Tensor
]
:
if
image_input
[
"type"
]
==
"image_embeds"
:
image_data
=
image_input
[
"data"
]
if
is_list_of
(
image_data
,
torch
.
Tensor
):
...
...
vllm/model_executor/models/dots_ocr.py
View file @
cc253b73
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
,
Mapping
from
typing
import
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Literal
,
Optional
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -42,34 +42,38 @@ from vllm.platforms import _Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.dotsocr
import
(
DotsOCRConfig
,
DotsVisionConfig
)
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.vision
import
run_dp_sharded_mrope_vision_model
IMAGE_TOKEN
=
"<|imgpad|>"
class
DotsOCRImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
,
"image_grid_thw"
]
class
DotsOCRImagePixelInputs
(
TensorSchema
):
"""
Dimensions:
- np: The total number of patches over each image over each prompt in
the batch
- ni: Number of images
- cps: Number of channels * patch_size * patch_size
"""
type
:
Literal
[
"pixel_values"
]
pixel_values
:
torch
.
Tensor
image_grid_thw
:
torch
.
Tensor
pixel_values
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"np"
,
"cps"
)]
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
class
DotsOCRImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
,
"image_grid_thw"
]
image_embeds
:
torch
.
Tensor
"""Supported types:
- List[`torch.Tensor`]: A list of tensors holding all images' features.
Each tensor holds an image's features.
- `torch.Tensor`: A tensor holding all images' features
(concatenation of all images' feature tensors).
Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
the number and resolution of the images.
- `hidden_size` must match the hidden size of language model backbone.
class
DotsOCRImageEmbeddingInputs
(
TensorSchema
):
"""
Dimensions:
- nf: Number of image features
- hs: Hidden size
- ni: Number of images
"""
type
:
Literal
[
"image_embeds"
]
image_grid_thw
:
torch
.
Tensor
image_embeds
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"nf"
,
"hs"
)]
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
DotsOCRImageInputs
=
Union
[
DotsOCRImagePixelInputs
,
...
...
@@ -654,6 +658,8 @@ class DotsVisionTransformer(nn.Module):
)
class
DotsOCRForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsLoRA
):
merge_by_field_config
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
".attn.qkv_proj."
:
".attn.qkv."
,
...
...
@@ -709,22 +715,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
architectures
=
[
"Qwen2ForCausalLM"
],
)
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. "
f
"Got type:
{
type
(
mm_input
)
}
"
)
if
isinstance
(
mm_input
,
torch
.
Tensor
):
if
mm_input
.
ndim
==
2
:
return
mm_input
if
mm_input
.
ndim
!=
3
:
raise
ValueError
(
f
"
{
name
}
should be 2D or batched 3D tensor. "
f
"Got ndim:
{
mm_input
.
ndim
}
"
f
"(shape=
{
mm_input
.
shape
}
)"
)
return
torch
.
concat
(
list
(
mm_input
))
else
:
return
torch
.
concat
(
mm_input
)
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
DotsOCRImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
...
@@ -735,28 +725,11 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return
None
if
pixel_values
is
not
None
:
pixel_values
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values
,
"image pixel values"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
"image grid_thw"
)
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
DotsOCRImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
pixel_values
,
image_grid_thw
=
image_grid_thw
)
if
image_embeds
is
not
None
:
image_embeds
=
self
.
_validate_and_reshape_mm_tensor
(
image_embeds
,
"image embeds"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
"image grid_thw"
)
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
DotsOCRImageEmbeddingInputs
(
type
=
"image_embeds"
,
image_embeds
=
image_embeds
,
image_grid_thw
=
image_grid_thw
)
...
...
vllm/model_executor/models/ernie45_vl.py
View file @
cc253b73
...
...
@@ -25,7 +25,7 @@
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
partial
from
typing
import
Any
,
Callable
,
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Any
,
Callable
,
Literal
,
Optional
,
Union
import
numpy
as
np
import
torch
...
...
@@ -56,6 +56,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.ernie45_vl_moe
import
Ernie4_5_VLMoeForCausalLM
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
...
...
@@ -579,38 +580,38 @@ class Ernie4_5_VisionTransformer(nn.Module):
# === Vision Inputs === #
class
Ernie4_5_VLImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
pixel_values
:
torch
.
Tensor
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
class
Ernie4_5_VLImagePixelInputs
(
TensorSchema
):
"""
grid_thw
:
torch
.
Tensor
"""Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
Dimensions:
- np: The total number of patches over each image over each prompt in
the batch
- ni: Number of images
- cps: Number of channels * patch_size * patch_size
"""
type
:
Literal
[
"pixel_values"
]
pixel_values
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"np"
,
"cps"
)]
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
Ernie4_5_VLImageInputs
=
Ernie4_5_VLImagePixelInputs
class
Ernie4_5_VLVideoPixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values_videos"
]
pixel_values_videos
:
torch
.
Tensor
"""Shape:
`(num_patches,
num_channels * temporal_patch_size * patch_size * patch_size)`
class
Ernie4_5_VLVideoPixelInputs
(
TensorSchema
):
"""
video_grid_thw
:
torch
.
Tensor
"""Shape: `(num_videos, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
Dimensions:
- np: The total number of patches over each image over each prompt in
the batch
- ni: Number of images
- cps: Number of channels * temporal_patch_size * patch_size *
patch_size
"""
type
:
Literal
[
"pixel_values_videos"
]
pixel_values_videos
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"np"
,
"cps"
)]
video_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
Ernie4_5_VLVideoInputs
=
Ernie4_5_VL
Image
PixelInputs
Ernie4_5_VLVideoInputs
=
Ernie4_5_VL
Video
PixelInputs
# === Vision Processor === #
...
...
@@ -1213,6 +1214,7 @@ class Ernie4_5_VLDummyInputsBuilder(
dummy_inputs
=
Ernie4_5_VLDummyInputsBuilder
)
class
Ernie4_5_VLMoeForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsLoRA
,
SupportsPP
):
merge_by_field_config
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
...
...
@@ -1325,22 +1327,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. "
f
"Got type:
{
type
(
mm_input
)
}
"
)
if
isinstance
(
mm_input
,
torch
.
Tensor
):
if
mm_input
.
ndim
==
2
:
return
mm_input
if
mm_input
.
ndim
!=
3
:
raise
ValueError
(
f
"
{
name
}
should be 2D or batched 3D tensor. "
f
"Got ndim:
{
mm_input
.
ndim
}
"
f
"(shape=
{
mm_input
.
shape
}
)"
)
return
mm_input
.
reshape
(
-
1
,
mm_input
.
shape
[
-
1
])
else
:
return
torch
.
concat
(
mm_input
)
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Ernie4_5_VLImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
...
@@ -1350,15 +1336,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
return
None
if
pixel_values
is
not
None
:
pixel_values
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values
,
"image pixel values"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
"image grid_thw"
)
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
Ernie4_5_VLImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
pixel_values
,
image_grid_thw
=
image_grid_thw
)
...
...
@@ -1372,11 +1349,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
return
None
if
pixel_values_videos
is
not
None
:
pixel_values_videos
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values_videos
,
"video pixel values"
)
video_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
video_grid_thw
,
"video grid_thw"
)
return
Ernie4_5_VLVideoPixelInputs
(
type
=
"pixel_values_videos"
,
pixel_values_videos
=
pixel_values_videos
,
...
...
vllm/model_executor/models/fuyu.py
View file @
cc253b73
...
...
@@ -59,17 +59,14 @@ class FuyuImagePatchInputs(TensorSchema):
type
:
Literal
[
"image_patches"
]
=
"image_patches"
flat_data
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bnp"
,
"fn"
),
]
image_patches_flat
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bnp"
,
"fn"
)]
patches_per_image
:
Annotated
[
list
[
int
],
TensorShape
(
"bn"
)]
"""
The number of total patches for each image in the batch.
This is used to split the embeddings which has the first two dimensions
flattened just like `
flat_d
at
a
`.
flattened just like `
image_patches_fl
at`.
"""
...
...
@@ -174,28 +171,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
tok_kwargs
=
tok_kwargs
,
)
image_patches
=
processed_outputs
.
get
(
"image_patches"
)
if
image_patches
is
not
None
:
images
=
mm_data
[
"images"
]
assert
isinstance
(
images
,
list
)
# Original output: (1, num_images, Pn, Px * Py * C)
# New output: (num_images, Pn, Px * Py * C)
# image_patches is a list with shape:
# (1, num_images, Pn, Px * Py * C)
# before Transformers 4.53
if
isinstance
(
image_patches
,
list
):
assert
len
(
image_patches
)
==
1
assert
(
isinstance
(
image_patches
[
0
],
torch
.
Tensor
)
and
len
(
image_patches
[
0
])
==
len
(
images
))
processed_outputs
[
"image_patches"
]
=
image_patches
[
0
]
# image_patches is a tensor with shape:
# (num_images, Pn, Px * Py * C)
# after Transformers 4.53
elif
isinstance
(
image_patches
,
torch
.
Tensor
):
assert
len
(
image_patches
)
==
len
(
images
)
else
:
raise
AssertionError
(
"This line should be unreachable."
)
image_patches
=
processed_outputs
[
"image_patches"
]
processed_outputs
[
"image_patches"
]
=
flatten_bn
(
image_patches
)
processed_outputs
[
"patches_per_image"
]
=
torch
.
tensor
(
[
len
(
p
)
for
p
in
image_patches
])
return
processed_outputs
...
...
@@ -218,7 +197,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
image_patches
=
MultiModalFieldConfig
.
batched
(
"image"
))
patches_per_image
=
hf_inputs
.
get
(
"patches_per_image"
,
torch
.
empty
(
0
))
return
dict
(
image_patches
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
patches_per_image
),
patches_per_image
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_updates
(
self
,
...
...
@@ -263,6 +248,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
info
=
FuyuProcessingInfo
,
dummy_inputs
=
FuyuDummyInputsBuilder
)
class
FuyuForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
merge_by_field_config
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
...
...
@@ -306,29 +292,28 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
FuyuImagePatchInputs
]:
image_patches
=
kwargs
.
pop
(
"image_patches"
,
None
)
if
image_patches
is
not
None
:
image_patches_flat
=
flatten_bn
(
image_patches
)
flat_data
=
flatten_bn
(
image_patches_flat
,
concat
=
True
)
return
FuyuImagePatchInputs
(
type
=
"image_patches"
,
flat_data
=
flat_data
,
patches_per_image
=
[
x
.
size
(
0
)
for
x
in
image_patches_flat
],
resolve_bindings
=
{
"fn"
:
self
.
image_feature_size
},
)
patches_per_image
=
kwargs
.
pop
(
"patches_per_image"
,
None
)
return
None
if
image_patches
is
None
:
return
None
return
FuyuImagePatchInputs
(
type
=
"image_patches"
,
image_patches_flat
=
image_patches
,
patches_per_image
=
patches_per_image
,
resolve_bindings
=
{
"fn"
:
self
.
image_feature_size
},
)
def
_process_image_input
(
self
,
image_input
:
FuyuImagePatchInputs
)
->
MultiModalEmbeddings
:
image_patches_flat
=
image_input
[
"
flat_d
at
a
"
]
image_patches_flat
=
image_input
[
"
image_patches_fl
at"
]
patches_per_image
=
image_input
[
"patches_per_image"
]
assert
self
.
vision_embed_tokens
is
not
None
vision_embeddings_flat
,
_
=
self
.
vision_embed_tokens
(
image_patches_flat
)
return
vision_embeddings_flat
.
split
(
patches_per_image
,
dim
=
0
)
return
vision_embeddings_flat
.
split
(
patches_per_image
.
tolist
()
,
dim
=
0
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment