"ml/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "95e744beeb82f725579932336eeabc0de019cbf4"
Unverified Commit 569f6c7d authored by Yoach Lacombe, committed by GitHub

Fix FA2 tests (#29909)

* fix FA2 tests

* refactor inference test name
parent 3b8e2932
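
In short, the shared FA2 tests are renamed from test_flash_attn_2_inference / test_flash_attn_2_inference_padding_right to test_flash_attn_2_inference_equivalence / test_flash_attn_2_inference_equivalence_right_padding, so every model-specific override below adopts the new names to keep shadowing the common test. A minimal sketch of the skip pattern used by the decoder-only models in this diff (the class name is a placeholder; decorators and skip message mirror the hunks below):

```python
import unittest

import pytest

from transformers.testing_utils import require_torch_gpu, slow


class SomeModelTest(unittest.TestCase):  # placeholder; the real classes also inherit ModelTesterMixin
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        # Must use the new name: an override kept under the old name
        # (test_flash_attn_2_inference_padding_right) would no longer shadow
        # the renamed common test, and the skip would silently stop applying.
        self.skipTest("This model's flash attention does not support right padding")
```
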
@@ -879,7 +879,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 return
@@ -936,7 +936,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 return

@@ -301,7 +301,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     @require_torch_accelerator
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         import torch
 
         for model_class in self.all_model_classes:
@@ -353,7 +353,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     @require_torch_accelerator
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         import torch
 
         for model_class in self.all_model_classes:

@@ -462,7 +462,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Gemma flash attention does not support right padding")
 
     @require_torch_sdpa

@@ -466,7 +466,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Mistral flash attention does not support right padding")

@@ -465,7 +465,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Mixtral flash attention does not support right padding")
 
     # Ignore copy

@@ -477,7 +477,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Qwen2 flash attention does not support right padding")

@@ -461,7 +461,7 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Starcoder2 flash attention does not support right padding")

@@ -888,7 +888,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         import torch
 
         for model_class in self.all_model_classes:
@@ -934,7 +934,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         import torch
 
         for model_class in self.all_model_classes:

@@ -3245,7 +3245,7 @@ class ModelTesterMixin:
     @require_torch_gpu
     @mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3260,9 +3260,7 @@
                 )
                 model_fa.to(torch_device)
 
-                model = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
-                )
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
                 model.to(torch_device)
 
                 dummy_input = inputs_dict[model.main_input_name][:1]
@@ -3340,7 +3338,7 @@
     @require_torch_gpu
     @mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3355,9 +3353,7 @@
                 )
                 model_fa.to(torch_device)
 
-                model = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
-                )
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
                 model.to(torch_device)
 
                 dummy_input = inputs_dict[model.main_input_name][:1]
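
Note that the two ModelTesterMixin hunks also drop attn_implementation="flash_attention_2" from the reference model's from_pretrained call, so the equivalence tests now compare a Flash Attention 2 model against one using the default attention implementation rather than two FA2 models. A rough sketch of the resulting comparison (check_fa2_equivalence, its parameters, and the tolerances are illustrative placeholders, not the actual fixtures or thresholds of the test suite):

```python
import tempfile

import torch


def check_fa2_equivalence(model_class, config, dummy_input, device="cuda"):
    """Rough sketch of the check performed by test_flash_attn_2_inference_equivalence
    after this change; model_class, config and dummy_input stand in for the fixtures
    the real ModelTesterMixin builds, and the tolerances are illustrative."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        model_class(config).save_pretrained(tmpdirname)

        # One model runs with Flash Attention 2...
        model_fa = model_class.from_pretrained(
            tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        ).to(device)
        # ...and the reference model now uses the default attention implementation.
        # Before this commit it was also loaded with attn_implementation="flash_attention_2",
        # so both sides of the comparison ran the same kernel.
        model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16).to(device)

        with torch.no_grad():
            out_fa = model_fa(dummy_input.to(device), output_hidden_states=True).hidden_states[-1]
            out = model(dummy_input.to(device), output_hidden_states=True).hidden_states[-1]

        # FA2 and the default implementation should agree up to numerical noise.
        torch.testing.assert_close(out_fa, out, atol=4e-2, rtol=4e-2)
```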