Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2410d0f8
Unverified
Commit
2410d0f8
authored
Mar 16, 2022
by
Patrick von Platen
Committed by
GitHub
Mar 16, 2022
Browse files
Fix generation min length (#16206)
* up
* fix min lengths
parent
667b823b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
2 deletions
+5
-2
src/transformers/generation_utils.py
src/transformers/generation_utils.py
+1
-1
tests/generation/test_generation_utils.py
tests/generation/test_generation_utils.py
+4
-1
No files found.
src/transformers/generation_utils.py
View file @
2410d0f8
@@ -741,7 +741,7 @@ class GenerationMixin:
             )
         if bad_words_ids is not None:
             processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
-        if min_length is not None and eos_token_id is not None and min_length > -1:
+        if min_length is not None and eos_token_id is not None and min_length > 0:
             processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
         if prefix_allowed_tokens_fn is not None:
             processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
tests/generation/test_generation_utils.py
View file @
2410d0f8
tests/generation/test_generation_utils.py
View file @ 2410d0f8
@@ -1949,11 +1949,14 @@ class GenerationIntegrationTests(unittest.TestCase):
     def test_custom_logits_processor(self):
         bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", min_length=1).to(
+            torch_device
+        )
         input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

         logits_processor = LogitsProcessorList()
         logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
+        # it should not be allowed to both define `min_length` via config and `logits_processor` list
         with self.assertRaises(ValueError):
             bart_model.generate(input_ids, logits_processor=logits_processor)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment