Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3b7675b2
"...git@developer.sourcefind.cn:modelzoo/qwen_lmdeploy.git" did not exist on "2067862d1b874704ff5e88e65c515a7ff062f85e"
Unverified
Commit
3b7675b2
authored
Dec 26, 2023
by
Sourab Mangrulkar
Committed by
GitHub
Dec 26, 2023
Browse files
fix FA2 when using quantization (#28203)
parent
fa21ead7
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
15 additions
and
15 deletions
+15
-15
src/transformers/models/falcon/modeling_falcon.py
src/transformers/models/falcon/modeling_falcon.py
+3
-3
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+3
-3
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+3
-3
src/transformers/models/mistral/modeling_mistral.py
src/transformers/models/mistral/modeling_mistral.py
+3
-3
src/transformers/models/mixtral/modeling_mixtral.py
src/transformers/models/mixtral/modeling_mixtral.py
+3
-3
No files found.
src/transformers/models/falcon/modeling_falcon.py
View file @
3b7675b2
...
...
@@ -617,11 +617,11 @@ class FalconFlashAttention2(FalconAttention):
# cast them back in float16 just to be sure everything works as expected.
input_dtype
=
query_layer
.
dtype
if
input_dtype
==
torch
.
float32
:
if
torch
.
is_autocast_enabled
():
target_dtype
=
torch
.
get_autocast_gpu_dtype
()
# Handle the case where the model is quantized
if
hasattr
(
self
.
config
,
"_pre_quantization_dtype"
):
el
if
hasattr
(
self
.
config
,
"_pre_quantization_dtype"
):
target_dtype
=
self
.
config
.
_pre_quantization_dtype
elif
torch
.
is_autocast_enabled
():
target_dtype
=
torch
.
get_autocast_gpu_dtype
()
else
:
target_dtype
=
self
.
query_key_value
.
weight
.
dtype
...
...
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
View file @
3b7675b2
...
...
@@ -375,11 +375,11 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
# cast them back in float16 just to be sure everything works as expected.
input_dtype
=
query
.
dtype
if
input_dtype
==
torch
.
float32
:
if
torch
.
is_autocast_enabled
():
target_dtype
=
torch
.
get_autocast_gpu_dtype
()
# Handle the case where the model is quantized
if
hasattr
(
self
.
config
,
"_pre_quantization_dtype"
):
el
if
hasattr
(
self
.
config
,
"_pre_quantization_dtype"
):
target_dtype
=
self
.
config
.
_pre_quantization_dtype
elif
torch
.
is_autocast_enabled
():
target_dtype
=
torch
.
get_autocast_gpu_dtype
()
else
:
target_dtype
=
self
.
c_attn
.
weight
.
dtype
...
...
src/transformers/models/llama/modeling_llama.py
View file @
3b7675b2
...
...
@@ -528,11 +528,11 @@ class LlamaFlashAttention2(LlamaAttention):
input_dtype
=
query_states
.
dtype
if
input_dtype
==
torch
.
float32
:
if
torch
.
is_autocast_enabled
():
target_dtype
=
torch
.
get_autocast_gpu_dtype
()
# Handle the case where the model is quantized
if
hasattr
(
self
.
config
,
"_pre_quantization_dtype"
):
el
if
hasattr
(
self
.
config
,
"_pre_quantization_dtype"
):
target_dtype
=
self
.
config
.
_pre_quantization_dtype
elif
torch
.
is_autocast_enabled
():
target_dtype
=
torch
.
get_autocast_gpu_dtype
()
else
:
target_dtype
=
self
.
q_proj
.
weight
.
dtype
...
...
src/transformers/models/mistral/modeling_mistral.py
View file @
3b7675b2
...
...
@@ -428,11 +428,11 @@ class MistralFlashAttention2(MistralAttention):
# cast them back in float16 just to be sure everything works as expected.
input_dtype
=
query_states
.
dtype
if
input_dtype
==
torch
.
float32
:
if
torch
.
is_autocast_enabled
():
target_dtype
=
torch
.
get_autocast_gpu_dtype
()
# Handle the case where the model is quantized
if
hasattr
(
self
.
config
,
"_pre_quantization_dtype"
):
el
if
hasattr
(
self
.
config
,
"_pre_quantization_dtype"
):
target_dtype
=
self
.
config
.
_pre_quantization_dtype
elif
torch
.
is_autocast_enabled
():
target_dtype
=
torch
.
get_autocast_gpu_dtype
()
else
:
target_dtype
=
self
.
q_proj
.
weight
.
dtype
...
...
src/transformers/models/mixtral/modeling_mixtral.py
View file @
3b7675b2
...
...
@@ -477,11 +477,11 @@ class MixtralFlashAttention2(MixtralAttention):
# cast them back in float16 just to be sure everything works as expected.
input_dtype
=
query_states
.
dtype
if
input_dtype
==
torch
.
float32
:
if
torch
.
is_autocast_enabled
():
target_dtype
=
torch
.
get_autocast_gpu_dtype
()
# Handle the case where the model is quantized
if
hasattr
(
self
.
config
,
"_pre_quantization_dtype"
):
el
if
hasattr
(
self
.
config
,
"_pre_quantization_dtype"
):
target_dtype
=
self
.
config
.
_pre_quantization_dtype
elif
torch
.
is_autocast_enabled
():
target_dtype
=
torch
.
get_autocast_gpu_dtype
()
else
:
target_dtype
=
self
.
q_proj
.
weight
.
dtype
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment