Qwen2 flash attention module, following the Qwen2 attention module. This module inherits from `Qwen2Attention`
as the weights of the module stay untouched. The only required change is in the forward pass,
where it needs to correctly call the public API of flash attention and deal with padding tokens
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
config.max_window_layers layers.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

    # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.
    # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default in flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
    # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
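# Illustrative sketch (not part of the original module): the comment above contrasts
# top-left and bottom-right aligned causal masks. The pure-Python helper below, with the
# hypothetical name `causal_mask`, builds both variants so the difference is visible when
# the query and key lengths differ. It makes no call into flash_attn itself.

```python
def causal_mask(q_len, k_len, bottom_right=True):
    """Return a q_len x k_len mask; True means the query may attend to the key."""
    # Bottom-right alignment shifts the diagonal so the last query sees every key.
    # Top-left alignment anchors the diagonal at query 0 / key 0 instead.
    offset = k_len - q_len if bottom_right else 0
    return [[k <= q + offset for k in range(k_len)] for q in range(q_len)]


def render(mask):
    """Draw a mask with 'x' for visible keys and '.' for masked keys."""
    return ["".join("x" if allowed else "." for allowed in row) for row in mask]


top_left = causal_mask(q_len=2, k_len=4, bottom_right=False)
bottom_right = causal_mask(q_len=2, k_len=4, bottom_right=True)

print("top-left:    ", render(top_left))
print("bottom-right:", render(bottom_right))
```

# When q_len == k_len the two alignments coincide, which is why the mismatch only
# shows up when the query is shorter than the key sequence.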
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
class SiglipVisionModelOutput(ModelOutput):
"""
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
Args:
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
class SiglipTextModelOutput(ModelOutput):
"""
Base class for text model's outputs that also contains a pooling of the last hidden states.
Args:
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
The text embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stay
untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
is_causal = False
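# Illustrative sketch (not part of the original module): the docstring above mentions
# dealing with padding tokens. One common approach is to pack the real tokens of a batch
# into a single stream and describe sequence boundaries with cumulative lengths. The
# helper name `unpad_metadata` and its return layout are assumptions made for this
# example; it works on plain nested lists rather than tensors.

```python
def unpad_metadata(attention_mask):
    """attention_mask: list of equal-length rows, 1 for real tokens, 0 for padding.

    Returns (indices, cu_seqlens, max_seqlen) for the packed token stream.
    """
    seqlens = [sum(row) for row in attention_mask]
    row_len = len(attention_mask[0]) if attention_mask else 0
    # Flat positions of the real tokens, as if the batch were one long row.
    indices = [
        b * row_len + t
        for b, row in enumerate(attention_mask)
        for t, keep in enumerate(row)
        if keep
    ]
    # cu_seqlens[i] is the offset at which sequence i starts in the packed stream.
    cu_seqlens = [0]
    for length in seqlens:
        cu_seqlens.append(cu_seqlens[-1] + length)
    return indices, cu_seqlens, max(seqlens, default=0)


example_mask = [
    [1, 1, 1, 0],  # three real tokens, one pad
    [1, 1, 0, 0],  # two real tokens, two pads
]
indices, cu_seqlens, max_seqlen = unpad_metadata(example_mask)
print(indices, cu_seqlens, max_seqlen)
```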
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

    # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.
    # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default in flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
    # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
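# Illustrative sketch (not part of the original module): one plausible way to act on the
# version note above. Both function names are hypothetical, and the q_len == 1 handling
# is an assumption about how the stated exception could be honoured, not a confirmed
# description of this module's forward pass.

```python
def uses_top_left_mask(flash_attn_version):
    """True when the given flash_attn version string predates 2.1."""
    major, minor = (int(part) for part in flash_attn_version.split(".")[:2])
    return (major, minor) < (2, 1)


def resolve_causal(is_causal, q_len, top_left_mask):
    """Decide whether to request causal masking from the kernel."""
    if not top_left_mask:
        # Bottom-right alignment is already what is needed.
        return is_causal
    # Assumption: a lone query (q_len == 1) should see every key, so causal
    # masking is dropped for it instead of relying on a top-left mask.
    return is_causal and q_len != 1


print(uses_top_left_mask("2.0.9"), uses_top_left_mask("2.1.0"))
print(resolve_causal(is_causal=True, q_len=1, top_left_mask=True))
```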
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
# therefore we just need to keep the original shape
query_states = query_states.view(
    batch_size, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
    batch_size, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
    batch_size, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
# TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view calls.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
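# Illustrative sketch (not part of the original module): the view/transpose steps above
# split each token vector into heads, move to a head-major layout, and then move back to
# the token-major layout [sequence_length, num_heads, head_dim]. The helpers below mimic
# that bookkeeping for a single batch item using nested lists; their names are
# hypothetical.

```python
def split_heads(hidden, num_heads):
    """[seq_len][hidden_size] -> [num_heads][seq_len][head_dim]."""
    hidden_size = len(hidden[0])
    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size must be divisible by num_heads")
    head_dim = hidden_size // num_heads
    # Counterpart of view(seq_len, num_heads, head_dim): chop each token vector.
    per_token = [
        [row[h * head_dim:(h + 1) * head_dim] for h in range(num_heads)]
        for row in hidden
    ]
    # Counterpart of swapping the sequence and head axes.
    return [[token[h] for token in per_token] for h in range(num_heads)]


def to_token_major(per_head):
    """[num_heads][seq_len][head_dim] -> [seq_len][num_heads][head_dim]."""
    return [list(token) for token in zip(*per_head)]


hidden = [[0, 1, 2, 3], [4, 5, 6, 7]]  # seq_len=2, hidden_size=4
head_major = split_heads(hidden, num_heads=2)
token_major = to_token_major(head_major)
print(head_major)
print(token_major)
```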
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, the layer norms are usually cast to float32 for training stability,
# so the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the LayerNorms to float32.
Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`SiglipAttention` as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to
the SDPA API.
"""
is_causal = False
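# Illustrative sketch (not part of the original module): a pure-Python reference for the
# quantity that scaled dot-product attention computes, softmax(Q K^T / sqrt(d)) V, for one
# head without a mask or dropout. The function name `sdpa_reference` is hypothetical and
# the code is meant for reading, not for performance.

```python
import math


def sdpa_reference(query, key, value):
    """Single-head, unmasked attention on nested lists of floats."""
    head_dim = len(query[0])
    scale = 1.0 / math.sqrt(head_dim)
    output = []
    for q in query:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in key]
        # Subtract the max before exponentiating for numerical stability.
        peak = max(scores)
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted average of the value vectors.
        output.append([
            sum(w * v[j] for w, v in zip(weights, value))
            for j in range(len(value[0]))
        ])
    return output


# Two identical keys receive equal weight, so the output is the mean of the values.
result = sdpa_reference(
    query=[[1.0, 0.0]],
    key=[[1.0, 0.0], [1.0, 0.0]],
    value=[[0.0, 2.0], [4.0, 6.0]],
)
print(result)
```

# The intermediate `weights` are exactly what `output_attentions=True` asks for, and the
# warning in the forward pass states that the fused SDPA call does not return them.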
# Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if output_attentions:
        # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
        logger.warning_once(
            "SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
            'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
        )
        return super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )

    batch_size, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(
        batch_size, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        batch_size, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        batch_size, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287