Unverified Commit afad0c18 authored by Matt, committed by GitHub

Fix TF nightly tests (#20507)



* Fixed test_saved_model_extended

* Fix TFGPT2 tests

* make fixup

* Make sure keras-nlp utils are available for type hinting too

* Update src/transformers/testing_utils.py
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* make fixup
Co-authored-by: default avatarYih-Dar <2521628+ydshieh@users.noreply.github.com>
parent 761b3fad
@@ -525,6 +525,7 @@ _import_structure = {
         "is_datasets_available",
         "is_faiss_available",
         "is_flax_available",
+        "is_keras_nlp_available",
         "is_phonemizer_available",
         "is_psutil_available",
         "is_py3nvml_available",
@@ -3706,6 +3707,7 @@ if TYPE_CHECKING:
         is_datasets_available,
         is_faiss_available,
         is_flax_available,
+        is_keras_nlp_available,
         is_phonemizer_available,
         is_psutil_available,
         is_py3nvml_available,
...
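Note: the new name is registered twice on purpose. Entries in _import_structure drive transformers' lazy module loading at runtime, while the duplicate import under if TYPE_CHECKING: keeps static type checkers and IDEs aware of the symbol. A simplified, standalone sketch of that pattern (using json.dumps as a stand-in name; the real implementation is transformers' _LazyModule, not the code below) looks roughly like this:

import importlib
from typing import TYPE_CHECKING

# Map of submodule name -> public names it provides (stand-in values for illustration).
_import_structure = {"json": ["dumps"]}

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime.
    from json import dumps
else:
    def __getattr__(name):
        # PEP 562 module-level __getattr__: import the submodule lazily on first access.
        for module_name, names in _import_structure.items():
            if name in names:
                return getattr(importlib.import_module(module_name), name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")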
@@ -57,6 +57,7 @@ from .utils import (
     is_ftfy_available,
     is_ipex_available,
     is_jumanpp_available,
+    is_keras_nlp_available,
     is_librosa_available,
     is_natten_available,
     is_onnx_available,
@@ -392,6 +393,13 @@ def require_tensorflow_text(test_case):
     return unittest.skipUnless(is_tensorflow_text_available(), "test requires tensorflow_text")(test_case)


+def require_keras_nlp(test_case):
+    """
+    Decorator marking a test that requires keras_nlp. These tests are skipped when keras_nlp isn't installed.
+    """
+    return unittest.skipUnless(is_keras_nlp_available(), "test requires keras_nlp")(test_case)
+
+
 def require_pandas(test_case):
     """
     Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
...
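The is_keras_nlp_available helper used by the new decorator is defined in transformers' import utilities rather than in this diff. A minimal sketch of what such an availability check typically looks like (an assumption for illustration, not the file's actual code; the real helper may also verify the installed version):

import importlib.util


def is_keras_nlp_available():
    # True if the keras_nlp package can be located without importing it.
    return importlib.util.find_spec("keras_nlp") is not None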
@@ -2,12 +2,12 @@ import unittest
 from pathlib import Path
 from tempfile import TemporaryDirectory

-from transformers import AutoConfig, TFGPT2LMHeadModel, is_tensorflow_text_available, is_tf_available
+from transformers import AutoConfig, TFGPT2LMHeadModel, is_keras_nlp_available, is_tf_available
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-from transformers.testing_utils import require_tensorflow_text, slow
+from transformers.testing_utils import require_keras_nlp, slow


-if is_tensorflow_text_available():
+if is_keras_nlp_available():
     from transformers.models.gpt2 import TFGPT2Tokenizer

 if is_tf_available():
@@ -40,7 +40,7 @@ if is_tf_available():
         return outputs


-@require_tensorflow_text
+@require_keras_nlp
 class GPTTokenizationTest(unittest.TestCase):
     # The TF tokenizers are usually going to be used as pretrained tokenizers from existing model checkpoints,
     # so that's what we focus on here.
...
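For reference, a test written against the new keras_nlp gating might look like the hypothetical sketch below; the "gpt2" checkpoint name and the assert are illustrative assumptions, not taken from this diff.

import unittest

from transformers import is_keras_nlp_available
from transformers.testing_utils import require_keras_nlp

if is_keras_nlp_available():
    from transformers.models.gpt2 import TFGPT2Tokenizer


@require_keras_nlp
class TFGPT2TokenizerExample(unittest.TestCase):
    def test_tokenizer_outputs(self):
        # Load the in-graph tokenizer from an existing checkpoint, as the test class above does.
        tf_tokenizer = TFGPT2Tokenizer.from_pretrained("gpt2")
        outputs = tf_tokenizer(["Hello world!"])
        self.assertIn("input_ids", outputs)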
@@ -218,11 +218,17 @@ class TFCoreModelTesterMixin:
             model = model_class(config)
             num_out = len(model(class_inputs_dict))

-            for key in class_inputs_dict.keys():
+            for key in list(class_inputs_dict.keys()):
+                # Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
+                if key not in model.serving.input_signature[0]:
+                    del class_inputs_dict[key]
                 # Check it's a tensor, in case the inputs dict has some bools in it too
-                if isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
+                elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
                     class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)

+            if set(class_inputs_dict.keys()) != set(model.serving.input_signature[0].keys()):
+                continue  # Some models have inputs that the preparation functions don't create, we skip those
+
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname, saved_model=True)
                 saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
...
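The filtering added above exists because a SavedModel export is compiled only for the inputs named in the model's serving signature; anything else must be dropped before saving. A rough standalone sketch of the idea (the tiny checkpoint name and the input values are assumptions for illustration, not part of this change):

import tensorflow as tf
from transformers import TFGPT2LMHeadModel

model = TFGPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")

# model.serving is a tf.function with an explicit input signature; its single
# argument is a dict of TensorSpecs, one per input name the export will accept.
serving_keys = set(model.serving.input_signature[0].keys())

inputs = {
    "input_ids": tf.constant([[1, 2, 3]], dtype=tf.int32),
    "attention_mask": tf.constant([[1, 1, 1]], dtype=tf.int32),
    "output_hidden_states": True,  # not a tensor and typically not in the signature
}
# Drop anything the exported SavedModel would not be compiled to handle,
# mirroring the loop in the test above.
inputs = {key: value for key, value in inputs.items() if key in serving_keys}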