Unverified Commit af3de8d8 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Whisper, Bart, MBart] Add Flash Attention 2 (#27203)



* add whisper fa2

* correct

* change all

* correct

* correct

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix more

* fix more

* fix more

* fix more

* fix more

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 3520e37e
......@@ -125,12 +125,15 @@ class Speech2Text2Attention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[Speech2Text2Config] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
......@@ -139,6 +142,7 @@ class Speech2Text2Attention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
......
......@@ -281,12 +281,15 @@ class TimeSeriesTransformerAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[TimeSeriesTransformerConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
......@@ -295,6 +298,7 @@ class TimeSeriesTransformerAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
......@@ -425,15 +429,18 @@ class TimeSeriesTransformerAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER
class TimeSeriesTransformerEncoderLayer(nn.Module):
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = TimeSeriesTransformerAttention(
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
......@@ -494,28 +501,35 @@ class TimeSeriesTransformerEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer
TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {"default": TimeSeriesTransformerAttention}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER
class TimeSeriesTransformerDecoderLayer(nn.Module):
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = TimeSeriesTransformerAttention(
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = TimeSeriesTransformerAttention(
self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
......
......@@ -432,12 +432,15 @@ class UniSpeechAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[UniSpeechConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
......@@ -446,6 +449,7 @@ class UniSpeechAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
......
......@@ -446,12 +446,15 @@ class UniSpeechSatAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[UniSpeechSatConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
......@@ -460,6 +463,7 @@ class UniSpeechSatAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
......
......@@ -498,12 +498,15 @@ class Wav2Vec2Attention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[Wav2Vec2Config] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
......@@ -512,6 +515,7 @@ class Wav2Vec2Attention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
......
......@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
......@@ -38,6 +39,7 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
replace_return_docstrings,
)
......@@ -45,6 +47,11 @@ from .configuration_whisper import WhisperConfig
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "WhisperConfig"
......@@ -57,6 +64,19 @@ WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
"""Returns sinusoids for positional embedding"""
if channels % 2 != 0:
......@@ -299,12 +319,15 @@ class WhisperAttention(nn.Module):
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[WhisperConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
......@@ -313,6 +336,7 @@ class WhisperAttention(nn.Module):
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
......@@ -445,15 +469,227 @@ class WhisperAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Whisper
class WhisperFlashAttention2(WhisperAttention):
"""
Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# WhisperFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("WhisperFlashAttention2 attention does not support output_attentions")
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, q_len, _ = hidden_states.size()
# get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
)
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=self.is_causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
WHISPER_ATTENTION_CLASSES = {
"default": WhisperAttention,
"flash_attention_2": WhisperFlashAttention2,
}
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
class WhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = WhisperAttention(
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
......@@ -514,28 +750,32 @@ class WhisperEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper, MBART->WHISPER
class WhisperDecoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = WhisperAttention(
self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = WhisperAttention(
self.encoder_attn = WHISPER_ATTENTION_CLASSES[attn_type](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
......@@ -638,6 +878,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
main_input_name = "input_features"
supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
_supports_flash_attn_2 = True
def _init_weights(self, module):
std = self.config.init_std
......@@ -1070,6 +1311,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if getattr(self.config, "_flash_attn_2_enabled", False):
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
......
......@@ -21,13 +21,16 @@ import tempfile
import unittest
import numpy as np
from pytest import mark
import transformers
from transformers import WhisperConfig
from transformers.testing_utils import (
is_pt_flax_cross_test,
require_flash_attn,
require_torch,
require_torch_fp16,
require_torch_gpu,
require_torchaudio,
slow,
torch_device,
......@@ -795,6 +798,107 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
use_cache=use_cache,
)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
logits = outputs.decoder_hidden_states[-1]
logits_fa = outputs_fa.decoder_hidden_states[-1]
# whisper FA2 needs very high tolerance
assert torch.allclose(logits_fa, logits, atol=4e-1)
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, decoder_input_ids=decoder_input_ids)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
dummy_input = dummy_input.to(torch.float16)
decoder_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=dummy_input.device, dtype=torch.long)
decoder_attention_mask = torch.tensor(
[[0, 0, 0, 1, 1, 1]], device=dummy_input.device, dtype=torch.long
)
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
logits = outputs.decoder_hidden_states[-1]
logits_fa = outputs_fa.decoder_hidden_states[-1]
# whisper FA2 needs very high tolerance
assert torch.allclose(logits_fa, logits, atol=4e-1)
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"output_hidden_states": True,
}
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = outputs.decoder_hidden_states[-1]
logits_fa = outputs_fa.decoder_hidden_states[-1]
# whisper FA2 needs very high tolerance
assert torch.allclose(logits_fa[:, -2:], logits[:, -2:], atol=4e-1)
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
return
......
......@@ -2856,7 +2856,7 @@ class ModelTesterMixin:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -2871,25 +2871,76 @@ class ModelTesterMixin:
)
model.to(torch_device)
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[0, 1, 1, 1, 1]]).to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
if model.config.is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
else:
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits_fa = output_fa.hidden_states[-1]
if model.config.is_encoder_decoder:
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits = output.hidden_states[-1]
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
else:
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
_ = model_fa(dummy_input, **other_inputs)
@require_flash_attn
@require_torch_gpu
......@@ -2902,7 +2953,7 @@ class ModelTesterMixin:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -2917,21 +2968,72 @@ class ModelTesterMixin:
)
model.to(torch_device)
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
if model.config.is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
else:
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
if model.config.is_encoder_decoder:
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
else:
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits_fa = output_fa.hidden_states[-1]
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits = output.hidden_states[-1]
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
......@@ -2944,7 +3046,7 @@ class ModelTesterMixin:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -2953,8 +3055,14 @@ class ModelTesterMixin:
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do left padding
dummy_attention_mask[:, :-1] = 0
dummy_attention_mask[:, -1:] = 1
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
......@@ -2981,7 +3089,7 @@ class ModelTesterMixin:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -2990,8 +3098,14 @@ class ModelTesterMixin:
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do left padding
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
......@@ -3014,26 +3128,39 @@ class ModelTesterMixin:
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
@require_flash_attn
......@@ -3048,14 +3175,18 @@ class ModelTesterMixin:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
dummy_input = inputs_dict[model.main_input_name]
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
if model.config.is_encoder_decoder:
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
model = model_class.from_pretrained(
tmpdirname,
......@@ -3070,10 +3201,19 @@ class ModelTesterMixin:
if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
param.data = param.data.to(torch.float32)
_ = model(input_ids=dummy_input)
if model.config.is_encoder_decoder:
_ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
# with attention mask
_ = model(
dummy_input,
attention_mask=dummy_attention_mask,
decoder_input_ids=dummy_decoder_input_ids,
decoder_attention_mask=dummy_decoder_attention_mask,
)
else:
_ = model(dummy_input)
# with attention mask
_ = model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
_ = model(dummy_input, attention_mask=dummy_attention_mask)
global_rng = random.Random()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment