chenpangpang / transformers · Commits

Unverified commit f614b6e3, authored Jul 07, 2023 by Joao Gante, committed by GitHub on Jul 07, 2023

Whisper: fix prompted max length (#24666)
parent 49572942

Showing 3 changed files with 27 additions and 9 deletions (+27 -9):

  src/transformers/generation/stopping_criteria.py      +17  -3
  src/transformers/generation/utils.py                    +7  -1
  src/transformers/models/whisper/modeling_whisper.py     +3  -5
src/transformers/generation/stopping_criteria.py

@@ -6,7 +6,10 @@ from typing import Optional
 
 import torch
 
-from ..utils import add_start_docstrings
+from ..utils import add_start_docstrings, logging
+
+
+logger = logging.get_logger(__name__)
 
 
 STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""

@@ -46,14 +49,25 @@ class MaxLengthCriteria(StoppingCriteria):
     Args:
         max_length (`int`):
             The maximum length that the output sequence can have in number of tokens.
+        max_position_embeddings (`int`, `optional`):
+            The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
     """
 
-    def __init__(self, max_length: int):
+    def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
         self.max_length = max_length
+        self.max_position_embeddings = max_position_embeddings
 
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        return input_ids.shape[-1] >= self.max_length
+        cur_len = input_ids.shape[-1]
+        is_done = cur_len >= self.max_length
+        if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
+            logger.warning_once(
+                "This is a friendly reminder - the current text generation call will exceed the model's predefined "
+                f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
+                "exceptions, performance degradation, or nothing at all."
+            )
+        return is_done
 
 
 class MaxNewTokensCriteria(StoppingCriteria):
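The new `max_position_embeddings` argument does not change when generation stops; it only lets the criterion emit a one-time warning when the sequence grows past the model's positional limit while still being under `max_length`. A minimal sketch of that behavior, assuming a transformers version that includes this commit (the tensor shapes and the 448-token limit are illustrative; 448 mirrors Whisper's decoder limit):

import torch
from transformers import MaxLengthCriteria

# User asks for up to 600 tokens, but the model's positional limit is 448.
criteria = MaxLengthCriteria(max_length=600, max_position_embeddings=448)

input_ids = torch.zeros((1, 500), dtype=torch.long)  # 500 tokens: past the model limit, under max_length
scores = torch.zeros((1, 32), dtype=torch.float)     # dummy scores; the criterion only inspects input_ids
print(criteria(input_ids, scores))  # False -> generation continues, but the warning above is logged once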
src/transformers/generation/utils.py

@@ -954,7 +954,13 @@ class GenerationMixin:
     ) -> StoppingCriteriaList:
         criteria = StoppingCriteriaList()
         if generation_config.max_length is not None:
-            criteria.append(MaxLengthCriteria(max_length=generation_config.max_length))
+            max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+            criteria.append(
+                MaxLengthCriteria(
+                    max_length=generation_config.max_length,
+                    max_position_embeddings=max_position_embeddings,
+                )
+            )
         if generation_config.max_time is not None:
             criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
         criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
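The change above reads the optional limit off the model config with `getattr`, so configs that do not define `max_position_embeddings` simply pass `None` through and the warning is never triggered. A hedged sketch of the same pattern outside of `GenerationMixin` (BertConfig is used only because it defines `max_position_embeddings` by default; the `max_length` of 1024 is arbitrary):

from transformers import BertConfig, MaxLengthCriteria, StoppingCriteriaList

config = BertConfig()  # defaults to max_position_embeddings=512
max_position_embeddings = getattr(config, "max_position_embeddings", None)  # None if the config lacks the field

criteria = StoppingCriteriaList()
criteria.append(
    MaxLengthCriteria(
        max_length=1024,
        max_position_embeddings=max_position_embeddings,
    )
)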
src/transformers/models/whisper/modeling_whisper.py

@@ -1715,11 +1715,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
             # Set the decoder_start_token_id to <|startofprev|>
             kwargs.update({"decoder_start_token_id": decoder_start_token_id})
 
-            # Update the max generation length to include the prompt
-            specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None)
-            default_max_length = generation_config.max_new_tokens or generation_config.max_length
-            non_prompt_max_length = specified_max_length or default_max_length
-            kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids)
+            # If the user passes `max_new_tokens`, increase its number to account for the prompt
+            if kwargs.get("max_new_tokens", None) is not None:
+                kwargs["max_new_tokens"] += len(text_prompt_ids)
 
             # Reformat the forced_decoder_ids to incorporate the prompt
             non_prompt_forced_decoder_ids = (
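Net effect on Whisper prompting: a user-supplied `max_new_tokens` now only needs to budget for non-prompt tokens, since `generate()` adds the prompt length on top internally, and when `max_new_tokens` is not passed the defaults are left alone instead of being overwritten. A hedged usage sketch (checkpoint name, prompt text, and the one-second silent waveform are illustrative placeholders):

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# One second of silence stands in for real audio so the snippet is self-contained.
audio = np.zeros(16000, dtype=np.float32)
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features

# get_prompt_ids prefixes the prompt tokens with <|startofprev|>.
prompt_ids = processor.get_prompt_ids("Glossary: GitLab, Whisper.")

# max_new_tokens counts only non-prompt tokens; generate() accounts for len(prompt_ids) itself.
generated_ids = model.generate(input_features, prompt_ids=prompt_ids, max_new_tokens=128)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])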