Unverified commit def581ef authored by Sourab Mangrulkar, committed by GitHub

Fix FA2 integration (#28142)



* fix fa2

* fix FA2 for popular models

* improve warning and add Younes as co-author
Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix the warning

* Add Tip

* typo fix

* nit

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent b134f685
@@ -67,6 +67,8 @@ come in several checkpoints they each contain a part of each weight of the model
 - The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.
+- When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method; instead, use Automatic Mixed-Precision training. When using `Trainer`, simply set either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because Flash Attention only supports the `fp16` and `bf16` data types.
 ## Resources
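The recipe in the added bullet looks roughly like the sketch below. The checkpoint name is only an example; the important points are that no `torch_dtype` is passed to `from_pretrained` and that the forward pass runs inside `torch.autocast` (with `Trainer`, setting `fp16=True` or `bf16=True` in `TrainingArguments` has the same effect).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Example checkpoint only; any model with Flash Attention 2 support works the same way.
model_id = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# No torch_dtype here, as the added documentation recommends.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

# Run under Automatic Mixed-Precision so the FA2 kernels receive bf16 tensors.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output_ids = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```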
@@ -1419,9 +1419,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
             )
         elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
-            raise ValueError(
-                f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to"
-                " unexpected behaviour."
+            logger.warning(
+                "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. "
+                "No dtype was provided, you should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator."
             )
         # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
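In practice this hunk downgrades a hard failure to a warning: loading an FA2 model with an unsupported `torch_dtype` no longer raises, because the dtype can still be handled at runtime through autocast. A small sketch of that behaviour, with a hypothetical fp32 load used purely for illustration:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"  # example FA2-capable checkpoint

# Before this change, torch.float32 here raised a ValueError;
# after it, only the warning shown in the diff above is logged.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Flash Attention test", return_tensors="pt").to(model.device)

# The fp32 weights remain usable because the forward pass runs under autocast.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    logits = model(**inputs).logits
```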
@@ -620,6 +620,8 @@ class FalconFlashAttention2(FalconAttention):
             # Handle the case where the model is quantized
             if hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.query_key_value.weight.dtype
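The same two-line branch is added to the GPTBigCode, LLaMA, Mistral, and Mixtral attention classes in the hunks that follow; only the projection attribute used as the fallback differs. Pulled out of its surroundings, the dtype-resolution logic is roughly the sketch below (`resolve_fa2_target_dtype` is a hypothetical helper name, and `proj_weight` stands in for whichever projection weight the given attention module falls back to):

```python
import torch
import torch.nn as nn


def resolve_fa2_target_dtype(config, proj_weight: nn.Parameter) -> torch.dtype:
    """Pick the dtype that Flash Attention 2 inputs should be cast to.

    Mirrors the branch shown in the diffs: quantized models keep their
    pre-quantization dtype, autocast contexts follow the autocast GPU dtype
    (the new case added by this commit), and everything else falls back to
    the dtype of the attention projection weights.
    """
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    elif torch.is_autocast_enabled():
        # Under AMP the weights may still be fp32 while the computation runs
        # in fp16/bf16, so the autocast compute dtype is the right target.
        return torch.get_autocast_gpu_dtype()
    else:
        return proj_weight.dtype
```

In the surrounding code these attention classes only reach this branch when the query/key/value states arrive in fp32 (as autocast produces after upcasted layer norms); they then cast those states to the resolved dtype before calling the Flash Attention kernel, which accepts only `fp16` and `bf16` inputs.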
@@ -378,6 +378,8 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
             # Handle the case where the model is quantized
             if hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.c_attn.weight.dtype
@@ -531,6 +531,8 @@ class LlamaFlashAttention2(LlamaAttention):
             # Handle the case where the model is quantized
             if hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype
@@ -431,6 +431,8 @@ class MistralFlashAttention2(MistralAttention):
             # Handle the case where the model is quantized
             if hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype
@@ -479,6 +479,8 @@ class MixtralFlashAttention2(MixtralAttention):
             # Handle the case where the model is quantized
             if hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype