Unverified Commit e3143952 authored by Arthur, committed by GitHub

Refactor flash attention implementation in transformers (#31446)



* dumb commit

* nit

* update

* something like this

* unpack in modeling utils

* safe import

* oups

* update

* nits

* diff convert gemma

* update

* start propagating

* update other modeling code as well

* update for sliding window models

* nits

* more init cleanups

* styling

* fixup

* noice

* pass fixup

* typo typing_extension -> typing_extensions

* torch.nn.functionnal -> torch.nn.functional

* add to import structure

* unpack

* simplify a bit more for this first version

* nut

* update

* update

* nit

* ease the import of `Unpack`

* remove useless `use_sliding_window`

* no qua please

* protect import?

* style

* [run-slow]

* [run slow] llama,gemma,mistral,mixtral

* remove extra kwargs

* fix llama

* address review comments

* apply diff_model_converter to modeling_gemma.py

* remove cache_position 1

* remove cache_position 2

* some cleaning

* refactor gemma2 as well

* apply review comments

* rename file to modeling_flash_attention_utils.py

* siglip refactor

* remove dead code

* is the hub down?

* still down?

* fix siglip

* fix gemma2

* fatal: Could not read from remote repository.

* fix typo in softcap implem

* flacky

* Failed: Timeout >120.0s

---------
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
parent ad4ef3a2
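The gist of the diff that follows: each model's private copy of `_flash_attention_forward`, `_upad_input`, and `_get_unpad_data` is deleted, and every `*FlashAttention2` class now calls the shared helper in `transformers.modeling_flash_attention_utils`, passing `is_causal`, `use_top_left_mask`, and (for sliding-window models) `sliding_window` explicitly. Below is a minimal sketch of the new call pattern, assuming flash-attn 2 is installed and a CUDA device is available; the tensor shapes are illustrative only, not taken from any model.

import torch

from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

    # The helper expects (batch, seq_len, num_heads, head_dim) tensors in fp16/bf16,
    # i.e. query/key/value already transposed the way the call sites below transpose them.
    bsz, q_len, num_heads, head_dim = 2, 16, 8, 64  # illustrative shapes
    query = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    key = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    value = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    attention_mask = None  # a (bsz, q_len) padding mask instead triggers the unpad/varlen path

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        q_len,
        dropout=0.0,
        sliding_window=None,  # sliding-window models pass config.sliding_window here
        is_causal=True,
        use_top_left_mask=False,  # True only for flash-attn < 2.1 (RoCm), as in the removed per-model code
    )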
@@ -19,7 +19,6 @@
# limitations under the License.
"""PyTorch Jamba model."""

-import inspect
import math
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -43,23 +42,20 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
-    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from ...utils.import_utils import (
    is_causal_conv1d_available,
    is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
    is_mamba_ssm_available,
)
from .configuration_jamba import JambaConfig

if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+    from ...modeling_flash_attention_utils import _flash_attention_forward

-    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)

if is_mamba_ssm_available():
...@@ -165,19 +161,6 @@ def load_balancing_loss_func( ...@@ -165,19 +161,6 @@ def load_balancing_loss_func(
return overall_loss * num_experts return overall_loss * num_experts
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Jamba # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Jamba
class JambaRMSNorm(nn.Module): class JambaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
...@@ -429,18 +412,6 @@ class JambaFlashAttention2(JambaAttention): ...@@ -429,18 +412,6 @@ class JambaFlashAttention2(JambaAttention):
kv_seq_len = cache_position[-1] kv_seq_len = cache_position[-1]
use_sliding_windows = (
_flash_supports_window_size
and getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
)
if not _flash_supports_window_size:
logger.warning_once(
"The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
" make sure to upgrade flash-attn library."
)
if past_key_value is not None: if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute # Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = cache_position[0] > 0 cache_has_contents = cache_position[0] > 0
@@ -502,14 +473,16 @@ class JambaFlashAttention2(JambaAttention):
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

-        attn_output = self._flash_attention_forward(
+        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
...@@ -520,149 +493,6 @@ class JambaFlashAttention2(JambaAttention): ...@@ -520,149 +493,6 @@ class JambaFlashAttention2(JambaAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
use_sliding_windows=False,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
use_sliding_windows (`bool`, *optional*):
Whether to activate sliding window attention.
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
if not use_sliding_windows:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
if not use_sliding_windows:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)
return attn_output
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
# On the first iteration we need to properly re-create the padding mask
# by slicing it on the proper place
if kv_seq_len != attention_mask.shape[-1]:
attention_mask_num_tokens = attention_mask.shape[-1]
attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba # Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba
class JambaSdpaAttention(JambaAttention): class JambaSdpaAttention(JambaAttention):
......
@@ -46,8 +46,7 @@ from .configuration_jetmoe import JetMoeConfig

if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+    from ...modeling_flash_attention_utils import _flash_attention_forward

logger = logging.get_logger(__name__)
...@@ -358,19 +357,6 @@ class JetMoeMoA(nn.Module): ...@@ -358,19 +357,6 @@ class JetMoeMoA(nn.Module):
raise NotImplementedError("This module doesn't support call and forward.") raise NotImplementedError("This module doesn't support call and forward.")
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->JetMoe # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->JetMoe
class JetMoeRMSNorm(nn.Module): class JetMoeRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
@@ -647,6 +633,7 @@ class JetMoeSdpaAttention(JetMoeAttention):
class JetMoeFlashAttention2(JetMoeAttention):
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
@@ -739,8 +726,15 @@ class JetMoeFlashAttention2(JetMoeAttention):
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
        ).to(input_dtype)

        # output projection
...@@ -753,105 +747,6 @@ class JetMoeFlashAttention2(JetMoeAttention): ...@@ -753,105 +747,6 @@ class JetMoeFlashAttention2(JetMoeAttention):
return attn_output, attn_weights, past_key_value, router_logits return attn_output, attn_weights, past_key_value, router_logits
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
JETMOE_ATTENTION_CLASSES = { JETMOE_ATTENTION_CLASSES = {
"eager": JetMoeAttention, "eager": JetMoeAttention,
......
@@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
@@ -41,7 +42,6 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
@@ -49,28 +49,11 @@ from ...utils import (
from .configuration_llama import LlamaConfig

-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"

-def _get_unpad_data(attention_mask):
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-    return (
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-    )

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
@@ -464,8 +447,16 @@ class LlamaFlashAttention2(LlamaAttention):
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            sliding_window=getattr(self, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
...@@ -476,103 +467,6 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -476,103 +467,6 @@ class LlamaFlashAttention2(LlamaAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class LlamaSdpaAttention(LlamaAttention): class LlamaSdpaAttention(LlamaAttention):
""" """
@@ -723,6 +617,7 @@ class LlamaDecoderLayer(nn.Module):
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
+            **kwargs,
        )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
......
@@ -389,8 +389,13 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
                image_newline=self.image_newline,
            )
            inputs_embeds = inputs_embeds.to(image_features.dtype)
-            inputs_embeds, attention_mask, position_ids, labels, input_ids = (
-                self._merge_input_ids_with_image_features(
+            (
+                inputs_embeds,
+                attention_mask,
+                position_ids,
+                labels,
+                input_ids,
+            ) = self._merge_input_ids_with_image_features(
                image_features,
                feature_lens,
                inputs_embeds,
@@ -400,7 +405,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
                labels=labels,
                image_token_index=self.config.image_token_index,
            )
-            )

        # Then merge video tokens if there are any
        if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
            video_features = self._get_video_features(pixel_values_videos)
@@ -408,8 +412,13 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
            feature_lens = [feature.size(0) for feature in video_features]
            video_features = torch.cat(video_features, dim=0)
            feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=video_features.device)
-            inputs_embeds, attention_mask, position_ids, labels, input_ids = (
-                self._merge_input_ids_with_image_features(
+            (
+                inputs_embeds,
+                attention_mask,
+                position_ids,
+                labels,
+                input_ids,
+            ) = self._merge_input_ids_with_image_features(
                video_features,
                feature_lens,
                inputs_embeds,
@@ -419,7 +428,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
                labels=labels,
                image_token_index=self.config.video_token_index,
            )
-            )

        # pixel_values is not None but is empty ---> text only cases
        elif (pixel_values is not None and pixel_values.size(0) == 0) or (
......
@@ -862,8 +862,13 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
                image_newline=self.image_newline,
            )
            inputs_embeds = inputs_embeds.to(image_features.dtype)
-            inputs_embeds, attention_mask, position_ids, labels, input_ids = (
-                self._merge_input_ids_with_image_features(
+            (
+                inputs_embeds,
+                attention_mask,
+                position_ids,
+                labels,
+                input_ids,
+            ) = self._merge_input_ids_with_image_features(
                image_features,
                feature_lens,
                inputs_embeds,
@@ -873,7 +878,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
                labels=labels,
                image_token_index=self.config.image_token_index,
            )
-            )

        # Then merge video tokens if there are any
        if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
            video_features = self._get_video_features(pixel_values_videos)
@@ -881,8 +885,13 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
            feature_lens = [feature.size(0) for feature in video_features]
            video_features = torch.cat(video_features, dim=0)
            feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=video_features.device)
-            inputs_embeds, attention_mask, position_ids, labels, input_ids = (
-                self._merge_input_ids_with_image_features(
+            (
+                inputs_embeds,
+                attention_mask,
+                position_ids,
+                labels,
+                input_ids,
+            ) = self._merge_input_ids_with_image_features(
                video_features,
                feature_lens,
                inputs_embeds,
@@ -892,7 +901,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
                labels=labels,
                image_token_index=self.config.video_token_index,
            )
-            )

        # pixel_values is not None but is empty ---> text only cases
        elif (pixel_values is not None and pixel_values.size(0) == 0) or (
......
@@ -18,7 +18,6 @@ import math
from typing import List, Optional, Tuple, Union

import torch
-import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -46,8 +45,7 @@ from .configuration_m2m_100 import M2M100Config

if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+    from ...modeling_flash_attention_utils import _flash_attention_forward

logger = logging.get_logger(__name__)
@@ -335,31 +333,14 @@ class M2M100Attention(nn.Module):
        return attn_output, attn_weights_reshaped, past_key_value

-# Copied from transformers.models.llama.modeling_llama._get_unpad_data
-def _get_unpad_data(attention_mask):
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-    return (
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-    )

class M2M100FlashAttention2(M2M100Attention):
-    def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        dropout: float = 0.0,
-        is_decoder: bool = False,
-        bias: bool = True,
-        is_causal: bool = False,
-        config: Optional[M2M100Config] = None,
-    ):
-        super().__init__(embed_dim, num_heads, dropout, is_decoder, bias, is_causal, config)
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -425,8 +406,16 @@ class M2M100FlashAttention2(M2M100Attention):
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout, softmax_scale=None
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=self.dropout,
+            softmax_scale=None,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
        )

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
...@@ -437,105 +426,6 @@ class M2M100FlashAttention2(M2M100Attention): ...@@ -437,105 +426,6 @@ class M2M100FlashAttention2(M2M100Attention):
return attn_output, None, past_key_value return attn_output, None, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100 # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100
class M2M100EncoderLayer(nn.Module): class M2M100EncoderLayer(nn.Module):
......
...@@ -19,7 +19,6 @@ import math ...@@ -19,7 +19,6 @@ import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -50,8 +49,7 @@ from .configuration_mbart import MBartConfig

if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+    from ...modeling_flash_attention_utils import _flash_attention_forward

logger = logging.get_logger(__name__)
...@@ -63,19 +61,6 @@ _CONFIG_FOR_DOC = "MBartConfig" ...@@ -63,19 +61,6 @@ _CONFIG_FOR_DOC = "MBartConfig"
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
""" """
Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
@@ -400,8 +385,15 @@ class MBartFlashAttention2(MBartAttention):
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=self.dropout,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)
...@@ -412,105 +404,6 @@ class MBartFlashAttention2(MBartAttention): ...@@ -412,105 +404,6 @@ class MBartFlashAttention2(MBartAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
MBART_ATTENTION_CLASSES = { MBART_ATTENTION_CLASSES = {
"eager": MBartAttention, "eager": MBartAttention,
......
@@ -19,12 +19,10 @@
# limitations under the License.
"""PyTorch Mistral model."""

-import inspect
import math
from typing import List, Optional, Tuple, Union

import torch
-import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -51,30 +49,13 @@ from .configuration_mistral import MistralConfig

if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+    from ...modeling_flash_attention_utils import _flash_attention_forward

-    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "MistralConfig"
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
class MistralRMSNorm(nn.Module): class MistralRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
...@@ -337,18 +318,6 @@ class MistralFlashAttention2(MistralAttention): ...@@ -337,18 +318,6 @@ class MistralFlashAttention2(MistralAttention):
cos, sin = self.rotary_emb(value_states, position_ids) cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
use_sliding_windows = (
_flash_supports_window_size
and getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
)
if not _flash_supports_window_size:
logger.warning_once(
"The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
" make sure to upgrade flash-attn library."
)
if past_key_value is not None: if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute # Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
@@ -411,14 +380,16 @@ class MistralFlashAttention2(MistralAttention):
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

-        attn_output = self._flash_attention_forward(
+        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
...@@ -429,148 +400,6 @@ class MistralFlashAttention2(MistralAttention): ...@@ -429,148 +400,6 @@ class MistralFlashAttention2(MistralAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
use_sliding_windows=False,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
use_sliding_windows (`bool`, *optional*):
Whether to activate sliding window attention.
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
if not use_sliding_windows:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
if not use_sliding_windows:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
# On the first iteration we need to properly re-create the padding mask
# by slicing it on the proper place
if kv_seq_len != attention_mask.shape[-1]:
attention_mask_num_tokens = attention_mask.shape[-1]
attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
class MistralSdpaAttention(MistralAttention): class MistralSdpaAttention(MistralAttention):
@@ -723,6 +552,7 @@ class MistralDecoderLayer(nn.Module):
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
+            **kwargs,
        )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
# limitations under the License. # limitations under the License.
"""PyTorch Mixtral model.""" """PyTorch Mixtral model."""
import inspect
import math import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
...@@ -47,7 +46,6 @@ from ...utils import ( ...@@ -47,7 +46,6 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -56,10 +54,7 @@ from .configuration_mixtral import MixtralConfig

if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+    from ...modeling_flash_attention_utils import _flash_attention_forward

-    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)

# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
...@@ -151,19 +146,6 @@ def load_balancing_loss_func( ...@@ -151,19 +146,6 @@ def load_balancing_loss_func(
return overall_loss * num_experts return overall_loss * num_experts
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
class MixtralRMSNorm(nn.Module): class MixtralRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
...@@ -402,15 +384,6 @@ class MixtralFlashAttention2(MixtralAttention): ...@@ -402,15 +384,6 @@ class MixtralFlashAttention2(MixtralAttention):
flash attention and deal with padding tokens in case the input contains any of them. flash attention and deal with padding tokens in case the input contains any of them.
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default in flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -447,18 +420,6 @@ class MixtralFlashAttention2(MixtralAttention): ...@@ -447,18 +420,6 @@ class MixtralFlashAttention2(MixtralAttention):
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
use_sliding_windows = (
_flash_supports_window_size
and getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
)
if not _flash_supports_window_size:
logger.warning_once(
"The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
" make sure to upgrade flash-attn library."
)
if past_key_value is not None: if past_key_value is not None:
# Activate slicing cache only if the config has a value set for the `sliding_window` attribute # Activate slicing cache only if the config has a value set for the `sliding_window` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
...@@ -521,14 +482,15 @@ class MixtralFlashAttention2(MixtralAttention): ...@@ -521,14 +482,15 @@ class MixtralFlashAttention2(MixtralAttention):
key_states = key_states.transpose(1, 2) key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2) value_states = value_states.transpose(1, 2)
attn_output = self._flash_attention_forward( attn_output = _flash_attention_forward(
query_states, query_states,
key_states, key_states,
value_states, value_states,
attention_mask, attention_mask,
q_len, q_len,
dropout=dropout_rate, dropout=dropout_rate,
use_sliding_windows=use_sliding_windows, sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
) )
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
...@@ -539,148 +501,6 @@ class MixtralFlashAttention2(MixtralAttention): ...@@ -539,148 +501,6 @@ class MixtralFlashAttention2(MixtralAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
use_sliding_windows=False,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
first unpads the input, then computes the attention scores and pads the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
use_sliding_windows (`bool`, *optional*):
Whether to activate sliding window attention.
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
if not use_sliding_windows:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
if not use_sliding_windows:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)
return attn_output
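For readers comparing the old and new code paths: each of these per-model `_flash_attention_forward` methods implemented the same unpad → variable-length attention → repad flow, with an optional `window_size` when sliding-window attention is active; that flow is what the shared helper now provides in one place. A condensed sketch of the removed logic (names mirror the code above; this is an illustration, not the shared helper's actual implementation):

```python
# Condensed sketch of the removed padded-input path (assumes the flash_attn
# imports and the _upad_input method shown in this diff).
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import pad_input


def _padded_flash_attention(self, query_states, key_states, value_states, attention_mask,
                            query_length, dropout=0.0, softmax_scale=None,
                            sliding_window=None, causal=True):
    batch_size = query_states.shape[0]
    # 1) Strip padding so the kernel only sees real tokens.
    query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
        query_states, key_states, value_states, attention_mask, query_length
    )
    cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    max_seqlen_q, max_seqlen_k = max_seq_lens
    # 2) window_size=(w, w) caps how far back each token may attend; it is only
    #    passed when a sliding window is configured, as in the removed branches.
    flash_kwargs = {}
    if sliding_window is not None:
        flash_kwargs["window_size"] = (sliding_window, sliding_window)
    attn_output_unpad = flash_attn_varlen_func(
        query_states, key_states, value_states,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
        dropout_p=dropout, softmax_scale=softmax_scale,
        causal=causal, **flash_kwargs,
    )
    # 3) Scatter the outputs back into the padded (batch, seq_len, heads, head_dim) layout.
    return pad_input(attn_output_unpad, indices_q, batch_size, query_length)
```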
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
# On the first iteration we need to properly re-create the padding mask
# by slicing it in the proper place
if kv_seq_len != attention_mask.shape[-1]:
attention_mask_num_tokens = attention_mask.shape[-1]
attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral # copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
# TODO @longjie no longer copied from Mistral after static cache # TODO @longjie no longer copied from Mistral after static cache
......
...@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union ...@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN from ...activations import ACT2FN
...@@ -58,8 +57,7 @@ from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig ...@@ -58,8 +57,7 @@ from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
if is_flash_attn_2_available(): if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func from ...modeling_flash_attention_utils import _flash_attention_forward
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
if TYPE_CHECKING: if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer from ...generation.streamers import BaseStreamer
...@@ -70,19 +68,6 @@ _CONFIG_FOR_DOC = "MusicgenConfig" ...@@ -70,19 +68,6 @@ _CONFIG_FOR_DOC = "MusicgenConfig"
_CHECKPOINT_FOR_DOC = "facebook/musicgen-small" _CHECKPOINT_FOR_DOC = "facebook/musicgen-small"
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
@dataclass @dataclass
class MusicgenUnconditionalInput(ModelOutput): class MusicgenUnconditionalInput(ModelOutput):
""" """
...@@ -434,8 +419,15 @@ class MusicgenFlashAttention2(MusicgenAttention): ...@@ -434,8 +419,15 @@ class MusicgenFlashAttention2(MusicgenAttention):
key_states = key_states.to(target_dtype) key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype) value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward( attn_output = _flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=self.dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
) )
attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = attn_output.reshape(bsz, q_len, -1)
...@@ -446,104 +438,6 @@ class MusicgenFlashAttention2(MusicgenAttention): ...@@ -446,104 +438,6 @@ class MusicgenFlashAttention2(MusicgenAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
first unpads the input, then computes the attention scores and pads the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
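The `use_top_left_mask` argument passed to the shared helper above replaces the per-class `_flash_attn_uses_top_left_mask` handling that these removed methods carried. The effective causal flag was derived as follows (a standalone sketch of the removed logic; the function name is illustrative, not a library API):

```python
# Sketch of the removed causal-flag handling, now driven by use_top_left_mask.
def resolve_causal(is_causal: bool, use_top_left_mask: bool, query_length: int) -> bool:
    if not use_top_left_mask:
        # flash_attn >= 2.1 aligns the causal mask bottom-right, so the flag
        # can be passed through unchanged.
        return is_causal
    # flash_attn < 2.1 aligns the causal mask top-left, which is wrong when
    # decoding a single token against a longer cache; a lone query attending
    # to all past keys needs no causal mask, so it is disabled in that case.
    return is_causal and query_length != 1
```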
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class MusicgenSdpaAttention(MusicgenAttention): class MusicgenSdpaAttention(MusicgenAttention):
def forward( def forward(
......
...@@ -119,7 +119,7 @@ class Qwen2Config(PretrainedConfig): ...@@ -119,7 +119,7 @@ class Qwen2Config(PretrainedConfig):
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window self.sliding_window = sliding_window if use_sliding_window else None
self.max_window_layers = max_window_layers self.max_window_layers = max_window_layers
# for backward compatibility # for backward compatibility
......
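A note on the `Qwen2Config` change above (the `Qwen2MoeConfig` hunk below is identical): the refactored attention code passes `getattr(config, "sliding_window", None)` to the shared helper (as seen in the Mixtral hunk above), so clearing the stored value whenever `use_sliding_window` is off is what actually disables windowing. A minimal illustration (example values only):

```python
# Behavioral effect of the config change above (example values only).
from transformers import Qwen2Config

config = Qwen2Config(use_sliding_window=False, sliding_window=4096)
# With this change the window size is cleared whenever the feature is off,
# so attention layers that read getattr(config, "sliding_window", None)
# never apply a sliding window.
assert config.sliding_window is None
```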
...@@ -149,7 +149,7 @@ class Qwen2MoeConfig(PretrainedConfig): ...@@ -149,7 +149,7 @@ class Qwen2MoeConfig(PretrainedConfig):
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window self.sliding_window = sliding_window if use_sliding_window else None
self.max_window_layers = max_window_layers self.max_window_layers = max_window_layers
self.num_key_value_heads = num_key_value_heads self.num_key_value_heads = num_key_value_heads
......