Unverified commit e3143952, authored by Arthur and committed by GitHub

Refactor flash attention implementation in transformers (#31446)



* dumb commit

* nit

* update

* something like this

* unpack in modeling utils

* safe import

* oups

* update

* nits

* diff convert gemma

* update

* start propagating

* update other modeling code as well

* update for sliding window models

* nits

* more init cleanups

* styling

* fixup

* noice

* pass fixup

* typo typing_extension -> typing_extensions

* torch.nn.functionnal -> torch.nn.functional

* add to import structure

* unpack

* simplify a bit more for this first version

* nit

* update

* update

* nit

* ease the import of `Unpack`

* remove useless `use_sliding_window`

* no qua please

* protect import?

* style

* [run-slow]

* [run slow] llama,gemma,mistral,mixtral

* remove extra kwargs

* fix llama

* address review comments

* apply diff_model_converter to modeling_gemma.py

* remove cache_position 1

* remove cache_position 2

* some cleaning

* refactor gemma2 as well

* apply review comments

* rename file to modeling_flash_attention_utils.py

* siglip refactor

* remove dead code

* is the hub down?

* still down?

* fix siglip

* fix gemma2

* fatal: Could not read from remote repository.

* fix typo in softcap implem

* flaky

* Failed: Timeout >120.0s

---------
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
parent ad4ef3a2
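The net effect of the refactor, visible in the hunks below, is that every Flash Attention 2 class stops carrying its own copy of `_flash_attention_forward`/`_upad_input` (and the module-level `_get_unpad_data`) and instead calls the shared helper in `modeling_flash_attention_utils.py`. A minimal sketch of the new call pattern, with argument names taken from the hunks below; running it assumes a CUDA device and an installed `flash-attn` package:

```python
import torch

from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

    # Toy tensors in the (batch, seq_len, num_heads, head_dim) layout the helper expects.
    batch_size, seq_len, num_heads, head_dim = 2, 16, 8, 64
    query = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    key = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    value = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long, device="cuda")  # 1 = real token

    # One shared helper now handles un-padding, causal masking and the RoCm
    # top-left-mask workaround that each model used to reimplement locally.
    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        seq_len,  # query length, passed positionally as in the diffs below
        dropout=0.0,
        is_causal=True,
        use_top_left_mask=False,
    )
    print(attn_output.shape)  # (2, 16, 8, 64)
```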
src/transformers/models/siglip/modeling_siglip.py
@@ -21,7 +21,6 @@ from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
-import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -44,8 +43,7 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionCo
if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+    from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
@@ -55,19 +53,6 @@ _CONFIG_FOR_DOC = "SiglipConfig"
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -516,8 +501,15 @@ class SiglipFlashAttention2(SiglipAttention):
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
@@ -528,106 +520,6 @@ class SiglipFlashAttention2(SiglipAttention):
        return attn_output, attn_weights
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class SiglipSdpaAttention(SiglipAttention):
    """
...
src/transformers/models/stablelm/modeling_stablelm.py
@@ -23,7 +23,6 @@ import math
from typing import List, Optional, Tuple, Union
import torch
-import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -52,8 +51,7 @@ from .configuration_stablelm import StableLmConfig
if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+    from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
@@ -61,19 +59,6 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "StableLmConfig"
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->StableLm
class StableLmRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -615,13 +600,15 @@ class StableLmFlashAttention2(StableLmAttention):
        dropout_rate = self.attention_dropout.p if self.training else 0.0

-        attn_output = self._flash_attention_forward(
+        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -632,105 +619,6 @@ class StableLmFlashAttention2(StableLmAttention):
        return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
ATTENTION_CLASSES = {
    "eager": StableLmAttention,
...
src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -19,12 +19,10 @@
# limitations under the License.
"""PyTorch Starcoder2 model."""
-import inspect
import math
from typing import List, Optional, Tuple, Union
import torch
-import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -53,10 +51,7 @@ from .configuration_starcoder2 import Starcoder2Config
if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
-    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+    from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
@@ -64,19 +59,6 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Starcoder2Config"
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Starcoder2
class Starcoder2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -301,7 +283,6 @@ class Starcoder2Attention(nn.Module):
        return attn_output, attn_weights, past_key_value

-# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Starcoder2
class Starcoder2FlashAttention2(Starcoder2Attention):
    """
    Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays
@@ -355,18 +336,6 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
use_sliding_windows = (
_flash_supports_window_size
and getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
)
if not _flash_supports_window_size:
logger.warning_once(
"The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
" make sure to upgrade flash-attn library."
)
        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
@@ -429,14 +398,16 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

-        attn_output = self._flash_attention_forward(
+        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -448,148 +419,6 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
        return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
use_sliding_windows=False,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
use_sliding_windows (`bool`, *optional*):
Whether to activate sliding window attention.
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
if not use_sliding_windows:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
if not use_sliding_windows:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
# On the first iteration we need to properly re-create the padding mask
# by slicing it on the proper place
if kv_seq_len != attention_mask.shape[-1]:
attention_mask_num_tokens = attention_mask.shape[-1]
attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2
class Starcoder2SdpaAttention(Starcoder2Attention):
...
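For sliding-window models such as Starcoder2, the local `_flash_supports_window_size` check and the `use_sliding_windows` plumbing removed above are replaced by simply forwarding the config's window to the shared helper, which decides internally whether the installed flash-attn build can apply it. A hedged sketch of that call shape; the config object and tensors are toy stand-ins, not the model's real setup, and running it assumes flash-attn plus a CUDA device:

```python
from types import SimpleNamespace

import torch

from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

    config = SimpleNamespace(sliding_window=4096)  # stand-in for a Starcoder2-style config
    b, s, h, d = 1, 32, 4, 64
    q = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")
    k = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")
    v = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")

    attn_output = _flash_attention_forward(
        q,
        k,
        v,
        None,  # no padding mask in this toy example
        s,
        dropout=0.0,
        # The helper only applies the window when flash-attn supports it and the value is not None.
        sliding_window=getattr(config, "sliding_window", None),
        is_causal=True,
        use_top_left_mask=False,
    )
```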
src/transformers/models/unispeech/modeling_unispeech.py
@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
-import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -44,8 +43,7 @@ from .configuration_unispeech import UniSpeechConfig
if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+    from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
@@ -65,19 +63,6 @@ _CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and w
_CTC_EXPECTED_LOSS = 17.17
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
@dataclass
class UniSpeechForPreTrainingOutput(ModelOutput):
    """
@@ -709,8 +694,15 @@ class UniSpeechFlashAttention2(UniSpeechAttention):
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=self.dropout,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -721,104 +713,6 @@ class UniSpeechFlashAttention2(UniSpeechAttention):
        return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class UniSpeechSdpaAttention(UniSpeechAttention):
    # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeech
...
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
-import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -52,8 +51,7 @@ from .configuration_unispeech_sat import UniSpeechSatConfig
if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+    from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
@@ -81,19 +79,6 @@ _XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.97
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
@dataclass
class UniSpeechSatForPreTrainingOutput(ModelOutput):
    """
@@ -726,8 +711,15 @@ class UniSpeechSatFlashAttention2(UniSpeechSatAttention):
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=self.dropout,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -738,104 +730,6 @@ class UniSpeechSatFlashAttention2(UniSpeechSatAttention):
        return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class UniSpeechSatSdpaAttention(UniSpeechSatAttention):
    # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeechSat
...
src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
-import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -63,8 +62,7 @@ if is_safetensors_available():
if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+    from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
@@ -96,19 +94,6 @@ _XVECTOR_CHECKPOINT = "anton-l/wav2vec2-base-superb-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.98
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
@dataclass
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
    """
@@ -773,8 +758,15 @@ class Wav2Vec2FlashAttention2(Wav2Vec2Attention):
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=self.dropout,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -785,105 +777,6 @@ class Wav2Vec2FlashAttention2(Wav2Vec2Attention):
        return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class Wav2Vec2SdpaAttention(Wav2Vec2Attention):
    # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Wav2Vec2
...
src/transformers/models/whisper/modeling_whisper.py
@@ -19,7 +19,6 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
-import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -49,8 +48,7 @@ from .generation_whisper import WhisperGenerationMixin
if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+    from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
@@ -61,19 +59,6 @@ _CONFIG_FOR_DOC = "WhisperConfig"
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    """Returns sinusoids for positional embedding"""
    if channels % 2 != 0:
@@ -465,8 +450,15 @@ class WhisperFlashAttention2(WhisperAttention):
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, causal_mask, tgt_len, dropout=self.dropout
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            causal_mask,
+            tgt_len,
+            dropout=self.dropout,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, tgt_len, -1)
@@ -477,105 +469,6 @@ class WhisperFlashAttention2(WhisperAttention):
        return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class WhisperSdpaAttention(WhisperAttention):
    def forward(
...
src/transformers/utils/import_utils.py
@@ -812,11 +812,11 @@ def is_flash_attn_greater_or_equal_2_10():
    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")

-def is_flash_attn_greater_or_equal(version):
+def is_flash_attn_greater_or_equal(library_version: str):
    if not _is_package_available("flash_attn"):
        return False

-    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(version)
+    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)

def is_torchdistx_available():
...
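The `is_flash_attn_greater_or_equal` change above also fixes a shadowing bug: the old parameter name `version` hid the imported `packaging.version` module, so `version.parse(...)` would have been looked up on the string argument instead. A minimal illustration of the pattern, using a hypothetical `require_min_version` helper and assuming `packaging` is installed:

```python
from packaging import version


def require_min_version(installed: str, minimum: str) -> bool:
    # Had the parameter been called `version`, it would shadow the module above
    # and `version.parse` would become an attribute lookup on a plain string.
    return version.parse(installed) >= version.parse(minimum)


print(require_min_version("2.6.3", "2.1.0"))  # True
```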
utils/check_config_attributes.py
@@ -41,6 +41,8 @@ SPECIAL_CASES_TO_ALLOW = {
        "expert_layer_offset",
        "expert_layer_period",
    ],
+    "Qwen2Config": ["use_sliding_window"],
+    "Qwen2MoeConfig": ["use_sliding_window"],
    "Gemma2Config": ["tie_word_embeddings"],
    # used to compute the property `self.chunk_length`
    "EncodecConfig": ["overlap"],
...