chenpangpang / transformers

Commit de2f7221 (unverified)
Generate: remove near-duplicate sample/greedy copy (#30773)

Authored May 13, 2024 by Joao Gante, committed by GitHub on May 13, 2024
Parent: ce87dca1

Showing 4 changed files with 92 additions and 418 deletions:

  src/transformers/generation/utils.py                                  +87  -413
  src/transformers/models/musicgen/modeling_musicgen.py                  +2    -2
  src/transformers/models/musicgen_melody/modeling_musicgen_melody.py    +2    -2
  src/transformers/models/rag/modeling_rag.py                            +1    -1
src/transformers/generation/utils.py
@@ -1683,17 +1683,6 @@ class GenerationMixin:
                 streamer=streamer,
                 **model_kwargs,
             )
-        if generation_mode == GenerationMode.GREEDY_SEARCH:
-            # 11. run greedy search
-            result = self._greedy_search(
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
-                **model_kwargs,
-            )
 
         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
             if not model_kwargs["use_cache"]:
@@ -1709,9 +1698,11 @@ class GenerationMixin:
                 **model_kwargs,
             )
 
-        elif generation_mode == GenerationMode.SAMPLE:
+        elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )
 
             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
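With this change, greedy search becomes the `do_sample=False` degenerate case of sampling: the logits warper is only built when sampling is requested, and stays `None` otherwise. Below is a minimal, self-contained sketch of that pattern (illustrative only; `build_warper` and `prepare_warper` are hypothetical stand-ins, not the transformers API):

import torch

def build_warper(temperature: float):
    # hypothetical stand-in for GenerationMixin._get_logits_warper
    def warper(scores: torch.Tensor) -> torch.Tensor:
        return scores / temperature
    return warper

def prepare_warper(do_sample: bool, temperature: float = 0.7):
    # mirrors `prepared_logits_warper = warper if do_sample else None`
    return build_warper(temperature) if do_sample else None

scores = torch.tensor([[1.0, 2.0, 3.0]])
assert prepare_warper(do_sample=False) is None    # greedy path: no warping
print(prepare_warper(do_sample=True)(scores))     # sampling path: temperature-scaled scores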
@@ -1721,11 +1712,11 @@ class GenerationMixin:
                 **model_kwargs,
             )
 
-            # 13. run sample
+            # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -1733,38 +1724,11 @@ class GenerationMixin:
                 **model_kwargs,
             )
 
-        elif generation_mode == GenerationMode.BEAM_SEARCH:
-            # 11. prepare beam search scorer
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=generation_config.num_beams,
-                device=inputs_tensor.device,
-                length_penalty=generation_config.length_penalty,
-                do_early_stopping=generation_config.early_stopping,
-                num_beam_hyps_to_keep=generation_config.num_return_sequences,
-                max_length=generation_config.max_length,
-            )
-            # 12. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 13. run beam search
-            result = self._beam_search(
-                input_ids,
-                beam_scorer,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                **model_kwargs,
-            )
-
-        elif generation_mode == GenerationMode.BEAM_SAMPLE:
+        elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )
 
             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
@@ -1786,11 +1750,11 @@ class GenerationMixin:
             )
 
             # 14. run beam sample
-            result = self._beam_sample(
+            result = self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
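The same merge is applied to the beam branches: `BEAM_SEARCH` and `BEAM_SAMPLE` now share `_beam_search`, with `do_sample` deciding whether a warper is attached. A toy sketch of the merged dispatch (illustrative enum and names, not the transformers `GenerationMode`):

from enum import Enum

class Mode(Enum):
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    BEAM_SEARCH = "beam_search"
    BEAM_SAMPLE = "beam_sample"

def dispatch(mode: Mode) -> str:
    # mirrors the merged membership checks in `generate()`
    if mode in (Mode.SAMPLE, Mode.GREEDY_SEARCH):
        return "_sample"
    if mode in (Mode.BEAM_SAMPLE, Mode.BEAM_SEARCH):
        return "_beam_search"
    return "other decoding method"

assert dispatch(Mode.GREEDY_SEARCH) == "_sample"
assert dispatch(Mode.BEAM_SEARCH) == "_beam_search"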
@@ -2284,162 +2248,32 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
-        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            streamer (`BaseStreamer`, *optional*):
-                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
-                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            model_kwargs:
-                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
-                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
-
-        Return:
-            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._sample()` instead, passing the same arguments.
         """
-        # init values
-        pad_token_id = generation_config.pad_token_id
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
-        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
-
-        # init attention / hidden states / scores tuples
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        # keep track of which sequences are already finished
-        batch_size = input_ids.shape[0]
-        this_peer_finished = False
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
-
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_tokens_scores,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # argmax
-            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-
-            # finished sentences should have their next token be a padding token
-            if has_eos_stopping_criteria:
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
-            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-            this_peer_finished = unfinished_sequences.max() == 0
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            if self.config.is_encoder_decoder:
-                return GenerateEncoderDecoderOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateDecoderOnlyOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return input_ids
+        logger.warning_once(
+            "Calling `._greedy_search()` directly is deprecated and will be removed in v4.42. Use `._sample()` "
+            "instead, passing the same arguments."
+        )
+        return self._sample(
+            input_ids=input_ids,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            streamer=streamer,
+            **model_kwargs,
+        )
 
     def _sample(
         self,
         input_ids: torch.LongTensor,
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
-        logits_warper: LogitsProcessorList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
         streamer: Optional["BaseStreamer"],
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
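From the public API nothing changes: greedy decoding is still requested through `generate(..., do_sample=False)`, and only direct callers of the private `_greedy_search` see the deprecation warning. A hedged usage sketch (assumes the `gpt2` checkpoint can be downloaded):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# Greedy decoding: previously dispatched to `_greedy_search`, now handled by `_sample`
# with no logits warper and argmax token selection.
output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))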
@@ -2455,10 +2289,6 @@ class GenerationMixin:
             stopping_criteria (`StoppingCriteriaList`):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
             generation_config ([`~generation.GenerationConfig`]):
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
@@ -2466,6 +2296,11 @@ class GenerationMixin:
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2485,6 +2320,12 @@ class GenerationMixin:
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )
 
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
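Because `logits_warper` is now optional, `_sample` (and `_beam_search`, below) validates that a warper is actually supplied whenever sampling is requested. A standalone sketch of that guard using toy types (plain callables in a list instead of `LogitsProcessorList`):

from typing import Callable, List, Optional

def check_warper(do_sample: bool, logits_warper: Optional[List[Callable]]) -> None:
    # mirrors the new ValueError guard: sampling without a warper list is a usage error
    if do_sample is True and not isinstance(logits_warper, list):
        raise ValueError(
            f"`do_sample` is set to `True`, `logits_warper` must be a list instance (it is {logits_warper})."
        )

check_warper(do_sample=False, logits_warper=None)  # fine: the greedy path needs no warper
try:
    check_warper(do_sample=True, logits_warper=None)
except ValueError as exc:
    print(exc)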
@@ -2525,7 +2366,8 @@ class GenerationMixin:
 
             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores = logits_warper(input_ids, next_token_scores)
 
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
@@ -2547,9 +2389,12 @@ class GenerationMixin:
                         else (outputs.hidden_states,)
                     )
 
-            # sample
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            # token selection
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
 
             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
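The per-step token selection is where the two strategies actually differ; after the merge it is a single branch on `do_sample`. A self-contained sketch of the selection step (not the transformers code itself):

import torch

def select_next_tokens(next_token_scores: torch.Tensor, do_sample: bool) -> torch.Tensor:
    # multinomial draw when sampling, plain argmax otherwise (the old greedy behaviour)
    if do_sample:
        probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)
    return torch.argmax(next_token_scores, dim=-1)

scores = torch.tensor([[0.1, 5.0, 0.2], [3.0, 0.1, 0.1]])
print(select_next_tokens(scores, do_sample=False))  # tensor([1, 0])
print(select_next_tokens(scores, do_sample=True))   # one random draw per row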
@@ -2622,6 +2467,7 @@ class GenerationMixin:
                 past_key_values.reorder_cache(beam_idx)
         return past_key_values
 
+    # TODO (joao, v4.42): remove default for `logits_warper`
     def _beam_search(
         self,
         input_ids: torch.LongTensor,
@@ -2630,6 +2476,7 @@ class GenerationMixin:
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
@@ -2652,6 +2499,11 @@ class GenerationMixin:
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2672,6 +2524,12 @@ class GenerationMixin:
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         sequential = generation_config.low_memory
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )
 
         batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams
@@ -2768,6 +2626,8 @@ class GenerationMixin:
             )  # (batch_size * num_beams, vocab_size)
 
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
@@ -2795,11 +2655,20 @@ class GenerationMixin:
             vocab_size = next_token_scores.shape[-1]
             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
-            # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+            # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
+            # non eos token per beam.
             n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
-            next_token_scores, next_tokens = torch.topk(
-                next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
-            )
+            n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
+                next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+                next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+                next_tokens = torch.gather(next_tokens, -1, _indices)
+            else:
+                next_token_scores, next_tokens = torch.topk(
+                    next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
+                )
 
             next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
             next_tokens = next_tokens % vocab_size
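In the merged beam loop, both paths keep `max(2, 1 + n_eos_tokens) * num_beams` candidates per batch item: beam sample draws them multinomially and then sorts by score, while beam search takes a plain `topk`. A standalone sketch of that candidate-selection step (shapes and names are illustrative):

import torch

def pick_beam_candidates(next_token_scores, num_beams, n_eos_tokens, do_sample):
    # `next_token_scores` has shape (batch_size, num_beams * vocab_size)
    n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
    if do_sample:
        probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
        next_token_scores, order = torch.sort(next_token_scores, descending=True, dim=1)
        next_tokens = torch.gather(next_tokens, -1, order)
    else:
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
        )
    return next_token_scores, next_tokens

scores = torch.log_softmax(torch.randn(2, 3 * 10), dim=-1)  # 2 items, 3 beams, vocab of 10
print(pick_beam_candidates(scores, num_beams=3, n_eos_tokens=1, do_sample=True)[1].shape)  # (2, 6)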
@@ -2897,219 +2766,24 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
-        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            beam_scorer (`BeamScorer`):
-                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
-                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            model_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
-                an encoder-decoder model the kwargs should include `encoder_outputs`.
-
-        Return:
-            [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._beam_search()` instead, passing the same arguments.
         """
-        # init values
-        pad_token_id = generation_config.pad_token_id
-        eos_token_id = generation_config.eos_token_id
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
-
-        batch_size = len(beam_scorer._beam_hyps)
-        num_beams = beam_scorer.num_beams
-
-        batch_beam_size, cur_len = input_ids.shape
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        beam_indices = (
-            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
-        )
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
-        beam_scores = beam_scores.view((batch_size * num_beams,))
-
-        this_peer_finished = False
-
-        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            next_token_scores = nn.functional.log_softmax(
-                next_token_logits, dim=-1
-            )  # (batch_size * num_beams, vocab_size)
-
-            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
-                next_token_scores_processed
-            )
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_scores_processed,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # reshape for beam search
-            vocab_size = next_token_scores.shape[-1]
-            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
-
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-
-            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
-            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
-
-            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
-            next_tokens = torch.gather(next_tokens, -1, _indices)
-
-            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
-            next_tokens = next_tokens % vocab_size
-
-            # stateless
-            beam_outputs = beam_scorer.process(
-                input_ids,
-                next_token_scores,
-                next_tokens,
-                next_indices,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                beam_indices=beam_indices,
-                decoder_prompt_len=decoder_prompt_len,
-            )
-            beam_scores = beam_outputs["next_beam_scores"]
-            beam_next_tokens = beam_outputs["next_beam_tokens"]
-            beam_idx = beam_outputs["next_beam_indices"]
-
-            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-            if model_kwargs.get("past_key_values", None) is not None:
-                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
-                    model_kwargs["past_key_values"], beam_idx
-                )
-
-            if return_dict_in_generate and output_scores:
-                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
-
-            # increase cur_len
-            cur_len = cur_len + 1
-
-            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                this_peer_finished = True
-
-        sequence_outputs = beam_scorer.finalize(
-            input_ids,
-            beam_scores,
-            next_tokens,
-            next_indices,
-            pad_token_id=pad_token_id,
-            eos_token_id=eos_token_id,
-            max_length=stopping_criteria.max_length,
-            beam_indices=beam_indices,
-            decoder_prompt_len=decoder_prompt_len,
-        )
-
-        if return_dict_in_generate:
-            if not output_scores:
-                sequence_outputs["sequence_scores"] = None
-
-            if self.config.is_encoder_decoder:
-                return GenerateBeamEncoderDecoderOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateBeamDecoderOnlyOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return sequence_outputs["sequences"]
+        logger.warning_once(
+            "Calling `._beam_sample()` directly is deprecated and will be removed in v4.42. Use `._beam_search()` "
+            "instead, passing the same arguments."
+        )
+        return self._beam_search(
+            input_ids=input_ids,
+            beam_scorer=beam_scorer,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            logits_warper=logits_warper,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            **model_kwargs,
+        )
 
     def _group_beam_search(
         self,
         input_ids: torch.LongTensor,
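`_beam_sample` is now just a deprecation shim around `_beam_search`; beam-search multinomial sampling is still requested through the public API exactly as before. A hedged usage sketch (again assuming the `gpt2` checkpoint can be downloaded):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# Beam-search multinomial sampling: previously dispatched to `_beam_sample`, now handled
# by `_beam_search` with a logits warper and sampled beam candidates.
output_ids = model.generate(**inputs, num_beams=4, do_sample=True, max_new_tokens=10)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))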
src/transformers/models/musicgen/modeling_musicgen.py
@@ -1739,7 +1739,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )
 
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
@@ -2832,7 +2832,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             )
 
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -1676,7 +1676,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
             )
 
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
@@ -2691,7 +2691,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
             )
 
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
src/transformers/models/rag/modeling_rag.py
@@ -1550,7 +1550,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                     f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                     " greedy search."
                 )
-            return self._greedy_search(
+            return self._sample(
                 input_ids,
                 logits_processor=pre_processor,
                 stopping_criteria=prepared_stopping_criteria,