Unverified commit af3de8d8, authored by Patrick von Platen, committed by GitHub

[Whisper, Bart, MBart] Add Flash Attention 2 (#27203)



* add whisper fa2

* correct

* change all

* correct

* correct

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix more

* fix more

* fix more

* fix more

* fix more

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 3520e37e
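For readers who want to try the new path, here is a minimal sketch (not part of this PR) of how the Flash Attention 2 classes added below get selected. The layers key off the private `_flash_attn_2_enabled` attribute checked throughout the diff; the public switch that normally sets it (for example a `use_flash_attention_2=True` argument to `from_pretrained`) is an assumption here and may differ between releases.

from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig()
config._flash_attn_2_enabled = True  # private flag the new layers check; normally set for you by the loading API

model = BartForConditionalGeneration(config)  # randomly initialised, just to inspect the wiring
print(type(model.model.encoder.layers[0].self_attn).__name__)  # expected: BartFlashAttention2

# An actual forward pass additionally requires the flash-attn package, a CUDA device,
# and fp16/bf16 weights, e.g. model.to("cuda", dtype=torch.float16).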
@@ -19,6 +19,7 @@ import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -40,12 +41,18 @@ from ...utils import (
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
replace_return_docstrings,
)
from .configuration_bart import BartConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/bart-base"
@@ -71,6 +78,19 @@ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
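# Illustrative sketch (not part of the diff): what _get_unpad_data returns for a small
# right-padded batch. The values below are worked out by hand from the definition above.
#
#   mask = torch.tensor([[1, 1, 0],
#                        [1, 1, 1]])
#   indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
#   indices    -> tensor([0, 1, 3, 4, 5])   # flat positions of the non-padding tokens
#   cu_seqlens -> tensor([0, 2, 5])         # cumulative sequence lengths, prefixed with 0
#   max_seqlen -> 3                         # longest sequence in the batch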
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
@@ -119,12 +139,15 @@ class BartAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BartConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -133,6 +156,7 @@ class BartAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -263,14 +287,225 @@ class BartAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
class BartFlashAttention2(BartAttention):
"""
Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# BartFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("BartFlashAttention2 attention does not support output_attentions")
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, q_len, _ = hidden_states.size()
# get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
)
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=self.is_causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
BART_ATTENTION_CLASSES = {
"default": BartAttention,
"flash_attention_2": BartFlashAttention2,
}
class BartEncoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BART_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
@@ -336,22 +571,26 @@ class BartDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BART_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BART_ATTENTION_CLASSES[attn_type](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
@@ -479,6 +718,7 @@ class BartPreTrainedModel(PreTrainedModel):
_keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
_no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
def _init_weights(self, module):
std = self.config.init_std
@@ -792,8 +1032,11 @@ class BartEncoder(BartPreTrainedModel):
# expand attention_mask
if attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
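# Illustrative note (not part of the diff): with Flash Attention 2 the encoder keeps the raw
# 2d padding mask only when it actually contains padding; an all-ones mask is dropped so the
# kernel can take the faster non-varlen path. A hand-worked sketch, assuming an int mask of
# shape (batch_size, seq_len):
#
#   torch.ones(2, 16)                           -> attention_mask becomes None (no padding)
#   torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])  -> kept as-is and later unpadded via _upad_input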
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -995,16 +1238,24 @@ class BartDecoder(BartPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale
if getattr(self.config, "_flash_attn_2_enabled", False):
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# embed positions
positions = self.embed_positions(input, past_key_values_length)
...
@@ -1174,7 +1174,7 @@ class BigBirdPegasusEncoderAttention(nn.Module):
return outputs
# Copied from transformers.models.bart.modeling_bart.BartAttention with BartConfig->BigBirdPegasusConfig, Bart->BigBirdPegasusDecoder
class BigBirdPegasusDecoderAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -1185,12 +1185,15 @@ class BigBirdPegasusDecoderAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BigBirdPegasusConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -1199,6 +1202,7 @@ class BigBirdPegasusDecoderAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -90,12 +90,15 @@ class BioGptAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BioGptConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -104,6 +107,7 @@ class BioGptAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -104,12 +104,15 @@ class BlenderbotAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BlenderbotConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -118,6 +121,7 @@ class BlenderbotAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -248,15 +252,21 @@ class BlenderbotAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
BLENDERBOT_ATTENTION_CLASSES = {"default": BlenderbotAttention}
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
class BlenderbotEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
@@ -317,28 +327,32 @@ class BlenderbotEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
class BlenderbotDecoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
...
@@ -101,12 +101,15 @@ class BlenderbotSmallAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[BlenderbotSmallConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -115,6 +118,7 @@ class BlenderbotSmallAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -245,15 +249,18 @@ class BlenderbotSmallAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
class BlenderbotSmallEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotSmallConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
@@ -314,28 +321,35 @@ class BlenderbotSmallEncoderLayer(nn.Module):
return outputs
BLENDERBOT_SMALL_ATTENTION_CLASSES = {"default": BlenderbotSmallAttention}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
class BlenderbotSmallDecoderLayer(nn.Module):
def __init__(self, config: BlenderbotSmallConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
...
@@ -330,12 +330,15 @@ class Data2VecAudioAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[Data2VecAudioConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -344,6 +347,7 @@ class Data2VecAudioAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -370,12 +370,15 @@ class GPTSanJapaneseAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[GPTSanJapaneseConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -384,6 +387,7 @@ class GPTSanJapaneseAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -396,12 +396,15 @@ class HubertAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[HubertConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -410,6 +413,7 @@ class HubertAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -287,12 +287,15 @@ class InformerAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[InformerConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -301,6 +304,7 @@ class InformerAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -172,12 +172,15 @@ class M2M100Attention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[M2M100Config] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -186,6 +189,7 @@ class M2M100Attention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -316,15 +320,18 @@ class M2M100Attention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100
class M2M100EncoderLayer(nn.Module):
def __init__(self, config: M2M100Config):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = M2M100_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
@@ -385,28 +392,35 @@ class M2M100EncoderLayer(nn.Module):
return outputs
M2M100_ATTENTION_CLASSES = {"default": M2M100Attention}
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100
class M2M100DecoderLayer(nn.Module):
def __init__(self, config: M2M100Config):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = M2M100_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = M2M100_ATTENTION_CLASSES[attn_type](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
...
@@ -119,12 +119,15 @@ class MarianAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[MarianConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -133,6 +136,7 @@ class MarianAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -263,15 +267,18 @@ class MarianAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN
class MarianEncoderLayer(nn.Module):
def __init__(self, config: MarianConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
@@ -332,28 +339,35 @@ class MarianEncoderLayer(nn.Module):
return outputs
MARIAN_ATTENTION_CLASSES = {"default": MarianAttention}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN
class MarianDecoderLayer(nn.Module):
def __init__(self, config: MarianConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MARIAN_ATTENTION_CLASSES[attn_type](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
...
@@ -18,6 +18,7 @@ import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -39,12 +40,18 @@ from ...utils import (
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
replace_return_docstrings,
)
from .configuration_mbart import MBartConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25"
@@ -59,6 +66,19 @@ MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
"""
Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
@@ -113,12 +133,15 @@ class MBartAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[MBartConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
@@ -127,6 +150,7 @@ class MBartAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -257,14 +281,226 @@ class MBartAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MBart
class MBartFlashAttention2(MBartAttention):
"""
MBart flash attention module. This module inherits from `MBartAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# MBartFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("MBartFlashAttention2 attention does not support output_attentions")
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, q_len, _ = hidden_states.size()
# get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
)
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=self.is_causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
MBART_ATTENTION_CLASSES = {
"default": MBartAttention,
"flash_attention_2": MBartFlashAttention2,
}
class MBartEncoderLayer(nn.Module):
def __init__(self, config: MBartConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = MBART_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
...@@ -329,23 +565,27 @@ class MBartDecoderLayer(nn.Module):
    def __init__(self, config: MBartConfig):
        super().__init__()
        self.embed_dim = config.d_model
        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
        self.self_attn = MBART_ATTENTION_CLASSES[attn_type](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = MBART_ATTENTION_CLASSES[attn_type](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
...@@ -472,6 +712,7 @@ class MBartPreTrainedModel(PreTrainedModel):
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        std = self.config.init_std
...@@ -766,7 +1007,11 @@ class MBartEncoder(MBartPreTrainedModel):
        # expand attention_mask
        if attention_mask is not None:
            if getattr(self.config, "_flash_attn_2_enabled", False):
                attention_mask = attention_mask if 0 in attention_mask else None
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
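The two branches above, sketched with plain tensors (editorial illustration, not part of the diff): the flash-attention path keeps the 2D `(batch, seq_len)` padding mask, or drops it entirely when nothing is padded, while the default path expands it into the additive 4D mask that eager attention adds to the scores.

import torch

attention_mask = torch.tensor([[1, 1, 0]])  # batch of 1, seq_len 3, last position padded

# flash_attention_2 branch: the 2d mask is forwarded unchanged because it contains a 0.
mask_for_fa2 = attention_mask if 0 in attention_mask else None

# default branch: rough equivalent of _prepare_4d_attention_mask -> (bsz, 1, tgt_len, src_len),
# 0.0 where attending is allowed and a large negative value on padded key positions.
min_value = torch.finfo(torch.float32).min
mask_4d = (1.0 - attention_mask[:, None, None, :].float()) * min_value

print(mask_for_fa2.shape, mask_4d.shape)  # torch.Size([1, 3]) torch.Size([1, 1, 1, 3])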
...@@ -970,16 +1215,24 @@ class MBartDecoder(MBartPreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if getattr(self.config, "_flash_attn_2_enabled", False):
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        # embed positions
        positions = self.embed_positions(input, past_key_values_length)
...
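The decoder side differs in one respect: the default path materializes a 4D causal mask, whereas the flash-attention path never builds one because the kernels receive `causal=self.is_causal` (set to `True` for decoder self-attention in this change) and apply causality internally. A rough sketch of what the default path produces (editorial illustration, not the exact `_prepare_4d_causal_attention_mask` output):

import torch

seq_len = 4
min_value = torch.finfo(torch.float32).min

# Default path: additive upper-triangular mask so position i only attends to positions <= i.
causal_4d = torch.triu(torch.full((seq_len, seq_len), min_value), diagonal=1)[None, None, :, :]
print(causal_4d[0, 0])  # 0.0 on and below the diagonal, a large negative value above it

# flash_attention_2 path: only the 2d padding mask (or None) reaches the layers; causal masking
# happens inside flash_attn_func / flash_attn_varlen_func.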
...@@ -145,7 +145,7 @@ class MusicgenSinusoidalPositionalEmbedding(nn.Module):
        return self.weights.index_select(0, position_ids.view(-1)).detach()


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Musicgen
class MusicgenAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

...@@ -156,12 +156,15 @@ class MusicgenAttention(nn.Module):
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[MusicgenConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
...@@ -170,6 +173,7 @@ class MusicgenAttention(nn.Module):
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
...@@ -467,12 +467,15 @@ class NllbMoeAttention(nn.Module):
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[NllbMoeConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
...@@ -481,6 +484,7 @@ class NllbMoeAttention(nn.Module):
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
...@@ -119,12 +119,15 @@ class PegasusAttention(nn.Module):
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PegasusConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
...@@ -133,6 +136,7 @@ class PegasusAttention(nn.Module):
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...@@ -263,15 +267,21 @@ class PegasusAttention(nn.Module):
        return attn_output, attn_weights_reshaped, past_key_value


PEGASUS_ATTENTION_CLASSES = {"default": PegasusAttention}


# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
class PegasusEncoderLayer(nn.Module):
    def __init__(self, config: PegasusConfig):
        super().__init__()
        self.embed_dim = config.d_model
        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
        self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
...@@ -332,28 +342,32 @@ class PegasusEncoderLayer(nn.Module):
        return outputs


# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS
class PegasusDecoderLayer(nn.Module):
    def __init__(self, config: PegasusConfig):
        super().__init__()
        self.embed_dim = config.d_model
        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
        self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
...
...@@ -128,12 +128,15 @@ class PegasusXAttention(nn.Module):
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PegasusXConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
...@@ -142,6 +145,7 @@ class PegasusXAttention(nn.Module):
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
...@@ -112,12 +112,15 @@ class PLBartAttention(nn.Module):
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PLBartConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
...@@ -126,6 +129,7 @@ class PLBartAttention(nn.Module):
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...@@ -256,15 +260,18 @@ class PLBartAttention(nn.Module):
        return attn_output, attn_weights_reshaped, past_key_value


# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->PLBart, BART->PLBART
class PLBartEncoderLayer(nn.Module):
    def __init__(self, config: PLBartConfig):
        super().__init__()
        self.embed_dim = config.d_model
        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
        self.self_attn = PLBART_ATTENTION_CLASSES[attn_type](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
...@@ -325,28 +332,35 @@ class PLBartEncoderLayer(nn.Module):
        return outputs


PLBART_ATTENTION_CLASSES = {"default": PLBartAttention}


# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART
class PLBartDecoderLayer(nn.Module):
    def __init__(self, config: PLBartConfig):
        super().__init__()
        self.embed_dim = config.d_model
        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
        self.self_attn = PLBART_ATTENTION_CLASSES[attn_type](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = PLBART_ATTENTION_CLASSES[attn_type](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
...@@ -743,8 +757,11 @@ class PLBartEncoder(PLBartPreTrainedModel):
        # expand attention_mask
        if attention_mask is not None:
            if getattr(self.config, "_flash_attn_2_enabled", False):
                attention_mask = attention_mask if 0 in attention_mask else None
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
...@@ -947,16 +964,24 @@ class PLBartDecoder(PLBartPreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input) * self.embed_scale

        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if getattr(self.config, "_flash_attn_2_enabled", False):
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        # embed positions
        positions = self.embed_positions(input, past_key_values_length)
...
...@@ -1101,12 +1101,15 @@ class SeamlessM4TAttention(nn.Module):
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[SeamlessM4TConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
...@@ -1115,6 +1118,7 @@ class SeamlessM4TAttention(nn.Module):
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
...@@ -392,12 +392,15 @@ class SEWAttention(nn.Module):
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[SEWConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
...@@ -406,6 +409,7 @@ class SEWAttention(nn.Module):
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
...@@ -178,12 +178,15 @@ class Speech2TextAttention(nn.Module):
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[Speech2TextConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
...@@ -192,6 +195,7 @@ class Speech2TextAttention(nn.Module):
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...@@ -322,15 +326,21 @@ class Speech2TextAttention(nn.Module):
        return attn_output, attn_weights_reshaped, past_key_value


SPEECH_TO_TEXT_ATTENTION_CLASSES = {"default": Speech2TextAttention}


# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT
class Speech2TextEncoderLayer(nn.Module):
    def __init__(self, config: Speech2TextConfig):
        super().__init__()
        self.embed_dim = config.d_model
        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
        self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
...@@ -391,28 +401,32 @@ class Speech2TextEncoderLayer(nn.Module):
        return outputs


# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT
class Speech2TextDecoderLayer(nn.Module):
    def __init__(self, config: Speech2TextConfig):
        super().__init__()
        self.embed_dim = config.d_model
        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
        self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
...
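Taken together, these changes let the seq2seq models touched by this PR opt into Flash Attention 2 at load time. A hedged usage sketch, assuming a transformers release from around this PR, a CUDA device, and the `flash-attn` package installed (the checkpoint name is only an example; later releases spell the same request as `attn_implementation="flash_attention_2"`):

import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/mbart-large-50",   # example checkpoint; any model with _supports_flash_attn_2 works
    torch_dtype=torch.float16,   # the FA2 kernels require fp16 or bf16
    use_flash_attention_2=True,  # sets config._flash_attn_2_enabled under the hood
).to("cuda")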