Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8e60afa1
Unverified
Commit
8e60afa1
authored
Sep 30, 2024
by
Jee Jee Li
Committed by
GitHub
Sep 30, 2024
Browse files
[Model][LoRA]LoRA support added for MiniCPMV2.6 (#8943)
Co-authored-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
b6d73925
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
880 deletions
+49
-880
vllm/model_executor/models/idefics2_vision_model.py
vllm/model_executor/models/idefics2_vision_model.py
+15
-9
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+34
-67
vllm/model_executor/models/na_vit.py
vllm/model_executor/models/na_vit.py
+0
-804
No files found.
vllm/model_executor/models/idefics2_vision_model.py
View file @
8e60afa1
...
@@ -65,11 +65,10 @@ class Idefics2VisionEmbeddings(nn.Module):
...
@@ -65,11 +65,10 @@ class Idefics2VisionEmbeddings(nn.Module):
self
.
position_embedding
=
nn
.
Embedding
(
self
.
num_positions
,
self
.
position_embedding
=
nn
.
Embedding
(
self
.
num_positions
,
self
.
embed_dim
)
self
.
embed_dim
)
def
forward
(
def
forward
(
self
,
self
,
pixel_values
:
torch
.
FloatTensor
,
pixel_values
:
torch
.
FloatTensor
,
patch_attention_mask
:
torch
.
BoolTensor
,
patch_attention_mask
:
torch
.
BoolTensor
,
tgt_sizes
:
Optional
[
torch
.
IntTensor
]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
batch_size
,
_
,
max_im_h
,
max_im_w
=
pixel_values
.
shape
batch_size
,
_
,
max_im_h
,
max_im_w
=
pixel_values
.
shape
patch_embeds
=
self
.
patch_embedding
(
pixel_values
)
patch_embeds
=
self
.
patch_embedding
(
pixel_values
)
embeddings
=
patch_embeds
.
flatten
(
2
).
transpose
(
1
,
2
)
embeddings
=
patch_embeds
.
flatten
(
2
).
transpose
(
1
,
2
)
...
@@ -84,8 +83,13 @@ class Idefics2VisionEmbeddings(nn.Module):
...
@@ -84,8 +83,13 @@ class Idefics2VisionEmbeddings(nn.Module):
fill_value
=
0
)
fill_value
=
0
)
for
batch_idx
,
p_attn_mask
in
enumerate
(
patch_attention_mask
):
for
batch_idx
,
p_attn_mask
in
enumerate
(
patch_attention_mask
):
nb_patches_h
=
p_attn_mask
[:,
0
].
sum
()
nb_patches_w
=
p_attn_mask
[
0
].
sum
()
if
tgt_sizes
is
not
None
:
nb_patches_h
=
tgt_sizes
[
batch_idx
][
0
]
nb_patches_w
=
tgt_sizes
[
batch_idx
][
1
]
else
:
nb_patches_h
=
p_attn_mask
[:,
0
].
sum
()
nb_patches_w
=
p_attn_mask
[
0
].
sum
()
fractional_coords_h
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
nb_patches_h
)
fractional_coords_h
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
nb_patches_h
)
fractional_coords_w
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
nb_patches_w
)
fractional_coords_w
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
nb_patches_w
)
bucket_coords_h
=
torch
.
bucketize
(
fractional_coords_h
,
bucket_coords_h
=
torch
.
bucketize
(
fractional_coords_h
,
...
@@ -287,10 +291,12 @@ class Idefics2VisionTransformer(nn.Module):
...
@@ -287,10 +291,12 @@ class Idefics2VisionTransformer(nn.Module):
self
,
self
,
pixel_values
,
pixel_values
,
patch_attention_mask
:
Optional
[
torch
.
BoolTensor
]
=
None
,
patch_attention_mask
:
Optional
[
torch
.
BoolTensor
]
=
None
,
)
->
torch
.
tensor
:
tgt_sizes
:
Optional
[
torch
.
IntTensor
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embeddings
(
hidden_states
=
self
.
embeddings
(
pixel_values
=
pixel_values
,
pixel_values
=
pixel_values
,
patch_attention_mask
=
patch_attention_mask
)
patch_attention_mask
=
patch_attention_mask
,
tgt_sizes
=
tgt_sizes
)
encoder_outputs
=
self
.
encoder
(
hidden_states
)
encoder_outputs
=
self
.
encoder
(
hidden_states
)
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
return
last_hidden_state
return
last_hidden_state
vllm/model_executor/models/minicpmv.py
View file @
8e60afa1
...
@@ -31,17 +31,15 @@ import torch
...
@@ -31,17 +31,15 @@ import torch
import
torch.types
import
torch.types
from
PIL
import
Image
from
PIL
import
Image
from
torch
import
nn
from
torch
import
nn
from
torch.nn.init
import
trunc_normal_
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
typing_extensions
import
NotRequired
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.resampler
import
(
Resampler2
,
from
vllm.model_executor.layers.resampler
import
(
BaseResampler
,
Resampler2
,
get_2d_sincos_pos_embed
)
get_2d_sincos_pos_embed
)
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
...
@@ -106,58 +104,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
...
@@ -106,58 +104,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
class
BaseResampler
(
nn
.
Module
):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def
__init__
(
self
,
num_queries
:
int
,
embed_dim
:
int
,
num_heads
:
int
,
kv_dim
:
Optional
[
int
]
=
None
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
)
->
None
:
super
().
__init__
()
self
.
num_queries
=
num_queries
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
query
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_queries
,
embed_dim
))
trunc_normal_
(
self
.
query
,
std
=
0.02
)
if
kv_dim
is
not
None
and
kv_dim
!=
embed_dim
:
self
.
kv_proj
=
ReplicatedLinear
(
kv_dim
,
embed_dim
,
bias
=
False
)
else
:
# Maintain the same return value with ReplicatedLinear.forward
self
.
kv_proj
=
lambda
*
args
,
**
kwargs
:
(
nn
.
Identity
()(
*
args
,
**
kwargs
),
None
,
)
self
.
attn
=
nn
.
MultiheadAttention
(
embed_dim
,
num_heads
)
self
.
ln_q
=
norm_layer
(
embed_dim
)
self
.
ln_kv
=
norm_layer
(
embed_dim
)
self
.
ln_post
=
norm_layer
(
embed_dim
)
self
.
proj
=
nn
.
Parameter
(
(
embed_dim
**-
0.5
)
*
torch
.
randn
(
embed_dim
,
embed_dim
))
def
_init_weights
(
self
,
m
:
nn
.
Module
)
->
None
:
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
0.02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
_repeat
(
self
,
query
,
N
:
int
):
return
query
.
unsqueeze
(
1
).
repeat
(
1
,
N
,
1
)
class
Resampler2_5
(
BaseResampler
):
class
Resampler2_5
(
BaseResampler
):
def
__init__
(
def
__init__
(
...
@@ -869,7 +815,35 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -869,7 +815,35 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
return
"resampler"
in
name
return
"resampler"
in
name
class
MiniCPMV2_6
(
MiniCPMVBaseModel
):
class
MiniCPMV2_6
(
MiniCPMVBaseModel
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
# vision encoder
"fc1"
,
"fc2"
,
"out_proj"
,
# language model
"qkv_proj"
,
# same name with vision encoder
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
# resampler
"kv_proj"
,
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -894,15 +868,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
...
@@ -894,15 +868,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
name
=
"model"
)
name
=
"model"
)
def
init_vision_module
(
self
)
->
nn
.
Module
:
def
init_vision_module
(
self
)
->
nn
.
Module
:
# A custom version of SiglipVisionTransformer, won't work with TP
from
vllm.model_executor.models.na_vit
import
SiglipVisionTransformer
if
self
.
config
.
_attn_implementation
==
"flash_attention_2"
:
model
=
Idefics2VisionTransformer
(
self
.
config
.
vision_config
)
self
.
config
.
vision_config
.
_attn_implementation
=
"flash_attention_2"
else
:
# not support sdpa
self
.
config
.
vision_config
.
_attn_implementation
=
"eager"
model
=
SiglipVisionTransformer
(
self
.
config
.
vision_config
)
if
self
.
config
.
drop_vision_last_layer
:
if
self
.
config
.
drop_vision_last_layer
:
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
return
model
return
model
...
@@ -928,7 +895,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
...
@@ -928,7 +895,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
pixel_values
,
pixel_values
,
patch_attention_mask
=
patch_attn_mask
,
patch_attention_mask
=
patch_attn_mask
,
tgt_sizes
=
tgt_sizes
,
tgt_sizes
=
tgt_sizes
,
)
.
last_hidden_state
)
return
vision_embedding
return
vision_embedding
def
get_vision_hidden_states
(
def
get_vision_hidden_states
(
...
@@ -960,12 +927,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
...
@@ -960,12 +927,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
all_pixel_values
.
type
(
dtype
),
all_pixel_values
.
type
(
dtype
),
patch_attention_mask
=
patch_attn_mask
,
patch_attention_mask
=
patch_attn_mask
,
tgt_sizes
=
tgt_sizes
,
tgt_sizes
=
tgt_sizes
,
)
.
last_hidden_state
)
return
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
return
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
return
"resampler"
in
name
or
"vpm"
in
name
return
"resampler"
in
name
_SUPPORT_VERSION
=
{
_SUPPORT_VERSION
=
{
...
...
vllm/model_executor/models/na_vit.py
deleted
100644 → 0
View file @
b6d73925
import
logging
import
math
import
os
import
warnings
from
typing
import
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.nn.init
import
_calculate_fan_in_and_fan_out
from
transformers.activations
import
ACT2FN
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.modeling_attn_mask_utils
import
_prepare_4d_attention_mask
from
transformers.modeling_outputs
import
(
BaseModelOutput
,
BaseModelOutputWithPooling
)
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.utils
import
(
ModelOutput
,
is_flash_attn_2_available
,
replace_return_docstrings
)
logger
=
logging
.
getLogger
(
"vllm"
)
# For Siglip: copied from
# HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
# Remove hints as there's little possibility to change these code.
class
SiglipVisionConfig
(
PretrainedConfig
):
model_type
=
"siglip_vision_model"
def
__init__
(
self
,
hidden_size
=
768
,
intermediate_size
=
3072
,
num_hidden_layers
=
12
,
num_attention_heads
=
12
,
num_channels
=
3
,
image_size
=
224
,
patch_size
=
16
,
hidden_act
=
"gelu_pytorch_tanh"
,
layer_norm_eps
=
1e-6
,
attention_dropout
=
0.0
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
num_channels
=
num_channels
self
.
patch_size
=
patch_size
self
.
image_size
=
image_size
self
.
attention_dropout
=
attention_dropout
self
.
layer_norm_eps
=
layer_norm_eps
self
.
hidden_act
=
hidden_act
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
)
->
"PretrainedConfig"
:
cls
.
_set_token_in_kwargs
(
kwargs
)
config_dict
,
kwargs
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
,
**
kwargs
)
# get the vision config dict if we are loading from SiglipConfig
if
config_dict
.
get
(
"model_type"
)
==
"siglip"
:
config_dict
=
config_dict
[
"vision_config"
]
if
"model_type"
in
config_dict
and
hasattr
(
cls
,
"model_type"
)
and
config_dict
[
"model_type"
]
!=
cls
.
model_type
:
logger
.
warning
(
"You are using a model of type %s to "
"instantiate a model of type %s. "
"This is not supported for all configurations"
"of models and can yield errors."
,
config_dict
[
'model_type'
],
cls
.
model_type
)
return
cls
.
from_dict
(
config_dict
,
**
kwargs
)
_CHECKPOINT_FOR_DOC
=
"google/siglip-base-patch16-224"
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST
=
[
"google/siglip-base-patch16-224"
,
# See all SigLIP models at https://huggingface.co/models?filter=siglip
]
if
is_flash_attn_2_available
():
from
flash_attn
import
flash_attn_func
,
flash_attn_varlen_func
from
flash_attn.bert_padding
import
pad_input
# noqa
from
flash_attn.bert_padding
import
index_first_axis
,
unpad_input
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def
_get_unpad_data
(
attention_mask
):
seqlens_in_batch
=
attention_mask
.
sum
(
dim
=-
1
,
dtype
=
torch
.
int32
)
indices
=
torch
.
nonzero
(
attention_mask
.
flatten
(),
as_tuple
=
False
).
flatten
()
max_seqlen_in_batch
=
seqlens_in_batch
.
max
().
item
()
cu_seqlens
=
F
.
pad
(
torch
.
cumsum
(
seqlens_in_batch
,
dim
=
0
,
dtype
=
torch
.
int32
),
(
1
,
0
))
return
(
indices
,
cu_seqlens
,
max_seqlen_in_batch
,
)
def
_trunc_normal_
(
tensor
,
mean
,
std
,
a
,
b
):
def
norm_cdf
(
x
):
# Computes standard normal cumulative distribution function
return
(
1.0
+
math
.
erf
(
x
/
math
.
sqrt
(
2.0
)))
/
2.0
if
(
mean
<
a
-
2
*
std
)
or
(
mean
>
b
+
2
*
std
):
warnings
.
warn
(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect."
,
stacklevel
=
2
,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l_
=
norm_cdf
((
a
-
mean
)
/
std
)
u
=
norm_cdf
((
b
-
mean
)
/
std
)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor
.
uniform_
(
2
*
l_
-
1
,
2
*
u
-
1
)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
if
tensor
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]:
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
og_dtype
=
tensor
.
dtype
tensor
=
tensor
.
to
(
torch
.
float32
)
tensor
.
erfinv_
()
tensor
=
tensor
.
to
(
og_dtype
)
else
:
tensor
.
erfinv_
()
# Transform to proper mean, std
tensor
.
mul_
(
std
*
math
.
sqrt
(
2.0
))
tensor
.
add_
(
mean
)
# Clamp to ensure it's in the proper range
if
tensor
.
dtype
==
torch
.
float16
:
# The `clamp_` op is not (yet?) defined in float16+cpu
tensor
=
tensor
.
to
(
torch
.
float32
)
tensor
.
clamp_
(
min
=
a
,
max
=
b
)
tensor
=
tensor
.
to
(
torch
.
float16
)
else
:
tensor
.
clamp_
(
min
=
a
,
max
=
b
)
def
trunc_normal_tf_
(
tensor
:
torch
.
Tensor
,
mean
:
float
=
0.0
,
std
:
float
=
1.0
,
a
:
float
=
-
2.0
,
b
:
float
=
2.0
)
->
torch
.
Tensor
:
with
torch
.
no_grad
():
_trunc_normal_
(
tensor
,
0
,
1.0
,
a
,
b
)
tensor
.
mul_
(
std
).
add_
(
mean
)
def
variance_scaling_
(
tensor
,
scale
=
1.0
,
mode
=
"fan_in"
,
distribution
=
"normal"
):
fan_in
,
fan_out
=
_calculate_fan_in_and_fan_out
(
tensor
)
if
mode
==
"fan_in"
:
denom
=
fan_in
elif
mode
==
"fan_out"
:
denom
=
fan_out
elif
mode
==
"fan_avg"
:
denom
=
(
fan_in
+
fan_out
)
/
2
variance
=
scale
/
denom
if
distribution
==
"truncated_normal"
:
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_
(
tensor
,
std
=
math
.
sqrt
(
variance
)
/
0.87962566103423978
)
elif
distribution
==
"normal"
:
with
torch
.
no_grad
():
tensor
.
normal_
(
std
=
math
.
sqrt
(
variance
))
elif
distribution
==
"uniform"
:
bound
=
math
.
sqrt
(
3
*
variance
)
with
torch
.
no_grad
():
tensor
.
uniform_
(
-
bound
,
bound
)
else
:
raise
ValueError
(
f
"invalid distribution
{
distribution
}
"
)
def
lecun_normal_
(
tensor
):
variance_scaling_
(
tensor
,
mode
=
"fan_in"
,
distribution
=
"truncated_normal"
)
def
default_flax_embed_init
(
tensor
):
variance_scaling_
(
tensor
,
mode
=
"fan_in"
,
distribution
=
"normal"
)
class
SiglipVisionModelOutput
(
ModelOutput
):
image_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
last_hidden_state
:
torch
.
FloatTensor
=
None
hidden_states
:
Optional
[
Tuple
[
torch
.
FloatTensor
]]
=
None
attentions
:
Optional
[
Tuple
[
torch
.
FloatTensor
]]
=
None
class
SiglipVisionEmbeddings
(
nn
.
Module
):
def
__init__
(
self
,
config
:
SiglipVisionConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
image_size
=
config
.
image_size
self
.
patch_size
=
config
.
patch_size
self
.
patch_embedding
=
nn
.
Conv2d
(
in_channels
=
config
.
num_channels
,
out_channels
=
self
.
embed_dim
,
kernel_size
=
self
.
patch_size
,
stride
=
self
.
patch_size
,
padding
=
"valid"
,
)
self
.
num_patches_per_side
=
self
.
image_size
//
self
.
patch_size
self
.
num_patches
=
self
.
num_patches_per_side
**
2
self
.
num_positions
=
self
.
num_patches
self
.
position_embedding
=
nn
.
Embedding
(
self
.
num_positions
,
self
.
embed_dim
)
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
,
patch_attention_mask
:
torch
.
BoolTensor
,
tgt_sizes
:
Optional
[
torch
.
IntTensor
]
=
None
)
->
torch
.
Tensor
:
batch_size
=
pixel_values
.
size
(
0
)
patch_embeds
=
self
.
patch_embedding
(
pixel_values
)
embeddings
=
patch_embeds
.
flatten
(
2
).
transpose
(
1
,
2
)
max_im_h
,
max_im_w
=
pixel_values
.
size
(
2
),
pixel_values
.
size
(
3
)
max_nb_patches_h
,
max_nb_patches_w
=
(
max_im_h
//
self
.
patch_size
,
max_im_w
//
self
.
patch_size
)
boundaries
=
torch
.
arange
(
1
/
self
.
num_patches_per_side
,
1.0
,
1
/
self
.
num_patches_per_side
)
position_ids
=
torch
.
full
(
size
=
(
batch_size
,
max_nb_patches_h
*
max_nb_patches_w
,
),
fill_value
=
0
,
)
for
batch_idx
,
p_attn_mask
in
enumerate
(
patch_attention_mask
):
if
tgt_sizes
is
not
None
:
nb_patches_h
=
tgt_sizes
[
batch_idx
][
0
]
nb_patches_w
=
tgt_sizes
[
batch_idx
][
1
]
else
:
nb_patches_h
=
p_attn_mask
[:,
0
].
sum
()
nb_patches_w
=
p_attn_mask
[
0
].
sum
()
fractional_coords_h
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
nb_patches_h
)
fractional_coords_w
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
nb_patches_w
)
bucket_coords_h
=
torch
.
bucketize
(
fractional_coords_h
,
boundaries
,
right
=
True
)
bucket_coords_w
=
torch
.
bucketize
(
fractional_coords_w
,
boundaries
,
right
=
True
)
pos_ids
=
(
bucket_coords_h
[:,
None
]
*
self
.
num_patches_per_side
+
bucket_coords_w
).
flatten
()
position_ids
[
batch_idx
][
p_attn_mask
.
view
(
-
1
).
cpu
()]
=
pos_ids
position_ids
=
position_ids
.
to
(
self
.
position_embedding
.
weight
.
device
)
embeddings
=
embeddings
+
self
.
position_embedding
(
position_ids
)
return
embeddings
class
SiglipAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
head_dim
=
self
.
embed_dim
//
self
.
num_heads
if
self
.
head_dim
*
self
.
num_heads
!=
self
.
embed_dim
:
raise
ValueError
(
"embed_dim must be divisible by num_heads (got `embed_dim`: "
f
"
{
self
.
embed_dim
}
and `num_heads`:"
f
"
{
self
.
num_heads
}
)."
)
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
dropout
=
config
.
attention_dropout
self
.
k_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
self
.
v_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
self
.
q_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
self
.
out_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
"""Input shape: Batch x Time x Channel"""
batch_size
,
q_len
,
_
=
hidden_states
.
size
()
query_states
=
self
.
q_proj
(
hidden_states
)
key_states
=
self
.
k_proj
(
hidden_states
)
value_states
=
self
.
v_proj
(
hidden_states
)
query_states
=
query_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
key_states
=
key_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
value_states
=
value_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
k_v_seq_len
=
key_states
.
shape
[
-
2
]
attn_weights
=
torch
.
matmul
(
query_states
,
key_states
.
transpose
(
2
,
3
))
*
self
.
scale
if
attn_weights
.
size
()
!=
(
batch_size
,
self
.
num_heads
,
q_len
,
k_v_seq_len
):
raise
ValueError
(
"Attention weights should be of size "
f
"
{
(
batch_size
,
self
.
num_heads
,
q_len
,
k_v_seq_len
)
}
, but is"
f
"
{
attn_weights
.
size
()
}
"
)
if
attention_mask
is
not
None
:
if
attention_mask
.
size
()
!=
(
batch_size
,
1
,
q_len
,
k_v_seq_len
):
raise
ValueError
(
"Attention mask should be of size "
f
"
{
(
batch_size
,
1
,
q_len
,
k_v_seq_len
)
}
"
,
f
"but is
{
attention_mask
.
size
()
}
"
)
attn_weights
=
attn_weights
+
attention_mask
# upcast attention to fp32
attn_weights
=
nn
.
functional
.
softmax
(
attn_weights
,
dim
=-
1
,
dtype
=
torch
.
float32
).
to
(
query_states
.
dtype
)
attn_weights
=
nn
.
functional
.
dropout
(
attn_weights
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn_output
=
torch
.
matmul
(
attn_weights
,
value_states
)
if
attn_output
.
size
()
!=
(
batch_size
,
self
.
num_heads
,
q_len
,
self
.
head_dim
):
raise
ValueError
(
"`attn_output` should be of size "
f
"
{
(
batch_size
,
self
.
num_heads
,
q_len
,
self
.
head_dim
)
}
, "
"but is"
f
"
{
attn_output
.
size
()
}
"
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
reshape
(
batch_size
,
q_len
,
self
.
embed_dim
)
attn_output
=
self
.
out_proj
(
attn_output
)
return
attn_output
,
attn_weights
class
SiglipFlashAttention2
(
SiglipAttention
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
is_causal
=
False
# Hack to make sure we don't use a causal mask
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
past_key_value
:
Optional
[
Tuple
[
torch
.
Tensor
]]
=
None
,
output_attentions
:
bool
=
False
,
use_cache
:
bool
=
False
,
**
kwargs
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
output_attentions
=
False
bsz
,
q_len
,
_
=
hidden_states
.
size
()
query_states
=
self
.
q_proj
(
hidden_states
)
key_states
=
self
.
k_proj
(
hidden_states
)
value_states
=
self
.
v_proj
(
hidden_states
)
query_states
=
query_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
key_states
=
key_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
kv_seq_len
=
key_states
.
shape
[
-
2
]
if
past_key_value
is
not
None
:
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
query_states
=
query_states
.
transpose
(
1
,
2
)
key_states
=
key_states
.
transpose
(
1
,
2
)
value_states
=
value_states
.
transpose
(
1
,
2
)
dropout_rate
=
self
.
dropout
if
self
.
training
else
0.0
input_dtype
=
query_states
.
dtype
if
input_dtype
==
torch
.
float32
:
if
torch
.
is_autocast_enabled
():
target_dtype
=
torch
.
get_autocast_gpu_dtype
()
# Handle the case where the model is quantized
elif
hasattr
(
self
.
config
,
"_pre_quantization_dtype"
):
target_dtype
=
self
.
config
.
_pre_quantization_dtype
else
:
target_dtype
=
self
.
q_proj
.
weight
.
dtype
logger
.
warning
(
"The input hidden states seems to be "
"silently casted in float32, "
"this might be related to the fact "
"you have upcasted embedding or layer norm layers in float32. "
"We will cast back the input in"
" %s."
,
target_dtype
)
query_states
=
query_states
.
to
(
target_dtype
)
key_states
=
key_states
.
to
(
target_dtype
)
value_states
=
value_states
.
to
(
target_dtype
)
attn_output
=
self
.
_flash_attention_forward
(
query_states
,
key_states
,
value_states
,
attention_mask
,
q_len
,
dropout
=
dropout_rate
)
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
embed_dim
).
contiguous
()
attn_output
=
self
.
out_proj
(
attn_output
)
if
not
output_attentions
:
attn_weights
=
None
return
attn_output
,
attn_weights
def
_flash_attention_forward
(
self
,
query_states
,
key_states
,
value_states
,
attention_mask
,
query_length
,
dropout
=
0.0
,
softmax_scale
=
None
):
causal
=
self
.
is_causal
and
query_length
!=
1
# Contains at least one padding token in the sequence
if
attention_mask
is
not
None
:
batch_size
=
query_states
.
shape
[
0
]
(
query_states
,
key_states
,
value_states
,
indices_q
,
cu_seq_lens
,
max_seq_lens
)
=
self
.
_upad_input
(
query_states
,
key_states
,
value_states
,
attention_mask
,
query_length
)
cu_seqlens_q
,
cu_seqlens_k
=
cu_seq_lens
max_seqlen_in_batch_q
,
max_seqlen_in_batch_k
=
max_seq_lens
attn_output_unpad
=
flash_attn_varlen_func
(
query_states
,
key_states
,
value_states
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
max_seqlen_in_batch_q
,
max_seqlen_k
=
max_seqlen_in_batch_k
,
dropout_p
=
dropout
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
)
attn_output
=
pad_input
(
attn_output_unpad
,
indices_q
,
batch_size
,
query_length
)
else
:
attn_output
=
flash_attn_func
(
query_states
,
key_states
,
value_states
,
dropout
,
softmax_scale
=
softmax_scale
,
causal
=
causal
)
return
attn_output
def
_upad_input
(
self
,
query_layer
,
key_layer
,
value_layer
,
attention_mask
,
query_length
):
indices_k
,
cu_seqlens_k
,
max_seqlen_in_batch_k
=
_get_unpad_data
(
attention_mask
)
batch_size
,
kv_seq_len
,
num_key_value_heads
,
head_dim
=
key_layer
.
shape
key_layer
=
index_first_axis
(
key_layer
.
reshape
(
batch_size
*
kv_seq_len
,
num_key_value_heads
,
head_dim
),
indices_k
)
value_layer
=
index_first_axis
(
value_layer
.
reshape
(
batch_size
*
kv_seq_len
,
num_key_value_heads
,
head_dim
),
indices_k
)
if
query_length
==
kv_seq_len
:
query_layer
=
index_first_axis
(
query_layer
.
reshape
(
batch_size
*
kv_seq_len
,
self
.
num_heads
,
head_dim
),
indices_k
)
cu_seqlens_q
=
cu_seqlens_k
max_seqlen_in_batch_q
=
max_seqlen_in_batch_k
indices_q
=
indices_k
elif
query_length
==
1
:
max_seqlen_in_batch_q
=
1
cu_seqlens_q
=
torch
.
arange
(
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
query_layer
.
device
)
# There is a memcpy here, that is very bad.
indices_q
=
cu_seqlens_q
[:
-
1
]
query_layer
=
query_layer
.
squeeze
(
1
)
else
:
# The -q_len: slice assumes left padding.
attention_mask
=
attention_mask
[:,
-
query_length
:]
(
query_layer
,
indices_q
,
cu_seqlens_q
,
max_seqlen_in_batch_q
)
=
unpad_input
(
query_layer
,
attention_mask
)
return
(
query_layer
,
key_layer
,
value_layer
,
indices_q
,
(
cu_seqlens_q
,
cu_seqlens_k
),
(
max_seqlen_in_batch_q
,
max_seqlen_in_batch_k
),
)
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class
SiglipMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
config
=
config
self
.
activation_fn
=
ACT2FN
[
config
.
hidden_act
]
self
.
fc1
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
self
.
fc2
=
nn
.
Linear
(
config
.
intermediate_size
,
config
.
hidden_size
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
=
self
.
fc1
(
hidden_states
)
hidden_states
=
self
.
activation_fn
(
hidden_states
)
hidden_states
=
self
.
fc2
(
hidden_states
)
return
hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer
# with CLIP->Siglip
class
SiglipEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
SiglipVisionConfig
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
_use_flash_attention_2
=
(
config
.
_attn_implementation
==
"flash_attention_2"
)
self
.
self_attn
=
(
SiglipAttention
(
config
)
if
not
self
.
_use_flash_attention_2
else
SiglipFlashAttention2
(
config
))
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
SiglipMLP
(
config
)
self
.
layer_norm2
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
output_attentions
:
Optional
[
bool
]
=
False
,
)
->
Tuple
[
torch
.
FloatTensor
]:
residual
=
hidden_states
hidden_states
=
self
.
layer_norm1
(
hidden_states
)
hidden_states
,
attn_weights
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
output_attentions
=
output_attentions
,
)
hidden_states
=
residual
+
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
layer_norm2
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
outputs
=
(
hidden_states
,
)
if
output_attentions
:
outputs
+=
(
attn_weights
,
)
return
outputs
class
SiglipPreTrainedModel
(
PreTrainedModel
):
config_class
=
SiglipVisionConfig
base_model_prefix
=
"siglip"
supports_gradient_checkpointing
=
True
def
_init_weights
(
self
,
module
):
"""Initialize the weights"""
if
isinstance
(
module
,
SiglipVisionEmbeddings
):
width
=
self
.
config
.
hidden_size
nn
.
init
.
normal_
(
module
.
position_embedding
.
weight
,
std
=
1
/
np
.
sqrt
(
width
))
elif
isinstance
(
module
,
nn
.
Embedding
):
default_flax_embed_init
(
module
.
weight
)
elif
isinstance
(
module
,
SiglipAttention
):
nn
.
init
.
normal_
(
module
.
q_proj
.
weight
)
nn
.
init
.
normal_
(
module
.
k_proj
.
weight
)
nn
.
init
.
normal_
(
module
.
v_proj
.
weight
)
nn
.
init
.
normal_
(
module
.
out_proj
.
weight
)
nn
.
init
.
zeros_
(
module
.
q_proj
.
bias
)
nn
.
init
.
zeros_
(
module
.
k_proj
.
bias
)
nn
.
init
.
zeros_
(
module
.
v_proj
.
bias
)
nn
.
init
.
zeros_
(
module
.
out_proj
.
bias
)
elif
isinstance
(
module
,
SiglipMLP
):
nn
.
init
.
normal_
(
module
.
fc1
.
weight
)
nn
.
init
.
normal_
(
module
.
fc2
.
weight
)
nn
.
init
.
normal_
(
module
.
fc1
.
bias
,
std
=
1e-6
)
nn
.
init
.
normal_
(
module
.
fc2
.
bias
,
std
=
1e-6
)
elif
isinstance
(
module
,
(
nn
.
Linear
,
nn
.
Conv2d
)):
lecun_normal_
(
module
.
weight
)
if
module
.
bias
is
not
None
:
nn
.
init
.
zeros_
(
module
.
bias
)
elif
isinstance
(
module
,
nn
.
LayerNorm
):
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder
# with CLIP->Siglip
class
SiglipEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
:
SiglipVisionConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
layers
=
nn
.
ModuleList
([
SiglipEncoderLayer
(
config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
gradient_checkpointing
=
False
# Ignore copy
def
forward
(
self
,
inputs_embeds
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutput
]:
output_attentions
=
output_attentions
if
output_attentions
is
not
None
\
else
self
.
config
.
output_attentions
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
return_dict
=
return_dict
if
return_dict
is
not
None
\
else
self
.
config
.
use_return_dict
encoder_states
=
()
if
output_hidden_states
else
None
all_attentions
=
()
if
output_attentions
else
None
hidden_states
=
inputs_embeds
for
encoder_layer
in
self
.
layers
:
if
output_hidden_states
:
encoder_states
=
encoder_states
+
(
hidden_states
,
)
if
self
.
gradient_checkpointing
and
self
.
training
:
layer_outputs
=
self
.
_gradient_checkpointing_func
(
encoder_layer
.
__call__
,
hidden_states
,
attention_mask
,
output_attentions
,
)
else
:
layer_outputs
=
encoder_layer
(
hidden_states
,
attention_mask
,
output_attentions
=
output_attentions
,
)
hidden_states
=
layer_outputs
[
0
]
if
output_attentions
:
all_attentions
=
all_attentions
+
(
layer_outputs
[
1
],
)
if
output_hidden_states
:
encoder_states
=
encoder_states
+
(
hidden_states
,
)
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
encoder_states
,
all_attentions
]
if
v
is
not
None
)
return
BaseModelOutput
(
last_hidden_state
=
hidden_states
,
hidden_states
=
encoder_states
,
attentions
=
all_attentions
)
class
SiglipVisionTransformer
(
SiglipPreTrainedModel
):
config_class
=
SiglipVisionConfig
main_input_name
=
"pixel_values"
_supports_flash_attn_2
=
True
def
__init__
(
self
,
config
:
SiglipVisionConfig
):
super
().
__init__
(
config
)
self
.
config
=
config
embed_dim
=
config
.
hidden_size
self
.
embeddings
=
SiglipVisionEmbeddings
(
config
)
self
.
encoder
=
SiglipEncoder
(
config
)
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
_use_flash_attention_2
=
(
config
.
_attn_implementation
==
"flash_attention_2"
)
# Initialize weights and apply final processing
self
.
post_init
()
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
embeddings
.
patch_embedding
@
replace_return_docstrings
(
output_type
=
BaseModelOutputWithPooling
,
config_class
=
SiglipVisionConfig
)
def
forward
(
self
,
pixel_values
,
patch_attention_mask
:
Optional
[
torch
.
BoolTensor
]
=
None
,
tgt_sizes
:
Optional
[
torch
.
IntTensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutputWithPooling
]:
r
"""
Returns:
"""
output_attentions
=
output_attentions
if
output_attentions
is
not
None
\
else
self
.
config
.
output_attentions
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
return_dict
=
return_dict
if
return_dict
is
not
None
\
else
self
.
config
.
use_return_dict
batch_size
=
pixel_values
.
size
(
0
)
if
patch_attention_mask
is
None
:
patch_attention_mask
=
torch
.
ones
(
size
=
(
batch_size
,
pixel_values
.
size
(
2
)
//
self
.
config
.
patch_size
,
pixel_values
.
size
(
3
)
//
self
.
config
.
patch_size
,
),
dtype
=
torch
.
bool
,
device
=
pixel_values
.
device
,
)
hidden_states
=
self
.
embeddings
(
pixel_values
=
pixel_values
,
patch_attention_mask
=
patch_attention_mask
,
tgt_sizes
=
tgt_sizes
)
patch_attention_mask
=
patch_attention_mask
.
view
(
batch_size
,
-
1
)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s
# (i.e. attending to the whole sequence),
# avoiding passing the attention_mask,
# which is equivalent to attending to the full sequence
if
not
torch
.
any
(
~
patch_attention_mask
):
attention_mask
=
None
else
:
attention_mask
=
(
_prepare_4d_attention_mask
(
patch_attention_mask
,
hidden_states
.
dtype
)
if
not
self
.
_use_flash_attention_2
else
patch_attention_mask
)
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
,
attention_mask
=
attention_mask
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
)
last_hidden_state
=
encoder_outputs
[
0
]
last_hidden_state
=
self
.
post_layernorm
(
last_hidden_state
)
if
not
return_dict
:
return
(
last_hidden_state
,
None
)
+
encoder_outputs
[
1
:]
return
BaseModelOutputWithPooling
(
last_hidden_state
=
last_hidden_state
,
pooler_output
=
None
,
hidden_states
=
encoder_outputs
.
hidden_states
,
attentions
=
encoder_outputs
.
attentions
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment