Unverified Commit 12febc20 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: Export TF generate with a TF tokenizer (#22310)

* Export TF generate with a TF tokenizer

* remove unused lines
parent 5fd4e3c8
......@@ -1725,7 +1725,6 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
if greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
maximum_iterations = max_length - cur_len
generated, _, cur_len, _ = tf.while_loop(
greedy_search_cond_fn,
......@@ -2016,7 +2015,6 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
if sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
maximum_iterations = max_length - cur_len
generated, _, cur_len, _ = tf.while_loop(
sample_cond_fn,
......@@ -2565,17 +2563,6 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
# NOT yield EOS token though)
if beam_search_cond_fn(
cur_len,
running_sequences,
running_scores,
running_beam_indices,
sequences,
scores,
beam_indices,
is_sent_finished,
model_kwargs,
):
maximum_iterations = max_length - cur_len
(
cur_len,
......@@ -3019,17 +3006,8 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
if contrastive_search_cond_fn(
generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables
):
maximum_iterations = max_length - cur_len
(
generated,
_,
cur_len,
_,
_,
) = tf.while_loop(
generated, _, cur_len, _, _ = tf.while_loop(
contrastive_search_cond_fn,
contrastive_search_body_fn,
(generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables),
......
......@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow
from transformers import is_tensorflow_text_available, is_tf_available
from transformers.testing_utils import require_tensorflow_text, require_tf, slow
from ..test_modeling_tf_common import floats_tensor
from .test_framework_agnostic import GenerationIntegrationTestsMixin
......@@ -40,6 +42,9 @@ if is_tf_available():
tf_top_k_top_p_filtering,
)
if is_tensorflow_text_available():
import tensorflow_text as text
@require_tf
class UtilsFunctionsTest(unittest.TestCase):
......@@ -239,6 +244,36 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
@slow
@require_tensorflow_text
def test_generate_tf_function_export_with_tf_tokenizer(self):
# TF-only test: tf.saved_model export
with tempfile.TemporaryDirectory() as tmp_dir:
# file needed to load the TF tokenizer
hf_hub_download(repo_id="google/flan-t5-small", filename="spiece.model", local_dir=tmp_dir)
class CompleteSentenceTransformer(tf.keras.layers.Layer):
def __init__(self):
super().__init__()
self.tokenizer = text.SentencepieceTokenizer(
model=tf.io.gfile.GFile(os.path.join(tmp_dir, "spiece.model"), "rb").read()
)
self.model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
def call(self, inputs, *args, **kwargs):
tokens = self.tokenizer.tokenize(inputs)
input_ids, attention_mask = text.pad_model_inputs(
tokens, max_seq_length=64, pad_value=self.model.config.pad_token_id
)
outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask)
return self.tokenizer.detokenize(outputs)
complete_model = CompleteSentenceTransformer()
inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string, name="inputs")
outputs = complete_model(inputs)
keras_model = tf.keras.Model(inputs, outputs)
keras_model.save(tmp_dir)
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
# Has PT equivalent: this test relies on random sampling
generation_kwargs = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment