Unverified Commit b65df514 authored by Alexander Visheratin, committed by GitHub

Add Flash Attention 2 to M2M100 model (#30256)



* Added flash attention 2.

* Fixes.

* Fix inheritance.

* Fixed init.

* Remove stuff.

* Added documentation.

* Add FA2 to M2M100 documentation.

* Add test.

* Fixed documentation.

* Update src/transformers/models/m2m_100/modeling_m2m_100.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/en/model_doc/nllb.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fixed variable name.

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent ec92f983
...@@ -121,3 +121,45 @@ Hindi to French and Chinese to English using the *facebook/m2m100_418M* checkpoi
[[autodoc]] M2M100ForConditionalGeneration
    - forward
## Using Flash Attention 2
Flash Attention 2 is a faster, optimized implementation of the attention score computation that relies on `cuda` kernels.
### Installation
First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features).
Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
```bash
pip install -U flash-attn --no-build-isolation
```
### Usage
To load a model using Flash Attention 2, pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). The model can be loaded in either `torch.float16` or `torch.bfloat16` precision.
```python
>>> import torch
>>> from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
>>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval()
>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
>>> # translate Hindi to French
>>> hi_text = "जीवन एक चॉकलेट बॉक्स की तरह है।"
>>> tokenizer.src_lang = "hi"
>>> encoded_hi = tokenizer(hi_text, return_tensors="pt").to("cuda")
>>> generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.get_lang_id("fr"))
>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
"La vie est comme une boîte de chocolat."
```
### Expected speedups
Below is an expected speedup diagram that compares pure inference time between the native implementation and Flash Attention 2.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/visheratin/documentation-images/resolve/main/nllb-speedup.webp">
</div>
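
The actual speedup depends on the GPU, batch size, and sequence length. A minimal timing sketch along the lines below can be used to reproduce the comparison on your own hardware; the input batch, warmup step, and iteration count are illustrative choices rather than part of an official benchmark.

```python
import time

import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


def average_generation_time(attn_implementation: str, sentences: list, n_iters: int = 5) -> float:
    """Return the mean generation time in seconds for the given attention implementation."""
    model = M2M100ForConditionalGeneration.from_pretrained(
        "facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation=attn_implementation
    ).to("cuda").eval()
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
    tokenizer.src_lang = "hi"
    inputs = tokenizer(sentences, padding=True, return_tensors="pt").to("cuda")

    # one warmup call so CUDA kernel compilation and caching is not included in the measurement
    model.generate(**inputs, forced_bos_token_id=tokenizer.get_lang_id("fr"))
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_iters):
        model.generate(**inputs, forced_bos_token_id=tokenizer.get_lang_id("fr"))
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters


sentences = ["जीवन एक चॉकलेट बॉक्स की तरह है।"] * 8
eager_time = average_generation_time("eager", sentences)
flash_time = average_generation_time("flash_attention_2", sentences)
print(f"eager: {eager_time:.3f}s, flash_attention_2: {flash_time:.3f}s, speedup: {eager_time / flash_time:.2f}x")
```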
...@@ -145,3 +145,46 @@ UN-Chef sagt, es gibt keine militärische Lösung in Syrien
## NllbTokenizerFast
[[autodoc]] NllbTokenizerFast
## Using Flash Attention 2
Flash Attention 2 is a faster, optimized implementation of the attention score computation that relies on `cuda` kernels.
### Installation
First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features).
Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
```bash
pip install -U flash-attn --no-build-isolation
```
### Usage
To load a model using Flash Attention 2, pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). The model can be loaded in either `torch.float16` or `torch.bfloat16` precision.
```python
>>> import torch
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
>>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval()
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
>>> article = "Şeful ONU spune că nu există o soluţie militară în Siria"
>>> inputs = tokenizer(article, return_tensors="pt").to("cuda")
>>> translated_tokens = model.generate(
... **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["deu_Latn"], max_length=30
... )
>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
"UN-Chef sagt, es gibt keine militärische Lösung in Syrien"
```
### Expected speedups
Below is an expected speedup diagram that compares pure inference time between the native implementation and Flash Attention 2.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/visheratin/documentation-images/resolve/main/nllb-speedup.webp">
</div>
\ No newline at end of file
...@@ -53,11 +53,13 @@ FlashAttention-2 is currently supported for the following architectures:
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
...
...@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch M2M100 model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
...@@ -37,12 +37,19 @@ from ...utils import (
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_m2m_100 import M2M100Config
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "M2M100Config"
...@@ -317,6 +324,208 @@ class M2M100Attention(nn.Module):
        return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
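# Illustrative note (not part of the original diff): for a padding mask such as
#     attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
# `_get_unpad_data` returns indices = tensor([0, 1, 2, 4, 5]) (flattened positions of real tokens),
# cu_seqlens = tensor([0, 3, 5], dtype=torch.int32) (cumulative sequence lengths consumed by the
# varlen Flash Attention kernels), and max_seqlen_in_batch = 3.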
class M2M100FlashAttention2(M2M100Attention):
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[M2M100Config] = None,
):
super().__init__(embed_dim, num_heads, dropout, is_decoder, bias, is_causal, config)
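        # flash-attn versions below 2.1 generate a top-left aligned causal mask, whereas decoding here needs a
        # bottom-right aligned one; this flag records which behaviour the installed flash-attn exposes so that
        # `causal` can be adjusted accordingly in `_flash_attention_forward`.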
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, q_len, _ = hidden_states.size()
# get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout, softmax_scale=None
)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
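        # Flash Attention does not materialize the attention probabilities, so no attention weights can be
        # returned; the second element of the returned tuple is always `None` for this implementation.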
return attn_output, None, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
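            # no padding tokens in the batch: call the dense kernel directly on (batch, seq_len, num_heads, head_dim) tensors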
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100
class M2M100EncoderLayer(nn.Module):
    def __init__(self, config: M2M100Config):
...@@ -388,7 +597,10 @@ class M2M100EncoderLayer(nn.Module):
        return outputs
M2M100_ATTENTION_CLASSES = {
"eager": M2M100Attention,
"flash_attention_2": M2M100FlashAttention2,
}
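# Illustrative note (not part of the original diff): the encoder/decoder layers look up the attention class
# from this mapping via the model config, roughly as
#     self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](...)
# so loading with `attn_implementation="flash_attention_2"` swaps `M2M100FlashAttention2` in for `M2M100Attention`.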
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100
...@@ -517,6 +729,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["M2M100Attention"]
_supports_flash_attn_2 = True
    def _init_weights(self, module):
        std = self.config.init_std
...@@ -687,6 +900,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
        )
        self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
...@@ -767,8 +981,11 @@ class M2M100Encoder(M2M100PreTrainedModel):
        # expand attention_mask
        if attention_mask is not None:
            if self._use_flash_attention_2:
                attention_mask = attention_mask if 0 in attention_mask else None
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
...@@ -857,6 +1074,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
            self.padding_idx,
        )
        self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.gradient_checkpointing = False
...@@ -967,18 +1185,24 @@ class M2M100Decoder(M2M100PreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            combined_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            combined_attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self._use_flash_attention_2:
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )
        # embed positions
        positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
...@@ -1102,6 +1326,11 @@ class M2M100Model(M2M100PreTrainedModel):
        self.encoder = M2M100Encoder(config, self.shared)
        self.decoder = M2M100Decoder(config, self.shared)
if config._attn_implementation == "flash_attention_2":
logger.warning_once(
"Attention with Flash Attention 2 does not support `layer_head_mask`. If you need this feature, please use standard attention."
)
        # Initialize weights and apply final processing
        self.post_init()
...
...@@ -19,12 +19,16 @@ import copy
import tempfile
import unittest
import pytest
from transformers import M2M100Config, is_torch_available
from transformers.testing_utils import (
require_flash_attn,
    require_sentencepiece,
    require_tokenizers,
    require_torch,
    require_torch_fp16,
require_torch_gpu,
    slow,
    torch_device,
)
...@@ -412,3 +416,48 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
            hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
        )
        assert generated == expected_en
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_seq_to_seq_generation(self):
"""
        Overwriting the common test as the test is flaky on tiny models
"""
model = M2M100ForConditionalGeneration.from_pretrained(
"facebook/m2m100_418M", attn_implementation="flash_attention_2"
).to(torch_device)
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en")
src_fr = [
"L'affaire NSA souligne l'absence totale de débat sur le renseignement",
"Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.",
"Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent"
" Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de"
" l'ampleur de la surveillance américaine sur l'ensemble des communications en France.",
]
# The below article tests that we don't add any hypotheses outside of the top n_beams
dct = tokenizer(src_fr, padding=True, return_tensors="pt")
hypotheses_batch = model.generate(
input_ids=dct["input_ids"].to(torch_device),
attention_mask=dct["attention_mask"].to(torch_device),
num_beams=5,
forced_bos_token_id=tokenizer.get_lang_id("en"),
)
expected_en = [
"The NSA case highlights the total absence of intelligence debate",
"I think there are two levels of response from the French government.",
"When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
" Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
" communications in France.",
]
generated = tokenizer.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
assert generated == expected_en