Unverified Commit 3258ff93 authored by Yih-Dar, committed by GitHub

use `pytest.mark` directly (#27390)



fix
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 791ec370
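
The change is mechanical and spans five test modules (Bark, DistilBERT, Llama, Mistral, and Whisper): instead of importing the `mark` object out of pytest, each file now does a plain `import pytest` and spells the decorator as `pytest.mark.flash_attn_test`. A minimal sketch of the two equivalent styles, with an illustrative test body that is not taken from the diff:

```python
import pytest

# Old style: import the mark object directly.
#
#   from pytest import mark
#
#   @mark.flash_attn_test
#   def test_flash_attn_smoke():
#       ...

# New style: keep the qualified name, so the decorator's origin is
# obvious at the call site and the import section stays uniform.
@pytest.mark.flash_attn_test
def test_flash_attn_smoke():
    assert True  # placeholder; the real tests iterate over model classes
```

Both spellings attach the same `flash_attn_test` marker to the test; only the import style differs.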
```diff
@@ -20,7 +20,7 @@ import inspect
 import tempfile
 import unittest
 
-from pytest import mark
+import pytest
 
 from transformers import (
     BarkCoarseConfig,
@@ -877,7 +877,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
 
     @require_flash_attn
     @require_torch_gpu
-    @mark.flash_attn_test
+    @pytest.mark.flash_attn_test
     @slow
     def test_flash_attn_2_inference(self):
         for model_class in self.all_model_classes:
@@ -936,7 +936,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
 
     @require_flash_attn
     @require_torch_gpu
-    @mark.flash_attn_test
+    @pytest.mark.flash_attn_test
     @slow
     def test_flash_attn_2_inference_padding_right(self):
         for model_class in self.all_model_classes:
```
```diff
@@ -16,7 +16,7 @@ import os
 import tempfile
 import unittest
 
-from pytest import mark
+import pytest
 
 from transformers import DistilBertConfig, is_torch_available
 from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device
@@ -290,7 +290,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
     @require_flash_attn
     @require_torch_accelerator
-    @mark.flash_attn_test
+    @pytest.mark.flash_attn_test
     @slow
     def test_flash_attn_2_inference(self):
         import torch
@@ -344,7 +344,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
     @require_flash_attn
     @require_torch_accelerator
-    @mark.flash_attn_test
+    @pytest.mark.flash_attn_test
     @slow
     def test_flash_attn_2_inference_padding_right(self):
         import torch
```
```diff
@@ -17,8 +17,8 @@
 import unittest
 
+import pytest
 from parameterized import parameterized
-from pytest import mark
 
 from transformers import LlamaConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
@@ -385,7 +385,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
 
     @require_flash_attn
     @require_torch_gpu
-    @mark.flash_attn_test
+    @pytest.mark.flash_attn_test
     @slow
     def test_flash_attn_2_generate_padding_right(self):
         """
```
```diff
@@ -19,7 +19,7 @@ import gc
 import tempfile
 import unittest
 
-from pytest import mark
+import pytest
 
 from transformers import AutoTokenizer, MistralConfig, is_torch_available
 from transformers.testing_utils import (
@@ -369,7 +369,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
 
     @require_flash_attn
     @require_torch_gpu
-    @mark.flash_attn_test
+    @pytest.mark.flash_attn_test
     @slow
     def test_flash_attn_2_generate_padding_right(self):
         import torch
@@ -403,7 +403,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
 
     @require_flash_attn
     @require_torch_gpu
-    @mark.flash_attn_test
+    @pytest.mark.flash_attn_test
     @slow
     def test_flash_attn_2_inference_padding_right(self):
         import torch
```
```diff
@@ -21,7 +21,7 @@ import tempfile
 import unittest
 
 import numpy as np
-from pytest import mark
+import pytest
 
 import transformers
 from transformers import WhisperConfig
@@ -800,7 +800,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
 
     @require_flash_attn
     @require_torch_gpu
-    @mark.flash_attn_test
+    @pytest.mark.flash_attn_test
     @slow
     def test_flash_attn_2_inference(self):
         import torch
@@ -845,7 +845,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
 
     @require_flash_attn
     @require_torch_gpu
-    @mark.flash_attn_test
+    @pytest.mark.flash_attn_test
     @slow
     def test_flash_attn_2_inference_padding_right(self):
         import torch
```
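
Custom markers such as `flash_attn_test` have to be registered with pytest, otherwise collecting these tests emits `PytestUnknownMarkWarning`. A hedged sketch of how a marker like this is typically registered from a `conftest.py` (the exact wording in the transformers conftest may differ):

```python
# conftest.py: register the custom marker so pytest recognizes it
# during collection instead of warning about an unknown mark.
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "flash_attn_test: marks tests that exercise Flash Attention 2 code paths",
    )
```

Once registered, the marked tests can be selected or skipped from the command line, for example `pytest -m flash_attn_test tests/` or `pytest -m "not flash_attn_test" tests/`.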