Unverified commit 3b7675b2 authored by Sourab Mangrulkar, committed by GitHub

fix FA2 when using quantization (#28203)

parent fa21ead7
@@ -617,11 +617,11 @@ class FalconFlashAttention2(FalconAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.query_key_value.weight.dtype
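Every hunk in this commit applies the same reorder: when autocast is enabled, the autocast GPU dtype now takes precedence over `config._pre_quantization_dtype` when choosing the dtype to cast the FlashAttention-2 inputs to. A minimal, self-contained sketch of the new resolution order follows; the helper name, the config stub, and the parameter standing in for the autocast state are illustrative assumptions, not part of the commit or the transformers API.

```python
import torch


class QuantizedConfigStub:
    """Stand-in for a model config on which the quantizer recorded the
    dtype the checkpoint had before quantization (hypothetical example)."""

    _pre_quantization_dtype = torch.float32


def resolve_fa2_target_dtype(weight_dtype, config, autocast_dtype=None):
    """Mirror the reordered logic above (illustrative helper only).

    `autocast_dtype` stands in for torch.is_autocast_enabled() /
    torch.get_autocast_gpu_dtype() so the sketch runs without a GPU.
    """
    if autocast_dtype is not None:
        return autocast_dtype
    # Handle the case where the model is quantized
    elif hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    else:
        return weight_dtype


# Under bf16 autocast with a config quantized from a float32 checkpoint,
# the old order returned float32 (which FlashAttention rejects); the new
# order returns the autocast dtype.
assert resolve_fa2_target_dtype(torch.float16, QuantizedConfigStub(), autocast_dtype=torch.bfloat16) == torch.bfloat16
assert resolve_fa2_target_dtype(torch.float16, QuantizedConfigStub()) == torch.float32
```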
@@ -375,11 +375,11 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.c_attn.weight.dtype
@@ -528,11 +528,11 @@ class LlamaFlashAttention2(LlamaAttention):
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype
@@ -428,11 +428,11 @@ class MistralFlashAttention2(MistralAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype
@@ -477,11 +477,11 @@ class MixtralFlashAttention2(MixtralAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype
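For context, a hypothetical end-to-end scenario this fix targets: a 4-bit quantized model loaded with FlashAttention-2 and run under bf16 autocast, where float32 hidden states (e.g. upcast by layer norms during PEFT training) must be cast to a FlashAttention-supported dtype. The checkpoint id and the input text are assumptions, not part of the commit.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint, any causal LM with FA2 support works
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)

# Before this fix, float32 hidden states could be cast to
# config._pre_quantization_dtype instead of the active autocast dtype.
with torch.autocast("cuda", dtype=torch.bfloat16):
    out = model(**inputs)

print(out.logits.dtype)
```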