chenpangpang / transformers / Commits / 569f6c7d

Unverified commit 569f6c7d, authored Apr 01, 2024 by Yoach Lacombe, committed by GitHub on Apr 01, 2024.

Fix FA2 tests (#29909)

* fix FA2 tests
* refactor inference test name

Parent: 3b8e2932
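Not part of the commit, but for orientation: the renamed tests keep the pytest flash_attn_test marker and the @slow decorator, so running just the renamed equivalence tests locally might look like the sketch below (assumes a CUDA GPU, the flash-attn package, and the RUN_SLOW=1 gate the transformers test suite uses for @slow tests; the chosen test files are illustrative).

# Sketch only: select the renamed FA2 tests by marker and name, from the repo root.
import os

import pytest

os.environ["RUN_SLOW"] = "1"  # assumption: how @slow tests are enabled locally
pytest.main(
    [
        "-m", "flash_attn_test",
        "-k", "inference_equivalence",
        "tests/models/bark/test_modeling_bark.py",
        "tests/models/whisper/test_modeling_whisper.py",
    ]
)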
Showing 9 changed files with 15 additions and 19 deletions.
tests/models/bark/test_modeling_bark.py              (+2 -2)
tests/models/distilbert/test_modeling_distilbert.py  (+2 -2)
tests/models/gemma/test_modeling_gemma.py            (+1 -1)
tests/models/mistral/test_modeling_mistral.py        (+1 -1)
tests/models/mixtral/test_modeling_mixtral.py        (+1 -1)
tests/models/qwen2/test_modeling_qwen2.py            (+1 -1)
tests/models/starcoder2/test_modeling_starcoder2.py  (+1 -1)
tests/models/whisper/test_modeling_whisper.py        (+2 -2)
tests/test_modeling_common.py                        (+4 -8)
tests/models/bark/test_modeling_bark.py

@@ -879,7 +879,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 return

@@ -936,7 +936,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 return
tests/models/distilbert/test_modeling_distilbert.py

@@ -301,7 +301,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
     @require_torch_accelerator
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         import torch

         for model_class in self.all_model_classes:

@@ -353,7 +353,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
     @require_torch_accelerator
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         import torch

         for model_class in self.all_model_classes:
tests/models/gemma/test_modeling_gemma.py

@@ -462,7 +462,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Gemma flash attention does not support right padding")

     @require_torch_sdpa
tests/models/mistral/test_modeling_mistral.py

@@ -466,7 +466,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Mistral flash attention does not support right padding")
tests/models/mixtral/test_modeling_mixtral.py

@@ -465,7 +465,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Mixtral flash attention does not support right padding")

     # Ignore copy
tests/models/qwen2/test_modeling_qwen2.py

@@ -477,7 +477,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Qwen2 flash attention does not support right padding")
tests/models/starcoder2/test_modeling_starcoder2.py

@@ -461,7 +461,7 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Starcoder2 flash attention does not support right padding")
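Aside (not part of the diff): the right-padding test is renamed but remains skipped for Gemma, Mistral, Mixtral, Qwen2, and Starcoder2, whose Flash Attention 2 path only supports left padding for batched inference. A minimal sketch of the usual batched-generation setup for such models, assuming flash-attn is installed and a CUDA GPU is available (the checkpoint name is illustrative only):

# Sketch only: left-padded batched generation with Flash Attention 2.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "mistralai/Mistral-7B-v0.1"  # illustrative checkpoint

# Pad on the left, since the FA2 path does not support right padding here.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # this model has no pad token by default

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package and a CUDA GPU
).to("cuda")

inputs = tokenizer(["Hello", "A much longer prompt"], padding=True, return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))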
tests/models/whisper/test_modeling_whisper.py

@@ -888,7 +888,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         import torch

         for model_class in self.all_model_classes:

@@ -934,7 +934,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         import torch

         for model_class in self.all_model_classes:
tests/test_modeling_common.py

@@ -3245,7 +3245,7 @@ class ModelTesterMixin:
     @require_torch_gpu
     @mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

@@ -3260,9 +3260,7 @@ class ModelTesterMixin:
                 )
                 model_fa.to(torch_device)

-                model = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
-                )
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
                 model.to(torch_device)

                 dummy_input = inputs_dict[model.main_input_name][:1]

@@ -3340,7 +3338,7 @@ class ModelTesterMixin:
     @require_torch_gpu
     @mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

@@ -3355,9 +3353,7 @@ class ModelTesterMixin:
                 )
                 model_fa.to(torch_device)

-                model = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
-                )
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
                 model.to(torch_device)

                 dummy_input = inputs_dict[model.main_input_name][:1]
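Aside (not part of the diff): the substantive fix in tests/test_modeling_common.py is that the reference model was previously also loaded with attn_implementation="flash_attention_2", so the equivalence tests compared Flash Attention 2 against itself. After this commit the reference copy uses the default attention implementation. A rough sketch of the comparison the fixed test performs (the helper name, tolerance, and the assumption that the model exposes last_hidden_state are illustrative, not the exact ModelTesterMixin code):

# Sketch only: compare FA2 outputs against the default attention implementation.
import tempfile

import torch


def check_fa2_equivalence(model_class, config, inputs, torch_device="cuda", atol=4e-2):
    # Build the model once and round-trip it through save_pretrained so both
    # copies start from identical weights.
    model = model_class(config)
    with tempfile.TemporaryDirectory() as tmpdirname:
        model.save_pretrained(tmpdirname)

        # Copy under test: Flash Attention 2.
        model_fa = model_class.from_pretrained(
            tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        ).to(torch_device)

        # Reference copy: default attention implementation (before #29909 this line
        # also passed attn_implementation="flash_attention_2").
        model_ref = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16).to(torch_device)

        with torch.no_grad():
            out_ref = model_ref(**inputs).last_hidden_state
            out_fa = model_fa(**inputs).last_hidden_state

        # Outputs should agree up to bfloat16 numerical noise.
        assert torch.allclose(out_ref.float(), out_fa.float(), atol=atol)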