Unverified Commit cadf93a6 authored by Susnato Dhar, committed by GitHub

fix FA2 when using quantization for remaining models (#28341)



* fix fa2 autocasting when using quantization

* Update src/transformers/models/distilbert/modeling_distilbert.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/distilbert/modeling_distilbert.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 899d8351
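
All eight hunks below apply the same dtype-selection logic before calling Flash Attention 2. As a rough standalone sketch of that shared pattern (the helper name and the fallback_linear parameter are illustrative and not part of the diff; each model uses its own projection layer, e.g. q_proj or q_lin):

import torch

def select_fa2_target_dtype(query_states, config, fallback_linear):
    # Flash Attention 2 kernels only accept fp16/bf16, so fp32 activations
    # (e.g. produced by upcasted LayerNorms) need a cast target.
    if query_states.dtype != torch.float32:
        return query_states.dtype
    # 1) If autocast is active, use its GPU dtype.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    # 2) Otherwise, for quantized models, use the dtype the weights had
    #    before quantization (the quantized weights themselves may be int8).
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    # 3) Fall back to the dtype of one of the attention projection weights.
    return fallback_linear.weight.dtype

The change gives the active autocast dtype priority over _pre_quantization_dtype, so under autocast the query/key/value states are cast to the autocast dtype rather than the dtype recorded before quantization.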
@@ -382,8 +382,10 @@ class BartFlashAttention2(BartAttention):
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
@@ -322,8 +322,10 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
         # in fp32. (LlamaRMSNorm handles it correctly)
         if query_states.dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_lin.weight.dtype
@@ -357,8 +357,10 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
         # in fp32. (LlamaRMSNorm handles it correctly)
         if query.dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
@@ -384,8 +384,10 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         input_dtype = query.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.query_key_value.weight.dtype
@@ -372,8 +372,10 @@ class MBartFlashAttention2(MBartAttention):
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
@@ -363,8 +363,10 @@ class OptFlashAttention2(OPTAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
@@ -484,8 +484,10 @@ class PhiFlashAttention2(PhiAttention):
         # in fp32.
         if query_states.dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
@@ -562,8 +562,10 @@ class WhisperFlashAttention2(WhisperAttention):
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype