Unverified Commit f9a98c47 authored by Arthur, committed by GitHub

[`Mixtral` & `Mistral`] Add support for sdpa (#28133)



* some nits

* update test

* add support for sdpa

* remove some dummy inputs

* all good

* style

* nits

* fixes

* fix more copies

* nits

* styling

* fix

* Update src/transformers/models/mistral/modeling_mistral.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* add a slow test just to be sure

* fixup

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 814619f5
@@ -31,7 +31,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
-from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -612,9 +612,98 @@ class MistralFlashAttention2(MistralAttention):
         )


+# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
+class MistralSdpaAttention(MistralAttention):
+    """
+    Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    SDPA API.
+    """
+
+    # Adapted from MistralAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and attention_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
 MISTRAL_ATTENTION_CLASSES = {
     "eager": MistralAttention,
     "flash_attention_2": MistralFlashAttention2,
+    "sdpa": MistralSdpaAttention,
 }
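Note: the SDPA class above leaves the projection weights untouched and only swaps the attention kernel. The following is a minimal, self-contained sketch (not the modeling code itself) of the two ways that kernel is driven: with `is_causal=True` when no padding mask is present, or with an explicit 4D additive mask otherwise, after the key/value heads have been repeated to match the query heads the way `repeat_kv` does. All tensor sizes are made up for illustration.

import torch
import torch.nn.functional as F

bsz, q_len, head_dim = 2, 5, 8
num_heads, num_kv_heads = 4, 2           # grouped-query attention: 2 query heads per KV head
n_rep = num_heads // num_kv_heads

query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_kv_heads, q_len, head_dim)
value = torch.randn(bsz, num_kv_heads, q_len, head_dim)

# Equivalent of repeat_kv: duplicate each KV head so K/V match the number of query heads.
key = key.repeat_interleave(n_rep, dim=1)
value = value.repeat_interleave(n_rep, dim=1)

# No padding mask: causality is delegated to the kernel via is_causal=True.
out_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# With a 4D additive mask (what the mask-preparation helpers produce), is_causal stays False.
causal_mask = torch.full((q_len, q_len), float("-inf")).triu(1).expand(bsz, 1, q_len, q_len)
out_masked = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)

# Both paths compute the same causal attention.
print(out_causal.shape, torch.allclose(out_causal, out_masked, atol=1e-6))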
@@ -717,6 +806,7 @@ class MistralPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["MistralDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_sdpa = True
     _supports_cache_class = True

     def _init_weights(self, module):
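The `_supports_sdpa` flag is what lets `attn_implementation="sdpa"` be requested at load time. A typical call, mirroring the new slow test added further down (the fp16 dtype is an assumption for fitting the 7B checkpoint on one GPU, not part of this PR):

import torch
from transformers import AutoTokenizer, MistralForCausalLM

model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    device_map="auto",
    torch_dtype=torch.float16,   # assumption: half precision so the checkpoint fits on a single GPU
    attn_implementation="sdpa",  # dispatches every decoder layer to MistralSdpaAttention
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

inputs = tokenizer("My favourite condiment is ", return_tensors="pt").to(model.device)
generated = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(generated[0], skip_special_tokens=True))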
@@ -822,7 +912,7 @@ class MistralModel(MistralPreTrainedModel):
         self.layers = nn.ModuleList(
             [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._attn_implementation = config._attn_implementation
         self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.gradient_checkpointing = False
@@ -893,7 +983,7 @@ class MistralModel(MistralPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -902,9 +992,18 @@ class MistralModel(MistralPreTrainedModel):
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
             )

-        if self._use_flash_attention_2:
+        if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa" and not output_attentions:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
         else:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
...
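For reference, the SDPA branch above delegates mask construction to the `_prepare_4d_causal_attention_mask_for_sdpa` helper imported at the top of the file. Below is a small sketch of that call in isolation; the helper is a private transformers utility, so its exact return value (a 4D float mask, or `None` when SDPA's own causal handling suffices) is an implementation detail that may change between versions.

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

batch_size, seq_length, hidden_size = 2, 6, 16
past_key_values_length = 0
inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)

# 2D padding mask: the second sequence is left-padded by two positions.
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1, 1]])

mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask,
    (batch_size, seq_length),
    inputs_embeds,
    past_key_values_length,
)
# With padding present, a (batch_size, 1, seq_length, seq_length) float mask is expected;
# with no padding it may be None so that SDPA's is_causal path can be used instead.
print(None if mask is None else mask.shape)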
@@ -33,6 +33,7 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from ...modeling_outputs import (
     MoeCausalLMOutputWithPast,
@@ -660,6 +661,101 @@ class MixtralFlashAttention2(MixtralAttention):
         )


+# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mixtral
+class MixtralSdpaAttention(MixtralAttention):
+    """
+    Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    SDPA API.
+    """
+
+    # Adapted from MixtralAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and attention_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+MIXTRAL_ATTENTION_CLASSES = {
+    "eager": MixtralAttention,
+    "flash_attention_2": MixtralFlashAttention2,
+    "sdpa": MixtralSdpaAttention,
+}
+
+
 class MixtralBLockSparseTop2MLP(nn.Module):
     def __init__(self, config: MixtralConfig):
         super().__init__()
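As the warning in `MixtralSdpaAttention.forward` spells out, requesting attention weights falls back to the eager implementation. Here is a hedged sketch of what a caller sees, using a tiny randomly initialised Mixtral so no weights need to be downloaded; all sizes are invented, and passing `attn_implementation` through the config is an assumption that matches how `config._attn_implementation` is read in this PR.

import torch
from transformers import MixtralConfig, MixtralForCausalLM

config = MixtralConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=4,
    num_experts_per_tok=2,
    attn_implementation="sdpa",  # assumption: selecting the implementation via the config
)
model = MixtralForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
# output_attentions=True cannot be served by SDPA, so this call falls back to the eager
# MixtralAttention math and logs the warning defined above (once).
outputs = model(input_ids, output_attentions=True)
print(len(outputs.attentions), outputs.attentions[0].shape)  # 2 layers, each (1, 4, 8, 8)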
@@ -678,12 +774,6 @@ class MixtralBLockSparseTop2MLP(nn.Module):
         return current_hidden_states


-MISTRAL_ATTENTION_CLASSES = {
-    "eager": MixtralAttention,
-    "flash_attention_2": MixtralFlashAttention2,
-}
-
-
 class MixtralSparseMoeBlock(nn.Module):
     """
     This implementation is
@@ -759,7 +849,7 @@ class MixtralDecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+        self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

         self.block_sparse_moe = MixtralSparseMoeBlock(config)
         self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
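The one-line change above is the whole dispatch mechanism: a plain dict from the `attn_implementation` string to an attention class, indexed when each decoder layer is built. A stripped-down illustration of the pattern with generic stand-in classes (none of these names are transformers code):

from dataclasses import dataclass


@dataclass
class DummyConfig:
    _attn_implementation: str = "sdpa"


class EagerAttention:
    def __init__(self, config, layer_idx):
        self.config = config
        self.layer_idx = layer_idx


class SdpaAttention(EagerAttention):
    # Same constructor and weights as the eager class; only forward() would differ.
    pass


ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}


def build_layer_attention(config, layer_idx):
    # Unknown keys raise KeyError, just as an unsupported attn_implementation string fails loudly.
    return ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)


print(type(build_layer_attention(DummyConfig(), layer_idx=0)).__name__)  # SdpaAttention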
@@ -861,6 +951,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["MixtralDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_sdpa = True
     _supports_cache_class = True

     def _init_weights(self, module):
@@ -964,7 +1055,7 @@ class MixtralModel(MixtralPreTrainedModel):
         self.layers = nn.ModuleList(
             [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._attn_implementation = config._attn_implementation
         self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.gradient_checkpointing = False
@@ -1040,7 +1131,7 @@ class MixtralModel(MixtralPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -1049,9 +1140,18 @@ class MixtralModel(MixtralPreTrainedModel):
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
             )

-        if self._use_flash_attention_2:
+        if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa" and not output_attentions:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
         else:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
...
@@ -28,6 +28,7 @@ from transformers.testing_utils import (
     require_flash_attn,
     require_torch,
     require_torch_gpu,
+    require_torch_sdpa,
     slow,
     torch_device,
 )
@@ -528,6 +529,44 @@ class MistralIntegrationTest(unittest.TestCase):
         backend_empty_cache(torch_device)
         gc.collect()

+    @slow
+    @require_torch_sdpa
+    def test_model_7b_long_prompt_sdpa(self):
+        EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
+        # An input with 4097 tokens that is above the size of the sliding window
+        input_ids = [1] + [306, 338] * 2048
+        model = MistralForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-v0.1",
+            device_map="auto",
+            attn_implementation="sdpa",
+        )
+        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+        # Assisted generation
+        assistant_model = model
+        assistant_model.generation_config.num_assistant_tokens = 2
+        assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
+        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+        del assistant_model
+
+        backend_empty_cache(torch_device)
+        gc.collect()
+
+        EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"""
+        prompt = "My favourite condiment is "
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
+        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
+
+        # greedy generation outputs
+        generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
+        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
     @slow
     def test_speculative_generation(self):
         EXPECTED_TEXT_COMPLETION = (
...
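The assisted-generation half of the new test reuses the model as its own assistant; with a separate draft model the call typically looks like the sketch below. The draft checkpoint path is a placeholder, and `assistant_model=` is the standard `generate()` keyword for assisted decoding.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
)

# Hypothetical small draft model that shares the target model's tokenizer/vocabulary.
assistant = AutoModelForCausalLM.from_pretrained("path/to/small-draft-model", device_map="auto")
assistant.generation_config.num_assistant_tokens = 2
assistant.generation_config.num_assistant_tokens_schedule = "constant"

inputs = tokenizer("My favourite condiment is ", return_tensors="pt").to(model.device)
generated = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(generated[0], skip_special_tokens=True))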
@@ -38,11 +38,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin

 if is_torch_available():
     import torch

-    from transformers import (
-        MixtralForCausalLM,
-        MixtralForSequenceClassification,
-        MixtralModel,
-    )
+    from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel


 class MixtralModelTester:
...