Unverified commit e85d8639 authored by Fanli Lin, committed by GitHub

add the missing flash attention test marker (#32419)

* add flash attention check

* fix

* fix

* add the missing marker

* bug fix

* add one more

* remove order

* add one more
parent 0aa83282
@@ -628,9 +628,9 @@ class GemmaIntegrationTest(unittest.TestCase):
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
+    @pytest.mark.flash_attn_test
     @require_flash_attn
     @require_read_token
-    @pytest.mark.flash_attn_test
     def test_model_2b_flash_attn(self):
         model_id = "google/gemma-2b"
         EXPECTED_TEXTS = [
...
@@ -620,6 +620,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     @require_flash_attn
     @require_torch_gpu
     @slow
+    @pytest.mark.flash_attn_test
     def test_use_flash_attention_2_true(self):
         """
         NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.
...
@@ -576,9 +576,10 @@ class MistralIntegrationTest(unittest.TestCase):
         backend_empty_cache(torch_device)
         gc.collect()
 
+    @require_flash_attn
     @require_bitsandbytes
     @slow
-    @require_flash_attn
+    @pytest.mark.flash_attn_test
     def test_model_7b_long_prompt(self):
         EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
         # An input with 4097 tokens that is above the size of the sliding window
...
@@ -544,6 +544,7 @@ class Qwen2IntegrationTest(unittest.TestCase):
     @require_bitsandbytes
     @slow
     @require_flash_attn
+    @pytest.mark.flash_attn_test
     def test_model_450m_long_prompt(self):
         EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
         # An input with 4097 tokens that is above the size of the sliding window
...
@@ -606,6 +606,7 @@ class Qwen2MoeIntegrationTest(unittest.TestCase):
     @require_bitsandbytes
     @slow
     @require_flash_attn
+    @pytest.mark.flash_attn_test
     def test_model_a2_7b_long_prompt(self):
         EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
         # An input with 4097 tokens that is above the size of the sliding window
...
@@ -16,6 +16,7 @@
 import unittest
 
+import pytest
 from parameterized import parameterized
 
 from transformers import StableLmConfig, is_torch_available, set_seed
@@ -539,6 +540,7 @@ class StableLmModelIntegrationTest(unittest.TestCase):
     @require_bitsandbytes
     @slow
     @require_flash_attn
+    @pytest.mark.flash_attn_test
     def test_model_3b_long_prompt(self):
         EXPECTED_OUTPUT_TOKEN_IDS = [3, 3, 3]
         input_ids = [306, 338] * 2047
...
@@ -528,6 +528,7 @@ class Starcoder2IntegrationTest(unittest.TestCase):
         self.assertEqual(EXPECTED_TEXT, output_text)
 
     @require_flash_attn
+    @pytest.mark.flash_attn_test
     def test_starcoder2_batched_generation_fa2(self):
         EXPECTED_TEXT = [
             "Hello my name is Younes and I am a student at the University of Liverpool. I am currently studying for my MSc in Computer Science. I am interested in the field of Machine Learning and I am currently working on",
...
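With the markers in place, the flash-attention tests can be filtered uniformly, e.g. `RUN_SLOW=1 pytest -m flash_attn_test tests/` (assuming the repo's usual `RUN_SLOW` gating for `@slow` tests). For the `-m` expression to work without a `PytestUnknownMarkWarning`, the marker must be registered; a hypothetical `conftest.py` registration — the repository's actual configuration may declare this elsewhere — would look like:

```python
# conftest.py (hypothetical registration; transformers may declare this marker elsewhere)
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "flash_attn_test: mark test as requiring the flash_attn package"
    )
```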