Unverified Commit 80377eb0 authored by fxmarty, committed by GitHub

F.scaled_dot_product_attention support (#26572)



* add sdpa

* wip

* cleaning

* add ref

* yet more cleaning

* and more :)

* wip llama

* working llama

* add output_attentions=True support

* bigcode sdpa support

* fixes

* gpt-bigcode support, require torch>=2.1.1

* add falcon support

* fix conflicts falcon

* style

* fix attention_mask definition

* remove output_attentions from attnmaskconverter

* support whisper without removing any Copied from statement

* fix mbart default to eager renaming

* fix typo in falcon

* fix is_causal in SDPA

* check is_flash_attn_2_available in the model's init as well, in case the model is not initialized through from_pretrained

* add warnings when falling back on the manual implementation

* make the doc more precise

* wip replace _flash_attn_enabled by config.attn_implementation

* fix typo

* add tests

* style

* add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace

* obey config.attn_implementation if a config is passed to from_pretrained

* fix is_torch_sdpa_available when torch is not installed

* remove dead code

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bart/modeling_bart.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove duplicate pretraining_tp code

* add dropout in llama

* make the comment on attn_mask more precise

* add fmt: off for _unmask_unattended docstring

* make the num_masks comment more precise

* nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion

* cleanup modeling_utils

* backward compatibility

* fix style as requested

* style

* improve documentation

* test pass

* style

* add _unmask_unattended tests

* skip meaningless tests for idefics

* hard_check SDPA requirements when specifically requested

* standardize the use of XXX_ATTENTION_CLASSES

* fix SDPA bug with mem-efficient backend on CUDA when using fp32

* fix test

* rely on SDPA is_causal parameter to handle the causal mask in some cases

* fix FALCON_ATTENTION_CLASSES

* remove _flash_attn_2_enabled occurrences

* fix test

* add OPT to the list of supported flash models

* improve test

* properly test on different SDPA backends and different dtypes & handle the pad tokens separately in the test

* remove remaining _flash_attn_2_enabled occurrence

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/perf_infer_gpu_one.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove use_attn_implementation

* fix docstring & slight bug

* make attn_implementation internal (_attn_implementation)

* typos

* fix tests

* deprecate use_flash_attention_2=True

* fix test

* add back llama that was removed by mistake

* fix tests

* remove _flash_attn_2_enabled occurrences (second pass)

* add check & test that passed attn_implementation is valid

* fix falcon torchscript export

* fix device of mask in tests

* add tip about torch.jit.trace and move bt doc below sdpa

* fix parameterized.expand order

* move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there

* update sdpaattention class with the new cache

* Update src/transformers/configuration_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bark/modeling_bark.py

* address review comments

* WIP torch.jit.trace fix. left: test both eager & sdpa

* add test for torch.jit.trace for both eager/sdpa

* fix falcon with torch==2.0 that needs to use sdpa

* fix doc

* hopefully last fix

* fix key_value_length that has no default now in mask converter

* is it flaky?

* fix speculative decoding bug

* tests do pass

* fix following #27907

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent ce0bbd51
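For context, a minimal usage sketch of the interface this commit introduces (the checkpoint name and dtype are illustrative, not part of the diff): `attn_implementation` accepts "eager", "sdpa" or "flash_attention_2" and supersedes the now-deprecated `use_flash_attention_2=True`; per the commit log, SDPA requires torch>=2.1.1 in general, with Falcon and Idefics also accepting torch>=2.0 as shown in the diff below.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
inputs = tokenizer("The quick brown fox", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))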
......@@ -471,6 +471,12 @@ class FFN(nn.Module):
return x
DISTILBERT_ATTENTION_CLASSES = {
"eager": MultiHeadSelfAttention,
"flash_attention_2": DistilBertFlashAttention2,
}
class TransformerBlock(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
......@@ -479,11 +485,7 @@ class TransformerBlock(nn.Module):
if config.dim % config.n_heads != 0:
raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly")
self.attention = (
MultiHeadSelfAttention(config)
if not getattr(config, "_flash_attn_2_enabled", False)
else DistilBertFlashAttention2(config)
)
self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config)
self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
self.ffn = FFN(config)
......@@ -703,6 +705,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
# Initialize weights and apply final processing
self.post_init()
......@@ -808,7 +811,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
if attention_mask is None:
......
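The `DISTILBERT_ATTENTION_CLASSES` dict above shows the pattern this commit rolls out across all models: the attention class is looked up from `config._attn_implementation` instead of branching on the old `_flash_attn_2_enabled` flag. A standalone sketch of the idea, using placeholder classes rather than the real DistilBert ones:

import torch.nn as nn

class EagerAttention(nn.Module):
    """Placeholder for a model's manual (eager) attention implementation."""

class SdpaAttention(nn.Module):
    """Placeholder for a model's F.scaled_dot_product_attention implementation."""

ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}

class Layer(nn.Module):
    def __init__(self, attn_implementation: str = "sdpa"):
        super().__init__()
        # An unknown key raises KeyError, mirroring the validity check this PR adds
        # for the attn_implementation argument.
        self.self_attn = ATTENTION_CLASSES[attn_implementation]()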
......@@ -16,7 +16,7 @@
import math
import warnings
from typing import Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
......@@ -24,7 +24,11 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
......@@ -33,6 +37,7 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_0
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
......@@ -44,6 +49,9 @@ from ...utils import (
from .configuration_falcon import FalconConfig
if TYPE_CHECKING:
from ...configuration_utils import PretrainedConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
......@@ -278,6 +286,7 @@ class FalconAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self._use_sdpa = config._attn_implementation == "sdpa"
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
......@@ -439,16 +448,15 @@ class FalconAttention(nn.Module):
present = None
if alibi is None:
if hasattr(F, "scaled_dot_product_attention") and not output_attentions:
# TODO: deprecate this once we add FA2 support in Falcon
logger.warning_once(
"The current implementation of Falcon calls `torch.scaled_dot_product_attention` directly, this will be deprecated in the"
" future in favor of the `BetterTransformer` API. Please install the latest optimum library with `pip install -U optimum` and call "
"`model.to_bettertransformer()` to benefit from `torch.scaled_dot_product_attention` and future performance optimizations."
)
if self._use_sdpa and not output_attentions:
attn_output = F.scaled_dot_product_attention(
query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False
query_layer,
key_layer,
value_layer,
attention_mask,
0.0,
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
attention_scores = None
else:
......@@ -456,58 +464,70 @@ class FalconAttention(nn.Module):
attention_scores /= math.sqrt(self.head_dim)
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
# It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
attn_output = attention_scores @ value_layer
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
output_tensor = self.dense(attn_output)
attn_output = self.dense(attn_output)
if output_attentions:
return output_tensor, present, attention_scores
return attn_output, present, attention_scores
else:
return output_tensor, present
return attn_output, present
else:
matmul_result = query_layer @ key_layer.transpose(-1, -2)
if self._use_sdpa and not output_attentions and head_mask is None:
attn_output = F.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.attention_dropout.p if self.training else 0.0,
is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
attn_output = self.dense(attn_output)
else:
matmul_result = query_layer @ key_layer.transpose(-1, -2)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
attention_scores = attention_scores.to(torch.float32)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
attention_scores = attention_scores.to(torch.float32)
# Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
# adding (alibi * self.inv_norm_factor) to attention_mask. I think this would be mathematically
# equivalent and more performant, but there might be a numerical difference. If you're reading this
# and you'd like to experiment and maybe file a PR, feel free!
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
attention_logits *= self.inv_norm_factor
attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
attention_logits *= self.inv_norm_factor
attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
if head_mask is not None:
attention_probs = attention_probs * head_mask
# change view [batch_size, num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
# change view [batch_size, num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = (attention_probs_reshaped @ value_layer).flatten(0, 1)
# matmul: [batch_size * num_heads, q_length, head_dim]
attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)
# change view [batch_size, q_length, num_heads * head_dim]
context_layer = self._merge_heads(context_layer)
# change view [batch_size, q_length, num_heads * head_dim]
attn_output = self._merge_heads(attn_output)
output_tensor = self.dense(context_layer)
attn_output = self.dense(attn_output)
if output_attentions:
return output_tensor, present, attention_probs
return attn_output, present, attention_probs
else:
return output_tensor, present
return attn_output, present
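The `is_causal=self.is_causal and attention_mask is None and query_length > 1` arguments above hand causal masking over to SDPA only when no explicit 4D mask was built; the `query_length > 1` guard matches `AttentionMaskConverter.to_causal_4d`, which builds no causal mask for a single query token. A small sanity check (a sketch, assuming equal query and key lengths) that the flag and an explicit lower-triangular mask agree:

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))

# Boolean attn_mask: True means "attend". For q_len == kv_len this lower-triangular
# mask is exactly what is_causal=True applies internally.
causal_mask = torch.ones(5, 5, dtype=torch.bool).tril()
out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(out_explicit, out_flag)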
class FalconFlashAttention2(FalconAttention):
......@@ -734,17 +754,20 @@ class FalconMLP(nn.Module):
return x
FALCON_ATTENTION_CLASSES = {
"eager": FalconAttention,
"sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA
"flash_attention_2": FalconFlashAttention2,
}
class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = (
FalconAttention(config)
if not getattr(config, "_flash_attn_2_enabled", False)
else FalconFlashAttention2(config)
)
self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = FalconMLP(config)
self.hidden_dropout = config.hidden_dropout
self.config = config
......@@ -912,6 +935,7 @@ class FalconPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["FalconDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
......@@ -932,6 +956,25 @@ class FalconPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
# NOTE: Falcon has supported SDPA since PyTorch 2.0. We keep it that way for backward compatibility (SDPA is used automatically for torch>=2.0).
if hard_check_only:
if not is_torch_greater_or_equal_than_2_0:
raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.")
if not is_torch_greater_or_equal_than_2_0:
return config
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config
if not hard_check_only:
config._attn_implementation = "sdpa"
return config
@add_start_docstrings(
"The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
......@@ -950,6 +993,8 @@ class FalconModel(FalconPreTrainedModel):
# Transformer blocks
self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......@@ -1003,12 +1048,6 @@ class FalconModel(FalconPreTrainedModel):
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
......@@ -1047,15 +1086,61 @@ class FalconModel(FalconPreTrainedModel):
)
position_ids = position_ids.unsqueeze(0)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if alibi is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
elif head_mask is None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
attention_mask_2d = attention_mask
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# We take care to integrate alibi bias in the attention_mask here.
if attention_mask_2d is None:
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
else:
attention_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
attention_mask < -1,
torch.finfo(alibi.dtype).min,
)
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
if seq_length > 1:
attention_mask = AttentionMaskConverter._unmask_unattended(
attention_mask, attention_mask_2d, unmasked_value=0.0
)
else:
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......
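The `AttentionMaskConverter._unmask_unattended` call above works around pytorch/pytorch#110213: with the memory-efficient backend, a row whose additive mask is entirely -inf (a fully padded sequence in a batch) makes the softmax ill-defined and SDPA returns NaNs. A minimal reproduction of the numerical issue itself, not of the Falcon code:

import torch
import torch.nn.functional as F

# A token that attends to nothing has every score at -inf; softmax then divides 0 by 0.
fully_masked_row = torch.full((1, 4), float("-inf"))
print(F.softmax(fully_masked_row, dim=-1))  # tensor([[nan, nan, nan, nan]])
# _unmask_unattended zeroes the bias on such rows; their outputs belong to padding
# positions that downstream code ignores, so generation stays NaN-free.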
......@@ -22,6 +22,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
......@@ -128,6 +129,7 @@ class GPTBigCodeAttention(nn.Module):
self.scale_attention_softmax_in_fp32 = (
config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
)
self.attn_pdrop = config.attn_pdrop
if self.is_cross_attention:
if self.multi_query:
......@@ -359,7 +361,7 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
attn_dropout = self.config.attn_pdrop if self.training else 0.0
attn_dropout = self.attn_pdrop if self.training else 0.0
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
upcast = query.dtype != softmax_dtype
......@@ -509,6 +511,137 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
)
class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if head_mask is not None:
# The super dispatch is done in the forward.
raise ValueError(
"PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository."
)
scale = None
if not self.scale_attn_weights:
scale = 1
# MQA models: (batch_size, query_length, num_heads * head_dim)
# MHA models: (batch_size, num_heads, query_length, head_dim)
query_shape = query.shape
batch_size = query_shape[0]
if self.multi_query:
query_length = query_shape[1]
# NOTE: Maybe there is a better way to do this?
query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
# Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
key = key.unsqueeze(1)
value = value.unsqueeze(1)
# Although these expands are not numerically needed, without them PyTorch 2.1 cannot dispatch to the
# memory-efficient or flash attention kernels (No available kernel. Aborting execution.) from the shapes
# query = [batch_size, num_heads, query_length, head_dim]
# key = [batch_size, 1, past_length, head_dim]
# value = [batch_size, 1, past_length, head_dim]
# which is unfortunate and will hopefully be improved in the future. These expands should not be too expensive as they do not copy memory.
key = key.expand(-1, self.num_heads, -1, -1)
value = value.expand(-1, self.num_heads, -1, -1)
else:
query_length = query_shape[-1]
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=self.attn_pdrop if self.training else 0.0,
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
is_causal=self.is_causal and attention_mask is None and query_length > 1,
scale=scale,
)
if self.multi_query:
# (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
sdpa_result = sdpa_result.transpose(1, 2)
# Reshape is kind of expensive here, as it does a memory copy,
# but I did not manage to get away without it (logits do not match when using view)
# (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
sdpa_result = sdpa_result.reshape(query_shape)
return sdpa_result, None
def forward(
self,
hidden_states: torch.Tensor,
layer_past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Optional[torch.Tensor]],
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn") or not self.is_cross_attention:
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
)
query = self.q_attn(hidden_states)
key_value = self.c_attn(encoder_hidden_states)
attention_mask = encoder_attention_mask
elif self.multi_query:
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
else:
# Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
# i.e., the memory layout is not the same as GPT2.
# This makes the concatenation with past_key_value more efficient.
query, key_value = (
self.c_attn(hidden_states)
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
.transpose(1, 2)
.split((self.head_dim, 2 * self.head_dim), dim=3)
)
if layer_past is not None:
key_value = torch.cat((layer_past, key_value), dim=-2)
present = key_value if use_cache else None
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
if not output_attentions and head_mask is None:
# Difference with the original implementation: there is no need to transpose the key here,
# as SDPA expects seq_length to be at index -2 for the key as well
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
else:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None."
' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
if not self.multi_query:
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
if self.multi_query:
# Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
attn_weights = attn_weights.transpose(1, 2)
outputs += (attn_weights,)
return outputs
class GPTBigCodeMLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
......@@ -527,6 +660,13 @@ class GPTBigCodeMLP(nn.Module):
return hidden_states
GPTBIGCODE_ATTENTION_CLASSES = {
"eager": GPTBigCodeAttention,
"flash_attention_2": GPTBigCodeFlashAttention2,
"sdpa": GPTBigCodeSdpaAttention,
}
class GPTBigCodeBlock(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
......@@ -534,21 +674,19 @@ class GPTBigCodeBlock(nn.Module):
self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = (
GPTBigCodeAttention(config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTBigCodeFlashAttention2(config, layer_idx=layer_idx)
)
self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if config.add_cross_attention:
if config.multi_query:
raise NotImplementedError("Cross-attention not implemented for MQA")
self.crossattention = (
GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx)
self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](
config, is_cross_attention=True, layer_idx=layer_idx
)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigCodeMLP(self.inner_dim, config)
......@@ -629,6 +767,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTBigCodeBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
......@@ -770,6 +909,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
self.gradient_checkpointing = False
self._use_sdpa = config._attn_implementation == "sdpa"
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
# Initialize weights and apply final processing
self.post_init()
......@@ -850,7 +992,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
key_length = past_length + query_length
self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
encoder_attention_mask = (
......@@ -867,7 +1009,34 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
# MQA models: (batch_size, query_length, n_heads, key_length)
# MHA models: (batch_size, n_heads, query_length, key_length)
attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
if self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if self.multi_query:
# gpt_bigcode using MQA has the bad taste to use a causal mask with shape
# [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
self_attention_mask = self_attention_mask.transpose(1, 2)
if query_length > 1 and attention_mask is not None:
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
self_attention_mask = AttentionMaskConverter._unmask_unattended(
self_attention_mask, attention_mask, unmasked_value=True
)
# SDPA with a custom attention mask is much faster with a floating-point (fp16/fp32) mask than with a bool mask. Cast here to floating point instead of at every layer.
dtype = self.wte.weight.dtype
self_attention_mask = torch.where(
self_attention_mask,
torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
torch.full(
[], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device
),
)
attention_mask = self_attention_mask
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
......
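The `GPTBigCodeSdpaAttention._attn` above broadcasts the single multi-query key/value head to all attention heads with `expand` (a view, no copy) so that torch>=2.1 can dispatch to a fused kernel. The trick in isolation, with illustrative shapes:

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, kv_len, head_dim = 2, 8, 5, 5, 16
query = torch.randn(bsz, num_heads, q_len, head_dim)
# Multi-query attention: one key/value head shared by every query head.
key = torch.randn(bsz, 1, kv_len, head_dim)
value = torch.randn(bsz, 1, kv_len, head_dim)

# expand() only adjusts strides, so no memory is copied.
key = key.expand(-1, num_heads, -1, -1)
value = value.expand(-1, num_heads, -1, -1)
out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 5, 16])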
......@@ -487,6 +487,12 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
)
GPT_NEO_ATTENTION_CLASSES = {
"eager": GPTNeoSelfAttention,
"flash_attention_2": GPTNeoFlashAttention2,
}
class GPTNeoAttention(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
......@@ -495,11 +501,7 @@ class GPTNeoAttention(nn.Module):
self.attention_type = self.attention_layers[layer_id]
if self.attention_type in ["global", "local"]:
self.attention = (
GPTNeoSelfAttention(config, self.attention_type)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTNeoFlashAttention2(config, self.attention_type)
)
self.attention = GPT_NEO_ATTENTION_CLASSES[config._attn_implementation](config, self.attention_type)
else:
raise NotImplementedError(
"Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
......@@ -718,6 +720,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(float(config.embed_dropout))
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.gradient_checkpointing = False
......@@ -795,7 +798,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
hidden_states = inputs_embeds + position_embeds
# Attention mask.
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
......
......@@ -658,6 +658,12 @@ class GPTNeoXMLP(nn.Module):
return hidden_states
GPT_NEOX_ATTENTION_CLASSES = {
"eager": GPTNeoXAttention,
"flash_attention_2": GPTNeoXFlashAttention2,
}
class GPTNeoXLayer(nn.Module):
def __init__(self, config):
super().__init__()
......@@ -666,11 +672,7 @@ class GPTNeoXLayer(nn.Module):
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
self.attention = (
GPTNeoXAttention(config)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTNeoXFlashAttention2(config)
)
self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = GPTNeoXMLP(config)
def forward(
......@@ -785,6 +787,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
self.emb_dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.gradient_checkpointing = False
......@@ -861,7 +864,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# We create a 3D attention mask from a 2D tensor mask.
......
......@@ -29,7 +29,7 @@ from torch.nn import CrossEntropyLoss
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
......@@ -578,6 +578,7 @@ class IdeficsAttention(nn.Module):
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.dropout = dropout
self.is_causal = True
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(
......@@ -693,6 +694,8 @@ class IdeficsAttention(nn.Module):
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
......@@ -960,6 +963,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
_supports_sdpa = True
def _init_weights(self, module):
# important: this ported version of Idefics isn't meant for training from scratch - only
......@@ -975,6 +979,18 @@ class IdeficsPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
# We remove the checks on `is_torch_sdpa_available()` and `cls._supports_sdpa` as Idefics supports SDPA from torch==2.0.0 (no requirement on 2.1).
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config
if not hard_check_only:
config._attn_implementation = "sdpa"
return config
LLAMA_INPUTS_DOCSTRING = r"""
Args:
......@@ -1240,7 +1256,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
......
......@@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
......@@ -518,7 +519,7 @@ class LlamaFlashAttention2(LlamaAttention):
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = 0.0 if not self.training else self.attention_dropout
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
......@@ -654,15 +655,99 @@ class LlamaFlashAttention2(LlamaAttention):
)
class LlamaSdpaAttention(LlamaAttention):
"""
Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`LlamaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
the SDPA API.
"""
# Adapted from LlamaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
LLAMA_ATTENTION_CLASSES = {
"eager": LlamaAttention,
"flash_attention_2": LlamaFlashAttention2,
"sdpa": LlamaSdpaAttention,
}
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = (
LlamaAttention(config=config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else LlamaFlashAttention2(config=config, layer_idx=layer_idx)
)
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -757,6 +842,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
......@@ -862,6 +948,8 @@ class LlamaModel(LlamaPreTrainedModel):
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self._use_sdpa = config._attn_implementation == "sdpa"
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
......@@ -922,9 +1010,18 @@ class LlamaModel(LlamaPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
......
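The commit log mentions testing on the different SDPA backends; as a debugging aid, a hedged sketch of pinning one backend with the torch 2.1 context manager (the `model` and `input_ids` objects are assumed to exist):

import torch

# Restrict SDPA to the math backend (always available) to rule out kernel-specific
# numerical differences; re-enable flash or mem-efficient one at a time when comparing.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    outputs = model(input_ids)  # `model` and `input_ids` defined elsewhere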
......@@ -232,9 +232,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.vocab_size
use_flash_attention_2 = getattr(config, "_flash_attn_2_enabled", False)
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, use_flash_attention_2=use_flash_attention_2
config.text_config, attn_implementation=config._attn_implementation
)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
......
......@@ -325,9 +325,8 @@ class M2M100EncoderLayer(nn.Module):
def __init__(self, config: M2M100Config):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = M2M100_ATTENTION_CLASSES[attn_type](
self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
......@@ -392,7 +391,7 @@ class M2M100EncoderLayer(nn.Module):
return outputs
M2M100_ATTENTION_CLASSES = {"default": M2M100Attention}
M2M100_ATTENTION_CLASSES = {"eager": M2M100Attention}
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100
......@@ -400,9 +399,8 @@ class M2M100DecoderLayer(nn.Module):
def __init__(self, config: M2M100Config):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = M2M100_ATTENTION_CLASSES[attn_type](
self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
......@@ -415,7 +413,7 @@ class M2M100DecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = M2M100_ATTENTION_CLASSES[attn_type](
self.encoder_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
......
......@@ -272,9 +272,8 @@ class MarianEncoderLayer(nn.Module):
def __init__(self, config: MarianConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type](
self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
......@@ -339,7 +338,7 @@ class MarianEncoderLayer(nn.Module):
return outputs
MARIAN_ATTENTION_CLASSES = {"default": MarianAttention}
MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN
......@@ -348,8 +347,7 @@ class MarianDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type](
self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
......@@ -362,7 +360,7 @@ class MarianDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MARIAN_ATTENTION_CLASSES[attn_type](
self.encoder_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
......
......@@ -501,7 +501,7 @@ class MBartFlashAttention2(MBartAttention):
MBART_ATTENTION_CLASSES = {
"default": MBartAttention,
"eager": MBartAttention,
"flash_attention_2": MBartFlashAttention2,
}
......@@ -510,9 +510,8 @@ class MBartEncoderLayer(nn.Module):
def __init__(self, config: MBartConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = MBART_ATTENTION_CLASSES[attn_type](
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
......@@ -581,9 +580,8 @@ class MBartDecoderLayer(nn.Module):
def __init__(self, config: MBartConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = MBART_ATTENTION_CLASSES[attn_type](
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
......@@ -596,7 +594,7 @@ class MBartDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MBART_ATTENTION_CLASSES[attn_type](
self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
......@@ -935,6 +933,7 @@ class MBartEncoder(MBartPreTrainedModel):
embed_dim,
)
self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.layer_norm = nn.LayerNorm(config.d_model)
......@@ -1023,7 +1022,7 @@ class MBartEncoder(MBartPreTrainedModel):
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......@@ -1112,6 +1111,7 @@ class MBartDecoder(MBartPreTrainedModel):
config.d_model,
)
self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.layer_norm = nn.LayerNorm(config.d_model)
......@@ -1231,7 +1231,7 @@ class MBartDecoder(MBartPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
......@@ -1242,7 +1242,7 @@ class MBartDecoder(MBartPreTrainedModel):
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -601,15 +601,19 @@ class MistralFlashAttention2(MistralAttention):
)
MISTRAL_ATTENTION_CLASSES = {
"eager": MistralAttention,
"flash_attention_2": MistralFlashAttention2,
}
class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = (
MistralAttention(config=config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else MistralFlashAttention2(config, layer_idx=layer_idx)
)
self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.mlp = MistralMLP(config)
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -807,6 +811,7 @@ class MistralModel(MistralPreTrainedModel):
self.layers = nn.ModuleList(
[MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
......@@ -870,12 +875,7 @@ class MistralModel(MistralPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if (
attention_mask is not None
and hasattr(self.config, "_flash_attn_2_enabled")
and self.config._flash_attn_2_enabled
and use_cache
):
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
......@@ -884,7 +884,7 @@ class MistralModel(MistralPreTrainedModel):
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
......
......@@ -491,15 +491,18 @@ class OptFlashAttention2(OPTAttention):
)
OPT_ATTENTION_CLASSES = {
"eager": OPTAttention,
"flash_attention_2": OptFlashAttention2,
}
class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):
super().__init__()
self.embed_dim = config.hidden_size
if not getattr(config, "_flash_attn_2_enabled", False):
self.self_attn = OPTAttention(config=config, is_decoder=True)
else:
self.self_attn = OptFlashAttention2(config=config, is_decoder=True)
self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
......@@ -732,6 +735,7 @@ class OPTDecoder(OPTPreTrainedModel):
self.final_layer_norm = None
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.gradient_checkpointing = False
# Initialize weights and apply final processing
......@@ -830,7 +834,7 @@ class OPTDecoder(OPTPreTrainedModel):
mask_seq_length = past_key_values_length + seq_length
# embed positions
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
......
......@@ -267,7 +267,7 @@ class PegasusAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
PEGASUS_ATTENTION_CLASSES = {"default": PegasusAttention}
PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention}
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
......@@ -275,9 +275,8 @@ class PegasusEncoderLayer(nn.Module):
def __init__(self, config: PegasusConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
......@@ -347,9 +346,8 @@ class PegasusDecoderLayer(nn.Module):
def __init__(self, config: PegasusConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
......@@ -362,7 +360,7 @@ class PegasusDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
......
......@@ -612,14 +612,16 @@ class PhiFlashAttention2(PhiAttention):
)
PHI_ATTENTION_CLASSES = {
"eager": PhiAttention,
"flash_attention_2": PhiFlashAttention2,
}
class PhiDecoderLayer(nn.Module):
def __init__(self, config: PhiConfig, layer_idx: int):
super().__init__()
self.self_attn = (
PhiAttention(config=config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else PhiFlashAttention2(config=config, layer_idx=layer_idx)
)
self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.mlp = PhiMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
......@@ -813,6 +815,7 @@ class PhiModel(PhiPreTrainedModel):
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.gradient_checkpointing = False
# Initialize weights and apply final processing
......@@ -876,7 +879,7 @@ class PhiModel(PhiPreTrainedModel):
inputs_embeds = self.embed_dropout(inputs_embeds)
# Attention mask.
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
......
......@@ -23,7 +23,12 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
......@@ -265,9 +270,8 @@ class PLBartEncoderLayer(nn.Module):
def __init__(self, config: PLBartConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = PLBART_ATTENTION_CLASSES[attn_type](
self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
......@@ -332,7 +336,8 @@ class PLBartEncoderLayer(nn.Module):
return outputs
PLBART_ATTENTION_CLASSES = {"default": PLBartAttention}
# TODO: Implement attention with SDPA for PLBart.
PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART
......@@ -341,8 +346,7 @@ class PLBartDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = PLBART_ATTENTION_CLASSES[attn_type](
self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
......@@ -355,7 +359,7 @@ class PLBartDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = PLBART_ATTENTION_CLASSES[attn_type](
self.encoder_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
......@@ -670,6 +674,8 @@ class PLBartEncoder(PLBartPreTrainedModel):
embed_dim,
)
self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.gradient_checkpointing = False
......@@ -757,8 +763,13 @@ class PLBartEncoder(PLBartPreTrainedModel):
# expand attention_mask
if attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True and head_mask cannot be supported when using SDPA; fall back to
# the manual implementation, which requires a 4D attention mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
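`_prepare_4d_attention_mask_for_sdpa` produces a mask that `torch.nn.functional.scaled_dot_product_attention` can consume directly. The snippet below is a simplified stand-in (assumed behaviour, not the helper's exact code): a 2D padding mask is expanded to an additive `[bsz, 1, tgt_len, src_len]` mask and passed through SDPA.

```python
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 1, 2, 4, 8
query = key = value = torch.randn(bsz, num_heads, seq_len, head_dim)

# [bsz, seq_len] padding mask: the last token is padding.
padding_mask = torch.tensor([[1, 1, 1, 0]])

# Assumed expansion: additive [bsz, 1, tgt_len, src_len] mask, broadcast over heads,
# with a very large negative value at padded key positions.
additive_mask = (1.0 - padding_mask[:, None, None, :].float()) * torch.finfo(torch.float32).min

attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=additive_mask)
print(attn_output.shape)  # torch.Size([1, 2, 4, 8])
```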
......@@ -846,6 +857,9 @@ class PLBartDecoder(PLBartPreTrainedModel):
config.d_model,
)
self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
......@@ -964,9 +978,18 @@ class PLBartDecoder(PLBartPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
# output_attentions=True and cross_attn_head_mask cannot be supported when using SDPA, so we fall back on
# the manual implementation, which requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
......@@ -975,8 +998,17 @@ class PLBartDecoder(PLBartPreTrainedModel):
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True and cross_attn_head_mask cannot be supported when using SDPA, so we fall back on
# the manual implementation, which requires a 4D attention mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
......
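For decoder self-attention the 4D mask must also be causal and account for `past_key_values_length` when keys/values are cached. A minimal toy reconstruction of the `[bsz, 1, tgt_len, past + tgt_len]` shape the 4D helpers produce (assuming no padding and float32; `toy_causal_mask` is not the library helper):

```python
import torch

def toy_causal_mask(bsz, tgt_len, past_key_values_length, dtype=torch.float32):
    # Query i may attend to every cached position plus new keys 0..i.
    src_len = past_key_values_length + tgt_len
    mask = torch.full((tgt_len, src_len), torch.finfo(dtype).min, dtype=dtype)
    allowed = torch.arange(src_len) <= (torch.arange(tgt_len)[:, None] + past_key_values_length)
    mask = mask.masked_fill(allowed, 0.0)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, src_len)

print(toy_causal_mask(bsz=1, tgt_len=2, past_key_values_length=3).shape)  # torch.Size([1, 1, 2, 5])
```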
......@@ -30,7 +30,12 @@ from ...modeling_outputs import (
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_speech_to_text import Speech2TextConfig
......@@ -326,7 +331,7 @@ class Speech2TextAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
SPEECH_TO_TEXT_ATTENTION_CLASSES = {"default": Speech2TextAttention}
SPEECH_TO_TEXT_ATTENTION_CLASSES = {"eager": Speech2TextAttention}
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT
......@@ -334,9 +339,8 @@ class Speech2TextEncoderLayer(nn.Module):
def __init__(self, config: Speech2TextConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
......@@ -406,9 +410,8 @@ class Speech2TextDecoderLayer(nn.Module):
def __init__(self, config: Speech2TextConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
......@@ -421,7 +424,7 @@ class Speech2TextDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
......
......@@ -32,7 +32,12 @@ from ...modeling_outputs import (
)
from ...modeling_utils import PreTrainedModel
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_time_series_transformer import TimeSeriesTransformerConfig
......@@ -436,9 +441,8 @@ class TimeSeriesTransformerEncoderLayer(nn.Module):
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
......@@ -503,7 +507,10 @@ class TimeSeriesTransformerEncoderLayer(nn.Module):
return outputs
TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {"default": TimeSeriesTransformerAttention}
# TODO: Implement attention with SDPA for TimeSeriesTransformer.
TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {
"eager": TimeSeriesTransformerAttention,
}
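Since only the `"eager"` entry exists here, these models are not expected to receive `"sdpa"` as `config._attn_implementation`; the library guards this upstream through class flags such as `_supports_sdpa` (visible for Whisper further down). A toy illustration of that kind of gate, under the assumption of a simple class-attribute check (not the actual `from_pretrained` logic):

```python
# Toy illustration only: refuse "sdpa" when a model class does not advertise support for it.
class ToyPreTrainedModel:
    _supports_sdpa = False

def resolve_attn_implementation(model_cls, requested: str) -> str:
    if requested == "sdpa" and not model_cls._supports_sdpa:
        raise ValueError(f"{model_cls.__name__} does not support attn_implementation='sdpa' yet.")
    return requested

print(resolve_attn_implementation(ToyPreTrainedModel, "eager"))  # eager
```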
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER
......@@ -512,8 +519,7 @@ class TimeSeriesTransformerDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
......@@ -526,7 +532,7 @@ class TimeSeriesTransformerDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
......
......@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation.logits_process import WhisperTimeStampLogitsProcessor
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
......@@ -690,9 +690,111 @@ class WhisperFlashAttention2(WhisperAttention):
)
class WhisperSdpaAttention(WhisperAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with BART->whisper, Bart->Whisper
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 check is necessary to match AttentionMaskConverter.to_causal_4d, which does not create a causal mask when tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
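The `is_causal=self.is_causal and attention_mask is None and tgt_len > 1` argument delegates causal masking to the SDPA kernel whenever no explicit mask was built. The quick check below (random float32 tensors on CPU) verifies that the flag is equivalent to an explicit lower-triangular additive mask when query and key lengths match:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Explicit additive causal mask equivalent to is_causal=True when tgt_len == src_len.
causal = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
print(torch.allclose(out_flag, out_mask, atol=1e-6))  # True
```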
WHISPER_ATTENTION_CLASSES = {
"default": WhisperAttention,
"eager": WhisperAttention,
"flash_attention_2": WhisperFlashAttention2,
"sdpa": WhisperSdpaAttention,
}
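With `WhisperSdpaAttention` registered under the `"sdpa"` key, the backend is selected at load time through the `attn_implementation` argument mentioned in the warning above. A minimal usage sketch (the checkpoint name is just an example):

```python
from transformers import WhisperForConditionalGeneration

# "openai/whisper-tiny" is only an example checkpoint.
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
print(model.config._attn_implementation)  # "sdpa"
```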
......@@ -701,9 +803,8 @@ class WhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type](
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
......@@ -773,9 +874,8 @@ class WhisperDecoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type](
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
......@@ -788,7 +888,7 @@ class WhisperDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = WHISPER_ATTENTION_CLASSES[attn_type](
self.encoder_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
......@@ -897,6 +997,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.init_std
......@@ -1227,6 +1328,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layer_norm = nn.LayerNorm(config.d_model)
......@@ -1336,9 +1439,14 @@ class WhisperDecoder(WhisperPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True and head_mask cannot be supported when using SDPA.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
......
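The decoder mask handling above follows the same three-way decision used throughout this PR: Flash Attention 2 keeps the 2D mask (or none), SDPA gets a dedicated 4D mask unless `output_attentions` or a head mask forces the eager fallback, and everything else takes the manual 4D path. A compact paraphrase of that branching:

```python
def pick_mask_path(use_flash_attention_2, use_sdpa, output_attentions, head_mask):
    # Paraphrase of the branching above, returning which mask preparation runs.
    if use_flash_attention_2:
        return "2d-or-none"        # raw padding mask, dropped when nothing is masked
    if use_sdpa and head_mask is None and not output_attentions:
        return "4d-for-sdpa"       # _prepare_4d_causal_attention_mask_for_sdpa
    return "4d-eager"              # _prepare_4d_causal_attention_mask

print(pick_mask_path(False, True, output_attentions=False, head_mask=None))  # 4d-for-sdpa
```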
......@@ -107,6 +107,7 @@ from .utils import (
is_torch_fp16_available_on_device,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
......@@ -440,6 +441,15 @@ def require_flash_attn(test_case):
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
def require_torch_sdpa(test_case):
"""
Decorator marking a test that requires PyTorch's SDPA.
These tests are skipped when requirements are not met (torch version).
"""
return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
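`require_torch_sdpa` mirrors `require_flash_attn`: it skips a test when the installed torch is too old for SDPA. A minimal usage sketch (the test class is hypothetical and the import location is assumed):

```python
import unittest

from transformers.testing_utils import require_torch_sdpa  # assumed import location

@require_torch_sdpa
class SdpaSmokeTest(unittest.TestCase):  # hypothetical test, for illustration only
    def test_runs_only_when_sdpa_is_available(self):
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()
```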
def require_peft(test_case):
"""
Decorator marking a test that requires PEFT.
......