Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
179a6a36
Unverified
Commit
179a6a36
authored
Aug 04, 2024
by
Jee Jee Li
Committed by
GitHub
Aug 04, 2024
Browse files
[Model]Refactor MiniCPMV (#7020)
Co-authored-by:
Cyrus Leung
<
cyrus.tl.leung@gmail.com
>
parent
83c644fe
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
937 additions
and
386 deletions
+937
-386
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+1
-1
vllm/model_executor/models/idefics2_vision_model.py
vllm/model_executor/models/idefics2_vision_model.py
+296
-0
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+639
-384
vllm/model_executor/models/na_vit.py
vllm/model_executor/models/na_vit.py
+1
-1
No files found.
docs/source/models/supported_models.rst
View file @
179a6a36
...
@@ -220,7 +220,7 @@ Vision Language Models
...
@@ -220,7 +220,7 @@ Vision Language Models
- Phi-3-Vision
- Phi-3-Vision
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
-
-
* - :code:`MiniCPM
-
V`
* - :code:`MiniCPMV`
- MiniCPM-V
- MiniCPM-V
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
-
-
...
...
vllm/model_executor/models/idefics2_vision_model.py
0 → 100644
View file @
179a6a36
# coding=utf-8
# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py
# Copyright 2024 The vLLM team.
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Idefics2 model."""
from
typing
import
Optional
import
torch
from
torch
import
nn
from
transformers.models.idefics2.configuration_idefics2
import
(
Idefics2Config
,
Idefics2VisionConfig
)
from
xformers
import
ops
as
xops
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
class
Idefics2VisionEmbeddings
(
nn
.
Module
):
"""
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings
` to enable images of variable
resolution.
The modifications are adapted from [Patch n' Pack: NaViT, a Vision
Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
which allows treating images in their native aspect ratio and without the
need to resize them to the same fixed size. In particular, we start from the
original pre-trained SigLIP model(which uses images of fixed-size square
images) and adapt it by training on images of variable resolutions.
"""
def
__init__
(
self
,
config
:
Idefics2VisionConfig
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
image_size
=
config
.
image_size
self
.
patch_size
=
config
.
patch_size
self
.
patch_embedding
=
nn
.
Conv2d
(
in_channels
=
config
.
num_channels
,
out_channels
=
self
.
embed_dim
,
kernel_size
=
self
.
patch_size
,
stride
=
self
.
patch_size
,
padding
=
"valid"
,
)
self
.
num_patches_per_side
=
self
.
image_size
//
self
.
patch_size
self
.
num_patches
=
self
.
num_patches_per_side
**
2
self
.
num_positions
=
self
.
num_patches
self
.
position_embedding
=
nn
.
Embedding
(
self
.
num_positions
,
self
.
embed_dim
)
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
,
patch_attention_mask
:
torch
.
BoolTensor
,
)
->
torch
.
Tensor
:
batch_size
,
_
,
max_im_h
,
max_im_w
=
pixel_values
.
shape
patch_embeds
=
self
.
patch_embedding
(
pixel_values
)
embeddings
=
patch_embeds
.
flatten
(
2
).
transpose
(
1
,
2
)
max_nb_patches_h
,
max_nb_patches_w
=
(
max_im_h
//
self
.
patch_size
,
max_im_w
//
self
.
patch_size
,
)
boundaries
=
torch
.
arange
(
1
/
self
.
num_patches_per_side
,
1.0
,
1
/
self
.
num_patches_per_side
)
position_ids
=
torch
.
full
(
size
=
(
batch_size
,
max_nb_patches_h
*
max_nb_patches_w
),
fill_value
=
0
)
for
batch_idx
,
p_attn_mask
in
enumerate
(
patch_attention_mask
):
nb_patches_h
=
p_attn_mask
[:,
0
].
sum
()
nb_patches_w
=
p_attn_mask
[
0
].
sum
()
fractional_coords_h
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
nb_patches_h
)
fractional_coords_w
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
nb_patches_w
)
bucket_coords_h
=
torch
.
bucketize
(
fractional_coords_h
,
boundaries
,
right
=
True
)
bucket_coords_w
=
torch
.
bucketize
(
fractional_coords_w
,
boundaries
,
right
=
True
)
pos_ids
=
(
bucket_coords_h
[:,
None
]
*
self
.
num_patches_per_side
+
bucket_coords_w
).
flatten
()
position_ids
[
batch_idx
][
p_attn_mask
.
view
(
-
1
).
cpu
()]
=
pos_ids
position_ids
=
position_ids
.
to
(
self
.
position_embedding
.
weight
.
device
)
embeddings
=
embeddings
+
self
.
position_embedding
(
position_ids
)
return
embeddings
class
Idefics2VisionAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
self
,
config
:
Idefics2Config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
head_dim
=
self
.
embed_dim
//
self
.
num_heads
if
self
.
head_dim
*
self
.
num_heads
!=
self
.
embed_dim
:
raise
ValueError
(
f
"embed_dim must be divisible by num_heads (got `embed_dim`:
{
self
.
embed_dim
}
and `num_heads`:"
# noqa: E501
f
"
{
self
.
num_heads
}
)."
)
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
dropout
=
config
.
attention_dropout
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
embed_dim
,
self
.
head_dim
,
self
.
num_heads
,
quant_config
=
quant_config
,
)
self
.
out_proj
=
RowParallelLinear
(
self
.
embed_dim
,
self
.
embed_dim
,
bias
=
True
,
quant_config
=
quant_config
,
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
is_causal
=
False
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
batch_size
,
q_len
,
_
=
hidden_states
.
size
()
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
# batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states
,
key_states
,
value_states
=
qkv
.
chunk
(
3
,
dim
=-
1
)
query_states
=
query_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
# see: https://facebookresearch.github.io/xformers/components/ops.html
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
,
)
out
=
out
.
view
(
batch_size
,
q_len
,
-
1
)
attn_output
,
_
=
self
.
out_proj
(
out
)
return
attn_output
class
Idefics2VisionMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Idefics2Config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
activation_fn
=
get_act_fn
(
config
.
hidden_act
)
self
.
fc1
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
intermediate_size
,
bias
=
True
,
quant_config
=
quant_config
,
)
self
.
fc2
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
fc1
(
hidden_states
)
hidden_states
=
self
.
activation_fn
(
hidden_states
)
hidden_states
,
_
=
self
.
fc2
(
hidden_states
)
return
hidden_states
class
Idefics2EncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Idefics2Config
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
self_attn
=
Idefics2VisionAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
Idefics2VisionMLP
(
config
)
self
.
layer_norm2
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
"""
residual
=
hidden_states
hidden_states
=
self
.
layer_norm1
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
layer_norm2
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
class
Idefics2Encoder
(
nn
.
Module
):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention
layers. Each layer is a
[`Idefics2EncoderLayer`].
Args:
config: Idefics2Config
"""
def
__init__
(
self
,
config
:
Idefics2Config
):
super
().
__init__
()
self
.
config
=
config
self
.
layers
=
nn
.
ModuleList
([
Idefics2EncoderLayer
(
config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
def
forward
(
self
,
inputs_embeds
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
r
"""
Args:
inputs_embeds (torch.Tensor):
Optionally, instead of passing `input_ids` you can choose to
directly pass an embedded representation.
This is useful if you want more control over how to convert
`input_ids` indices into associated vectorsthan the model's
internal embedding lookup matrix.
"""
hidden_states
=
inputs_embeds
for
encoder_layer
in
self
.
layers
:
layer_outputs
=
encoder_layer
(
hidden_states
)
hidden_states
=
layer_outputs
return
hidden_states
class
Idefics2VisionTransformer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Idefics2VisionConfig
):
super
().
__init__
()
embed_dim
=
config
.
hidden_size
self
.
config
=
config
self
.
embeddings
=
Idefics2VisionEmbeddings
(
config
)
self
.
encoder
=
Idefics2Encoder
(
config
)
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
eps
=
config
.
layer_norm_eps
)
def
get_input_embeddings
(
self
):
return
self
.
embeddings
def
forward
(
self
,
pixel_values
,
patch_attention_mask
:
Optional
[
torch
.
BoolTensor
]
=
None
,
)
->
torch
.
tensor
:
hidden_states
=
self
.
embeddings
(
pixel_values
=
pixel_values
,
patch_attention_mask
=
patch_attention_mask
)
encoder_outputs
=
self
.
encoder
(
hidden_states
)
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
return
last_hidden_state
vllm/model_executor/models/minicpmv.py
View file @
179a6a36
...
@@ -24,7 +24,8 @@
...
@@ -24,7 +24,8 @@
import
math
import
math
import
re
import
re
from
functools
import
partial
from
functools
import
partial
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -38,11 +39,14 @@ from transformers.configuration_utils import PretrainedConfig
...
@@ -38,11 +39,14 @@ from transformers.configuration_utils import PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsVision
from
vllm.model_executor.models.interfaces
import
SupportsVision
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.models.llama
import
LlamaModel
...
@@ -54,12 +58,45 @@ from vllm.multimodal.image import (cached_get_image_processor,
...
@@ -54,12 +58,45 @@ from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer
)
cached_get_tokenizer
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
.idefics2_vision_model
import
Idefics2VisionTransformer
logger
=
init_logger
(
__name__
)
_KEYS_TO_MODIFY_MAPPING
=
{
_KEYS_TO_MODIFY_MAPPING
=
{
"llm.lm_head"
:
"lm_head"
,
"llm.lm_head"
:
"lm_head"
,
"llm.model"
:
"llm"
,
"llm.model"
:
"llm"
,
}
}
class
MiniCPMVImagePixelInputs
(
TypedDict
):
pixel_values
:
List
[
torch
.
Tensor
]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Note that the image size may vary, so we pass it as a list
instead of a batched tensor.
"""
image_bounds
:
torch
.
Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(start, stop)` format.
"""
tgt_sizes
:
torch
.
Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
MiniCPMVImageInputs
=
MiniCPMVImagePixelInputs
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
def
get_abs_pos
(
abs_pos
:
torch
.
Tensor
,
tgt_size
:
torch
.
Tensor
):
def
get_abs_pos
(
abs_pos
:
torch
.
Tensor
,
tgt_size
:
torch
.
Tensor
):
# abs_pos: L, C
# abs_pos: L, C
# tgt_size: (H, W)
# tgt_size: (H, W)
...
@@ -68,23 +105,25 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
...
@@ -68,23 +105,25 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
# tgt_size = int(math.sqrt(tgt_size))
# tgt_size = int(math.sqrt(tgt_size))
dtype
=
abs_pos
.
dtype
dtype
=
abs_pos
.
dtype
return
F
.
interpolate
(
return
(
F
.
interpolate
(
abs_pos
.
float
().
reshape
(
1
,
src_size
,
src_size
,
-
1
).
permute
(
0
,
3
,
1
,
2
),
abs_pos
.
float
().
reshape
(
1
,
src_size
,
src_size
,
-
1
).
permute
(
0
,
3
,
1
,
2
),
size
=
(
tgt_size
[
0
],
tgt_size
[
1
]),
size
=
(
tgt_size
[
0
],
tgt_size
[
1
]),
mode
=
"bicubic"
,
mode
=
"bicubic"
,
align_corners
=
False
,
align_corners
=
False
,
).
permute
(
0
,
2
,
3
,
1
).
flatten
(
0
,
2
).
to
(
dtype
=
dtype
)
).
permute
(
0
,
2
,
3
,
1
).
flatten
(
0
,
2
).
to
(
dtype
=
dtype
)
)
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def
get_2d_sincos_pos_embed
(
embed_dim
:
int
,
def
get_2d_sincos_pos_embed
(
grid_size
:
Union
[
int
,
Tuple
[
int
,
int
]],
embed_dim
:
int
,
cls_token
:
bool
=
False
,
grid_size
:
Union
[
int
,
Tuple
[
int
,
int
]],
version
:
Tuple
[
int
,
int
]
=
(
2
,
0
)):
cls_token
:
bool
=
False
,
version
:
Tuple
[
int
,
int
]
=
(
2
,
0
),
):
"""
"""
grid_size: int of the grid height and width
grid_size: int of the grid height and width
return:
return:
pos_embed: [grid_size*grid_size, embed_dim] or
pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
"""
if
isinstance
(
grid_size
,
int
):
if
isinstance
(
grid_size
,
int
):
...
@@ -109,7 +148,7 @@ def get_2d_sincos_pos_embed(embed_dim: int,
...
@@ -109,7 +148,7 @@ def get_2d_sincos_pos_embed(embed_dim: int,
def
get_2d_sincos_pos_embed_from_grid
(
embed_dim
:
int
,
def
get_2d_sincos_pos_embed_from_grid
(
embed_dim
:
int
,
grid
:
Union
[
int
,
Tuple
[
int
,
int
]]
,
grid
:
np
.
ndarray
,
version
:
Tuple
[
int
,
int
]
=
(
2
,
0
)):
version
:
Tuple
[
int
,
int
]
=
(
2
,
0
)):
assert
embed_dim
%
2
==
0
assert
embed_dim
%
2
==
0
...
@@ -127,7 +166,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
...
@@ -127,7 +166,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
def
get_1d_sincos_pos_embed_from_grid
(
embed_dim
:
int
,
def
get_1d_sincos_pos_embed_from_grid
(
embed_dim
:
int
,
pos
:
int
,
pos
:
np
.
ndarray
,
version
:
Tuple
[
int
,
int
]
=
(
2
,
0
)):
version
:
Tuple
[
int
,
int
]
=
(
2
,
0
)):
"""
"""
embed_dim: output dimension for each position
embed_dim: output dimension for each position
...
@@ -136,24 +175,24 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
...
@@ -136,24 +175,24 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
"""
"""
assert
embed_dim
%
2
==
0
assert
embed_dim
%
2
==
0
omega
=
np
.
arange
(
embed_dim
//
2
,
dtype
=
np
.
float32
)
omega
=
np
.
arange
(
embed_dim
//
2
,
dtype
=
np
.
float32
)
omega
/=
embed_dim
/
2.
omega
/=
embed_dim
/
2.
0
omega
=
1.
/
10000
**
omega
# (D/2,)
omega
=
1.
0
/
10000
**
omega
# (D/2,)
if
version
==
(
2
,
0
):
if
version
==
(
2
,
0
):
pos
=
pos
.
reshape
(
-
1
)
# (M,)
pos
=
pos
.
reshape
(
-
1
)
# (M,)
out
=
np
.
einsum
(
'
m,d->md
'
,
pos
,
omega
)
# (M, D/2), outer product
out
=
np
.
einsum
(
"
m,d->md
"
,
pos
,
omega
)
# (M, D/2), outer product
emb_sin
=
np
.
sin
(
out
)
# (M, D/2)
emb_sin
=
np
.
sin
(
out
)
# (M, D/2)
emb_cos
=
np
.
cos
(
out
)
# (M, D/2)
emb_cos
=
np
.
cos
(
out
)
# (M, D/2)
emb
=
np
.
concatenate
([
emb_sin
,
emb_cos
],
axis
=
1
)
# (M, D)
emb
=
np
.
concatenate
([
emb_sin
,
emb_cos
],
axis
=
1
)
# (M, D)
else
:
else
:
out
=
np
.
einsum
(
'
hw,d->hwd
'
,
pos
,
omega
)
# (H, W, D/2), outer product
out
=
np
.
einsum
(
"
hw,d->hwd
"
,
pos
,
omega
)
# (H, W, D/2), outer product
emb_sin
=
np
.
sin
(
out
)
# (H, W, D/2)
emb_sin
=
np
.
sin
(
out
)
# (H, W, D/2)
emb_cos
=
np
.
cos
(
out
)
# (H, W, D/2)
emb_cos
=
np
.
cos
(
out
)
# (H, W, D/2)
emb
=
np
.
concatenate
([
emb_sin
,
emb_cos
],
axis
=-
1
)
# (H, W, D)
emb
=
np
.
concatenate
([
emb_sin
,
emb_cos
],
axis
=-
1
)
# (H, W, D)
return
emb
return
emb
class
Resampler
(
nn
.
Module
):
class
Base
Resampler
(
nn
.
Module
):
"""
"""
A 2D perceiver-resampler network with one cross attention layers by
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
(grid_size**2) learnable queries and 2d sincos pos_emb
...
@@ -161,89 +200,151 @@ class Resampler(nn.Module):
...
@@ -161,89 +200,151 @@ class Resampler(nn.Module):
A tensor with the shape of (grid_size**2, embed_dim)
A tensor with the shape of (grid_size**2, embed_dim)
"""
"""
default_norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
def
__init__
(
self
,
def
__init__
(
self
,
num_queries
:
int
,
num_queries
:
int
,
embed_dim
:
int
,
grid_size
:
int
,
num_heads
:
int
,
embed_dim
:
int
,
kv_dim
:
Optional
[
int
]
=
None
,
num_heads
:
int
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
kv_dim
:
Optional
[
int
]
=
None
,
)
->
None
:
norm_layer
:
nn
.
Module
=
default_norm_layer
,
adaptive
:
bool
=
False
,
max_size
:
Tuple
[
int
,
int
]
=
(
70
,
70
),
version
:
Tuple
[
int
,
int
]
=
(
2
,
0
)):
super
().
__init__
()
super
().
__init__
()
self
.
version
=
version
self
.
num_queries
=
num_queries
if
self
.
version
==
(
2
,
0
):
self
.
num_queries
=
grid_size
**
2
else
:
self
.
num_queries
=
num_queries
self
.
max_size
=
max_size
self
.
embed_dim
=
embed_dim
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
adaptive
=
adaptive
self
.
query
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_queries
,
embed_dim
))
self
.
query
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_queries
,
embed_dim
))
trunc_normal_
(
self
.
query
,
std
=
.
02
)
trunc_normal_
(
self
.
query
,
std
=
0
.02
)
if
kv_dim
is
not
None
and
kv_dim
!=
embed_dim
:
if
kv_dim
is
not
None
and
kv_dim
!=
embed_dim
:
self
.
kv_proj
=
nn
.
Linear
(
kv_dim
,
embed_dim
,
bias
=
False
)
self
.
kv_proj
=
Replicated
Linear
(
kv_dim
,
embed_dim
,
bias
=
False
)
else
:
else
:
self
.
kv_proj
=
nn
.
Identity
()
# Maintain the same return value with ReplicatedLinear.forward
self
.
kv_proj
=
lambda
*
args
,
**
kwargs
:
(
nn
.
Identity
()(
*
args
,
**
kwargs
),
None
,
)
self
.
attn
=
nn
.
MultiheadAttention
(
embed_dim
,
num_heads
)
self
.
attn
=
nn
.
MultiheadAttention
(
embed_dim
,
num_heads
)
self
.
ln_q
=
norm_layer
(
embed_dim
)
self
.
ln_q
=
norm_layer
(
embed_dim
)
self
.
ln_kv
=
norm_layer
(
embed_dim
)
self
.
ln_kv
=
norm_layer
(
embed_dim
)
self
.
ln_post
=
norm_layer
(
embed_dim
)
self
.
ln_post
=
norm_layer
(
embed_dim
)
self
.
proj
=
nn
.
Parameter
(
self
.
proj
=
nn
.
Parameter
(
(
embed_dim
**-
0.5
)
*
torch
.
randn
(
embed_dim
,
embed_dim
))
(
embed_dim
**-
0.5
)
*
torch
.
randn
(
embed_dim
,
embed_dim
))
if
self
.
version
==
(
2
,
0
):
def
_init_weights
(
self
,
m
:
nn
.
Module
)
->
None
:
self
.
pos_embed
=
nn
.
Parameter
(
if
isinstance
(
m
,
nn
.
Linear
):
torch
.
from_numpy
(
trunc_normal_
(
m
.
weight
,
std
=
0.02
)
get_2d_sincos_pos_embed
(
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
embed_dim
,
grid_size
,
nn
.
init
.
constant_
(
m
.
bias
,
0
)
version
=
self
.
version
)).
float
()).
requires_grad_
(
False
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
_repeat
(
self
,
query
,
N
:
int
):
return
query
.
unsqueeze
(
1
).
repeat
(
1
,
N
,
1
)
class
Resampler2
(
BaseResampler
):
def
__init__
(
self
,
grid_size
:
int
,
embed_dim
:
int
,
num_heads
:
int
,
kv_dim
:
Optional
[
int
]
=
None
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
adaptive
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
grid_size
**
2
,
embed_dim
,
num_heads
,
kv_dim
,
norm_layer
)
self
.
adaptive
=
adaptive
pos_embed_arr
=
get_2d_sincos_pos_embed
(
embed_dim
,
grid_size
,
version
=
(
2
,
0
))
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
from_numpy
(
pos_embed_arr
).
float
()).
requires_grad_
(
False
)
self
.
apply
(
self
.
_init_weights
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
tgt_sizes
:
torch
.
Tensor
,
attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
):
if
self
.
adaptive
:
pos_embed_arr
=
get_2d_sincos_pos_embed
(
self
.
embed_dim
,
tgt_sizes
,
version
=
(
2
,
0
))
pos_embed
=
torch
.
from_numpy
(
pos_embed_arr
).
to
(
device
=
x
.
device
,
dtype
=
x
.
dtype
)
else
:
else
:
self
.
_set_2d_pos_cache
(
self
.
max_size
)
pos_embed
=
get_abs_pos
(
self
.
pos_embed
,
tgt_sizes
)
x
,
_
=
self
.
kv_proj
(
x
)
x
=
self
.
ln_kv
(
x
).
permute
(
1
,
0
,
2
)
N
=
x
.
shape
[
1
]
q
=
self
.
ln_q
(
self
.
query
)
out
=
self
.
attn
(
self
.
_repeat
(
q
,
N
)
+
self
.
pos_embed
.
unsqueeze
(
1
),
x
+
pos_embed
.
unsqueeze
(
1
),
x
,
attn_mask
=
attn_mask
,
)[
0
]
x
=
out
.
permute
(
1
,
0
,
2
)
x
=
self
.
ln_post
(
x
)
x
=
x
@
self
.
proj
return
x
class
Resampler2_5
(
BaseResampler
):
def
__init__
(
self
,
num_queries
:
int
,
embed_dim
:
int
,
num_heads
:
int
,
kv_dim
:
Optional
[
int
]
=
None
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
max_size
:
Tuple
[
int
,
int
]
=
(
70
,
70
),
)
->
None
:
super
().
__init__
(
num_queries
,
embed_dim
,
num_heads
,
kv_dim
,
norm_layer
)
self
.
max_size
=
max_size
self
.
_set_2d_pos_cache
(
self
.
max_size
)
self
.
apply
(
self
.
_init_weights
)
self
.
apply
(
self
.
_init_weights
)
def
_set_2d_pos_cache
(
self
,
def
_set_2d_pos_cache
(
self
,
max_size
:
Tuple
[
int
,
int
],
max_size
:
Tuple
[
int
,
int
],
device
:
torch
.
types
.
Device
=
'
cpu
'
)
:
device
:
torch
.
types
.
Device
=
"
cpu
"
)
->
None
:
pos_embed
=
torch
.
from_numpy
(
pos_embed
_arr
=
get_2d_sincos_pos_embed
(
self
.
embed_dim
,
get_2d_sincos_pos_embed
(
self
.
embed_dim
,
max_size
,
max_size
,
version
=
(
2
,
5
))
version
=
self
.
version
)
).
float
().
to
(
device
)
pos_embed
=
torch
.
from_numpy
(
pos_embed_arr
).
float
().
to
(
device
)
self
.
register_buffer
(
"pos_embed"
,
pos_embed
,
persistent
=
False
)
self
.
register_buffer
(
"pos_embed"
,
pos_embed
,
persistent
=
False
)
def
_adjust_pos_cache
(
self
,
tgt_sizes
:
torch
.
Tensor
,
def
_adjust_pos_cache
(
self
,
tgt_sizes
:
torch
.
Tensor
,
device
:
torch
.
types
.
Device
):
device
:
torch
.
types
.
Device
)
->
None
:
max_h
=
torch
.
max
(
tgt_sizes
[:,
0
])
max_h
=
tgt_sizes
[:,
0
].
max
().
item
()
max_w
=
torch
.
max
(
tgt_sizes
[:,
1
])
max_w
=
tgt_sizes
[:,
1
].
max
().
item
()
assert
isinstance
(
max_h
,
int
)
and
isinstance
(
max_w
,
int
)
if
max_h
>
self
.
max_size
[
0
]
or
max_w
>
self
.
max_size
[
1
]:
if
max_h
>
self
.
max_size
[
0
]
or
max_w
>
self
.
max_size
[
1
]:
self
.
max_size
=
[
self
.
max_size
=
(
max
(
max_h
,
self
.
max_size
[
0
]),
max
(
max_h
,
self
.
max_size
[
0
]),
max
(
max_w
,
self
.
max_size
[
1
])
max
(
max_w
,
self
.
max_size
[
1
])
,
]
)
self
.
_set_2d_pos_cache
(
self
.
max_size
,
device
)
self
.
_set_2d_pos_cache
(
self
.
max_size
,
device
)
def
_init_weights
(
self
,
m
:
nn
.
Module
):
def
forward
(
self
,
x
:
torch
.
Tensor
,
if
isinstance
(
m
,
nn
.
Linear
):
tgt_sizes
:
torch
.
Tensor
)
->
torch
.
Tensor
:
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
forward_2_5
(
self
,
x
:
torch
.
Tensor
,
tgt_sizes
:
Optional
[
torch
.
Tensor
]
=
None
):
assert
x
.
shape
[
0
]
==
tgt_sizes
.
shape
[
0
]
assert
x
.
shape
[
0
]
==
tgt_sizes
.
shape
[
0
]
bs
=
x
.
shape
[
0
]
bs
=
x
.
shape
[
0
]
...
@@ -254,25 +355,25 @@ class Resampler(nn.Module):
...
@@ -254,25 +355,25 @@ class Resampler(nn.Module):
self
.
_adjust_pos_cache
(
tgt_sizes
,
device
=
device
)
self
.
_adjust_pos_cache
(
tgt_sizes
,
device
=
device
)
max_patch_len
=
torch
.
max
(
patch_len
)
max_patch_len
=
patch_len
.
max
().
item
()
assert
isinstance
(
max_patch_len
,
int
)
key_padding_mask
=
torch
.
zeros
((
bs
,
max_patch_len
),
key_padding_mask
=
torch
.
zeros
((
bs
,
max_patch_len
),
dtype
=
torch
.
bool
,
dtype
=
torch
.
bool
,
device
=
device
)
device
=
device
)
pos_embed
=
[]
pos_embed
=
[]
for
i
in
range
(
bs
):
for
i
in
range
(
bs
):
tgt_h
,
tgt_w
=
tgt_sizes
[
i
]
tgt_h
,
tgt_w
=
tgt_sizes
[
i
]
.
tolist
()
pos_embed
.
append
(
self
.
pos_embed
[:
tgt_h
,
:
tgt_w
,
:].
reshape
(
pos_embed
.
append
(
self
.
pos_embed
[:
tgt_h
,
:
tgt_w
,
:].
reshape
(
(
tgt_h
*
tgt_w
,
-
1
)).
to
(
dtype
))
# patches * D
(
tgt_h
*
tgt_w
,
-
1
)).
to
(
dtype
))
# patches * D
key_padding_mask
[
i
,
patch_len
[
i
]:]
=
True
key_padding_mask
[
i
,
patch_len
[
i
]:]
=
True
pos_embed
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
pos_embed
,
pos_embed
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
pos_embed
,
batch_first
=
True
,
batch_first
=
True
,
padding_value
=
0.0
).
permute
(
padding_value
=
0.0
).
permute
(
1
,
0
,
1
,
0
,
2
)
# BLD => L * B * D
2
)
# BLD => L * B * D
x
,
_
=
self
.
kv_proj
(
x
)
# B * L * D
x
=
self
.
kv_proj
(
x
)
# B * L * D
x
=
self
.
ln_kv
(
x
).
permute
(
1
,
0
,
2
)
# L * B * D
x
=
self
.
ln_kv
(
x
).
permute
(
1
,
0
,
2
)
# L * B * D
q
=
self
.
ln_q
(
self
.
query
)
# Q * D
q
=
self
.
ln_q
(
self
.
query
)
# Q * D
...
@@ -281,7 +382,8 @@ class Resampler(nn.Module):
...
@@ -281,7 +382,8 @@ class Resampler(nn.Module):
self
.
_repeat
(
q
,
bs
),
# Q * B * D
self
.
_repeat
(
q
,
bs
),
# Q * B * D
x
+
pos_embed
,
# L * B * D + L * B * D
x
+
pos_embed
,
# L * B * D + L * B * D
x
,
x
,
key_padding_mask
=
key_padding_mask
)[
0
]
key_padding_mask
=
key_padding_mask
,
)[
0
]
# out: Q * B * D
# out: Q * B * D
x
=
out
.
permute
(
1
,
0
,
2
)
# B * Q * D
x
=
out
.
permute
(
1
,
0
,
2
)
# B * Q * D
...
@@ -289,45 +391,6 @@ class Resampler(nn.Module):
...
@@ -289,45 +391,6 @@ class Resampler(nn.Module):
x
=
x
@
self
.
proj
x
=
x
@
self
.
proj
return
x
return
x
def
forward_2
(
self
,
x
:
torch
.
Tensor
,
tgt_sizes
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
):
if
self
.
adaptive
:
pos_embed
=
torch
.
Tensor
(
get_2d_sincos_pos_embed
(
self
.
embed_dim
,
tgt_sizes
)).
float
().
to
(
device
=
x
.
device
,
dtype
=
x
.
dtype
)
else
:
pos_embed
=
get_abs_pos
(
self
.
pos_embed
,
tgt_sizes
)
x
=
self
.
kv_proj
(
x
)
x
=
self
.
ln_kv
(
x
).
permute
(
1
,
0
,
2
)
N
=
x
.
shape
[
1
]
q
=
self
.
ln_q
(
self
.
query
)
out
=
self
.
attn
(
self
.
_repeat
(
q
,
N
)
+
self
.
pos_embed
.
unsqueeze
(
1
),
x
+
pos_embed
.
unsqueeze
(
1
),
x
,
attn_mask
=
attn_mask
)[
0
]
x
=
out
.
permute
(
1
,
0
,
2
)
x
=
self
.
ln_post
(
x
)
x
=
x
@
self
.
proj
return
x
def
forward
(
self
,
x
:
torch
.
Tensor
,
tgt_sizes
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
):
if
self
.
version
==
(
2
,
0
):
return
self
.
forward_2
(
x
,
tgt_sizes
=
tgt_sizes
,
attn_mask
=
attn_mask
)
else
:
return
self
.
forward_2_5
(
x
,
tgt_sizes
=
tgt_sizes
)
def
_repeat
(
self
,
query
,
N
:
int
):
return
query
.
unsqueeze
(
1
).
repeat
(
1
,
N
,
1
)
def
get_max_minicpmv_image_tokens
(
ctx
:
InputContext
):
def
get_max_minicpmv_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
...
@@ -348,10 +411,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
...
@@ -348,10 +411,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
# image_feature_size = get_max_minicpmv_image_tokens(ctx)
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
)
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
)
return
seq_data
,
mm_data
return
seq_data
,
mm_data
...
@@ -376,25 +436,36 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -376,25 +436,36 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
pattern
=
"(<image>./</image>)"
pattern
=
"(<image>./</image>)"
image
=
multi_modal_data
[
"image"
]
image
=
multi_modal_data
[
"image"
]
image_tags
=
re
.
findall
(
pattern
,
prompt
)
image_tags
=
re
.
findall
(
pattern
,
prompt
)
assert
len
(
image_tags
)
<=
1
text_chunks
=
prompt
.
split
(
pattern
)
new_prompt
=
text_chunks
[
0
]
\
+
image_processor
.
get_slice_image_placeholder
(
image
.
size
)
\
+
text_chunks
[
1
]
new_token_ids
=
tokenizer
.
encode
(
new_prompt
)
if
len
(
image_tags
)
==
0
:
new_token_ids
=
token_ids
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
new_token_ids
,
new_prompt
=
prompt
prompt
=
new_prompt
,
else
:
multi_modal_data
=
multi_modal_data
)
if
len
(
image_tags
)
>
1
:
logger
.
warning
(
"Multiple image input is not supported yet, "
"so any extra image tokens will be treated "
"as plain text."
)
text_chunks
=
prompt
.
split
(
pattern
)
new_prompt
=
(
text_chunks
[
0
]
+
image_processor
.
get_slice_image_placeholder
(
image
.
size
)
+
""
.
join
(
text_chunks
[
1
:]))
new_token_ids
=
tokenizer
.
encode
(
new_prompt
)
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
,
)
return
llm_inputs
return
llm_inputs
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
class
MiniCPMVBaseModel
(
nn
.
Module
,
SupportsVision
):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_minicpmv_image_tokens
)
"""
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_minicpmv
)
The abstract class of MiniCPMV can only be inherited, but cannot be
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_minicpmv
)
instantiated.
class
MiniCPMV
(
nn
.
Module
,
SupportsVision
):
"""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -419,8 +490,8 @@ class MiniCPMV(nn.Module, SupportsVision):
...
@@ -419,8 +490,8 @@ class MiniCPMV(nn.Module, SupportsVision):
self
.
vpm
=
self
.
init_vision_module
()
self
.
vpm
=
self
.
init_vision_module
()
param_dtype
=
torch
.
get_default_dtype
()
param_dtype
=
torch
.
get_default_dtype
()
self
.
vpm
.
to
(
dtype
=
param_dtype
)
self
.
vpm
.
to
(
dtype
=
param_dtype
)
self
.
vision_dim
=
self
.
vpm
.
embed_dim
if
self
.
version
==
(
2
,
0
)
\
self
.
vision_dim
=
(
self
.
vpm
.
embed_dim
if
self
.
version
==
(
2
,
0
)
else
else
self
.
vpm
.
embeddings
.
embed_dim
self
.
vpm
.
embeddings
.
embed_dim
)
self
.
embed_dim
=
self
.
config
.
hidden_size
self
.
embed_dim
=
self
.
config
.
hidden_size
self
.
resampler
=
self
.
init_resampler
(
self
.
embed_dim
,
self
.
vision_dim
)
self
.
resampler
=
self
.
init_resampler
(
self
.
embed_dim
,
self
.
vision_dim
)
self
.
resampler
.
to
(
device
=
"cuda"
,
dtype
=
param_dtype
)
self
.
resampler
.
to
(
device
=
"cuda"
,
dtype
=
param_dtype
)
...
@@ -430,248 +501,100 @@ class MiniCPMV(nn.Module, SupportsVision):
...
@@ -430,248 +501,100 @@ class MiniCPMV(nn.Module, SupportsVision):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
def
init_llm
(
self
,
def
get_embedding
(
config
:
PretrainedConfig
,
self
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
input_ids
:
torch
.
Tensor
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
image_inputs
:
Optional
[
MiniCPMVImageInputs
],
if
self
.
version
==
(
2
,
0
):
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
MiniCPMModel
(
config
,
vlm_embedding
:
torch
.
Tensor
=
self
.
llm
.
embed_tokens
(
input_ids
)
cache_config
=
cache_config
,
if
hasattr
(
self
.
config
,
"scale_emb"
):
quant_config
=
quant_config
)
vlm_embedding
*=
self
.
config
.
scale_emb
elif
self
.
version
==
(
2
,
5
):
return
LlamaModel
(
config
,
if
image_inputs
is
None
:
# No image
cache_config
=
cache_config
,
vision_hidden_states
=
torch
.
tensor
([],
device
=
input_ids
.
device
)
quant_config
=
quant_config
)
else
:
return
Qwen2Model
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
def
init_vision_module
(
self
):
if
self
.
version
==
(
2
,
0
):
try
:
import
timm
except
ImportError
:
raise
ImportError
(
'Please install timm==0.9.10'
)
from
ImportError
default_dtype
=
torch
.
get_default_dtype
()
torch
.
set_default_dtype
(
torch
.
float16
)
model
=
timm
.
create_model
(
'vit_so400m_patch14_siglip_384.webli'
,
pretrained
=
False
,
num_classes
=
0
,
dynamic_img_size
=
True
,
dynamic_img_pad
=
True
)
torch
.
set_default_dtype
(
default_dtype
)
if
isinstance
(
model
,
timm
.
models
.
VisionTransformer
)
and
model
.
attn_pool
is
not
None
:
model
.
attn_pool
=
torch
.
nn
.
Identity
()
if
self
.
config
.
drop_vision_last_layer
:
model
.
blocks
=
model
.
blocks
[:
-
1
]
elif
self
.
version
==
(
2
,
5
):
from
transformers.models.idefics2.modeling_idefics2
import
(
Idefics2VisionTransformer
)
model
=
Idefics2VisionTransformer
(
self
.
config
.
vision_config
)
if
self
.
config
.
drop_vision_last_layer
:
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
else
:
from
vllm.model_executor.models.na_vit
import
(
SiglipVisionTransformer
)
if
self
.
config
.
_attn_implementation
==
'flash_attention_2'
:
self
.
config
.
vision_config
.
_attn_implementation
\
=
'flash_attention_2'
else
:
# not support sdpa
self
.
config
.
vision_config
.
_attn_implementation
=
'eager'
model
=
SiglipVisionTransformer
(
self
.
config
.
vision_config
)
if
self
.
config
.
drop_vision_last_layer
:
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
return
model
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
):
default_dtype
=
torch
.
get_default_dtype
()
torch
.
set_default_dtype
(
torch
.
float16
)
if
self
.
version
==
(
2
,
0
):
resampler
=
Resampler
(
grid_size
=
int
(
math
.
sqrt
(
self
.
config
.
query_num
)),
num_queries
=
None
,
embed_dim
=
embed_dim
,
num_heads
=
embed_dim
//
128
,
kv_dim
=
vision_dim
,
adaptive
=
True
,
version
=
self
.
version
)
else
:
else
:
resampler
=
Resampler
(
num_queries
=
self
.
config
.
query_num
,
vision_hidden_states
=
self
.
get_vision_hidden_states
(
image_inputs
)
grid_size
=
None
,
embed_dim
=
embed_dim
,
# See NOTE in _parse_and_validate_inputs
num_heads
=
embed_dim
//
128
,
image_bounds
=
image_inputs
[
"image_bounds"
]
kv_dim
=
vision_dim
,
if
len
(
image_bounds
)
>
0
:
adaptive
=
True
,
image_indices
=
torch
.
stack
([
version
=
self
.
version
)
torch
.
arange
(
start
,
end
,
dtype
=
torch
.
long
)
torch
.
set_default_dtype
(
default_dtype
)
for
start
,
end
in
image_bounds
.
tolist
()
return
resampler
]).
to
(
vlm_embedding
.
device
)
vlm_embedding
.
scatter_
(
0
,
image_indices
.
view
(
-
1
,
1
).
repeat
(
1
,
vlm_embedding
.
shape
[
-
1
]),
vision_hidden_states
.
view
(
-
1
,
vision_hidden_states
.
shape
[
-
1
]),
)
def
get_vision_embedding
(
self
,
return
vlm_embedding
,
vision_hidden_states
pixel_values
:
List
[
List
[
torch
.
Tensor
]],
patch_attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
tgt_sizes
:
Optional
[
torch
.
Tensor
]
=
None
,
version
:
Tuple
[
int
,
int
]
=
(
2
,
0
)):
if
version
==
(
2
,
0
):
res
=
[]
dtype
=
self
.
vpm
.
pos_embed
.
data
.
dtype
for
pixel_value
in
pixel_values
:
# V2.0 start
H
,
W
=
pixel_value
[
0
].
shape
[
-
2
:]
tgt_size
=
(
math
.
ceil
(
H
/
self
.
vpm
.
patch_embed
.
patch_size
[
0
]),
math
.
ceil
(
W
/
self
.
vpm
.
patch_embed
.
patch_size
[
0
]))
# V2.0 end
vision_embedding
=
self
.
vpm
.
forward_features
(
pixel_value
.
unsqueeze
(
0
).
type
(
dtype
))
if
hasattr
(
self
.
vpm
,
'num_prefix_tokens'
)
and
self
.
vpm
.
num_prefix_tokens
>
0
:
vision_embedding
=
vision_embedding
[:,
self
.
vpm
.
num_prefix_tokens
:]
res
.
append
(
self
.
resampler
(
vision_embedding
,
tgt_size
))
return
torch
.
vstack
(
res
)
elif
version
==
(
2
,
5
):
vision_embedding
=
self
.
vpm
(
pixel_values
.
type
(
dtype
),
patch_attention_mask
=
patch_attn_mask
).
last_hidden_state
vision_embedding
=
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
else
:
vision_embedding
=
self
.
vpm
(
pixel_values
.
type
(
dtype
),
patch_attention_mask
=
patch_attn_mask
,
tgt_sizes
=
tgt_sizes
).
last_hidden_state
def
get_image_bounds
(
self
,
input_ids
:
torch
.
Tensor
):
def
_
get_image_bounds
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tokenizer
=
cached_get_tokenizer
(
self
.
config
.
_name_or_path
,
tokenizer
=
cached_get_tokenizer
(
self
.
config
.
_name_or_path
,
trust_remote_code
=
True
)
trust_remote_code
=
True
)
if
not
hasattr
(
tokenizer
,
"slice_start_id"
):
start_cond
=
input_ids
==
tokenizer
.
im_start_id
start_cond
=
input_ids
==
tokenizer
.
im_start_id
end_cond
=
input_ids
==
tokenizer
.
im_end_id
end_cond
=
input_ids
==
tokenizer
.
im_end_id
if
hasattr
(
tokenizer
,
"slice_start_id"
):
else
:
start_cond
|=
(
input_ids
==
tokenizer
.
slice_start_id
)
start_cond
=
(
input_ids
==
tokenizer
.
im_start_id
)
|
(
end_cond
|=
(
input_ids
==
tokenizer
.
slice_end_id
)
input_ids
==
tokenizer
.
slice_start_id
)
end_cond
=
(
input_ids
==
tokenizer
.
im_end_id
)
|
(
input_ids
==
tokenizer
.
slice_end_id
)
image_start_tokens
=
torch
.
where
(
start_cond
)
[
0
]
image_start_tokens
,
=
torch
.
where
(
start_cond
)
image_start_tokens
+=
1
image_start_tokens
+=
1
image_end_tokens
=
torch
.
where
(
end_cond
)
[
0
]
image_end_tokens
,
=
torch
.
where
(
end_cond
)
valid_image_nums
=
max
(
len
(
image_start_tokens
),
len
(
image_end_tokens
))
valid_image_nums
=
max
(
len
(
image_start_tokens
),
len
(
image_end_tokens
))
if
valid_image_nums
==
0
:
if
valid_image_nums
==
0
:
return
[]
return
torch
.
zeros
((
0
,
2
),
device
=
input_ids
.
device
)
image_bound
=
torch
.
hstack
([
return
torch
.
hstack
([
image_start_tokens
[:
valid_image_nums
].
unsqueeze
(
-
1
),
image_start_tokens
[:
valid_image_nums
].
unsqueeze
(
-
1
),
image_end_tokens
[:
valid_image_nums
].
unsqueeze
(
-
1
),
image_end_tokens
[:
valid_image_nums
].
unsqueeze
(
-
1
),
])
])
return
image_bound
def
_parse_and_validate_inputs
(
self
,
def
get_vision_hidden_states
(
self
,
data
:
Dict
[
str
,
input_ids
:
torch
.
Tensor
,
Union
[
List
[
torch
.
Tensor
],
**
kwargs
:
object
,
torch
.
Tensor
]]):
)
->
Optional
[
MiniCPMVImageInputs
]:
if
"vision_hidden_states"
not
in
data
:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
[])
pixel_values
=
data
[
"pixel_values"
]
tgt_sizes
=
kwargs
.
pop
(
"tgt_sizes"
,
[])
tgt_sizes
=
data
[
"tgt_sizes"
]
vision_hidden_states
=
[]
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
if
self
.
version
==
(
2
,
0
):
raise
ValueError
(
"Incorrect type of pixel values. "
if
pixel_values
is
not
None
and
len
(
pixel_values
)
>
0
:
f
"Got type:
{
type
(
pixel_values
)
}
"
)
vision_hidden_states
=
self
.
get_vision_embedding
(
pixel_values
)
if
not
isinstance
(
tgt_sizes
,
(
torch
.
Tensor
,
list
)):
else
:
raise
ValueError
(
"Incorrect type of target sizes. "
vision_hidden_states
=
torch
.
tensor
([]).
to
(
f
"Got type:
{
type
(
tgt_sizes
)
}
"
)
data
[
"input_ids"
].
device
)
else
:
if
len
(
pixel_values
)
!=
len
(
tgt_sizes
):
device
=
self
.
vpm
.
embeddings
.
position_embedding
.
weight
.
device
raise
ValueError
(
"Inconsistent batch lengths, found: "
dtype
=
self
.
vpm
.
embeddings
.
position_embedding
.
weight
.
dtype
f
"
{
len
(
pixel_values
)
}
vs.
{
len
(
tgt_sizes
)
}
"
)
all_pixel_values
=
[
i
.
flatten
(
end_dim
=
1
).
permute
(
1
,
0
)
for
i
in
pixel_values
pixel_values_flat
:
List
[
torch
.
Tensor
]
=
[]
]
tgt_sizes_flat
:
List
[
torch
.
Tensor
]
=
[]
if
all_pixel_values
:
for
b
in
range
(
len
(
pixel_values
)):
tgt_sizes
=
torch
.
vstack
(
tgt_sizes
).
type
(
torch
.
int32
)
pixel_values_flat
+=
pixel_values
[
b
]
max_patches
=
torch
.
max
(
tgt_sizes
[:,
0
]
*
tgt_sizes
[:,
1
])
tgt_sizes_flat
+=
tgt_sizes
[
b
]
all_pixel_values
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
all_pixel_values
,
batch_first
=
True
,
padding_value
=
0.0
)
# NOTE: Input IDs does not contain image tokens during memory profiling,
B
,
L
,
_
=
all_pixel_values
.
shape
# so we allow it to be empty
all_pixel_values
=
all_pixel_values
.
permute
(
if
len
(
pixel_values_flat
)
!=
len
(
tgt_sizes_flat
):
0
,
2
,
1
).
reshape
(
B
,
3
,
-
1
,
L
)
raise
ValueError
(
"Inconsistent flattened lengths, found: "
patch_attn_mask
=
torch
.
zeros
((
B
,
1
,
max_patches
),
f
"
{
len
(
pixel_values_flat
)
}
vs. "
dtype
=
torch
.
bool
,
f
"
{
len
(
tgt_sizes_flat
)
}
"
)
device
=
device
)
if
self
.
version
==
(
2
,
5
):
if
len
(
pixel_values_flat
)
==
0
:
for
i
in
range
(
B
):
return
None
patch_attn_mask
[
i
,
:
tgt_sizes
[
i
][
0
]
*
tgt_sizes
[
i
][
1
]]
=
True
return
MiniCPMVImageInputs
(
vision_embedding
=
self
.
vpm
(
image_bounds
=
self
.
_get_image_bounds
(
input_ids
),
all_pixel_values
.
type
(
dtype
),
pixel_values
=
pixel_values_flat
,
patch_attention_mask
=
patch_attn_mask
tgt_sizes
=
torch
.
stack
(
tgt_sizes_flat
),
).
last_hidden_state
)
else
:
for
i
in
range
(
B
):
patch_attn_mask
[
i
,
0
,
:
tgt_sizes
[
i
][
0
]
*
tgt_sizes
[
i
][
1
]]
=
True
vision_embedding
=
self
.
vpm
(
all_pixel_values
.
type
(
dtype
),
patch_attention_mask
=
patch_attn_mask
,
tgt_sizes
=
tgt_sizes
).
last_hidden_state
vision_hidden_states
=
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
else
:
# no image
dummy_feature
=
[]
vision_hidden_states
=
dummy_feature
else
:
vision_hidden_states
=
data
[
"vision_hidden_states"
]
return
vision_hidden_states
def
get_embedding
(
self
,
data
:
Dict
[
str
,
Union
[
List
[
torch
.
Tensor
],
torch
.
Tensor
]]):
input_ids
=
data
[
"input_ids"
]
vision_hidden_states
=
self
.
get_vision_hidden_states
(
data
)
if
vision_hidden_states
is
not
None
and
len
(
vision_hidden_states
)
>
0
:
image_bounds
=
self
.
get_image_bounds
(
input_ids
)
else
:
image_bounds
=
[]
if
hasattr
(
self
.
config
,
'scale_emb'
):
vlm_embedding
=
self
.
llm
.
embed_tokens
(
input_ids
)
*
self
.
config
.
scale_emb
else
:
vlm_embedding
=
self
.
llm
.
embed_tokens
(
input_ids
)
vision_hidden_states
=
[
i
.
type
(
vlm_embedding
.
dtype
)
if
isinstance
(
i
,
torch
.
Tensor
)
else
i
for
i
in
vision_hidden_states
]
if
len
(
vision_hidden_states
)
>
0
and
len
(
image_bounds
)
>
0
:
vision_hidden_states
=
torch
.
cat
(
vision_hidden_states
,
dim
=
0
)
image_indices
=
torch
.
stack
([
torch
.
arange
(
r
[
0
],
r
[
1
],
dtype
=
torch
.
long
)
for
r
in
image_bounds
]).
to
(
vlm_embedding
.
device
)
vlm_embedding
.
scatter_
(
0
,
image_indices
.
view
(
-
1
,
1
).
repeat
(
1
,
vlm_embedding
.
shape
[
-
1
]),
vision_hidden_states
.
view
(
-
1
,
vision_hidden_states
.
shape
[
-
1
]))
return
vlm_embedding
,
vision_hidden_states
def
process_multimodal_inputs
(
self
,
inputs
:
Dict
[
str
,
Union
[
List
[
torch
.
Tensor
],
torch
.
Tensor
]]):
pixel_values
=
[]
tgt_sizes
=
[]
for
b
in
range
(
len
(
inputs
[
"pixel_values"
])):
pixel_values
+=
inputs
[
"pixel_values"
][
b
]
tgt_sizes
+=
inputs
[
"tgt_sizes"
][
b
]
return
{
"pixel_values"
:
pixel_values
,
"input_ids"
:
inputs
[
"input_ids"
],
"tgt_sizes"
:
tgt_sizes
}
def
forward
(
def
forward
(
self
,
self
,
...
@@ -680,23 +603,20 @@ class MiniCPMV(nn.Module, SupportsVision):
...
@@ -680,23 +603,20 @@ class MiniCPMV(nn.Module, SupportsVision):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
Any
,
):
)
->
torch
.
Tensor
:
inputs
=
{
image_inputs
=
self
.
_parse_and_validate_inputs
(
input_ids
,
**
kwargs
)
"pixel_values"
:
kwargs
.
pop
(
"pixel_values"
,
[]),
"input_ids"
:
input_ids
,
vlm_embeddings
,
_
=
self
.
get_embedding
(
input_ids
,
image_inputs
)
"tgt_sizes"
:
kwargs
.
pop
(
"tgt_sizes"
,
None
),
}
output
=
self
.
llm
(
inputs
=
self
.
process_multimodal_inputs
(
inputs
)
input_ids
=
None
,
positions
=
positions
,
vlm_embeddings
,
vision_hidden_states
=
self
.
get_embedding
(
inputs
)
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
output
=
self
.
llm
(
input_ids
=
None
,
intermediate_tensors
=
intermediate_tensors
,
positions
=
positions
,
inputs_embeds
=
vlm_embeddings
,
kv_caches
=
kv_caches
,
)
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
vlm_embeddings
)
return
output
return
output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
@@ -735,13 +655,10 @@ class MiniCPMV(nn.Module, SupportsVision):
...
@@ -735,13 +655,10 @@ class MiniCPMV(nn.Module, SupportsVision):
# the checkpoint. Skip them.
# the checkpoint. Skip them.
continue
continue
use_default_weight_loading
=
False
use_default_weight_loading
=
False
if
"vpm"
in
name
or
'resampler'
in
name
:
if
self
.
is_default_weight_loading
(
name
):
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading
=
True
use_default_weight_loading
=
True
else
:
else
:
for
(
param_name
,
weight_name
,
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
...
@@ -755,3 +672,341 @@ class MiniCPMV(nn.Module, SupportsVision):
...
@@ -755,3 +672,341 @@ class MiniCPMV(nn.Module, SupportsVision):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
def
init_llm
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
nn
.
Module
:
raise
NotImplementedError
def
init_vision_module
(
self
)
->
nn
.
Module
:
raise
NotImplementedError
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
)
->
nn
.
Module
:
raise
NotImplementedError
def
get_vision_embedding
(
self
,
pixel_values
:
List
[
torch
.
Tensor
],
patch_attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
tgt_sizes
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
raise
NotImplementedError
class
MiniCPMV2
(
MiniCPMVBaseModel
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
(
config
,
multimodal_config
,
cache_config
,
quant_config
)
assert
self
.
version
==
(
2
,
0
)
def
init_llm
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
nn
.
Module
:
return
MiniCPMModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
def
init_vision_module
(
self
)
->
nn
.
Module
:
# TODO :refactor this vision model
try
:
import
timm
except
ImportError
:
raise
ImportError
(
"Please install timm==0.9.10"
)
from
ImportError
with
set_default_torch_dtype
(
torch
.
float16
):
model
=
timm
.
create_model
(
"vit_so400m_patch14_siglip_384.webli"
,
pretrained
=
False
,
num_classes
=
0
,
dynamic_img_size
=
True
,
dynamic_img_pad
=
True
,
)
if
(
isinstance
(
model
,
timm
.
models
.
VisionTransformer
)
and
model
.
attn_pool
is
not
None
):
model
.
attn_pool
=
torch
.
nn
.
Identity
()
if
self
.
config
.
drop_vision_last_layer
:
model
.
blocks
=
model
.
blocks
[:
-
1
]
return
model
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
torch
.
float16
):
resampler
=
Resampler2
(
embed_dim
=
embed_dim
,
num_heads
=
embed_dim
//
128
,
grid_size
=
int
(
math
.
sqrt
(
self
.
config
.
query_num
)),
kv_dim
=
vision_dim
,
adaptive
=
True
,
)
return
resampler
def
get_vision_embedding
(
self
,
pixel_values
:
List
[
torch
.
Tensor
],
patch_attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
tgt_sizes
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
res
=
[]
dtype
=
self
.
vpm
.
pos_embed
.
data
.
dtype
for
pixel_value
in
pixel_values
:
H
,
W
=
pixel_value
[
0
].
shape
[
-
2
:]
tgt_size
=
(
math
.
ceil
(
H
/
self
.
vpm
.
patch_embed
.
patch_size
[
0
]),
math
.
ceil
(
W
/
self
.
vpm
.
patch_embed
.
patch_size
[
0
]),
)
vision_embedding
=
self
.
vpm
.
forward_features
(
pixel_value
.
unsqueeze
(
0
).
type
(
dtype
))
if
(
hasattr
(
self
.
vpm
,
"num_prefix_tokens"
)
and
self
.
vpm
.
num_prefix_tokens
>
0
):
vision_embedding
=
vision_embedding
[:,
self
.
vpm
.
num_prefix_tokens
:]
res
.
append
(
self
.
resampler
(
vision_embedding
,
tgt_size
))
return
torch
.
vstack
(
res
)
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
return
self
.
get_vision_embedding
(
pixel_values
)
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
return
"resampler"
in
name
or
"vpm"
in
name
class
MiniCPMV2_5
(
MiniCPMVBaseModel
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
(
config
,
multimodal_config
,
cache_config
,
quant_config
)
assert
self
.
version
==
(
2
,
5
)
def
init_llm
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
nn
.
Module
:
return
LlamaModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
def
init_vision_module
(
self
)
->
nn
.
Module
:
model
=
Idefics2VisionTransformer
(
self
.
config
.
vision_config
)
if
self
.
config
.
drop_vision_last_layer
:
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
return
model
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
torch
.
float16
):
resampler
=
Resampler2_5
(
num_queries
=
self
.
config
.
query_num
,
embed_dim
=
embed_dim
,
num_heads
=
embed_dim
//
128
,
kv_dim
=
vision_dim
,
)
return
resampler
def
get_vision_embedding
(
self
,
pixel_values
:
List
[
torch
.
Tensor
],
patch_attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
tgt_sizes
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
vision_embedding
=
self
.
vpm
(
pixel_values
,
patch_attention_mask
=
patch_attn_mask
)
vision_embedding
=
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
return
vision_embedding
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
device
=
self
.
vpm
.
embeddings
.
position_embedding
.
weight
.
device
dtype
=
self
.
vpm
.
embeddings
.
position_embedding
.
weight
.
dtype
all_pixel_values_lst
=
[
i
.
flatten
(
end_dim
=
1
).
permute
(
1
,
0
)
for
i
in
pixel_values
]
max_patches
=
(
tgt_sizes
[:,
0
]
*
tgt_sizes
[:,
1
]).
max
().
item
()
assert
isinstance
(
max_patches
,
int
)
all_pixel_values
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
all_pixel_values_lst
,
batch_first
=
True
,
padding_value
=
0.0
)
B
,
L
,
_
=
all_pixel_values
.
shape
all_pixel_values
=
all_pixel_values
.
permute
(
0
,
2
,
1
).
reshape
(
B
,
3
,
-
1
,
L
)
patch_attn_mask
=
torch
.
zeros
((
B
,
1
,
max_patches
),
dtype
=
torch
.
bool
,
device
=
device
)
for
i
in
range
(
B
):
patch_attn_mask
[
i
,
:
tgt_sizes
[
i
][
0
]
*
tgt_sizes
[
i
][
1
]]
=
True
return
self
.
get_vision_embedding
(
all_pixel_values
.
type
(
dtype
),
patch_attn_mask
,
tgt_sizes
)
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
return
"resampler"
in
name
# NOTE: Currently, information about this model is unavailable. We are
# temporarily using `MiniCPMVQwen2` as it's name. The name may need
# to be modified in the future.
class
MiniCPMVQwen2
(
MiniCPMVBaseModel
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
(
config
,
multimodal_config
,
cache_config
,
quant_config
)
def
init_llm
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
nn
.
Module
:
return
Qwen2Model
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
def
init_vision_module
(
self
)
->
nn
.
Module
:
# A custom version of SiglipVisionTransformer, won't work with TP
from
vllm.model_executor.models.na_vit
import
SiglipVisionTransformer
if
self
.
config
.
_attn_implementation
==
"flash_attention_2"
:
self
.
config
.
vision_config
.
_attn_implementation
=
"flash_attention_2"
else
:
# not support sdpa
self
.
config
.
vision_config
.
_attn_implementation
=
"eager"
model
=
SiglipVisionTransformer
(
self
.
config
.
vision_config
)
if
self
.
config
.
drop_vision_last_layer
:
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
return
model
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
torch
.
float16
):
resampler
=
Resampler2_5
(
num_queries
=
self
.
config
.
query_num
,
embed_dim
=
embed_dim
,
num_heads
=
embed_dim
//
128
,
kv_dim
=
vision_dim
,
)
return
resampler
def
get_vision_embedding
(
self
,
pixel_values
:
List
[
torch
.
Tensor
],
patch_attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
tgt_sizes
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
vision_embedding
=
self
.
vpm
(
pixel_values
,
patch_attention_mask
=
patch_attn_mask
,
tgt_sizes
=
tgt_sizes
,
).
last_hidden_state
return
vision_embedding
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
device
=
self
.
vpm
.
embeddings
.
position_embedding
.
weight
.
device
dtype
=
self
.
vpm
.
embeddings
.
position_embedding
.
weight
.
dtype
all_pixel_values_lst
=
[
i
.
flatten
(
end_dim
=
1
).
permute
(
1
,
0
)
for
i
in
pixel_values
]
max_patches
=
(
tgt_sizes
[:,
0
]
*
tgt_sizes
[:,
1
]).
max
().
item
()
assert
isinstance
(
max_patches
,
int
)
all_pixel_values
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
all_pixel_values_lst
,
batch_first
=
True
,
padding_value
=
0.0
)
B
,
L
,
_
=
all_pixel_values
.
shape
all_pixel_values
=
all_pixel_values
.
permute
(
0
,
2
,
1
).
reshape
(
B
,
3
,
-
1
,
L
)
patch_attn_mask
=
torch
.
zeros
((
B
,
1
,
max_patches
),
dtype
=
torch
.
bool
,
device
=
device
)
for
i
in
range
(
B
):
patch_attn_mask
[
i
,
0
,
:
tgt_sizes
[
i
][
0
]
*
tgt_sizes
[
i
][
1
]]
=
True
vision_embedding
=
self
.
vpm
(
all_pixel_values
.
type
(
dtype
),
patch_attention_mask
=
patch_attn_mask
,
tgt_sizes
=
tgt_sizes
,
).
last_hidden_state
return
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
return
"resampler"
in
name
or
"vpm"
in
name
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_minicpmv_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_minicpmv
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_minicpmv
)
class
MiniCPMV
(
MiniCPMVBaseModel
):
"""
Different versions of MiniCPMV use different visual encoders and LLMs,
which is not conducive to the current integration logic of LoRA and
bitsandbytes in vLLM. Therefore, it is necessary to separate them.
"""
def
__new__
(
cls
,
config
:
PretrainedConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
if
not
hasattr
(
config
,
"version"
):
if
config
.
hidden_size
==
2304
and
config
.
query_num
==
64
:
version
=
(
2
,
0
)
else
:
version
=
(
2
,
5
)
else
:
version
=
str
(
config
.
version
).
split
(
"."
)
version
=
tuple
([
int
(
x
)
for
x
in
version
])
# Dispatch class based on version
if
version
==
(
2
,
0
):
instance_class
=
MiniCPMV2
elif
version
==
(
2
,
5
):
instance_class
=
MiniCPMV2_5
else
:
instance_class
=
MiniCPMVQwen2
return
instance_class
(
config
,
multimodal_config
,
cache_config
,
quant_config
)
vllm/model_executor/models/na_vit.py
View file @
179a6a36
...
@@ -100,7 +100,7 @@ def _get_unpad_data(attention_mask):
...
@@ -100,7 +100,7 @@ def _get_unpad_data(attention_mask):
indices
=
torch
.
nonzero
(
attention_mask
.
flatten
(),
as_tuple
=
False
).
flatten
()
indices
=
torch
.
nonzero
(
attention_mask
.
flatten
(),
as_tuple
=
False
).
flatten
()
max_seqlen_in_batch
=
seqlens_in_batch
.
max
().
item
()
max_seqlen_in_batch
=
seqlens_in_batch
.
max
().
item
()
cu_seqlens
=
F
.
pad
(
cu_seqlens
=
F
.
pad
(
torch
.
cumsum
(
seqlens_in_batch
,
dim
=
0
,
dtype
=
torch
.
torch
.
int32
),
(
1
,
0
))
torch
.
cumsum
(
seqlens_in_batch
,
dim
=
0
,
dtype
=
torch
.
int32
),
(
1
,
0
))
return
(
return
(
indices
,
indices
,
cu_seqlens
,
cu_seqlens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment