Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f21af262
Unverified
Commit
f21af262
authored
Feb 03, 2023
by
Joao Gante
Committed by
GitHub
Feb 03, 2023
Browse files
🚨
🚨
Generate: standardize beam search behavior across frameworks (#21368)
parent
ea55bd86
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
121 additions
and
117 deletions
+121
-117
src/transformers/generation/beam_search.py
src/transformers/generation/beam_search.py
+50
-31
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+17
-2
src/transformers/generation/flax_utils.py
src/transformers/generation/flax_utils.py
+21
-13
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+24
-14
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+5
-1
src/transformers/models/rag/modeling_rag.py
src/transformers/models/rag/modeling_rag.py
+1
-0
tests/generation/test_utils.py
tests/generation/test_utils.py
+0
-53
tests/models/bart/test_modeling_flax_bart.py
tests/models/bart/test_modeling_flax_bart.py
+1
-1
tests/models/gpt2/test_modeling_flax_gpt2.py
tests/models/gpt2/test_modeling_flax_gpt2.py
+1
-1
tests/models/t5/test_modeling_flax_t5.py
tests/models/t5/test_modeling_flax_t5.py
+1
-1
No files found.
src/transformers/generation/beam_search.py
View file @
f21af262
...
...
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
warnings
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
from
typing
import
List
,
Optional
,
Tuple
,
Union
...
...
@@ -130,8 +129,6 @@ class BeamSearchScorer(BeamScorer):
Args:
batch_size (`int`):
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
max_length (`int`):
The maximum length of the sequence to be generated.
num_beams (`int`):
Number of beams for beam search.
device (`torch.device`):
...
...
@@ -142,14 +139,20 @@ class BeamSearchScorer(BeamScorer):
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
`length_penalty` < 0.0 encourages shorter sequences.
do_early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
heuristic is applied and the generation stops when it is very unlikely to find better candidates;
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
beam search algorithm).
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformer.BeamSearchScorer.finalize`].
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
max_length (`int`, *optional*):
The maximum length of the sequence to be generated.
"""
def
__init__
(
...
...
@@ -158,10 +161,10 @@ class BeamSearchScorer(BeamScorer):
num_beams
:
int
,
device
:
torch
.
device
,
length_penalty
:
Optional
[
float
]
=
1.0
,
do_early_stopping
:
Optional
[
bool
]
=
False
,
do_early_stopping
:
Optional
[
Union
[
bool
,
str
]
]
=
False
,
num_beam_hyps_to_keep
:
Optional
[
int
]
=
1
,
num_beam_groups
:
Optional
[
int
]
=
1
,
**
kwargs
,
max_length
:
Optional
[
int
]
=
None
,
):
self
.
num_beams
=
num_beams
self
.
device
=
device
...
...
@@ -177,6 +180,7 @@ class BeamSearchScorer(BeamScorer):
num_beams
=
self
.
num_beams
,
length_penalty
=
self
.
length_penalty
,
early_stopping
=
self
.
do_early_stopping
,
max_length
=
max_length
,
)
for
_
in
range
(
batch_size
)
]
...
...
@@ -194,13 +198,6 @@ class BeamSearchScorer(BeamScorer):
f
" divisible by `num_beam_groups`, but is
{
num_beam_groups
}
with `num_beams` being
{
num_beams
}
."
)
if
"max_length"
in
kwargs
:
warnings
.
warn
(
"Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
", or `group_beam_search(...)`."
)
@
property
def
is_done
(
self
)
->
bool
:
return
self
.
_done
.
all
()
...
...
@@ -402,8 +399,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
Args:
batch_size (`int`):
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
max_length (`int`):
The maximum length of the sequence to be generated.
num_beams (`int`):
Number of beams for beam search.
constraints (`List[Constraint]`):
...
...
@@ -417,14 +412,20 @@ class ConstrainedBeamSearchScorer(BeamScorer):
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
`length_penalty` < 0.0 encourages shorter sequences.
do_early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
heuristic is applied and the generation stops when it is very unlikely to find better candidates;
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
beam search algorithm).
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformer.BeamSearchScorer.finalize`].
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
max_length (`int`, *optional*):
The maximum length of the sequence to be generated.
"""
def
__init__
(
...
...
@@ -434,10 +435,10 @@ class ConstrainedBeamSearchScorer(BeamScorer):
constraints
:
List
[
Constraint
],
device
:
torch
.
device
,
length_penalty
:
Optional
[
float
]
=
1.0
,
do_early_stopping
:
Optional
[
bool
]
=
False
,
do_early_stopping
:
Optional
[
Union
[
bool
,
str
]
]
=
False
,
num_beam_hyps_to_keep
:
Optional
[
int
]
=
1
,
num_beam_groups
:
Optional
[
int
]
=
1
,
**
kwargs
,
max_length
:
Optional
[
int
]
=
None
,
):
self
.
num_beams
=
num_beams
self
.
device
=
device
...
...
@@ -454,6 +455,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
num_beams
=
self
.
num_beams
,
length_penalty
=
self
.
length_penalty
,
early_stopping
=
self
.
do_early_stopping
,
max_length
=
max_length
,
)
for
_
in
range
(
batch_size
)
]
...
...
@@ -471,13 +473,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
f
" divisible by `num_beam_groups`, but is
{
num_beam_groups
}
with `num_beams` being
{
num_beams
}
."
)
if
"max_length"
in
kwargs
:
warnings
.
warn
(
"Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. "
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
", or `group_beam_search(...)`."
)
@
property
def
is_done
(
self
)
->
bool
:
return
self
.
_done
.
all
()
...
...
@@ -865,16 +860,23 @@ class ConstrainedBeamSearchScorer(BeamScorer):
class
BeamHypotheses
:
def
__init__
(
self
,
num_beams
:
int
,
length_penalty
:
float
,
early_stopping
:
bool
):
def
__init__
(
self
,
num_beams
:
int
,
length_penalty
:
float
,
early_stopping
:
bool
,
max_length
:
Optional
[
int
]
=
None
):
"""
Initialize n-best list of hypotheses.
"""
self
.
length_penalty
=
length_penalty
self
.
early_stopping
=
early_stopping
self
.
max_length
=
max_length
self
.
num_beams
=
num_beams
self
.
beams
=
[]
self
.
worst_score
=
1e9
if
not
isinstance
(
self
.
early_stopping
,
bool
)
and
self
.
max_length
is
None
:
raise
ValueError
(
"When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
" BeamScorer class instance at initialization time."
)
def
__len__
(
self
):
"""
Number of hypotheses in the list.
...
...
@@ -903,9 +905,26 @@ class BeamHypotheses:
if
len
(
self
)
<
self
.
num_beams
:
return
False
elif
self
.
early_stopping
:
# `True`: stop as soon as at least `num_beams` hypotheses are finished
if
self
.
early_stopping
is
True
:
return
True
# `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
# when `length_penalty` is positive. See the discussion below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
elif
self
.
early_stopping
is
False
:
highest_attainable_score
=
best_sum_logprobs
/
cur_len
**
self
.
length_penalty
ret
=
self
.
worst_score
>=
highest_attainable_score
return
ret
# `"never"`: compute the best possible score, depending on the sign of `length_penalty`
else
:
cur_score
=
best_sum_logprobs
/
cur_len
**
self
.
length_penalty
ret
=
self
.
worst_score
>=
cur_score
# `length_penalty` > 0.0 -> max denominator is obtained from `max_length`, not from `cur_len` -> min
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
# its max this way
if
self
.
length_penalty
>
0.0
:
highest_attainable_score
=
best_sum_logprobs
/
self
.
max_length
**
self
.
length_penalty
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
else
:
highest_attainable_score
=
best_sum_logprobs
/
cur_len
**
self
.
length_penalty
ret
=
self
.
worst_score
>=
highest_attainable_score
return
ret
src/transformers/generation/configuration_utils.py
View file @
f21af262
...
...
@@ -71,8 +71,12 @@ class GenerationConfig(PushToHubMixin):
`min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
min_new_tokens (`int`, *optional*):
The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
heuristic is applied and the generation stops when it is very unlikely to find better candidates;
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
beam search algorithm).
max_time(`float`, *optional*):
The maximum amount of time you allow the computation to run for in seconds. generation will still finish
the current pass after allocated time has been passed.
...
...
@@ -290,6 +294,9 @@ class GenerationConfig(PushToHubMixin):
logger
.
error
(
f
"Can't set
{
key
}
with value
{
value
}
for
{
self
}
"
)
raise
err
# Validate the values of the attributes
self
.
validate
()
def
__eq__
(
self
,
other
):
self_dict
=
self
.
__dict__
.
copy
()
other_dict
=
other
.
__dict__
.
copy
()
...
...
@@ -302,6 +309,14 @@ class GenerationConfig(PushToHubMixin):
def
__repr__
(
self
):
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
()
}
"
def
validate
(
self
):
"""
Validates the values of the attributes of the GenerationConfig instance, and raises a `ValueError` if any of
the values are invalid.
"""
if
self
.
early_stopping
not
in
{
True
,
False
,
"never"
}:
raise
ValueError
(
f
"`early_stopping` must be a boolean or 'never', but is
{
self
.
early_stopping
}
."
)
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
],
...
...
src/transformers/generation/flax_utils.py
View file @
f21af262
...
...
@@ -19,7 +19,7 @@ import copy
import
inspect
import
warnings
from
functools
import
partial
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
numpy
as
np
...
...
@@ -275,6 +275,7 @@ class FlaxGenerationMixin:
generation_config
=
copy
.
deepcopy
(
generation_config
)
model_kwargs
=
generation_config
.
update
(
**
kwargs
)
# All unused kwargs must be model kwargs
generation_config
.
validate
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# set init values
...
...
@@ -633,7 +634,7 @@ class FlaxGenerationMixin:
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
int
]
=
None
,
length_penalty
:
Optional
[
float
]
=
None
,
early_stopping
:
Optional
[
bool
]
=
None
,
early_stopping
:
Optional
[
Union
[
bool
,
str
]
]
=
None
,
logits_processor
:
Optional
[
FlaxLogitsProcessorList
]
=
None
,
trace
:
bool
=
True
,
params
:
Optional
[
Dict
[
str
,
jnp
.
ndarray
]]
=
None
,
...
...
@@ -733,14 +734,22 @@ class FlaxGenerationMixin:
not_max_length_yet
=
state
.
cur_len
<
max_length
# 2. can the new beams still improve?
best_running_score
=
state
.
running_scores
[:,
-
1
:]
/
(
max_length
**
length_penalty
)
# early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
# below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
# early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
# length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
if
early_stopping
==
"never"
and
length_penalty
>
0.0
:
best_running_score
=
state
.
running_scores
[:,
:
1
]
/
(
max_length
**
length_penalty
)
else
:
best_running_score
=
state
.
running_scores
[:,
:
1
]
/
(
state
.
cur_len
**
length_penalty
)
worst_finished_score
=
jnp
.
where
(
state
.
is_sent_finished
,
jnp
.
min
(
state
.
scores
,
axis
=
1
,
keepdims
=
True
),
np
.
array
(
-
1.0e7
)
)
improvement_still_possible
=
jnp
.
a
ll
(
worst_finished_score
<
best_running
_score
)
improvement_still_possible
=
jnp
.
a
ny
(
best_running_score
>
worst_finished
_score
)
# 3. is there still a beam that has not finished?
still_open_beam
=
~
(
jnp
.
all
(
state
.
is_sent_finished
)
&
early_stopping
)
still_open_beam
=
~
(
jnp
.
all
(
state
.
is_sent_finished
)
&
(
early_stopping
is
True
)
)
return
not_max_length_yet
&
still_open_beam
&
improvement_still_possible
...
...
@@ -813,7 +822,7 @@ class FlaxGenerationMixin:
# 5. Get running sequences scores for next
# Determine the top k beam indices (from top 2*k beams) from log probs
# and gather top k beams (from top 2*k beams).
next_topk_indices
=
jnp
.
flip
(
lax
.
top_k
(
running_topk_log_probs
,
k
=
num_beams
)[
1
]
,
axis
=
1
)
next_topk_indices
=
lax
.
top_k
(
running_topk_log_probs
,
k
=
num_beams
)[
1
]
next_running_sequences
,
next_running_scores
=
gather_beams
(
[
topk_sequences
,
running_topk_log_probs
],
next_topk_indices
,
batch_size
,
num_beams
)
...
...
@@ -824,10 +833,9 @@ class FlaxGenerationMixin:
# - make sure no scores can be added anymore if beam is full
# - make sure still running sequences cannot be chosen as finalized beam
topk_log_probs
=
topk_log_probs
/
(
state
.
cur_len
**
length_penalty
)
beams_in_batch_are_full
=
(
jnp
.
broadcast_to
(
state
.
is_sent_finished
.
all
(
axis
=-
1
,
keepdims
=
True
),
did_topk_just_finished
.
shape
)
&
early_stopping
)
beams_in_batch_are_full
=
jnp
.
broadcast_to
(
state
.
is_sent_finished
.
all
(
axis
=-
1
,
keepdims
=
True
),
did_topk_just_finished
.
shape
)
&
(
early_stopping
is
True
)
add_penalty
=
~
did_topk_just_finished
|
beams_in_batch_are_full
topk_log_probs
+=
add_penalty
*
np
.
array
(
-
1.0e7
)
...
...
@@ -838,7 +846,7 @@ class FlaxGenerationMixin:
merged_sequences
=
jnp
.
concatenate
([
state
.
sequences
,
topk_sequences
],
axis
=
1
)
merged_scores
=
jnp
.
concatenate
([
state
.
scores
,
topk_log_probs
],
axis
=
1
)
merged_is_sent_finished
=
jnp
.
concatenate
([
state
.
is_sent_finished
,
did_topk_just_finished
],
axis
=
1
)
topk_merged_indices
=
jnp
.
flip
(
lax
.
top_k
(
merged_scores
,
k
=
num_beams
)[
1
]
,
axis
=
1
)
topk_merged_indices
=
lax
.
top_k
(
merged_scores
,
k
=
num_beams
)[
1
]
next_sequences
,
next_scores
,
next_is_sent_finished
=
gather_beams
(
[
merged_sequences
,
merged_scores
,
merged_is_sent_finished
],
topk_merged_indices
,
batch_size
,
num_beams
)
...
...
@@ -877,7 +885,7 @@ class FlaxGenerationMixin:
scores
=
jnp
.
where
(
none_finished
[:,
None
],
state
.
scores
,
state
.
running_scores
)
# take best beam for each batch
sequences
=
sequences
[:,
-
1
]
scores
=
scores
[:,
-
1
]
sequences
=
sequences
[:,
0
]
scores
=
scores
[:,
0
]
return
FlaxBeamSearchOutput
(
sequences
=
sequences
,
scores
=
scores
)
src/transformers/generation/tf_utils.py
View file @
f21af262
...
...
@@ -611,6 +611,7 @@ class TFGenerationMixin:
generation_config
=
copy
.
deepcopy
(
generation_config
)
model_kwargs
=
generation_config
.
update
(
**
kwargs
)
# All unused kwargs must be model kwargs
generation_config
.
validate
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
...
...
@@ -1808,7 +1809,7 @@ class TFGenerationMixin:
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
int
]
=
None
,
length_penalty
:
Optional
[
float
]
=
None
,
early_stopping
:
Optional
[
bool
]
=
None
,
early_stopping
:
Optional
[
Union
[
bool
,
str
]
]
=
None
,
logits_processor
:
Optional
[
TFLogitsProcessorList
]
=
None
,
logits_warper
:
Optional
[
TFLogitsProcessorList
]
=
None
,
num_return_sequences
:
Optional
[
int
]
=
None
,
...
...
@@ -1838,8 +1839,12 @@ class TFGenerationMixin:
to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
while `length_penalty` < 0.0 encourages shorter sequences.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following
values: `True`, where the generation stops as soon as there are `num_beams` complete candidates;
`False`, where a heuristic is applied and the generation stops when it is very unlikely to find better
candidates; `"never"`, where the beam search procedure only stops when there cannot be better
candidates (canonical beam search algorithm).
logits_processor (`[TFLogitsProcessorList]`, *optional*):
An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
...
...
@@ -2009,16 +2014,24 @@ class TFGenerationMixin:
not_max_length_yet
=
cur_len
<
max_length
# 2. can the new beams still improve?
best_running_score
=
running_scores
[:,
:
1
]
/
(
max_length
**
length_penalty
)
# early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
# below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
# early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
# length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
if
early_stopping
==
"never"
and
length_penalty
>
0.0
:
best_running_score
=
running_scores
[:,
:
1
]
/
(
max_length
**
length_penalty
)
else
:
best_running_score
=
running_scores
[:,
:
1
]
/
(
tf
.
cast
(
cur_len
,
dtype
=
tf
.
float32
)
**
length_penalty
)
worst_finished_score
=
tf
.
where
(
is_sent_finished
,
tf
.
math
.
reduce_min
(
scores
,
axis
=
1
,
keepdims
=
True
),
-
1.0e9
)
improvement_still_possible
=
tf
.
math
.
reduce_a
ll
(
worst_finished_score
<
best_running
_score
)
improvement_still_possible
=
tf
.
math
.
reduce_a
ny
(
best_running_score
>
worst_finished
_score
)
# 3. is there still a beam that has not finished?
still_open_beam
=
~
(
tf
.
math
.
reduce_all
(
is_sent_finished
)
&
early_stopping
)
still_open_beam
=
~
(
tf
.
math
.
reduce_all
(
is_sent_finished
)
&
(
early_stopping
is
True
)
)
return
not_max_length_yet
&
(
still_open_beam
|
improvement_still_possible
)
return
not_max_length_yet
&
still_open_beam
&
improvement_still_possible
def
beam_search_body_fn
(
cur_len
,
...
...
@@ -2140,12 +2153,9 @@ class TFGenerationMixin:
# - make sure no scores can be added anymore if beam is full
# - make sure still running sequences cannot be chosen as finalized beam
topk_log_probs
=
topk_log_probs
/
(
tf
.
cast
(
cur_len
,
dtype
=
tf
.
float32
)
**
length_penalty
)
beams_in_batch_are_full
=
(
tf
.
broadcast_to
(
tf
.
math
.
reduce_all
(
is_sent_finished
,
axis
=-
1
,
keepdims
=
True
),
shape_list
(
did_topk_just_finished
)
)
&
early_stopping
)
beams_in_batch_are_full
=
tf
.
broadcast_to
(
tf
.
math
.
reduce_all
(
is_sent_finished
,
axis
=-
1
,
keepdims
=
True
),
shape_list
(
did_topk_just_finished
)
)
&
(
early_stopping
is
True
)
add_penalty
=
~
did_topk_just_finished
|
beams_in_batch_are_full
topk_log_probs
+=
tf
.
cast
(
add_penalty
,
tf
.
float32
)
*
-
1.0e9
...
...
@@ -2239,7 +2249,7 @@ class TFGenerationMixin:
sequences
=
tf
.
where
(
none_finished
[:,
None
,
None
],
sequences
,
running_sequences
)
scores
=
tf
.
where
(
none_finished
[:,
None
],
scores
,
running_scores
)
# Take best beams for each batch (the score is sorted in
a
scending order)
# Take best beams for each batch (the score is sorted in
de
scending order)
sequences
=
flatten_beam_dim
(
sequences
[:,
:
num_return_sequences
,
:])
scores
=
flatten_beam_dim
(
scores
[:,
:
num_return_sequences
])
...
...
src/transformers/generation/utils.py
View file @
f21af262
...
...
@@ -1190,6 +1190,7 @@ class GenerationMixin:
generation_config
=
copy
.
deepcopy
(
generation_config
)
model_kwargs
=
generation_config
.
update
(
**
kwargs
)
# All unused kwargs must be model kwargs
generation_config
.
validate
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# 2. Set generation parameters if not already defined
...
...
@@ -1458,6 +1459,7 @@ class GenerationMixin:
length_penalty
=
generation_config
.
length_penalty
,
do_early_stopping
=
generation_config
.
early_stopping
,
num_beam_hyps_to_keep
=
generation_config
.
num_return_sequences
,
max_length
=
generation_config
.
max_length
,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids
,
model_kwargs
=
self
.
_expand_inputs_for_generation
(
...
...
@@ -1493,6 +1495,7 @@ class GenerationMixin:
device
=
inputs_tensor
.
device
,
length_penalty
=
generation_config
.
length_penalty
,
do_early_stopping
=
generation_config
.
early_stopping
,
max_length
=
generation_config
.
max_length
,
)
# 13. interleave input_ids with `num_beams` additional sequences per batch
...
...
@@ -1536,12 +1539,12 @@ class GenerationMixin:
beam_scorer
=
BeamSearchScorer
(
batch_size
=
batch_size
,
num_beams
=
generation_config
.
num_beams
,
max_length
=
stopping_criteria
.
max_length
,
device
=
inputs_tensor
.
device
,
length_penalty
=
generation_config
.
length_penalty
,
do_early_stopping
=
generation_config
.
early_stopping
,
num_beam_hyps_to_keep
=
generation_config
.
num_return_sequences
,
num_beam_groups
=
generation_config
.
num_beam_groups
,
max_length
=
generation_config
.
max_length
,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids
,
model_kwargs
=
self
.
_expand_inputs_for_generation
(
...
...
@@ -1629,6 +1632,7 @@ class GenerationMixin:
length_penalty
=
generation_config
.
length_penalty
,
do_early_stopping
=
generation_config
.
early_stopping
,
num_beam_hyps_to_keep
=
generation_config
.
num_return_sequences
,
max_length
=
generation_config
.
max_length
,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids
,
model_kwargs
=
self
.
_expand_inputs_for_generation
(
...
...
src/transformers/models/rag/modeling_rag.py
View file @
f21af262
...
...
@@ -1566,6 +1566,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
length_penalty
=
generation_config
.
length_penalty
,
do_early_stopping
=
generation_config
.
early_stopping
,
num_beam_hyps_to_keep
=
generation_config
.
num_return_sequences
,
max_length
=
generation_config
.
max_length
,
)
return
self
.
beam_search
(
input_ids
,
...
...
tests/generation/test_utils.py
View file @
f21af262
...
...
@@ -2034,59 +2034,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
**
model_kwargs
,
)
def
test_beam_search_warning_if_max_length_is_passed
(
self
):
article
=
"""Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer
=
BartTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
bart_model
=
BartForConditionalGeneration
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
).
to
(
torch_device
)
batch_size
=
1
num_beams
=
3
input_ids
=
bart_tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
input_ids
=
input_ids
.
expand
(
num_beams
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
# pretend decoder_input_ids correspond to first encoder input id
decoder_input_ids
=
input_ids
[:,
:
1
]
stopping_criteria_max_length
=
18
stopping_criteria
=
StoppingCriteriaList
([
MaxLengthCriteria
(
max_length
=
stopping_criteria_max_length
)])
with
self
.
assertWarns
(
UserWarning
):
beam_scorer
=
BeamSearchScorer
(
batch_size
=
batch_size
,
num_beams
=
num_beams
,
device
=
torch_device
,
max_length
=
10
,
)
generated_ids
=
bart_model
.
beam_search
(
decoder_input_ids
,
num_beams
=
num_beams
,
stopping_criteria
=
stopping_criteria
,
beam_scorer
=
beam_scorer
,
**
model_kwargs
,
)
beam_scorer_no_max_len
=
BeamSearchScorer
(
batch_size
=
batch_size
,
num_beams
=
num_beams
,
device
=
torch_device
,
)
generated_ids_no_max_len
=
bart_model
.
beam_search
(
decoder_input_ids
,
num_beams
=
num_beams
,
stopping_criteria
=
stopping_criteria
,
beam_scorer
=
beam_scorer_no_max_len
,
**
model_kwargs
,
)
# BeamSearchScorer max_length should not influence "real" max_length
self
.
assertEqual
(
generated_ids
.
tolist
(),
generated_ids_no_max_len
.
tolist
())
def
test_custom_stopping_criteria_overload_error
(
self
):
article
=
"""Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer
=
BartTokenizer
.
from_pretrained
(
"sshleifer/bart-tiny-random"
)
...
...
tests/models/bart/test_modeling_flax_bart.py
View file @
f21af262
...
...
@@ -426,7 +426,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationT
)
input_ids
=
tokenizer
(
input_str
,
return_tensors
=
"np"
).
input_ids
sequences
=
model
.
generate
(
input_ids
,
num_beams
=
2
,
max_length
=
20
).
sequences
sequences
=
model
.
generate
(
input_ids
,
num_beams
=
2
,
min_length
=
None
,
max_length
=
20
).
sequences
output_str
=
tokenizer
.
batch_decode
(
sequences
)[
0
]
...
...
tests/models/gpt2/test_modeling_flax_gpt2.py
View file @
f21af262
...
...
@@ -224,7 +224,7 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
output_string
=
tokenizer
.
batch_decode
(
output_sequences
,
skip_special_tokens
=
True
)
expected_string
=
[
"Hello this is a long string of words. I'm going to
try to explain what I mean.
"
,
"Hello this is a long string of words. I'm going to
start with the first one.
\n
"
,
"Hey, I'm not sure if I'm going to be able to do"
,
]
...
...
tests/models/t5/test_modeling_flax_t5.py
View file @
f21af262
...
...
@@ -1076,7 +1076,7 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
expected_summaries
=
[
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
" cell phone video of the final seconds .
\"
one can hear cries of 'My God' in several languages,
\"
one"
" magazine says . all 150 on board were killed
when germanwings flight 9525
crash
ed
."
,
" magazine says . all 150 on board were killed
in the
crash ."
,
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
" preliminary examination into the situation in the occupied Palestinian territory . as members of the"
" court, Palestinians may be subject to counter-charges as well ."
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment