"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "90cb55bf773d6879441616e6378d16971b557868"
Unverified Commit fdcde144 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Add XLA test (#9848)

parent 99b9affa
......@@ -281,6 +281,10 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make BART float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make BART XLA compliant
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
......
......@@ -217,6 +217,10 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Blenderbot float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -282,6 +282,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Blenderbot Small float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot Small XLA compliant
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
......
......@@ -141,6 +141,19 @@ class TFModelTesterMixin:
outputs = run_in_graph_mode()
self.assertIsNotNone(outputs)
def test_xla_mode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@tf.function(experimental_compile=True)
def run_in_graph_mode():
return model(inputs)
outputs = run_in_graph_mode()
self.assertIsNotNone(outputs)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -301,6 +301,10 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
[self.model_tester.num_attention_heads / 2, encoder_seq_length, encoder_key_length],
)
def test_xla_mode(self):
# TODO JP: Make ConvBert XLA compliant
pass
@slow
def test_model_from_pretrained(self):
model = TFConvBertModel.from_pretrained("YituTech/conv-bert-base")
......
......@@ -225,6 +225,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make CTRL float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make CTRL XLA compliant
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
......@@ -334,6 +334,10 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Flaubert float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Flaubert XLA compliant
pass
@require_tf
@require_sentencepiece
......
......@@ -391,6 +391,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make GPT2 float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make GPT2 XLA compliant
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
......@@ -361,6 +361,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make LED float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make LED XLA compliant
pass
def test_saved_model_with_attentions_output(self):
# This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
......
......@@ -359,6 +359,10 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Longformer float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot XLA compliant
pass
@require_tf
@require_sentencepiece
......
......@@ -250,6 +250,10 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Marian float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Marian XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -221,6 +221,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make MBart float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make MBart XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -231,6 +231,10 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mpnet_for_token_classification(*config_and_inputs)
def test_xla_mode(self):
# TODO JP: Make MPNet XLA compliant
pass
@slow
def test_model_from_pretrained(self):
for model_name in ["microsoft/mpnet-base"]:
......
......@@ -249,6 +249,10 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make OpenAIGPT float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make OpenAIGPT XLA compliant
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
......@@ -248,6 +248,10 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Pegasus float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Pegasus XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -310,6 +310,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make T5 float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make T5 XLA compliant
pass
@slow
def test_model_from_pretrained(self):
model = TFT5Model.from_pretrained("t5-small")
......@@ -443,6 +447,10 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make T5 float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make T5 XLA compliant
pass
@require_tf
@require_sentencepiece
......
......@@ -208,6 +208,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make TransfoXL float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make TransfoXL XLA compliant
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
......@@ -330,6 +330,10 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make XLM float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make XLM XLA compliant
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment