chenpangpang / transformers / Commits / de2f7221

Generate: remove near-duplicate sample/greedy copy (#30773)

Commit de2f7221 (unverified), authored May 13, 2024 by Joao Gante, committed by GitHub on May 13, 2024.
Parent: ce87dca1
Showing 4 changed files with 92 additions and 418 deletions (+92, -418).
src/transformers/generation/utils.py                                  +87 -413
src/transformers/models/musicgen/modeling_musicgen.py                  +2 -2
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py    +2 -2
src/transformers/models/rag/modeling_rag.py                            +1 -1
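The commit merges greedy search into the sampling code path: `generate()` now routes both `GenerationMode.GREEDY_SEARCH` and `GenerationMode.SAMPLE` through `GenerationMixin._sample`, which degenerates to greedy search when `generation_config.do_sample=False`. The public API is unchanged; a minimal, illustrative sketch (the checkpoint name is only an example):

# Illustrative only: both calls below go through the same _sample loop internally;
# the first degenerates to greedy decoding because do_sample=False.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

greedy_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
sampled_ids = model.generate(**inputs, do_sample=True, top_k=50, max_new_tokens=20)
print(tokenizer.decode(greedy_ids[0], skip_special_tokens=True))
print(tokenizer.decode(sampled_ids[0], skip_special_tokens=True))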
src/transformers/generation/utils.py
@@ -1683,17 +1683,6 @@ class GenerationMixin:
                 streamer=streamer,
                 **model_kwargs,
             )

-        if generation_mode == GenerationMode.GREEDY_SEARCH:
-            # 11. run greedy search
-            result = self._greedy_search(
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
-                **model_kwargs,
-            )
-
         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
             if not model_kwargs["use_cache"]:
@@ -1709,9 +1698,11 @@ class GenerationMixin:
                 **model_kwargs,
             )

-        elif generation_mode == GenerationMode.SAMPLE:
+        elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )

             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -1721,11 +1712,11 @@ class GenerationMixin:
                 **model_kwargs,
             )

-            # 13. run sample
+            # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -1733,38 +1724,11 @@ class GenerationMixin:
                 **model_kwargs,
             )

-        elif generation_mode == GenerationMode.BEAM_SEARCH:
-            # 11. prepare beam search scorer
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=generation_config.num_beams,
-                device=inputs_tensor.device,
-                length_penalty=generation_config.length_penalty,
-                do_early_stopping=generation_config.early_stopping,
-                num_beam_hyps_to_keep=generation_config.num_return_sequences,
-                max_length=generation_config.max_length,
-            )
-            # 12. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 13. run beam search
-            result = self._beam_search(
-                input_ids,
-                beam_scorer,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                **model_kwargs,
-            )
-
-        elif generation_mode == GenerationMode.BEAM_SAMPLE:
+        elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )

             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
@@ -1786,11 +1750,11 @@ class GenerationMixin:
             )

             # 14. run beam sample
-            result = self._beam_sample(
+            result = self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -2284,162 +2248,32 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
-        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            streamer (`BaseStreamer`, *optional*):
-                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
-                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            model_kwargs:
-                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
-                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
-
-        Return:
-            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._sample()` instead, passing the same arguments.
         """
-        # init values
-        pad_token_id = generation_config.pad_token_id
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
-        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
-
-        # init attention / hidden states / scores tuples
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        # keep track of which sequences are already finished
-        batch_size = input_ids.shape[0]
-        this_peer_finished = False
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
-
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_tokens_scores,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # argmax
-            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-
-            # finished sentences should have their next token be a padding token
-            if has_eos_stopping_criteria:
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
-            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-            this_peer_finished = unfinished_sequences.max() == 0
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            if self.config.is_encoder_decoder:
-                return GenerateEncoderDecoderOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateDecoderOnlyOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return input_ids
+        logger.warning_once(
+            "Calling `._greedy_search()` directly is deprecated and will be removed in v4.42. Use `._sample()` "
+            "instead, passing the same arguments."
+        )
+        return self._sample(
+            input_ids=input_ids,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            streamer=streamer,
+            **model_kwargs,
+        )

     def _sample(
         self,
         input_ids: torch.LongTensor,
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
-        logits_warper: LogitsProcessorList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
         streamer: Optional["BaseStreamer"],
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -2455,10 +2289,6 @@ class GenerationMixin:
             stopping_criteria (`StoppingCriteriaList`):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
             generation_config ([`~generation.GenerationConfig`]):
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
@@ -2466,6 +2296,11 @@ class GenerationMixin:
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2485,6 +2320,12 @@ class GenerationMixin:
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
@@ -2525,7 +2366,8 @@ class GenerationMixin:

             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores = logits_warper(input_ids, next_token_scores)

             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
@@ -2547,9 +2389,12 @@ class GenerationMixin:
                         else (outputs.hidden_states,)
                     )

-            # sample
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            # token selection
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)

             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
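The hunk above is the core of the merge: `_sample` now performs its own token selection, gated on `do_sample`. A self-contained sketch of that step, assuming `next_token_scores` has shape `(batch_size, vocab_size)`; the helper name is illustrative, not part of the library:

import torch
from torch import nn

def select_next_tokens(next_token_scores: torch.Tensor, do_sample: bool) -> torch.Tensor:
    # next_token_scores: (batch_size, vocab_size), already run through processors/warpers
    if do_sample:
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)  # multinomial sampling
    return torch.argmax(next_token_scores, dim=-1)  # greedy decoding

scores = torch.randn(2, 10)
print(select_next_tokens(scores, do_sample=False))  # deterministic
print(select_next_tokens(scores, do_sample=True))   # stochastic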
@@ -2622,6 +2467,7 @@ class GenerationMixin:
                 past_key_values.reorder_cache(beam_idx)
         return past_key_values

+    # TODO (joao, v4.42): remove default for `logits_warper`
     def _beam_search(
         self,
         input_ids: torch.LongTensor,
@@ -2630,6 +2476,7 @@ class GenerationMixin:
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
@@ -2652,6 +2499,11 @@ class GenerationMixin:
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2672,6 +2524,12 @@ class GenerationMixin:
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         sequential = generation_config.low_memory
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )

         batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams
@@ -2768,6 +2626,8 @@ class GenerationMixin:
             )  # (batch_size * num_beams, vocab_size)

             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
@@ -2795,11 +2655,20 @@ class GenerationMixin:
             vocab_size = next_token_scores.shape[-1]
             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

-            # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+            # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
+            # non eos token per beam.
             n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
-            next_token_scores, next_tokens = torch.topk(
-                next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
-            )
+            n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
+                next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+                next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+                next_tokens = torch.gather(next_tokens, -1, _indices)
+            else:
+                next_token_scores, next_tokens = torch.topk(
+                    next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
+                )

             next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
             next_tokens = next_tokens % vocab_size
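`_beam_search` is gated the same way: with `do_sample=True` it reproduces the former `_beam_sample` behaviour by drawing `n_tokens_to_keep` candidates per batch entry and re-sorting them by score, otherwise it keeps the deterministic top-k. A self-contained sketch of the two candidate-selection paths, assuming scores of shape `(batch_size, num_beams * vocab_size)`; the helper name is illustrative, not part of the library:

import torch
from torch import nn

def select_beam_candidates(scores: torch.Tensor, n_tokens_to_keep: int, do_sample: bool):
    # scores: (batch_size, num_beams * vocab_size) accumulated log-probabilities
    if do_sample:
        probs = nn.functional.softmax(scores, dim=-1)
        tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
        picked = torch.gather(scores, -1, tokens)
        picked, order = torch.sort(picked, descending=True, dim=1)
        return picked, torch.gather(tokens, -1, order)
    return torch.topk(scores, n_tokens_to_keep, dim=1, largest=True, sorted=True)

scores = torch.randn(2, 3 * 8)  # 2 sequences, 3 beams, vocabulary of 8
print(select_beam_candidates(scores, n_tokens_to_keep=6, do_sample=False))
print(select_beam_candidates(scores, n_tokens_to_keep=6, do_sample=True))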
@@ -2897,219 +2766,24 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
-        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            beam_scorer (`BeamScorer`):
-                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
-                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            model_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
-                an encoder-decoder model the kwargs should include `encoder_outputs`.
-
-        Return:
-            [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._beam_search()` instead, passing the same arguments.
         """
-        # init values
-        pad_token_id = generation_config.pad_token_id
-        eos_token_id = generation_config.eos_token_id
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
-
-        batch_size = len(beam_scorer._beam_hyps)
-        num_beams = beam_scorer.num_beams
-
-        batch_beam_size, cur_len = input_ids.shape
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        beam_indices = (
-            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
-        )
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
-        beam_scores = beam_scores.view((batch_size * num_beams,))
-
-        this_peer_finished = False
-
-        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            next_token_scores = nn.functional.log_softmax(
-                next_token_logits, dim=-1
-            )  # (batch_size * num_beams, vocab_size)
-
-            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
-                next_token_scores_processed
-            )
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_scores_processed,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # reshape for beam search
-            vocab_size = next_token_scores.shape[-1]
-            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
-
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-
-            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
-            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
-
-            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
-            next_tokens = torch.gather(next_tokens, -1, _indices)
-
-            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
-            next_tokens = next_tokens % vocab_size
-
-            # stateless
-            beam_outputs = beam_scorer.process(
-                input_ids,
-                next_token_scores,
-                next_tokens,
-                next_indices,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                beam_indices=beam_indices,
-                decoder_prompt_len=decoder_prompt_len,
-            )
-            beam_scores = beam_outputs["next_beam_scores"]
-            beam_next_tokens = beam_outputs["next_beam_tokens"]
-            beam_idx = beam_outputs["next_beam_indices"]
-
-            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-            if model_kwargs.get("past_key_values", None) is not None:
-                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
-                    model_kwargs["past_key_values"], beam_idx
-                )
-
-            if return_dict_in_generate and output_scores:
-                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
-
-            # increase cur_len
-            cur_len = cur_len + 1
-
-            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                this_peer_finished = True
-
-        sequence_outputs = beam_scorer.finalize(
-            input_ids,
-            beam_scores,
-            next_tokens,
-            next_indices,
-            pad_token_id=pad_token_id,
-            eos_token_id=eos_token_id,
-            max_length=stopping_criteria.max_length,
-            beam_indices=beam_indices,
-            decoder_prompt_len=decoder_prompt_len,
-        )
-
-        if return_dict_in_generate:
-            if not output_scores:
-                sequence_outputs["sequence_scores"] = None
-
-            if self.config.is_encoder_decoder:
-                return GenerateBeamEncoderDecoderOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateBeamDecoderOnlyOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return sequence_outputs["sequences"]
+        logger.warning_once(
+            "Calling `._beam_sample()` directly is deprecated and will be removed in v4.42. Use `._beam_search()` "
+            "instead, passing the same arguments."
+        )
+
+        return self._beam_search(
+            input_ids=input_ids,
+            beam_scorer=beam_scorer,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            logits_warper=logits_warper,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            **model_kwargs,
+        )

     def _group_beam_search(
         self,
         input_ids: torch.LongTensor,
src/transformers/models/musicgen/modeling_musicgen.py

@@ -1739,7 +1739,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )

             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,

@@ -2832,7 +2832,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             )

             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py

@@ -1676,7 +1676,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
             )

             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,

@@ -2691,7 +2691,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
             )

             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
src/transformers/models/rag/modeling_rag.py

@@ -1550,7 +1550,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                     f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                     " greedy search."
                 )
-            return self._greedy_search(
+            return self._sample(
                 input_ids,
                 logits_processor=pre_processor,
                 stopping_criteria=prepared_stopping_criteria,