Unverified Commit 0d04b1e2 authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Add Flash Attention 2 support to Musicgen and Musicgen Melody (#29939)

* add FA2 to o.g Musicgen

* make style

* add FA2 support to Musicgen Melody

* add generation FA2 tests to o.g Musicgen

* make style and fix copies

* add Musicgen to FA2 docs + deprecate list

* add sdpa supports to Musicgen's

* make style and fix copies

* refactor attention implementation arguments

* add Copied from to sdpa tests

* add copied form in sdpa tests melody

* add copied for FA2 generation tests

* add FA2 inference copied from

* make style
parent fed27ffc
......@@ -55,6 +55,8 @@ FlashAttention-2 is currently supported for the following architectures:
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
......@@ -190,6 +192,8 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
<Tip>
......
......@@ -1470,6 +1470,12 @@ MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = DeprecatedDict(
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(["facebook/musicgen-small"])
MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP = DeprecatedDict(
{"facebook/musicgen-melody": "https://huggingface.co/facebook/musicgen-melody/resolve/main/config.json"}
)
MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(["facebook/musicgen-melody"])
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(
[
"RUCAIBox/mvp",
......
......@@ -239,3 +239,20 @@ class MusicgenConfig(PretrainedConfig):
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate
@property
def _attn_implementation(self):
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
if hasattr(self, "_attn_implementation_internal"):
if self._attn_implementation_internal is None:
# `config.attn_implementation` should never be None, for backward compatibility.
return "eager"
else:
return self._attn_implementation_internal
else:
return "eager"
@_attn_implementation.setter
def _attn_implementation(self, value):
self._attn_implementation_internal = value
self.decoder._attn_implementation = value
......@@ -22,13 +22,19 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation.configuration_utils import GenerationConfig
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
from ...generation.stopping_criteria import StoppingCriteriaList
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
......@@ -40,6 +46,8 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
......@@ -48,6 +56,10 @@ from ..auto.modeling_auto import AutoModel
from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
......@@ -60,6 +72,19 @@ _CHECKPOINT_FOR_DOC = "facebook/musicgen-small"
from ..deprecated._archive_maps import MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
@dataclass
class MusicgenUnconditionalInput(ModelOutput):
"""
......@@ -302,29 +327,361 @@ class MusicgenAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Musicgen
class MusicgenFlashAttention2(MusicgenAttention):
"""
Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# MusicgenFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("MusicgenFlashAttention2 attention does not support output_attentions")
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, q_len, _ = hidden_states.size()
# get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
)
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
class MusicgenSdpaAttention(MusicgenAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
MUSICGEN_ATTENTION_CLASSES = {
"eager": MusicgenAttention,
"sdpa": MusicgenSdpaAttention,
"flash_attention_2": MusicgenFlashAttention2,
}
class MusicgenDecoderLayer(nn.Module):
def __init__(self, config: MusicgenDecoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = MusicgenAttention(
self.self_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=False,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MusicgenAttention(
self.encoder_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=False,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
......@@ -432,6 +789,8 @@ class MusicgenPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.initializer_factor
......@@ -667,6 +1026,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.attn_implementation = config._attn_implementation
self.gradient_checkpointing = False
# Initialize weights and apply final processing
......@@ -721,12 +1081,36 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
if self.attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self.attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
......@@ -1409,6 +1793,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
base_model_prefix = "encoder_decoder"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(
self,
......
......@@ -21,9 +21,7 @@ from ..auto.configuration_auto import AutoConfig
logger = logging.get_logger(__name__)
MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/musicgen-melody": "https://huggingface.co/facebook/musicgen-melody/resolve/main/config.json",
}
from ..deprecated._archive_maps import MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP # noqa: F401, E402
class MusicgenMelodyDecoderConfig(PretrainedConfig):
......@@ -254,3 +252,20 @@ class MusicgenMelodyConfig(PretrainedConfig):
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate
@property
def _attn_implementation(self):
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
if hasattr(self, "_attn_implementation_internal"):
if self._attn_implementation_internal is None:
# `config.attn_implementation` should never be None, for backward compatibility.
return "eager"
else:
return self._attn_implementation_internal
else:
return "eager"
@_attn_implementation.setter
def _attn_implementation(self, value):
self._attn_implementation_internal = value
self.decoder._attn_implementation = value
......@@ -22,13 +22,14 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation.configuration_utils import GenerationConfig
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
from ...generation.stopping_criteria import StoppingCriteriaList
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutputWithPast,
ModelOutput,
......@@ -37,6 +38,8 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
......@@ -45,6 +48,10 @@ from ..auto.modeling_auto import AutoModel, AutoModelForTextEncoding
from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
......@@ -53,10 +60,20 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MusicgenMelodyConfig"
_CHECKPOINT_FOR_DOC = "facebook/musicgen-melody"
MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/musicgen-melody",
# See all Musicgen Melody models at https://huggingface.co/models?filter=musicgen_melody
]
from ..deprecated._archive_maps import MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
@dataclass
......@@ -324,17 +341,348 @@ class MusicgenMelodyAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MusicgenMelody
class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention):
"""
MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# MusicgenMelodyFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("MusicgenMelodyFlashAttention2 attention does not support output_attentions")
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, q_len, _ = hidden_states.size()
# get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
)
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MusicgenMelody
class MusicgenMelodySdpaAttention(MusicgenMelodyAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"MusicgenMelodyModel is using MusicgenMelodySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
MUSICGEN_MELODY_ATTENTION_CLASSES = {
"eager": MusicgenMelodyAttention,
"sdpa": MusicgenMelodySdpaAttention,
"flash_attention_2": MusicgenMelodyFlashAttention2,
}
class MusicgenMelodyDecoderLayer(nn.Module):
def __init__(self, config: MusicgenMelodyDecoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = MusicgenMelodyAttention(
self.self_attn = MUSICGEN_MELODY_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=False,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
......@@ -414,6 +762,8 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.initializer_factor
......@@ -626,6 +976,7 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel):
self.layers = nn.ModuleList([MusicgenMelodyDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.attn_implementation = config._attn_implementation
self.gradient_checkpointing = False
# Initialize weights and apply final processing
......@@ -695,6 +1046,18 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel):
input_shape = inputs_embeds.size()[:-1]
if self.attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
......@@ -1373,6 +1736,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
config_class = MusicgenMelodyConfig
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(
self,
......
......@@ -16,9 +16,12 @@
import copy
import inspect
import math
import tempfile
import unittest
import numpy as np
from parameterized import parameterized
from pytest import mark
from transformers import (
EncodecConfig,
......@@ -30,12 +33,15 @@ from transformers import (
)
from transformers.testing_utils import (
is_torch_available,
require_flash_attn,
require_torch,
require_torch_fp16,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
from transformers.utils import cached_property
from transformers.utils import cached_property, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
......@@ -277,6 +283,615 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
self.assertNotIn(config.pad_token_id, output_generate)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
# Ignore copy
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
# Ignore copy
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
# Ignore copy
outputs = model(dummy_input, output_hidden_states=True)
# Ignore copy
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
# Ignore copy
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, **other_inputs)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
# Ignore copy
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
# Ignore copy
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
if model.config.is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
else:
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
# Ignore copy
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
def test_flash_attn_2_generate_left_padding(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do left padding
dummy_attention_mask[:, :-1] = 0
dummy_attention_mask[:, -1:] = 1
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
def test_flash_attn_2_generate_padding_right(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do right padding
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
def test_flash_attn_2_generate_use_cache(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_inference
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [1, 5]:
# Ignore copy
batch_size_input_ids = self.model_tester.num_codebooks * batch_size
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
# Ignore copy
dummy_input = dummy_input[:batch_size_input_ids]
# Ignore copy
if dummy_input.shape[0] != batch_size_input_ids:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
# Ignore copy
extension = torch.rand(
batch_size_input_ids - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
# Ignore copy
extension = torch.randint(
high=5,
size=(batch_size_input_ids - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :-1] = 1
dummy_attention_mask[-1, -4:] = 0
elif padding_side == "right":
dummy_attention_mask[-1, 1:] = 1
dummy_attention_mask[-1, :3] = 0
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
other_inputs = {
"output_hidden_states": True,
}
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
other_inputs["attention_mask"] = dummy_attention_mask
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
if padding_side == "left":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, :-4]
sub_eager = logits_eager[-1, :-4]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, -4:]
# sub_eager = logits_eager[-1, -4:]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
elif padding_side == "right":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, 3:]
sub_eager = logits_eager[-1, 3:]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, :3]
# sub_eager = logits_eager[-1, :3]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
else:
if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_torch_sdpa
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_generate
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
def prepare_musicgen_inputs_dict(
config,
......@@ -941,6 +1556,639 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
self.assertNotIn(config.pad_token_id, output_generate)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
# Ignore copy
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
# Ignore copy
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
# Ignore copy
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
# Ignore copy
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
# Ignore copy
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
# Ignore copy
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
# Ignore copy
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
# Ignore copy
outputs = model(dummy_input, **other_inputs)
# Ignore copy
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, **other_inputs)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
# Ignore copy
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
# Ignore copy
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
# Ignore copy
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
# Ignore copy
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
# Ignore copy
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
# Ignore copy
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
# Ignore copy
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
# Ignore copy
outputs = model(dummy_input, **other_inputs)
# Ignore copy
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
def test_flash_attn_2_generate_left_padding(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask")
if dummy_attention_mask is None:
dummy_attention_mask = torch.ones_like(dummy_input)
# make sure we do left padding
dummy_attention_mask[:, :-1] = 0
dummy_attention_mask[:, -1:] = 1
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
def test_flash_attn_2_generate_padding_right(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask")
if dummy_attention_mask is None:
dummy_attention_mask = torch.ones_like(dummy_input)
# make sure we do right padding
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
def test_flash_attn_2_generate_use_cache(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_inference
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [1, 5]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
if dummy_input.shape[0] != batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
# Ignore copy
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
# Ignore copy
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :-1] = 1
dummy_attention_mask[-1, -4:] = 0
elif padding_side == "right":
dummy_attention_mask[-1, 1:] = 1
dummy_attention_mask[-1, :3] = 0
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
# Ignore copy
batch_size_input_ids = self.model_tester.num_codebooks * batch_size
# Ignore copy
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
:batch_size_input_ids
]
# Ignore copy
if decoder_input_ids.shape[0] != batch_size_input_ids:
# Ignore copy
extension = torch.ones(
batch_size_input_ids - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
# Ignore copy
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
# TODO: test gradients as well (& for FA2 as well!)
# Ignore copy
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
if padding_side == "left":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, :-4]
sub_eager = logits_eager[-1, :-4]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, -4:]
# sub_eager = logits_eager[-1, -4:]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
elif padding_side == "right":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, 3:]
sub_eager = logits_eager[-1, 3:]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, :3]
# sub_eager = logits_eager[-1, :3]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
else:
if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_torch_sdpa
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_generate
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
"""Produces a series of 'bip bip' sounds at a given frequency."""
......
......@@ -16,9 +16,12 @@
import copy
import inspect
import math
import tempfile
import unittest
import numpy as np
from parameterized import parameterized
from pytest import mark
from transformers import (
EncodecConfig,
......@@ -30,13 +33,16 @@ from transformers import (
from transformers.testing_utils import (
is_torch_available,
is_torchaudio_available,
require_flash_attn,
require_torch,
require_torch_fp16,
require_torch_gpu,
require_torch_sdpa,
require_torchaudio,
slow,
torch_device,
)
from transformers.utils import cached_property
from transformers.utils import cached_property, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
......@@ -277,6 +283,615 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
# Ignore copy
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
# Ignore copy
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
# Ignore copy
outputs = model(dummy_input, output_hidden_states=True)
# Ignore copy
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
# Ignore copy
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, **other_inputs)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence_right_padding
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
# Ignore copy
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
# Ignore copy
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
if model.config.is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
else:
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
# Ignore copy
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
def test_flash_attn_2_generate_left_padding(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do left padding
dummy_attention_mask[:, :-1] = 0
dummy_attention_mask[:, -1:] = 1
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
def test_flash_attn_2_generate_padding_right(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do right padding
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_generate_use_cache
def test_flash_attn_2_generate_use_cache(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_eager_matches_sdpa_inference
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [1, 5]:
# Ignore copy
batch_size_input_ids = self.model_tester.num_codebooks * batch_size
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
# Ignore copy
dummy_input = dummy_input[:batch_size_input_ids]
# Ignore copy
if dummy_input.shape[0] != batch_size_input_ids:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
# Ignore copy
extension = torch.rand(
batch_size_input_ids - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
# Ignore copy
extension = torch.randint(
high=5,
size=(batch_size_input_ids - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :-1] = 1
dummy_attention_mask[-1, -4:] = 0
elif padding_side == "right":
dummy_attention_mask[-1, 1:] = 1
dummy_attention_mask[-1, :3] = 0
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
other_inputs = {
"output_hidden_states": True,
}
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
other_inputs["attention_mask"] = dummy_attention_mask
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
if padding_side == "left":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, :-4]
sub_eager = logits_eager[-1, :-4]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, -4:]
# sub_eager = logits_eager[-1, -4:]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
elif padding_side == "right":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, 3:]
sub_eager = logits_eager[-1, 3:]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, :3]
# sub_eager = logits_eager[-1, :3]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
else:
if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_torch_sdpa
@slow
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_eager_matches_sdpa_generate
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
def prepare_musicgen_melody_inputs_dict(
config,
......@@ -923,6 +1538,639 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
self.assertNotIn(config.pad_token_id, output_generate)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
# Ignore copy
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
# Ignore copy
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
# Ignore copy
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
# Ignore copy
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
# Ignore copy
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
# Ignore copy
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
# Ignore copy
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
# Ignore copy
outputs = model(dummy_input, **other_inputs)
# Ignore copy
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, **other_inputs)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
# Ignore copy
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
# Ignore copy
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
# Ignore copy
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
# Ignore copy
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
# Ignore copy
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
# Ignore copy
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
# Ignore copy
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
# Ignore copy
outputs = model(dummy_input, **other_inputs)
# Ignore copy
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
def test_flash_attn_2_generate_left_padding(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask")
if dummy_attention_mask is None:
dummy_attention_mask = torch.ones_like(dummy_input)
# make sure we do left padding
dummy_attention_mask[:, :-1] = 0
dummy_attention_mask[:, -1:] = 1
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
def test_flash_attn_2_generate_padding_right(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask")
if dummy_attention_mask is None:
dummy_attention_mask = torch.ones_like(dummy_input)
# make sure we do right padding
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
def test_flash_attn_2_generate_use_cache(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_inference
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [1, 5]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
if dummy_input.shape[0] != batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
# Ignore copy
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
# Ignore copy
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :-1] = 1
dummy_attention_mask[-1, -4:] = 0
elif padding_side == "right":
dummy_attention_mask[-1, 1:] = 1
dummy_attention_mask[-1, :3] = 0
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
# Ignore copy
batch_size_input_ids = self.model_tester.num_codebooks * batch_size
# Ignore copy
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
:batch_size_input_ids
]
# Ignore copy
if decoder_input_ids.shape[0] != batch_size_input_ids:
# Ignore copy
extension = torch.ones(
batch_size_input_ids - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
# Ignore copy
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
# TODO: test gradients as well (& for FA2 as well!)
# Ignore copy
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
if padding_side == "left":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, :-4]
sub_eager = logits_eager[-1, :-4]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, -4:]
# sub_eager = logits_eager[-1, -4:]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
elif padding_side == "right":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, 3:]
sub_eager = logits_eager[-1, 3:]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, :3]
# sub_eager = logits_eager[-1, :3]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
else:
if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_torch_sdpa
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_generate
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
# Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment