Unverified Commit 546cbe7e authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Speedup tf tests (#10601)

* Pipeline tests should be slow

* Temporarily mark some tests as slow

* Temporarily mark Barthez tests as slow
parent 696e8a43
...@@ -129,6 +129,7 @@ class TFModelTesterMixin: ...@@ -129,6 +129,7 @@ class TFModelTesterMixin:
self.assert_outputs_same(after_outputs, outputs) self.assert_outputs_same(after_outputs, outputs)
@slow
def test_graph_mode(self): def test_graph_mode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
...@@ -142,6 +143,7 @@ class TFModelTesterMixin: ...@@ -142,6 +143,7 @@ class TFModelTesterMixin:
outputs = run_in_graph_mode() outputs = run_in_graph_mode()
self.assertIsNotNone(outputs) self.assertIsNotNone(outputs)
@slow
def test_xla_mode(self): def test_xla_mode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
...@@ -182,6 +184,7 @@ class TFModelTesterMixin: ...@@ -182,6 +184,7 @@ class TFModelTesterMixin:
expected_arg_names = ["input_ids"] expected_arg_names = ["input_ids"]
self.assertListEqual(arg_names[:1], expected_arg_names) self.assertListEqual(arg_names[:1], expected_arg_names)
@slow
def test_saved_model_creation(self): def test_saved_model_creation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = False config.output_hidden_states = False
...@@ -311,6 +314,7 @@ class TFModelTesterMixin: ...@@ -311,6 +314,7 @@ class TFModelTesterMixin:
onnxruntime.InferenceSession(onnx_model.SerializeToString()) onnxruntime.InferenceSession(onnx_model.SerializeToString())
@slow
def test_mixed_precision(self): def test_mixed_precision(self):
tf.keras.mixed_precision.experimental.set_policy("mixed_float16") tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
...@@ -484,6 +488,7 @@ class TFModelTesterMixin: ...@@ -484,6 +488,7 @@ class TFModelTesterMixin:
max_diff = np.amax(np.abs(tfo - pto)) max_diff = np.amax(np.abs(tfo - pto))
self.assertLessEqual(max_diff, 4e-2) self.assertLessEqual(max_diff, 4e-2)
@slow
def test_train_pipeline_custom_model(self): def test_train_pipeline_custom_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# head_mask and decoder_head_mask has different shapes than other input args # head_mask and decoder_head_mask has different shapes than other input args
...@@ -904,6 +909,7 @@ class TFModelTesterMixin: ...@@ -904,6 +909,7 @@ class TFModelTesterMixin:
model(inputs) model(inputs)
@slow
def test_graph_mode_with_inputs_embeds(self): def test_graph_mode_with_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -17,13 +17,14 @@ ...@@ -17,13 +17,14 @@
import unittest import unittest
from transformers import BarthezTokenizer, BarthezTokenizerFast, BatchEncoding from transformers import BarthezTokenizer, BarthezTokenizerFast, BatchEncoding
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow
from .test_tokenization_common import TokenizerTesterMixin from .test_tokenization_common import TokenizerTesterMixin
@require_tokenizers @require_tokenizers
@require_sentencepiece @require_sentencepiece
@slow
class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase): class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = BarthezTokenizer tokenizer_class = BarthezTokenizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment