"...resnet50_tensorflow.git" did not exist on "c3fe0550a70a807ffef5c0c49573624abd52d813"
Unverified Commit 5a73316b authored by Younes Belkada, committed by GitHub

[`FA-2`] Final fix for FA2 dtype (#26846)



* final fix for FA2 dtype

* try

* oops

* Update src/transformers/models/falcon/modeling_falcon.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* apply fix everywhere

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 732d2a8a
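The diff below applies the same pattern to each Flash Attention 2 class: instead of hard-coding torch.float16 when the hidden states arrive in float32, the code now resolves a target dtype from the config's `_pre_quantization_dtype` (recorded when the model is quantized) and otherwise falls back to the dtype of the attention projection weights. A minimal, self-contained sketch of that resolution logic follows; the helper name is illustrative, not part of the commit.

import torch
from torch import nn

def resolve_flash_attn_dtype(config, proj: nn.Linear) -> torch.dtype:
    # Quantized models record the dtype they were loaded with on the config,
    # because the packed 4-/8-bit weights no longer carry a float dtype.
    # Non-quantized models simply reuse the projection weight dtype.
    return getattr(config, "_pre_quantization_dtype", proj.weight.dtype)

# Example: a plain fp16 linear layer resolves to float16.
config = type("Cfg", (), {})()  # stand-in config without _pre_quantization_dtype
proj = nn.Linear(4, 4, dtype=torch.float16)
print(resolve_flash_attn_dtype(config, proj))  # torch.float16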
@@ -613,15 +613,18 @@ class FalconFlashAttention2(FalconAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
         if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            target_dtype = getattr(self.config, "_pre_quantization_dtype", self.query_key_value.weight.dtype)
             logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to"
-                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                " float16."
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
             )
-            query_layer = query_layer.to(torch.float16)
-            key_layer = key_layer.to(torch.float16)
-            value_layer = value_layer.to(torch.float16)
+            query_layer = query_layer.to(target_dtype)
+            key_layer = key_layer.to(target_dtype)
+            value_layer = value_layer.to(target_dtype)
         attn_output = self._flash_attention_forward(
             query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout
...
@@ -469,20 +469,24 @@ class LlamaFlashAttention2(LlamaAttention):
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in float16 just to be sure everything works as expected.
+        # cast them back in the correct dtype just to be sure everything works as expected.
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)
             logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to"
-                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                " float16."
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
             )
-            query_states = query_states.to(torch.float16)
-            key_states = key_states.to(torch.float16)
-            value_states = value_states.to(torch.float16)
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
         attn_output = self._flash_attention_forward(
             query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
...
@@ -408,15 +408,18 @@ class MistralFlashAttention2(MistralAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)
             logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to"
-                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                " float16."
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
             )
-            query_states = query_states.to(torch.float16)
-            key_states = key_states.to(torch.float16)
-            value_states = value_states.to(torch.float16)
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
         # Reashape to the expected shape for Flash Attention
         query_states = query_states.transpose(1, 2)
...
@@ -64,6 +64,7 @@ from transformers.testing_utils import (
     is_pt_flax_cross_test,
     is_pt_tf_cross_test,
     require_accelerate,
+    require_bitsandbytes,
     require_flash_attn,
     require_safetensors,
     require_torch,
@@ -2959,6 +2960,45 @@ class ModelTesterMixin:
                 dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
             )
+    @require_flash_attn
+    @require_torch_gpu
+    @require_bitsandbytes
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_fp32_ln(self):
+        import torch
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_flash_attn_2:
+                return
+
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
+
+                model = model_class.from_pretrained(
+                    tmpdirname,
+                    torch_dtype=torch.float16,
+                    use_flash_attention_2=True,
+                    low_cpu_mem_usage=True,
+                    load_in_4bit=True,
+                )
+
+                for _, param in model.named_parameters():
+                    # upcast only layer norms
+                    if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+                        param.data = param.data.to(torch.float32)
+
+                _ = model(input_ids=dummy_input)
+
+                # with attention mask
+                _ = model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
 global_rng = random.Random()
...
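The new test_flash_attn_2_fp32_ln test codifies the user-facing scenario: load a Flash Attention 2 model in 4-bit with an fp16 compute dtype, upcast the remaining half-precision parameters (the layer norms) to float32 as PEFT-style training does, and run a forward pass. Outside the test harness the same scenario looks roughly like the sketch below; the checkpoint name is illustrative, and a CUDA GPU with flash-attn and bitsandbytes installed is assumed.

import torch
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any FA2-capable causal LM would do.
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
    load_in_4bit=True,
)

# Upcast the non-quantized half-precision parameters (layer norms, embeddings)
# to float32, mirroring what PEFT training setups commonly do.
for _, param in model.named_parameters():
    if param.dtype in (torch.float16, torch.bfloat16):
        param.data = param.data.to(torch.float32)

input_ids = torch.tensor([[0, 2, 3, 4]], device=model.device)
with torch.no_grad():
    # The hidden states arrive in float32 and are now cast back to the
    # pre-quantization dtype instead of a hard-coded float16.
    _ = model(input_ids=input_ids)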