Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
12febc20
Unverified
Commit
12febc20
authored
Mar 22, 2023
by
Joao Gante
Committed by
GitHub
Mar 22, 2023
Browse files
Generate: Export TF generate with a TF tokenizer (#22310)
* Export TF generate with a TF tokenizer * remove unused lines
parent
5fd4e3c8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
68 additions
and
55 deletions
+68
-55
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+31
-53
tests/generation/test_tf_utils.py
tests/generation/test_tf_utils.py
+37
-2
No files found.
src/transformers/generation/tf_utils.py
View file @
12febc20
...
...
@@ -1725,7 +1725,6 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
if
greedy_search_cond_fn
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
):
maximum_iterations
=
max_length
-
cur_len
generated
,
_
,
cur_len
,
_
=
tf
.
while_loop
(
greedy_search_cond_fn
,
...
...
@@ -2016,7 +2015,6 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
if
sample_cond_fn
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
):
maximum_iterations
=
max_length
-
cur_len
generated
,
_
,
cur_len
,
_
=
tf
.
while_loop
(
sample_cond_fn
,
...
...
@@ -2565,17 +2563,6 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
# NOT yield EOS token though)
if
beam_search_cond_fn
(
cur_len
,
running_sequences
,
running_scores
,
running_beam_indices
,
sequences
,
scores
,
beam_indices
,
is_sent_finished
,
model_kwargs
,
):
maximum_iterations
=
max_length
-
cur_len
(
cur_len
,
...
...
@@ -3019,17 +3006,8 @@ class TFGenerationMixin:
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
if
contrastive_search_cond_fn
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
,
next_step_cached_variables
):
maximum_iterations
=
max_length
-
cur_len
(
generated
,
_
,
cur_len
,
_
,
_
,
)
=
tf
.
while_loop
(
generated
,
_
,
cur_len
,
_
,
_
=
tf
.
while_loop
(
contrastive_search_cond_fn
,
contrastive_search_body_fn
,
(
generated
,
finished_sequences
,
cur_len
,
model_kwargs
,
next_step_cached_variables
),
...
...
tests/generation/test_tf_utils.py
View file @
12febc20
...
...
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
tempfile
import
unittest
import
numpy
as
np
from
huggingface_hub
import
hf_hub_download
from
transformers
import
is_tf_available
from
transformers.testing_utils
import
require_tf
,
slow
from
transformers
import
is_tensorflow_text_available
,
is_tf_available
from
transformers.testing_utils
import
require_tensorflow_text
,
require_tf
,
slow
from
..test_modeling_tf_common
import
floats_tensor
from
.test_framework_agnostic
import
GenerationIntegrationTestsMixin
...
...
@@ -40,6 +42,9 @@ if is_tf_available():
tf_top_k_top_p_filtering
,
)
if
is_tensorflow_text_available
():
import
tensorflow_text
as
text
@
require_tf
class
UtilsFunctionsTest
(
unittest
.
TestCase
):
...
...
@@ -239,6 +244,36 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
tf_model_outputs
=
test_model
.
generate
(
**
inputs
,
max_new_tokens
=
max_new_tokens
)
tf
.
debugging
.
assert_equal
(
tf_func_outputs
,
tf_model_outputs
)
@
slow
@
require_tensorflow_text
def
test_generate_tf_function_export_with_tf_tokenizer
(
self
):
# TF-only test: tf.saved_model export
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
# file needed to load the TF tokenizer
hf_hub_download
(
repo_id
=
"google/flan-t5-small"
,
filename
=
"spiece.model"
,
local_dir
=
tmp_dir
)
class
CompleteSentenceTransformer
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
):
super
().
__init__
()
self
.
tokenizer
=
text
.
SentencepieceTokenizer
(
model
=
tf
.
io
.
gfile
.
GFile
(
os
.
path
.
join
(
tmp_dir
,
"spiece.model"
),
"rb"
).
read
()
)
self
.
model
=
TFAutoModelForSeq2SeqLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-t5"
)
def
call
(
self
,
inputs
,
*
args
,
**
kwargs
):
tokens
=
self
.
tokenizer
.
tokenize
(
inputs
)
input_ids
,
attention_mask
=
text
.
pad_model_inputs
(
tokens
,
max_seq_length
=
64
,
pad_value
=
self
.
model
.
config
.
pad_token_id
)
outputs
=
self
.
model
.
generate
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
)
return
self
.
tokenizer
.
detokenize
(
outputs
)
complete_model
=
CompleteSentenceTransformer
()
inputs
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
1
,),
dtype
=
tf
.
string
,
name
=
"inputs"
)
outputs
=
complete_model
(
inputs
)
keras_model
=
tf
.
keras
.
Model
(
inputs
,
outputs
)
keras_model
.
save
(
tmp_dir
)
def
test_eos_token_id_int_and_list_top_k_top_sampling
(
self
):
# Has PT equivalent: this test relies on random sampling
generation_kwargs
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment