Unverified Commit 12febc20 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: Export TF generate with a TF tokenizer (#22310)

* Export TF generate with a TF tokenizer

* remove unused lines
parent 5fd4e3c8
...@@ -1725,14 +1725,13 @@ class TFGenerationMixin: ...@@ -1725,14 +1725,13 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion # 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though # only in case 1st generation step does NOT yield EOS token though
if greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs): maximum_iterations = max_length - cur_len
maximum_iterations = max_length - cur_len generated, _, cur_len, _ = tf.while_loop(
generated, _, cur_len, _ = tf.while_loop( greedy_search_cond_fn,
greedy_search_cond_fn, greedy_search_body_fn,
greedy_search_body_fn, (generated, finished_sequences, cur_len, model_kwargs),
(generated, finished_sequences, cur_len, model_kwargs), maximum_iterations=maximum_iterations,
maximum_iterations=maximum_iterations, )
)
# 6. prepare outputs # 6. prepare outputs
if not use_xla: if not use_xla:
...@@ -2016,14 +2015,13 @@ class TFGenerationMixin: ...@@ -2016,14 +2015,13 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion # 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though # only in case 1st generation step does NOT yield EOS token though
if sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs): maximum_iterations = max_length - cur_len
maximum_iterations = max_length - cur_len generated, _, cur_len, _ = tf.while_loop(
generated, _, cur_len, _ = tf.while_loop( sample_cond_fn,
sample_cond_fn, sample_body_fn,
sample_body_fn, (generated, finished_sequences, cur_len, model_kwargs),
(generated, finished_sequences, cur_len, model_kwargs), maximum_iterations=maximum_iterations,
maximum_iterations=maximum_iterations, )
)
# 6. prepare outputs # 6. prepare outputs
if not use_xla: if not use_xla:
...@@ -2565,7 +2563,8 @@ class TFGenerationMixin: ...@@ -2565,7 +2563,8 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
# NOT yield EOS token though) # NOT yield EOS token though)
if beam_search_cond_fn( maximum_iterations = max_length - cur_len
(
cur_len, cur_len,
running_sequences, running_sequences,
running_scores, running_scores,
...@@ -2574,9 +2573,10 @@ class TFGenerationMixin: ...@@ -2574,9 +2573,10 @@ class TFGenerationMixin:
scores, scores,
beam_indices, beam_indices,
is_sent_finished, is_sent_finished,
model_kwargs, _,
): ) = tf.while_loop(
maximum_iterations = max_length - cur_len beam_search_cond_fn,
beam_search_body_fn,
( (
cur_len, cur_len,
running_sequences, running_sequences,
...@@ -2586,23 +2586,10 @@ class TFGenerationMixin: ...@@ -2586,23 +2586,10 @@ class TFGenerationMixin:
scores, scores,
beam_indices, beam_indices,
is_sent_finished, is_sent_finished,
_, model_kwargs,
) = tf.while_loop( ),
beam_search_cond_fn, maximum_iterations=maximum_iterations,
beam_search_body_fn, )
(
cur_len,
running_sequences,
running_scores,
running_beam_indices,
sequences,
scores,
beam_indices,
is_sent_finished,
model_kwargs,
),
maximum_iterations=maximum_iterations,
)
# 6. prepare outputs # 6. prepare outputs
# Account for the edge-case where there are no finished sequences for a particular batch item. If so, return # Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
...@@ -3019,22 +3006,13 @@ class TFGenerationMixin: ...@@ -3019,22 +3006,13 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion # 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though # only in case 1st generation step does NOT yield EOS token though
if contrastive_search_cond_fn( maximum_iterations = max_length - cur_len
generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables generated, _, cur_len, _, _ = tf.while_loop(
): contrastive_search_cond_fn,
maximum_iterations = max_length - cur_len contrastive_search_body_fn,
( (generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables),
generated, maximum_iterations=maximum_iterations,
_, )
cur_len,
_,
_,
) = tf.while_loop(
contrastive_search_cond_fn,
contrastive_search_body_fn,
(generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables),
maximum_iterations=maximum_iterations,
)
# 6. prepare outputs # 6. prepare outputs
if not use_xla: if not use_xla:
......
...@@ -13,13 +13,15 @@ ...@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import tempfile import tempfile
import unittest import unittest
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download
from transformers import is_tf_available from transformers import is_tensorflow_text_available, is_tf_available
from transformers.testing_utils import require_tf, slow from transformers.testing_utils import require_tensorflow_text, require_tf, slow
from ..test_modeling_tf_common import floats_tensor from ..test_modeling_tf_common import floats_tensor
from .test_framework_agnostic import GenerationIntegrationTestsMixin from .test_framework_agnostic import GenerationIntegrationTestsMixin
...@@ -40,6 +42,9 @@ if is_tf_available(): ...@@ -40,6 +42,9 @@ if is_tf_available():
tf_top_k_top_p_filtering, tf_top_k_top_p_filtering,
) )
if is_tensorflow_text_available():
import tensorflow_text as text
@require_tf @require_tf
class UtilsFunctionsTest(unittest.TestCase): class UtilsFunctionsTest(unittest.TestCase):
...@@ -239,6 +244,36 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests ...@@ -239,6 +244,36 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens) tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs) tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
@slow
@require_tensorflow_text
def test_generate_tf_function_export_with_tf_tokenizer(self):
    """Smoke test: wrap a TF-native tokenizer and `generate()` in a Keras layer and export it.

    TF-only test (no PT equivalent): verifies that `generate()` can be traced and
    serialized via `tf.keras.Model.save` (SavedModel export) when tokenization is
    done inside the graph with `tensorflow_text`. The test passes if `save` does
    not raise; no output values are asserted.
    """
    # TF-only test: tf.saved_model export
    with tempfile.TemporaryDirectory() as tmp_dir:
        # file needed to load the TF tokenizer
        # NOTE: downloads the sentencepiece vocabulary from the Hub into tmp_dir.
        hf_hub_download(repo_id="google/flan-t5-small", filename="spiece.model", local_dir=tmp_dir)

        class CompleteSentenceTransformer(tf.keras.layers.Layer):
            # Keras layer bundling: in-graph tokenization -> generate() -> detokenization.
            def __init__(self):
                super().__init__()
                # Tokenizer is built from the raw sentencepiece model bytes read above.
                self.tokenizer = text.SentencepieceTokenizer(
                    model=tf.io.gfile.GFile(os.path.join(tmp_dir, "spiece.model"), "rb").read()
                )
                self.model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")

            def call(self, inputs, *args, **kwargs):
                # Tokenize string tensors, then pad/truncate to a fixed length so the
                # exported graph has static input shapes for generate().
                tokens = self.tokenizer.tokenize(inputs)
                input_ids, attention_mask = text.pad_model_inputs(
                    tokens, max_seq_length=64, pad_value=self.model.config.pad_token_id
                )
                outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask)
                # Convert generated token ids back to strings inside the graph.
                return self.tokenizer.detokenize(outputs)

        complete_model = CompleteSentenceTransformer()
        # Functional model with a single string input; shape=(1,) means one string per example.
        inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string, name="inputs")
        outputs = complete_model(inputs)
        keras_model = tf.keras.Model(inputs, outputs)
        # The actual check: SavedModel export must succeed (raises on trace failure).
        keras_model.save(tmp_dir)
def test_eos_token_id_int_and_list_top_k_top_sampling(self): def test_eos_token_id_int_and_list_top_k_top_sampling(self):
# Has PT equivalent: this test relies on random sampling # Has PT equivalent: this test relies on random sampling
generation_kwargs = { generation_kwargs = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment