Unverified Commit fdcde144 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Add XLA test (#9848)

parent 99b9affa
...@@ -281,6 +281,10 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -281,6 +281,10 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make BART float16 compliant # TODO JP: Make BART float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make BART XLA compliant
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
......
...@@ -217,6 +217,10 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -217,6 +217,10 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Blenderbot float16 compliant # TODO JP: Make Blenderbot float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot XLA compliant
pass
def test_resize_token_embeddings(self): def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -282,6 +282,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -282,6 +282,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Blenderbot Small float16 compliant # TODO JP: Make Blenderbot Small float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot Small XLA compliant
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
......
...@@ -141,6 +141,19 @@ class TFModelTesterMixin: ...@@ -141,6 +141,19 @@ class TFModelTesterMixin:
outputs = run_in_graph_mode() outputs = run_in_graph_mode()
self.assertIsNotNone(outputs) self.assertIsNotNone(outputs)
def test_xla_mode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@tf.function(experimental_compile=True)
def run_in_graph_mode():
return model(inputs)
outputs = run_in_graph_mode()
self.assertIsNotNone(outputs)
def test_forward_signature(self): def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -301,6 +301,10 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -301,6 +301,10 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
[self.model_tester.num_attention_heads / 2, encoder_seq_length, encoder_key_length], [self.model_tester.num_attention_heads / 2, encoder_seq_length, encoder_key_length],
) )
def test_xla_mode(self):
# TODO JP: Make ConvBert XLA compliant
pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
model = TFConvBertModel.from_pretrained("YituTech/conv-bert-base") model = TFConvBertModel.from_pretrained("YituTech/conv-bert-base")
......
...@@ -225,6 +225,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -225,6 +225,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make CTRL float16 compliant # TODO JP: Make CTRL float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make CTRL XLA compliant
pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
...@@ -334,6 +334,10 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -334,6 +334,10 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Flaubert float16 compliant # TODO JP: Make Flaubert float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make Flaubert XLA compliant
pass
@require_tf @require_tf
@require_sentencepiece @require_sentencepiece
......
...@@ -391,6 +391,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -391,6 +391,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make GPT2 float16 compliant # TODO JP: Make GPT2 float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make GPT2 XLA compliant
pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
...@@ -361,6 +361,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -361,6 +361,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make LED float16 compliant # TODO JP: Make LED float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make LED XLA compliant
pass
def test_saved_model_with_attentions_output(self): def test_saved_model_with_attentions_output(self):
# This test don't pass because of the error: # This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable # condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
......
...@@ -359,6 +359,10 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -359,6 +359,10 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Longformer float16 compliant # TODO JP: Make Longformer float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot XLA compliant
pass
@require_tf @require_tf
@require_sentencepiece @require_sentencepiece
......
...@@ -250,6 +250,10 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -250,6 +250,10 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Marian float16 compliant # TODO JP: Make Marian float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make Marian XLA compliant
pass
def test_resize_token_embeddings(self): def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -221,6 +221,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -221,6 +221,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make MBart float16 compliant # TODO JP: Make MBart float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make MBart XLA compliant
pass
def test_resize_token_embeddings(self): def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -231,6 +231,10 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -231,6 +231,10 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mpnet_for_token_classification(*config_and_inputs) self.model_tester.create_and_check_mpnet_for_token_classification(*config_and_inputs)
def test_xla_mode(self):
# TODO JP: Make MPNet XLA compliant
pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in ["microsoft/mpnet-base"]: for model_name in ["microsoft/mpnet-base"]:
......
...@@ -249,6 +249,10 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -249,6 +249,10 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make OpenAIGPT float16 compliant # TODO JP: Make OpenAIGPT float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make OpenAIGPT XLA compliant
pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
...@@ -248,6 +248,10 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -248,6 +248,10 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Pegasus float16 compliant # TODO JP: Make Pegasus float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make Pegasus XLA compliant
pass
def test_resize_token_embeddings(self): def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -310,6 +310,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -310,6 +310,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make T5 float16 compliant # TODO JP: Make T5 float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make T5 XLA compliant
pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
model = TFT5Model.from_pretrained("t5-small") model = TFT5Model.from_pretrained("t5-small")
...@@ -443,6 +447,10 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -443,6 +447,10 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make T5 float16 compliant # TODO JP: Make T5 float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make T5 XLA compliant
pass
@require_tf @require_tf
@require_sentencepiece @require_sentencepiece
......
...@@ -208,6 +208,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -208,6 +208,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make TransfoXL float16 compliant # TODO JP: Make TransfoXL float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make TransfoXL XLA compliant
pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
...@@ -330,6 +330,10 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -330,6 +330,10 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make XLM float16 compliant # TODO JP: Make XLM float16 compliant
pass pass
def test_xla_mode(self):
# TODO JP: Make XLM XLA compliant
pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment