Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
204c54d4
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f62f992cf7aa7f1e4eb0d1ef912bd06d26c4dd8c"
Unverified
Commit
204c54d4
authored
Mar 16, 2022
by
Joao Gante
Committed by
GitHub
Mar 16, 2022
Browse files
TF: add beam search tests (#16202)
parent
19099457
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
0 deletions
+51
-0
tests/gpt2/test_modeling_tf_gpt2.py
tests/gpt2/test_modeling_tf_gpt2.py
+28
-0
tests/t5/test_modeling_tf_t5.py
tests/t5/test_modeling_tf_t5.py
+23
-0
No files found.
tests/gpt2/test_modeling_tf_gpt2.py
View file @
204c54d4
...
@@ -521,6 +521,34 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
...
@@ -521,6 +521,34 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
]
]
self
.
assertListEqual
(
output_strings
,
expected_output_string
)
self
.
assertListEqual
(
output_strings
,
expected_output_string
)
@slow
def test_lm_generate_greedy_distilgpt2_beam_search_special(self, *, checkpoint="distilgpt2"):
    """Beam-search generation on distilgpt2 with special constraints.

    Exercises `generate()` with num_beams=2 plus bad-word filtering,
    no-repeat-ngram blocking and a repetition penalty, on a left-padded
    batch of two prompts, and pins the exact decoded outputs.
    """
    lm_model = TFGPT2LMHeadModel.from_pretrained(checkpoint)
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)

    # GPT-2 has no pad token; reuse EOS and left-pad so generation
    # continues from the real prompt tokens.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    gpt2_tokenizer.padding_side = "left"

    prompts = ["Today is a beautiful day and", "Yesterday was"]
    batch = gpt2_tokenizer(prompts, return_tensors="tf", padding=True)

    banned = [gpt2_tokenizer("is").input_ids, gpt2_tokenizer("angry about").input_ids]
    generated_ids = lm_model.generate(
        batch.input_ids,
        bad_words_ids=banned,
        no_repeat_ngram_size=2,
        do_sample=False,
        repetition_penalty=1.3,
        num_beams=2,
    )
    decoded = gpt2_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected = [
        "Today is a beautiful day and I hope you enjoy it.\nI am very happy to announce that",
        "Yesterday was the first time I've ever seen a game where you can play with",
    ]
    self.assertListEqual(decoded, expected)
@
slow
@
slow
def
test_lm_generate_gpt2
(
self
):
def
test_lm_generate_gpt2
(
self
):
model
=
TFGPT2LMHeadModel
.
from_pretrained
(
"gpt2"
)
model
=
TFGPT2LMHeadModel
.
from_pretrained
(
"gpt2"
)
...
...
tests/t5/test_modeling_tf_t5.py
View file @
204c54d4
...
@@ -548,6 +548,29 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
...
@@ -548,6 +548,29 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self
.
assertListEqual
(
expected_output_string
,
output_strings
)
self
.
assertListEqual
(
expected_output_string
,
output_strings
)
@slow
def test_beam_search_generate(self, *, checkpoint="t5-small"):
    """Beam-search generation on t5-small.

    Runs `generate()` with num_beams=4 together with bad-word filtering,
    no-repeat-ngram blocking and a repetition penalty, on a padded batch
    of two prompts, and pins the exact decoded outputs.
    """
    t5_model = TFT5ForConditionalGeneration.from_pretrained(checkpoint)
    t5_tokenizer = T5Tokenizer.from_pretrained(checkpoint)

    prompts = ["I really love my", "Translate English to German: the transformers are truly amazing"]
    batch = t5_tokenizer(prompts, return_tensors="tf", padding=True)

    banned = [t5_tokenizer("my").input_ids, t5_tokenizer("ein schöner").input_ids]
    generated_ids = t5_model.generate(
        batch.input_ids,
        bad_words_ids=banned,
        no_repeat_ngram_size=3,
        do_sample=False,
        repetition_penalty=2.2,
        num_beams=4,
    )
    decoded = t5_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"]
    self.assertListEqual(expected, decoded)
@
require_tf
@
require_tf
@
require_sentencepiece
@
require_sentencepiece
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment