Unverified commit ee0d001d authored by Matt, committed by GitHub

Add a TF in-graph tokenizer for BERT (#17701)

* Add a TF in-graph tokenizer for BERT

* Add from_pretrained

* Add proper truncation, option handling to match other tokenizers

* Add proper imports and guards

* Add test, fix all the bugs exposed by said test

* Fix truncation of paired texts in graph mode, more test updates

* Small fixes, add a (very careful) test for savedmodel

* Add tensorflow-text dependency, make fixup

* Update documentation

* Update documentation

* make fixup

* Slight changes to tests

* Add some docstring examples

* Update tests

* Update tests and add proper lowercasing/normalization

* make fixup

* Add docstring for padding!

* Mark slow tests

* make fixup

* Fall back to BertTokenizerFast if BertTokenizer is unavailable

* Fall back to BertTokenizerFast if BertTokenizer is unavailable

* make fixup

* Properly handle tensorflow-text dummies
parent 401fcca6
docs/source/en/model_doc/bert.mdx
@@ -58,6 +58,10 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
[[autodoc]] BertTokenizerFast

## TFBertTokenizer

[[autodoc]] TFBertTokenizer

## Bert specific outputs

[[autodoc]] models.bert.modeling_bert.BertForPreTrainingOutput
...
setup.py
@@ -155,6 +155,7 @@ _deps = [
    "starlette",
    "tensorflow-cpu>=2.3",
    "tensorflow>=2.3",
    "tensorflow-text",
    "tf2onnx",
    "timeout-decorator",
    "timm",
@@ -238,8 +239,8 @@ extras = {}
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic")
extras["sklearn"] = deps_list("scikit-learn")
-extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx")
+extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text")
-extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx")
+extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text")
extras["torch"] = deps_list("torch")
extras["accelerate"] = deps_list("accelerate")
...
src/transformers/__init__.py
@@ -35,6 +35,7 @@ from .utils import (
    is_scatter_available,
    is_sentencepiece_available,
    is_speech_available,
    is_tensorflow_text_available,
    is_tf_available,
    is_timm_available,
    is_tokenizers_available,
@@ -435,6 +436,7 @@ _import_structure = {
        "is_sentencepiece_available",
        "is_sklearn_available",
        "is_speech_available",
        "is_tensorflow_text_available",
        "is_tf_available",
        "is_timm_available",
        "is_tokenizers_available",
@@ -575,6 +577,19 @@ else:
    _import_structure["models.mctct"].append("MCTCTFeatureExtractor")
    _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")

# Tensorflow-text-specific objects
try:
    if not is_tensorflow_text_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_tensorflow_text_objects

    _import_structure["utils.dummy_tensorflow_text_objects"] = [
        name for name in dir(dummy_tensorflow_text_objects) if not name.startswith("_")
    ]
else:
    _import_structure["models.bert"].append("TFBertTokenizer")

try:
    if not (is_sentencepiece_available() and is_speech_available()):
        raise OptionalDependencyNotAvailable()
@@ -3067,6 +3082,7 @@ if TYPE_CHECKING:
        is_sentencepiece_available,
        is_sklearn_available,
        is_speech_available,
        is_tensorflow_text_available,
        is_tf_available,
        is_timm_available,
        is_tokenizers_available,
@@ -3183,6 +3199,14 @@ if TYPE_CHECKING:
    from .models.mctct import MCTCTFeatureExtractor
    from .models.speech_to_text import Speech2TextFeatureExtractor

    try:
        if not is_tensorflow_text_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_tensorflow_text_objects import *
    else:
        from .models.bert import TFBertTokenizer

    try:
        if not (is_speech_available() and is_sentencepiece_available()):
            raise OptionalDependencyNotAvailable()
...
src/transformers/dependency_versions_table.py
@@ -61,6 +61,7 @@ deps = {
    "starlette": "starlette",
    "tensorflow-cpu": "tensorflow-cpu>=2.3",
    "tensorflow": "tensorflow>=2.3",
    "tensorflow-text": "tensorflow-text",
    "tf2onnx": "tf2onnx",
    "timeout-decorator": "timeout-decorator",
    "timm": "timm",
...
src/transformers/models/bert/__init__.py
@@ -22,6 +22,7 @@ from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_tensorflow_text_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
@@ -84,6 +85,13 @@ else:
        "TFBertModel",
        "TFBertPreTrainedModel",
    ]

try:
    if not is_tensorflow_text_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_bert_tf"] = ["TFBertTokenizer"]

try:
    if not is_flax_available():
@@ -160,6 +168,14 @@ if TYPE_CHECKING:
        TFBertPreTrainedModel,
    )

    try:
        if not is_tensorflow_text_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_bert_tf import TFBertTokenizer

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
...
src/transformers/models/bert/tokenization_bert_tf.py
import os
from typing import List, Union

import tensorflow as tf
from tensorflow_text import FastBertTokenizer, ShrinkLongestTrimmer, case_fold_utf8, combine_segments, pad_model_inputs

from .tokenization_bert import BertTokenizer


class TFBertTokenizer(tf.keras.layers.Layer):
    """
    This is an in-graph tokenizer for BERT. It should be initialized similarly to other tokenizers, using the
    `from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings
    from an existing standard tokenizer object.

    In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run
    when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options
    than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes
    straight from `tf.string` inputs to outputs.

    Args:
        vocab_list (`list`):
            List containing the vocabulary.
        do_lower_case (`bool`):
            Whether or not to lowercase the input when tokenizing.
        cls_token_id (`int`, *optional*, defaults to the index of `"[CLS]"` in `vocab_list`):
            The ID of the classifier token, which is used when doing sequence classification (classification of the
            whole sequence instead of per-token classification). It is the first token of a sequence built with
            special tokens.
        sep_token_id (`int`, *optional*, defaults to the index of `"[SEP]"` in `vocab_list`):
            The ID of the separator token, which is used when building a sequence from multiple sequences, e.g. two
            sequences for sequence classification or a text and a question for question answering. It is also the
            last token of a sequence built with special tokens.
        pad_token_id (`int`, *optional*, defaults to the index of `"[PAD]"` in `vocab_list`):
            The ID of the token used for padding, for example when batching sequences of different lengths.
        padding (`str`, defaults to `"longest"`):
            The type of padding to use. Can be either `"longest"`, to pad only up to the longest sample in the batch,
            or `"max_length"`, to pad all inputs to the maximum length supported by the tokenizer.
        truncation (`bool`, *optional*, defaults to `True`):
            Whether to truncate the sequence to the maximum length.
        max_length (`int`, *optional*, defaults to `512`):
            The maximum length of the sequence, used for padding (if `padding` is `"max_length"`) and/or truncation
            (if `truncation` is `True`).
        pad_to_multiple_of (`int`, *optional*, defaults to `None`):
            If set, the sequence will be padded to a multiple of this value.
        return_token_type_ids (`bool`, *optional*, defaults to `True`):
            Whether to return token_type_ids.
        return_attention_mask (`bool`, *optional*, defaults to `True`):
            Whether to return the attention_mask.
    """

    def __init__(
        self,
        vocab_list: List,
        do_lower_case: bool,
        cls_token_id: int = None,
        sep_token_id: int = None,
        pad_token_id: int = None,
        padding: str = "longest",
        truncation: bool = True,
        max_length: int = 512,
        pad_to_multiple_of: int = None,
        return_token_type_ids: bool = True,
        return_attention_mask: bool = True,
    ):
        super().__init__()
        self.tf_tokenizer = FastBertTokenizer(
            vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case
        )
        self.vocab_list = vocab_list
        self.do_lower_case = do_lower_case
        self.cls_token_id = cls_token_id or vocab_list.index("[CLS]")
        self.sep_token_id = sep_token_id or vocab_list.index("[SEP]")
        self.pad_token_id = pad_token_id or vocab_list.index("[PAD]")
        self.paired_trimmer = ShrinkLongestTrimmer(max_length - 3, axis=1)  # Allow room for special tokens
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation
        self.pad_to_multiple_of = pad_to_multiple_of
        self.return_token_type_ids = return_token_type_ids
        self.return_attention_mask = return_attention_mask
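    # Illustrative direct construction (not part of this commit). The toy vocabulary below is
    # hypothetical; in practice the vocabulary comes from a checkpoint via `from_pretrained()`
    # or `from_tokenizer()`:
    #
    #     vocab_list = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "hello", "world"]
    #     tokenizer = TFBertTokenizer(vocab_list, do_lower_case=True, padding="max_length", max_length=8)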
    @classmethod
    def from_tokenizer(cls, tokenizer: "PreTrainedTokenizerBase", **kwargs):  # noqa: F821
        """
        Initialize a `TFBertTokenizer` from an existing `Tokenizer`.

        Args:
            tokenizer (`PreTrainedTokenizerBase`):
                The tokenizer to use to initialize the `TFBertTokenizer`.

        Examples:

        ```python
        from transformers import AutoTokenizer, TFBertTokenizer

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        tf_tokenizer = TFBertTokenizer.from_tokenizer(tokenizer)
        ```
        """
        vocab = tokenizer.get_vocab()
        vocab = sorted([(wordpiece, idx) for wordpiece, idx in vocab.items()], key=lambda x: x[1])
        vocab_list = [entry[0] for entry in vocab]
        return cls(
            vocab_list=vocab_list,
            do_lower_case=tokenizer.do_lower_case,
            cls_token_id=tokenizer.cls_token_id,
            sep_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
            **kwargs,
        )
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
        """
        Instantiate a `TFBertTokenizer` from a pre-trained tokenizer.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name or path of the pre-trained tokenizer.

        Examples:

        ```python
        from transformers import TFBertTokenizer

        tf_tokenizer = TFBertTokenizer.from_pretrained("bert-base-uncased")
        ```
        """
        try:
            tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        except:  # noqa: E722
            # Fall back to BertTokenizerFast if the slow BertTokenizer cannot be loaded
            from .tokenization_bert_fast import BertTokenizerFast

            tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        return cls.from_tokenizer(tokenizer, **kwargs)
    def unpaired_tokenize(self, texts):
        if self.do_lower_case:
            # case_fold_utf8 lowercases the raw strings with Unicode case folding before tokenization
            texts = case_fold_utf8(texts)
        return self.tf_tokenizer.tokenize(texts)
    def call(
        self,
        text,
        text_pair=None,
        padding=None,
        truncation=None,
        max_length=None,
        pad_to_multiple_of=None,
        return_token_type_ids=None,
        return_attention_mask=None,
    ):
        if padding is None:
            padding = self.padding
        if padding not in ("longest", "max_length"):
            raise ValueError("Padding must be either 'longest' or 'max_length'!")
        if max_length is not None and text_pair is not None:
            # Because we have to instantiate a Trimmer to do it properly
            raise ValueError("max_length cannot be overridden at call time when truncating paired texts!")
        if max_length is None:
            max_length = self.max_length
        if truncation is None:
            truncation = self.truncation
        if pad_to_multiple_of is None:
            pad_to_multiple_of = self.pad_to_multiple_of
        if return_token_type_ids is None:
            return_token_type_ids = self.return_token_type_ids
        if return_attention_mask is None:
            return_attention_mask = self.return_attention_mask
        if not isinstance(text, tf.Tensor):
            text = tf.convert_to_tensor(text)
        if text_pair is not None and not isinstance(text_pair, tf.Tensor):
            text_pair = tf.convert_to_tensor(text_pair)
        if text_pair is not None:
            if text.shape.rank > 1:
                raise ValueError("text argument should not be multidimensional when a text pair is supplied!")
            if text_pair.shape.rank > 1:
                raise ValueError("text_pair should not be multidimensional!")
        if text.shape.rank == 2:
            text, text_pair = text[:, 0], text[:, 1]
        text = self.unpaired_tokenize(text)
        if text_pair is None:  # Unpaired text
            if truncation:
                text = text[:, : max_length - 2]  # Allow room for special tokens
            input_ids, token_type_ids = combine_segments(
                (text,), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id
            )
        else:  # Paired text
            text_pair = self.unpaired_tokenize(text_pair)
            if truncation:
                text, text_pair = self.paired_trimmer.trim([text, text_pair])
            input_ids, token_type_ids = combine_segments(
                (text, text_pair), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id
            )
        if padding == "longest":
            pad_length = input_ids.bounding_shape(axis=1)
            if pad_to_multiple_of is not None:
                # No ceiling division in tensorflow, so we negate floordiv instead
                pad_length = pad_to_multiple_of * (-tf.math.floordiv(-pad_length, pad_to_multiple_of))
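                # e.g. pad_length = 10, pad_to_multiple_of = 8 gives 8 * -(-10 // 8) = 8 * 2 = 16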
        else:
            pad_length = max_length

        input_ids, attention_mask = pad_model_inputs(
            input_ids, max_seq_length=pad_length, pad_value=self.pad_token_id
        )
        output = {"input_ids": input_ids}
        if return_attention_mask:
            output["attention_mask"] = attention_mask
        if return_token_type_ids:
            token_type_ids, _ = pad_model_inputs(
                token_type_ids, max_seq_length=pad_length, pad_value=self.pad_token_id
            )
            output["token_type_ids"] = token_type_ids
        return output
    def get_config(self):
        return {
            "vocab_list": self.vocab_list,
            "do_lower_case": self.do_lower_case,
            "cls_token_id": self.cls_token_id,
            "sep_token_id": self.sep_token_id,
            "pad_token_id": self.pad_token_id,
        }
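To illustrate the end-to-end use case the docstring describes, here is a minimal sketch (not part of this diff; the `EndToEndModel` class name, checkpoint choice, and pooled-output selection are illustrative) that chains the in-graph tokenizer with a TF BERT model so raw strings go in and embeddings come out, mirroring the `ModelToSave` pattern in the tests further down:

```python
import tensorflow as tf

from transformers import TFAutoModel, TFBertTokenizer


class EndToEndModel(tf.keras.Model):
    """Bundles in-graph tokenization and BERT so the model accepts raw strings."""

    def __init__(self, checkpoint):
        super().__init__()
        self.tokenizer = TFBertTokenizer.from_pretrained(checkpoint)
        self.bert = TFAutoModel.from_pretrained(checkpoint)

    def call(self, inputs):
        tokenized = self.tokenizer(inputs)  # runs inside the TF graph, no Python preprocessing
        return self.bert(**tokenized).pooler_output


model = EndToEndModel("bert-base-uncased")
embeddings = model(tf.constant(["This is a test sentence."]))
```

Because tokenization is part of the graph, such a model can be exported as a SavedModel and served on plain string inputs, which is exactly what the savedmodel test below exercises.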
src/transformers/testing_utils.py
@@ -63,6 +63,7 @@ from .utils import (
    is_soundfile_availble,
    is_spacy_available,
    is_tensorflow_probability_available,
    is_tensorflow_text_available,
    is_tf2onnx_available,
    is_tf_available,
    is_timm_available,
@@ -361,6 +362,14 @@ def require_tokenizers(test_case):
    return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)


def require_tensorflow_text(test_case):
    """
    Decorator marking a test that requires tensorflow_text. These tests are skipped when tensorflow_text isn't
    installed.
    """
    return unittest.skipUnless(is_tensorflow_text_available(), "test requires tensorflow_text")(test_case)


def require_pandas(test_case):
    """
    Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
...
src/transformers/utils/__init__.py
@@ -119,6 +119,7 @@ from .import_utils import (
    is_spacy_available,
    is_speech_available,
    is_tensorflow_probability_available,
    is_tensorflow_text_available,
    is_tf2onnx_available,
    is_tf_available,
    is_timm_available,
...
src/transformers/utils/dummy_tensorflow_text_objects.py
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends


class TFBertTokenizer(metaclass=DummyObject):
    _backends = ["tensorflow_text"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tensorflow_text"])
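For context, a short sketch of how this dummy behaves (assuming `tensorflow_text` is *not* installed): the name `TFBertTokenizer` can still be imported from `transformers`, but any attempt to use it raises an `ImportError` built from the `TENSORFLOW_TEXT_IMPORT_ERROR` message defined in `import_utils` below:

```python
from transformers import TFBertTokenizer  # import succeeds even without tensorflow_text

try:
    TFBertTokenizer()  # instantiation triggers requires_backends and fails
except ImportError as err:
    print(err)  # points at https://www.tensorflow.org/text/guide/tf_text_intro
```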
src/transformers/utils/import_utils.py
@@ -491,6 +491,10 @@ def is_spacy_available():
    return importlib.util.find_spec("spacy") is not None


def is_tensorflow_text_available():
    return importlib.util.find_spec("tensorflow_text") is not None


def is_in_notebook():
    try:
        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
@@ -721,6 +725,12 @@ TENSORFLOW_PROBABILITY_IMPORT_ERROR = """
explained here: https://github.com/tensorflow/probability.
"""

# docstyle-ignore
TENSORFLOW_TEXT_IMPORT_ERROR = """
{0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as
explained here: https://www.tensorflow.org/text/guide/tf_text_intro.
"""

# docstyle-ignore
PANDAS_IMPORT_ERROR = """
@@ -800,6 +810,7 @@ BACKENDS_MAPPING = OrderedDict(
        ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
        ("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)),
        ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
        ("tensorflow_text", (is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR)),
        ("timm", (is_timm_available, TIMM_IMPORT_ERROR)),
        ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
        ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
...
tests/models/bert/test_tokenization_bert_tf.py
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory

from transformers import AutoConfig, TFAutoModel, is_tensorflow_text_available, is_tf_available
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.testing_utils import require_tensorflow_text, slow


if is_tensorflow_text_available():
    from transformers.models.bert import TFBertTokenizer

if is_tf_available():
    import tensorflow as tf


TOKENIZER_CHECKPOINTS = ["bert-base-uncased", "bert-base-cased"]
TINY_MODEL_CHECKPOINT = "hf-internal-testing/tiny-bert-tf-only"

if is_tf_available():

    class ModelToSave(tf.keras.Model):
        def __init__(self, tokenizer):
            super().__init__()
            self.tokenizer = tokenizer
            config = AutoConfig.from_pretrained(TINY_MODEL_CHECKPOINT)
            self.bert = TFAutoModel.from_config(config)

        def call(self, inputs):
            tokenized = self.tokenizer(inputs)
            out = self.bert(**tokenized)
            return out["pooler_output"]


@require_tensorflow_text
class BertTokenizationTest(unittest.TestCase):
    # The TF tokenizers are usually going to be used as pretrained tokenizers from existing model checkpoints,
    # so that's what we focus on here.

    def setUp(self):
        super().setUp()
        self.tokenizers = [BertTokenizer.from_pretrained(checkpoint) for checkpoint in TOKENIZER_CHECKPOINTS]
        self.tf_tokenizers = [TFBertTokenizer.from_pretrained(checkpoint) for checkpoint in TOKENIZER_CHECKPOINTS]
        self.test_sentences = [
            "This is a straightforward English test sentence.",
            "This one has some weird characters\rto\nsee\r\nif those\u00E9break things.",
            "Now we're going to add some Chinese: 一 二 三 一二三",
            "And some much more rare Chinese: 齉 堃 齉堃",
            "Je vais aussi écrire en français pour tester les accents",
            "Classical Irish also has some unusual characters, so in they go: Gaelaċ, ꝼ",
        ]
        self.paired_sentences = list(zip(self.test_sentences, self.test_sentences[::-1]))

    def test_output_equivalence(self):
        for tokenizer, tf_tokenizer in zip(self.tokenizers, self.tf_tokenizers):
            for test_inputs in (self.test_sentences, self.paired_sentences):
                python_outputs = tokenizer(test_inputs, return_tensors="tf", padding="longest")
                tf_outputs = tf_tokenizer(test_inputs)

                for key in python_outputs.keys():
                    self.assertTrue(tf.reduce_all(python_outputs[key].shape == tf_outputs[key].shape))
                    self.assertTrue(tf.reduce_all(tf.cast(python_outputs[key], tf.int64) == tf_outputs[key]))

    @slow
    def test_different_pairing_styles(self):
        for tf_tokenizer in self.tf_tokenizers:
            merged_outputs = tf_tokenizer(self.paired_sentences)
            separated_outputs = tf_tokenizer(
                text=[sentence[0] for sentence in self.paired_sentences],
                text_pair=[sentence[1] for sentence in self.paired_sentences],
            )
            for key in merged_outputs.keys():
                self.assertTrue(tf.reduce_all(tf.cast(merged_outputs[key], tf.int64) == separated_outputs[key]))

    @slow
    def test_graph_mode(self):
        for tf_tokenizer in self.tf_tokenizers:
            compiled_tokenizer = tf.function(tf_tokenizer)
            for test_inputs in (self.test_sentences, self.paired_sentences):
                test_inputs = tf.constant(test_inputs)
                compiled_outputs = compiled_tokenizer(test_inputs)
                eager_outputs = tf_tokenizer(test_inputs)

                for key in eager_outputs.keys():
                    self.assertTrue(tf.reduce_all(eager_outputs[key] == compiled_outputs[key]))

    @slow
    def test_saved_model(self):
        for tf_tokenizer in self.tf_tokenizers:
            model = ModelToSave(tokenizer=tf_tokenizer)
            test_inputs = tf.convert_to_tensor(self.test_sentences)
            out = model(test_inputs)  # Build model with some sample inputs
            with TemporaryDirectory() as tempdir:
                save_path = Path(tempdir) / "saved.model"
                model.save(save_path)
                loaded_model = tf.keras.models.load_model(save_path)
                loaded_output = loaded_model(test_inputs)
            # We may see small differences because the loaded model is compiled, so we need an epsilon for the test
            self.assertLessEqual(tf.reduce_max(tf.abs(out - loaded_output)), 1e-5)
utils/check_dummies.py
@@ -26,7 +26,7 @@ PATH_TO_TRANSFORMERS = "src/transformers"
_re_backend = re.compile(r"is\_([a-z_]*)_available()")

# Matches from xxx import bla
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
-_re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z]*\_available\(\)")
+_re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z_]*\_available\(\)")
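(The only change here widens the character class from `[a-z]*` to `[a-z_]*`, so that backend names containing underscores, such as `tensorflow_text`, are recognized when generating dummy objects.)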
DUMMY_CONSTANT = """
...