Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f21af262
Unverified
Commit
f21af262
authored
Feb 03, 2023
by
Joao Gante
Committed by
GitHub
Feb 03, 2023
Browse files
🚨
🚨
Generate: standardize beam search behavior across frameworks (#21368)
parent
ea55bd86
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
121 additions
and
117 deletions
+121
-117
src/transformers/generation/beam_search.py
src/transformers/generation/beam_search.py
+50
-31
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+17
-2
src/transformers/generation/flax_utils.py
src/transformers/generation/flax_utils.py
+21
-13
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+24
-14
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+5
-1
src/transformers/models/rag/modeling_rag.py
src/transformers/models/rag/modeling_rag.py
+1
-0
tests/generation/test_utils.py
tests/generation/test_utils.py
+0
-53
tests/models/bart/test_modeling_flax_bart.py
tests/models/bart/test_modeling_flax_bart.py
+1
-1
tests/models/gpt2/test_modeling_flax_gpt2.py
tests/models/gpt2/test_modeling_flax_gpt2.py
+1
-1
tests/models/t5/test_modeling_flax_t5.py
tests/models/t5/test_modeling_flax_t5.py
+1
-1
No files found.
src/transformers/generation/beam_search.py
View file @
f21af262
...
...
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
warnings
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
from
typing
import
List
,
Optional
,
Tuple
,
Union
...
...
@@ -130,8 +129,6 @@ class BeamSearchScorer(BeamScorer):
Args:
batch_size (`int`):
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
max_length (`int`):
The maximum length of the sequence to be generated.
num_beams (`int`):
Number of beams for beam search.
device (`torch.device`):
...
...
@@ -142,14 +139,20 @@ class BeamSearchScorer(BeamScorer):
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
`length_penalty` < 0.0 encourages shorter sequences.
do_early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
beam search algorithm).
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformer.BeamSearchScorer.finalize`].
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
max_length (`int`, *optional*):
The maximum length of the sequence to be generated.
"""
def
__init__
(
...
...
@@ -158,10 +161,10 @@ class BeamSearchScorer(BeamScorer):
num_beams
:
int
,
device
:
torch
.
device
,
length_penalty
:
Optional
[
float
]
=
1.0
,
do_early_stopping
:
Optional
[
bool
]
=
False
,
do_early_stopping
:
Optional
[
Union
[
bool
,
str
]
]
=
False
,
num_beam_hyps_to_keep
:
Optional
[
int
]
=
1
,
num_beam_groups
:
Optional
[
int
]
=
1
,
**
kwargs
,
max_length
:
Optional
[
int
]
=
None
,
):
self
.
num_beams
=
num_beams
self
.
device
=
device
...
...
@@ -177,6 +180,7 @@ class BeamSearchScorer(BeamScorer):
num_beams
=
self
.
num_beams
,
length_penalty
=
self
.
length_penalty
,
early_stopping
=
self
.
do_early_stopping
,
max_length
=
max_length
,
)
for
_
in
range
(
batch_size
)
]
...
...
@@ -194,13 +198,6 @@ class BeamSearchScorer(BeamScorer):
f
" divisible by `num_beam_groups`, but is
{
num_beam_groups
}
with `num_beams` being
{
num_beams
}
."
)
if
"max_length"
in
kwargs
:
warnings
.
warn
(
"Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
", or `group_beam_search(...)`."
)
@
property
def
is_done
(
self
)
->
bool
:
return
self
.
_done
.
all
()
...
...
@@ -402,8 +399,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
Args:
batch_size (`int`):
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
max_length (`int`):
The maximum length of the sequence to be generated.
num_beams (`int`):
Number of beams for beam search.
constraints (`List[Constraint]`):
...
...
@@ -417,14 +412,20 @@ class ConstrainedBeamSearchScorer(BeamScorer):
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
`length_penalty` < 0.0 encourages shorter sequences.
do_early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
beam search algorithm).
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformer.BeamSearchScorer.finalize`].
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
max_length (`int`, *optional*):
The maximum length of the sequence to be generated.
"""
def
__init__
(
...
...
@@ -434,10 +435,10 @@ class ConstrainedBeamSearchScorer(BeamScorer):
constraints
:
List
[
Constraint
],
device
:
torch
.
device
,
length_penalty
:
Optional
[
float
]
=
1.0
,
do_early_stopping
:
Optional
[
bool
]
=
False
,
do_early_stopping
:
Optional
[
Union
[
bool
,
str
]
]
=
False
,
num_beam_hyps_to_keep
:
Optional
[
int
]
=
1
,
num_beam_groups
:
Optional
[
int
]
=
1
,
**
kwargs
,
max_length
:
Optional
[
int
]
=
None
,
):
self
.
num_beams
=
num_beams
self
.
device
=
device
...
...
@@ -454,6 +455,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
num_beams
=
self
.
num_beams
,
length_penalty
=
self
.
length_penalty
,
early_stopping
=
self
.
do_early_stopping
,
max_length
=
max_length
,
)
for
_
in
range
(
batch_size
)
]
...
...
@@ -471,13 +473,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
f
" divisible by `num_beam_groups`, but is
{
num_beam_groups
}
with `num_beams` being
{
num_beams
}
."
)
if
"max_length"
in
kwargs
:
warnings
.
warn
(
"Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. "
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
", or `group_beam_search(...)`."
)
@
property
def
is_done
(
self
)
->
bool
:
return
self
.
_done
.
all
()
...
...
@@ -865,16 +860,23 @@ class ConstrainedBeamSearchScorer(BeamScorer):
class
BeamHypotheses
:
def
__init__
(
self
,
num_beams
:
int
,
length_penalty
:
float
,
early_stopping
:
bool
):
def
__init__
(
self
,
num_beams
:
int
,
length_penalty
:
float
,
early_stopping
:
bool
,
max_length
:
Optional
[
int
]
=
None
):
"""
Initialize n-best list of hypotheses.
"""
self
.
length_penalty
=
length_penalty
self
.
early_stopping
=
early_stopping
self
.
max_length
=
max_length
self
.
num_beams
=
num_beams
self
.
beams
=
[]
self
.
worst_score
=
1e9
if
not
isinstance
(
self
.
early_stopping
,
bool
)
and
self
.
max_length
is
None
:
raise
ValueError
(
"When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
" BeamScorer class instance at initialization time."
)
def
__len__
(
self
):
"""
Number of hypotheses in the list.
...
...
@@ -903,9 +905,26 @@ class BeamHypotheses:
if
len
(
self
)
<
self
.
num_beams
:
return
False
elif
self
.
early_stopping
:
# `True`: stop as soon as at least `num_beams` hypotheses are finished
if
self
.
early_stopping
is
True
:
return
True
# `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
# when `length_penalty` is positive. See the discussion below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
elif
self
.
early_stopping
is
False
:
highest_attainable_score
=
best_sum_logprobs
/
cur_len
**
self
.
length_penalty
ret
=
self
.
worst_score
>=
highest_attainable_score
return
ret
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
else
:
# `length_penalty` > 0.0 -> max denominator is obtaned from `max_length`, not from `cur_len` -> min
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
# its max this way
if
self
.
length_penalty
>
0.0
:
highest_attainable_score
=
best_sum_logprobs
/
self
.
max_length
**
self
.
length_penalty
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
else
:
cur
_score
=
best_sum_logprobs
/
cur_len
**
self
.
length_penalty
ret
=
self
.
worst_score
>=
cur
_score
highest_attainable
_score
=
best_sum_logprobs
/
cur_len
**
self
.
length_penalty
ret
=
self
.
worst_score
>=
highest_attainable
_score
return
ret
src/transformers/generation/configuration_utils.py
View file @
f21af262
...
...
@@ -71,8 +71,12 @@ class GenerationConfig(PushToHubMixin):
`min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
min_new_tokens (`int`, *optional*):
The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
beam search algorithm).
max_time(`float`, *optional*):
The maximum amount of time you allow the computation to run for in seconds. generation will still finish
the current pass after allocated time has been passed.
...
...
@@ -290,6 +294,9 @@ class GenerationConfig(PushToHubMixin):
logger
.
error
(
f
"Can't set
{
key
}
with value
{
value
}
for
{
self
}
"
)
raise
err
# Validate the values of the attributes
self
.
validate
()
def
__eq__
(
self
,
other
):
self_dict
=
self
.
__dict__
.
copy
()
other_dict
=
other
.
__dict__
.
copy
()
...
...
@@ -302,6 +309,14 @@ class GenerationConfig(PushToHubMixin):
def
__repr__
(
self
):
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
()
}
"
def
validate
(
self
):
"""
Validates the values of the attributes of the GenerationConfig instance, and raises a `ValueError` if any of
the values are invalid.
"""
if
self
.
early_stopping
not
in
{
True
,
False
,
"never"
}:
raise
ValueError
(
f
"`early_stopping` must be a boolean or 'never', but is
{
self
.
early_stopping
}
."
)
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
],
...
...
src/transformers/generation/flax_utils.py
View file @
f21af262
...
...
@@ -19,7 +19,7 @@ import copy
import
inspect
import
warnings
from
functools
import
partial
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
numpy
as
np
...
...
@@ -275,6 +275,7 @@ class FlaxGenerationMixin:
generation_config
=
copy
.
deepcopy
(
generation_config
)
model_kwargs
=
generation_config
.
update
(
**
kwargs
)
# All unused kwargs must be model kwargs
generation_config
.
validate
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# set init values
...
...
@@ -633,7 +634,7 @@ class FlaxGenerationMixin:
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
int
]
=
None
,
length_penalty
:
Optional
[
float
]
=
None
,
early_stopping
:
Optional
[
bool
]
=
None
,
early_stopping
:
Optional
[
Union
[
bool
,
str
]
]
=
None
,
logits_processor
:
Optional
[
FlaxLogitsProcessorList
]
=
None
,
trace
:
bool
=
True
,
params
:
Optional
[
Dict
[
str
,
jnp
.
ndarray
]]
=
None
,
...
...
@@ -733,14 +734,22 @@ class FlaxGenerationMixin:
not_max_length_yet
=
state
.
cur_len
<
max_length
# 2. can the new beams still improve?
best_running_score
=
state
.
running_scores
[:,
-
1
:]
/
(
max_length
**
length_penalty
)
# early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
# below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
# early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
# length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
if
early_stopping
==
"never"
and
length_penalty
>
0.0
:
best_running_score
=
state
.
running_scores
[:,
:
1
]
/
(
max_length
**
length_penalty
)
else
:
best_running_score
=
state
.
running_scores
[:,
:
1
]
/
(
state
.
cur_len
**
length_penalty
)
worst_finished_score
=
jnp
.
where
(
state
.
is_sent_finished
,
jnp
.
min
(
state
.
scores
,
axis
=
1
,
keepdims
=
True
),
np
.
array
(
-
1.0e7
)
)
improvement_still_possible
=
jnp
.
a
ll
(
worst_finished_score
<
best_running
_score
)
improvement_still_possible
=
jnp
.
a
ny
(
best_running_score
>
worst_finished
_score
)
# 3. is there still a beam that has not finished?
still_open_beam
=
~
(
jnp
.
all
(
state
.
is_sent_finished
)
&
early_stopping
)
still_open_beam
=
~
(
jnp
.
all
(
state
.
is_sent_finished
)
&
(
early_stopping
is
True
)
)
return
not_max_length_yet
&
still_open_beam
&
improvement_still_possible
...
...
@@ -813,7 +822,7 @@ class FlaxGenerationMixin:
# 5. Get running sequences scores for next
# Determine the top k beam indices (from top 2*k beams) from log probs
# and gather top k beams (from top 2*k beams).
next_topk_indices
=
jnp
.
flip
(
lax
.
top_k
(
running_topk_log_probs
,
k
=
num_beams
)[
1
]
,
axis
=
1
)
next_topk_indices
=
lax
.
top_k
(
running_topk_log_probs
,
k
=
num_beams
)[
1
]
next_running_sequences
,
next_running_scores
=
gather_beams
(
[
topk_sequences
,
running_topk_log_probs
],
next_topk_indices
,
batch_size
,
num_beams
)
...
...
@@ -824,10 +833,9 @@ class FlaxGenerationMixin:
# - make sure no scores can be added anymore if beam is full
# - make sure still running sequences cannot be chosen as finalized beam
topk_log_probs
=
topk_log_probs
/
(
state
.
cur_len
**
length_penalty
)
beams_in_batch_are_full
=
(
jnp
.
broadcast_to
(
state
.
is_sent_finished
.
all
(
axis
=-
1
,
keepdims
=
True
),
did_topk_just_finished
.
shape
)
&
early_stopping
)
beams_in_batch_are_full
=
jnp
.
broadcast_to
(
state
.
is_sent_finished
.
all
(
axis
=-
1
,
keepdims
=
True
),
did_topk_just_finished
.
shape
)
&
(
early_stopping
is
True
)
add_penalty
=
~
did_topk_just_finished
|
beams_in_batch_are_full
topk_log_probs
+=
add_penalty
*
np
.
array
(
-
1.0e7
)
...
...
@@ -838,7 +846,7 @@ class FlaxGenerationMixin:
merged_sequences
=
jnp
.
concatenate
([
state
.
sequences
,
topk_sequences
],
axis
=
1
)
merged_scores
=
jnp
.
concatenate
([
state
.
scores
,
topk_log_probs
],
axis
=
1
)
merged_is_sent_finished
=
jnp
.
concatenate
([
state
.
is_sent_finished
,
did_topk_just_finished
],
axis
=
1
)
topk_merged_indices
=
jnp
.
flip
(
lax
.
top_k
(
merged_scores
,
k
=
num_beams
)[
1
]
,
axis
=
1
)
topk_merged_indices
=
lax
.
top_k
(
merged_scores
,
k
=
num_beams
)[
1
]
next_sequences
,
next_scores
,
next_is_sent_finished
=
gather_beams
(
[
merged_sequences
,
merged_scores
,
merged_is_sent_finished
],
topk_merged_indices
,
batch_size
,
num_beams
)
...
...
@@ -877,7 +885,7 @@ class FlaxGenerationMixin:
scores
=
jnp
.
where
(
none_finished
[:,
None
],
state
.
scores
,
state
.
running_scores
)
# take best beam for each batch
sequences
=
sequences
[:,
-
1
]
scores
=
scores
[:,
-
1
]
sequences
=
sequences
[:,
0
]
scores
=
scores
[:,
0
]
return
FlaxBeamSearchOutput
(
sequences
=
sequences
,
scores
=
scores
)
src/transformers/generation/tf_utils.py
View file @
f21af262
...
...
@@ -611,6 +611,7 @@ class TFGenerationMixin:
generation_config
=
copy
.
deepcopy
(
generation_config
)
model_kwargs
=
generation_config
.
update
(
**
kwargs
)
# All unused kwargs must be model kwargs
generation_config
.
validate
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
...
...
@@ -1808,7 +1809,7 @@ class TFGenerationMixin:
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
int
]
=
None
,
length_penalty
:
Optional
[
float
]
=
None
,
early_stopping
:
Optional
[
bool
]
=
None
,
early_stopping
:
Optional
[
Union
[
bool
,
str
]
]
=
None
,
logits_processor
:
Optional
[
TFLogitsProcessorList
]
=
None
,
logits_warper
:
Optional
[
TFLogitsProcessorList
]
=
None
,
num_return_sequences
:
Optional
[
int
]
=
None
,
...
...
@@ -1838,8 +1839,12 @@ class TFGenerationMixin:
to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
while `length_penalty` < 0.0 encourages shorter sequences.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following
values: `True`, where the generation stops as soon as there are `num_beams` complete candidates;
`False`, where an heuristic is applied and the generation stops when is it very unlikely to find better
candidates; `"never"`, where the beam search procedure only stops when there cannot be better
candidates (canonical beam search algorithm).
logits_processor (`[TFLogitsProcessorList]`, *optional*):
An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
...
...
@@ -2009,16 +2014,24 @@ class TFGenerationMixin:
not_max_length_yet
=
cur_len
<
max_length
# 2. can the new beams still improve?
# early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
# below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
# early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
# length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
if
early_stopping
==
"never"
and
length_penalty
>
0.0
:
best_running_score
=
running_scores
[:,
:
1
]
/
(
max_length
**
length_penalty
)
else
:
best_running_score
=
running_scores
[:,
:
1
]
/
(
tf
.
cast
(
cur_len
,
dtype
=
tf
.
float32
)
**
length_penalty
)
worst_finished_score
=
tf
.
where
(
is_sent_finished
,
tf
.
math
.
reduce_min
(
scores
,
axis
=
1
,
keepdims
=
True
),
-
1.0e9
)
improvement_still_possible
=
tf
.
math
.
reduce_a
ll
(
worst_finished_score
<
best_running
_score
)
improvement_still_possible
=
tf
.
math
.
reduce_a
ny
(
best_running_score
>
worst_finished
_score
)
# 3. is there still a beam that has not finished?
still_open_beam
=
~
(
tf
.
math
.
reduce_all
(
is_sent_finished
)
&
early_stopping
)
still_open_beam
=
~
(
tf
.
math
.
reduce_all
(
is_sent_finished
)
&
(
early_stopping
is
True
)
)
return
not_max_length_yet
&
(
still_open_beam
|
improvement_still_possible
)
return
not_max_length_yet
&
still_open_beam
&
improvement_still_possible
def
beam_search_body_fn
(
cur_len
,
...
...
@@ -2140,12 +2153,9 @@ class TFGenerationMixin:
# - make sure no scores can be added anymore if beam is full
# - make sure still running sequences cannot be chosen as finalized beam
topk_log_probs
=
topk_log_probs
/
(
tf
.
cast
(
cur_len
,
dtype
=
tf
.
float32
)
**
length_penalty
)
beams_in_batch_are_full
=
(
tf
.
broadcast_to
(
beams_in_batch_are_full
=
tf
.
broadcast_to
(
tf
.
math
.
reduce_all
(
is_sent_finished
,
axis
=-
1
,
keepdims
=
True
),
shape_list
(
did_topk_just_finished
)
)
&
early_stopping
)
)
&
(
early_stopping
is
True
)
add_penalty
=
~
did_topk_just_finished
|
beams_in_batch_are_full
topk_log_probs
+=
tf
.
cast
(
add_penalty
,
tf
.
float32
)
*
-
1.0e9
...
...
@@ -2239,7 +2249,7 @@ class TFGenerationMixin:
sequences
=
tf
.
where
(
none_finished
[:,
None
,
None
],
sequences
,
running_sequences
)
scores
=
tf
.
where
(
none_finished
[:,
None
],
scores
,
running_scores
)
# Take best beams for each batch (the score is sorted in
a
scending order)
# Take best beams for each batch (the score is sorted in
de
scending order)
sequences
=
flatten_beam_dim
(
sequences
[:,
:
num_return_sequences
,
:])
scores
=
flatten_beam_dim
(
scores
[:,
:
num_return_sequences
])
...
...
src/transformers/generation/utils.py
View file @
f21af262
...
...
@@ -1190,6 +1190,7 @@ class GenerationMixin:
generation_config
=
copy
.
deepcopy
(
generation_config
)
model_kwargs
=
generation_config
.
update
(
**
kwargs
)
# All unused kwargs must be model kwargs
generation_config
.
validate
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# 2. Set generation parameters if not already defined
...
...
@@ -1458,6 +1459,7 @@ class GenerationMixin:
length_penalty
=
generation_config
.
length_penalty
,
do_early_stopping
=
generation_config
.
early_stopping
,
num_beam_hyps_to_keep
=
generation_config
.
num_return_sequences
,
max_length
=
generation_config
.
max_length
,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids
,
model_kwargs
=
self
.
_expand_inputs_for_generation
(
...
...
@@ -1493,6 +1495,7 @@ class GenerationMixin:
device
=
inputs_tensor
.
device
,
length_penalty
=
generation_config
.
length_penalty
,
do_early_stopping
=
generation_config
.
early_stopping
,
max_length
=
generation_config
.
max_length
,
)
# 13. interleave input_ids with `num_beams` additional sequences per batch
...
...
@@ -1536,12 +1539,12 @@ class GenerationMixin:
beam_scorer
=
BeamSearchScorer
(
batch_size
=
batch_size
,
num_beams
=
generation_config
.
num_beams
,
max_length
=
stopping_criteria
.
max_length
,
device
=
inputs_tensor
.
device
,
length_penalty
=
generation_config
.
length_penalty
,
do_early_stopping
=
generation_config
.
early_stopping
,
num_beam_hyps_to_keep
=
generation_config
.
num_return_sequences
,
num_beam_groups
=
generation_config
.
num_beam_groups
,
max_length
=
generation_config
.
max_length
,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids
,
model_kwargs
=
self
.
_expand_inputs_for_generation
(
...
...
@@ -1629,6 +1632,7 @@ class GenerationMixin:
length_penalty
=
generation_config
.
length_penalty
,
do_early_stopping
=
generation_config
.
early_stopping
,
num_beam_hyps_to_keep
=
generation_config
.
num_return_sequences
,
max_length
=
generation_config
.
max_length
,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids
,
model_kwargs
=
self
.
_expand_inputs_for_generation
(
...
...
src/transformers/models/rag/modeling_rag.py
View file @
f21af262
...
...
@@ -1566,6 +1566,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
length_penalty
=
generation_config
.
length_penalty
,
do_early_stopping
=
generation_config
.
early_stopping
,
num_beam_hyps_to_keep
=
generation_config
.
num_return_sequences
,
max_length
=
generation_config
.
max_length
,
)
return
self
.
beam_search
(
input_ids
,
...
...
tests/generation/test_utils.py
View file @
f21af262
...
...
@@ -2034,59 +2034,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
**
model_kwargs
,
)
def
test_beam_search_warning_if_max_length_is_passed
(
self
):
article
=
"""Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer
=
BartTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
bart_model
=
BartForConditionalGeneration
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
).
to
(
torch_device
)
batch_size
=
1
num_beams
=
3
input_ids
=
bart_tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
input_ids
=
input_ids
.
expand
(
num_beams
,
-
1
)
model_kwargs
=
bart_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
{})
# pretend decoder_input_ids correspond to first encoder input id
decoder_input_ids
=
input_ids
[:,
:
1
]
stopping_criteria_max_length
=
18
stopping_criteria
=
StoppingCriteriaList
([
MaxLengthCriteria
(
max_length
=
stopping_criteria_max_length
)])
with
self
.
assertWarns
(
UserWarning
):
beam_scorer
=
BeamSearchScorer
(
batch_size
=
batch_size
,
num_beams
=
num_beams
,
device
=
torch_device
,
max_length
=
10
,
)
generated_ids
=
bart_model
.
beam_search
(
decoder_input_ids
,
num_beams
=
num_beams
,
stopping_criteria
=
stopping_criteria
,
beam_scorer
=
beam_scorer
,
**
model_kwargs
,
)
beam_scorer_no_max_len
=
BeamSearchScorer
(
batch_size
=
batch_size
,
num_beams
=
num_beams
,
device
=
torch_device
,
)
generated_ids_no_max_len
=
bart_model
.
beam_search
(
decoder_input_ids
,
num_beams
=
num_beams
,
stopping_criteria
=
stopping_criteria
,
beam_scorer
=
beam_scorer_no_max_len
,
**
model_kwargs
,
)
# BeamSearchScorer max_length should not influence "real" max_length
self
.
assertEqual
(
generated_ids
.
tolist
(),
generated_ids_no_max_len
.
tolist
())
def
test_custom_stopping_criteria_overload_error
(
self
):
article
=
"""Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer
=
BartTokenizer
.
from_pretrained
(
"sshleifer/bart-tiny-random"
)
...
...
tests/models/bart/test_modeling_flax_bart.py
View file @
f21af262
...
...
@@ -426,7 +426,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationT
)
input_ids
=
tokenizer
(
input_str
,
return_tensors
=
"np"
).
input_ids
sequences
=
model
.
generate
(
input_ids
,
num_beams
=
2
,
max_length
=
20
).
sequences
sequences
=
model
.
generate
(
input_ids
,
num_beams
=
2
,
min_length
=
None
,
max_length
=
20
).
sequences
output_str
=
tokenizer
.
batch_decode
(
sequences
)[
0
]
...
...
tests/models/gpt2/test_modeling_flax_gpt2.py
View file @
f21af262
...
...
@@ -224,7 +224,7 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
output_string
=
tokenizer
.
batch_decode
(
output_sequences
,
skip_special_tokens
=
True
)
expected_string
=
[
"Hello this is a long string of words. I'm going to
try to explain what I mean.
"
,
"Hello this is a long string of words. I'm going to
start with the first one.
\n
"
,
"Hey, I'm not sure if I'm going to be able to do"
,
]
...
...
tests/models/t5/test_modeling_flax_t5.py
View file @
f21af262
...
...
@@ -1076,7 +1076,7 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
expected_summaries
=
[
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
" cell phone video of the final seconds .
\"
one can hear cries of 'My God' in several languages,
\"
one"
" magazine says . all 150 on board were killed
when germanwings flight 9525
crash
ed
."
,
" magazine says . all 150 on board were killed
in the
crash ."
,
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
" preliminary examination into the situation in the occupied Palestinian territory . as members of the"
" court, Palestinians may be subject to counter-charges as well ."
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment