"...git@developer.sourcefind.cn:modelzoo/qwen_lmdeploy.git" did not exist on "2067862d1b874704ff5e88e65c515a7ff062f85e"
Unverified Commit 3b7675b2 authored by Sourab Mangrulkar, committed by GitHub

fix FA2 when using quantization (#28203)

parent fa21ead7
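The change is identical across the five FlashAttention-2 attention classes touched below: when the input hidden states are in float32, the autocast target dtype is now checked before the `_pre_quantization_dtype` attribute that quantization backends record on the config, so a quantized model running under autocast casts to the autocast dtype rather than the pre-quantization one. A minimal, illustrative sketch of the resulting selection order follows; `select_fa2_target_dtype` is a hypothetical helper (not part of the commit), and `fallback_weight` stands in for the model-specific projection weight (`query_key_value` for Falcon, `c_attn` for GPT-BigCode, `q_proj` for Llama/Mistral/Mixtral).

import torch
import torch.nn as nn


def select_fa2_target_dtype(query_states: torch.Tensor, config, fallback_weight: nn.Parameter) -> torch.dtype:
    # Illustrative helper mirroring the dtype-selection order after this commit.
    input_dtype = query_states.dtype
    if input_dtype != torch.float32:
        return input_dtype
    if torch.is_autocast_enabled():
        # Autocast now takes precedence over the quantization hint.
        return torch.get_autocast_gpu_dtype()
    # Handle the case where the model is quantized
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    # Fall back to the dtype of the attention projection weight.
    return fallback_weight.dtype
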
@@ -617,11 +617,11 @@ class FalconFlashAttention2(FalconAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.query_key_value.weight.dtype

@@ -375,11 +375,11 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.c_attn.weight.dtype

@@ -528,11 +528,11 @@ class LlamaFlashAttention2(LlamaAttention):
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype

@@ -428,11 +428,11 @@ class MistralFlashAttention2(MistralAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype

@@ -477,11 +477,11 @@ class MixtralFlashAttention2(MixtralAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
             else:
                 target_dtype = self.q_proj.weight.dtype
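For reference, this is roughly the setup the fix targets: FlashAttention-2 on a quantized model, optionally under autocast. The sketch below is not part of the commit; the model id, the bitsandbytes settings, and the prompt are placeholders, and it assumes a CUDA GPU with flash-attn and bitsandbytes installed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder; any FA2-capable model works
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
# With this fix, the FA2 cast under autocast uses the autocast dtype rather than
# the dtype recorded on the config before quantization.
with torch.autocast("cuda", dtype=torch.bfloat16):
    output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))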