Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e66984f9
Unverified
Commit
e66984f9
authored
Nov 20, 2023
by
Younes Belkada
Committed by
GitHub
Nov 20, 2023
Browse files
[`FA-2`] Add fa2 support for `from_config` (#26914)
* add fa2 support for from_config * Update test_modeling_common.py
parent
f31af392
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
0 deletions
+54
-0
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+6
-0
tests/test_modeling_common.py
tests/test_modeling_common.py
+48
-0
No files found.
src/transformers/modeling_utils.py
View file @
e66984f9
...
...
@@ -1173,14 +1173,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Args:
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype.
use_flash_attention_2 (`bool`, *optional*):
Whether to load the model with Flash Attention 2 modules.
"""
torch_dtype
=
kwargs
.
pop
(
"torch_dtype"
,
None
)
use_flash_attention_2
=
kwargs
.
pop
(
"use_flash_attention_2"
,
False
)
# override default dtype if needed
dtype_orig
=
None
if
torch_dtype
is
not
None
:
dtype_orig
=
cls
.
_set_default_torch_dtype
(
torch_dtype
)
if
use_flash_attention_2
:
config
=
cls
.
_check_and_enable_flash_attn_2
(
config
,
torch_dtype
)
if
is_deepspeed_zero3_enabled
():
import
deepspeed
...
...
tests/test_modeling_common.py
View file @
e66984f9
...
...
@@ -33,6 +33,7 @@ from pytest import mark
import
transformers
from
transformers
import
(
AutoModel
,
AutoModelForCausalLM
,
AutoModelForSequenceClassification
,
PretrainedConfig
,
is_torch_available
,
...
...
@@ -3269,6 +3270,53 @@ class ModelTesterMixin:
# Check models are equal
self
.
assertTrue
(
check_models_equal
(
flax_model_1
,
flax_model_2
))
@
require_flash_attn
@
require_torch_gpu
@
mark
.
flash_attn_test
@
slow
def
test_flash_attn_2_from_config
(
self
):
import
torch
for
model_class
in
self
.
all_generative_model_classes
:
if
not
model_class
.
_supports_flash_attn_2
:
return
config
,
_
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
# TODO: to change it in the future with other relevant auto classes
fa2_model
=
AutoModelForCausalLM
.
from_config
(
config
,
use_flash_attention_2
=
True
,
torch_dtype
=
torch
.
bfloat16
).
to
(
torch_device
)
dummy_input
=
torch
.
LongTensor
([[
0
,
2
,
3
,
4
],
[
0
,
2
,
3
,
4
]]).
to
(
torch_device
)
dummy_attention_mask
=
torch
.
LongTensor
([[
1
,
1
,
1
,
1
],
[
0
,
1
,
1
,
1
]]).
to
(
torch_device
)
fa2_correctly_converted
=
False
for
_
,
module
in
fa2_model
.
named_modules
():
if
"FlashAttention"
in
module
.
__class__
.
__name__
:
fa2_correctly_converted
=
True
break
self
.
assertTrue
(
fa2_correctly_converted
)
_
=
fa2_model
(
input_ids
=
dummy_input
,
attention_mask
=
dummy_attention_mask
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
fa2_model
.
save_pretrained
(
tmpdirname
)
model_from_pretrained
=
AutoModelForCausalLM
.
from_pretrained
(
tmpdirname
)
self
.
assertFalse
(
getattr
(
model_from_pretrained
.
config
,
"_flash_attn_2_enabled"
,
False
))
fa2_correctly_converted
=
False
for
_
,
module
in
model_from_pretrained
.
named_modules
():
if
"FlashAttention"
in
module
.
__class__
.
__name__
:
fa2_correctly_converted
=
True
break
self
.
assertFalse
(
fa2_correctly_converted
)
global_rng
=
random
.
Random
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment