chenpangpang / transformers · Commits · 2410d0f8

Commit 2410d0f8 (unverified), authored Mar 16, 2022 by Patrick von Platen, committed by GitHub on Mar 16, 2022
Parent: 667b823b

Fix generation min length (#16206)

* up
* fix min lengths
Changes: 2 changed files, with 5 additions and 2 deletions

  src/transformers/generation_utils.py        +1 -1
  tests/generation/test_generation_utils.py   +4 -1
src/transformers/generation_utils.py  (+1, -1)

@@ -741,7 +741,7 @@ class GenerationMixin:
                 )
         if bad_words_ids is not None:
             processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
-        if min_length is not None and eos_token_id is not None and min_length > -1:
+        if min_length is not None and eos_token_id is not None and min_length > 0:
             processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
         if prefix_allowed_tokens_fn is not None:
             processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
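The one-character change above tightens the condition for installing the default min-length constraint: under the old `min_length > -1`, the `PretrainedConfig` default of `min_length = 0` (a no-op constraint) still caused a `MinLengthLogitsProcessor` to be added to the default processor list, which then collided with any user-supplied processor of the same type. Below is a minimal sketch, not the library implementation, of what `MinLengthLogitsProcessor(min_length, eos_token_id)` does to next-token logits; the helper name `min_length_processor` is invented for illustration.

# Minimal sketch (not transformers' code) of the min-length constraint:
# while the sequence is shorter than `min_length`, EOS is made unselectable.
import torch

def min_length_processor(input_ids: torch.Tensor, scores: torch.Tensor,
                         min_length: int, eos_token_id: int) -> torch.Tensor:
    cur_len = input_ids.shape[-1]
    if cur_len < min_length:
        # Force the EOS logit to -inf so generation cannot terminate yet.
        scores[:, eos_token_id] = -float("inf")
    return scores

# With min_length == 0 the guard `cur_len < 0` can never fire, so the
# processor is dead weight -- which is why the fixed condition
# `min_length > 0` skips creating it entirely.
ids = torch.tensor([[2, 5]])            # cur_len == 2
scores = torch.zeros(1, 10)             # (batch, vocab) logits
out = min_length_processor(ids, scores, min_length=5, eos_token_id=0)
assert out[0, 0] == -float("inf")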
tests/generation/test_generation_utils.py  (+4, -1)

@@ -1949,11 +1949,14 @@ class GenerationIntegrationTests(unittest.TestCase):
     def test_custom_logits_processor(self):
         bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", min_length=1).to(torch_device)
         input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
 
         logits_processor = LogitsProcessorList()
         logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
+        # it should not be allowed to both define `min_length` via config and `logits_processor` list
+        with self.assertRaises(ValueError):
+            bart_model.generate(input_ids, logits_processor=logits_processor)
 
         bart_model.generate(input_ids, logits_processor=logits_processor)
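The test pins down the intended contract: define the min-length constraint either via the config or via a custom `logits_processor`, never both. A hedged usage sketch follows, borrowing the model name and `eos_token_id=0` from the test above and reflecting behavior as of this commit; it is an illustration, not part of the commit.

from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)

tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random")
input_ids = tokenizer(
    "Justin Timberlake and Jessica Biel, welcome to parenthood.",
    return_tensors="pt",
).input_ids

# Option A: let generate() build the MinLengthLogitsProcessor from `min_length`.
model.generate(input_ids, min_length=10)

# Option B: supply the processor yourself. After this fix this works whenever
# the effective config `min_length` is 0 (the PretrainedConfig default);
# before the fix the always-created default processor made this raise a
# duplicate-processor ValueError.
processors = LogitsProcessorList([MinLengthLogitsProcessor(min_length=10, eos_token_id=0)])
model.generate(input_ids, logits_processor=processors)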