Unverified Commit de2f7221 authored by Joao Gante, committed by GitHub

Generate: remove near-duplicate sample/greedy copy (#30773)

parent ce87dca1
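The diff below folds greedy search into `_sample` and beam sample into `_beam_search`: within each pair the decoding loop is identical, and only the token-selection step differs, gated by `generation_config.do_sample`. As a rough illustration of that shared selection step, here is a minimal sketch in plain PyTorch (`select_next_tokens` is a made-up helper for this sketch, not part of the library):

```python
import torch

def select_next_tokens(next_token_scores: torch.Tensor, do_sample: bool) -> torch.Tensor:
    """Pick one next token per row of a (batch, vocab_size) score matrix."""
    if do_sample:
        # sampling path: normalize the (already warped) scores and draw from them
        probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)
    # do_sample=False degenerates to the old greedy search: plain argmax
    return torch.argmax(next_token_scores, dim=-1)

scores = torch.randn(2, 10)  # fake logits: batch of 2, vocabulary of 10
print(select_next_tokens(scores, do_sample=False))  # deterministic
print(select_next_tokens(scores, do_sample=True))   # stochastic
```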
@@ -1683,17 +1683,6 @@ class GenerationMixin:
                 streamer=streamer,
                 **model_kwargs,
             )
-        if generation_mode == GenerationMode.GREEDY_SEARCH:
-            # 11. run greedy search
-            result = self._greedy_search(
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
-                **model_kwargs,
-            )
         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
             if not model_kwargs["use_cache"]:
@@ -1709,9 +1698,11 @@ class GenerationMixin:
                 **model_kwargs,
             )
-        elif generation_mode == GenerationMode.SAMPLE:
+        elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )
             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -1721,11 +1712,11 @@ class GenerationMixin:
                 **model_kwargs,
             )
-            # 13. run sample
+            # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -1733,38 +1724,11 @@ class GenerationMixin:
                 **model_kwargs,
             )
-        elif generation_mode == GenerationMode.BEAM_SEARCH:
+        elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
-            # 11. prepare beam search scorer
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=generation_config.num_beams,
-                device=inputs_tensor.device,
-                length_penalty=generation_config.length_penalty,
-                do_early_stopping=generation_config.early_stopping,
-                num_beam_hyps_to_keep=generation_config.num_return_sequences,
-                max_length=generation_config.max_length,
-            )
-            # 12. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 13. run beam search
-            result = self._beam_search(
-                input_ids,
-                beam_scorer,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                **model_kwargs,
-            )
-        elif generation_mode == GenerationMode.BEAM_SAMPLE:
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )
             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
@@ -1786,11 +1750,11 @@ class GenerationMixin:
             )
             # 14. run beam sample
-            result = self._beam_sample(
+            result = self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -2284,162 +2248,32 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
-        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            streamer (`BaseStreamer`, *optional*):
-                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
-                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            model_kwargs:
-                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
-                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
-        Return:
-            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._sample()` instead, passing the same arguments.
         """
-        # init values
-        pad_token_id = generation_config.pad_token_id
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
-        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
-        # init attention / hidden states / scores tuples
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-        # keep track of which sequences are already finished
-        batch_size = input_ids.shape[0]
-        this_peer_finished = False
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-            if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
-            next_token_logits = outputs.logits[:, -1, :]
-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_tokens_scores,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-            # argmax
-            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-            # finished sentences should have their next token be a padding token
-            if has_eos_stopping_criteria:
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-            # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-            this_peer_finished = unfinished_sequences.max() == 0
-        if streamer is not None:
-            streamer.end()
-        if return_dict_in_generate:
-            if self.config.is_encoder_decoder:
-                return GenerateEncoderDecoderOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateDecoderOnlyOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return input_ids
+        logger.warning_once(
+            "Calling `._greedy_search()` directly is deprecated and will be removed in v4.42. Use `._sample()` "
+            "instead, passing the same arguments."
+        )
+        return self._sample(
+            input_ids=input_ids,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            streamer=streamer,
+            **model_kwargs,
+        )
     def _sample(
         self,
         input_ids: torch.LongTensor,
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
-        logits_warper: LogitsProcessorList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
         streamer: Optional["BaseStreamer"],
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -2455,10 +2289,6 @@ class GenerationMixin:
             stopping_criteria (`StoppingCriteriaList`):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
             generation_config ([`~generation.GenerationConfig`]):
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
@@ -2466,6 +2296,11 @@ class GenerationMixin:
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2485,6 +2320,12 @@ class GenerationMixin:
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
@@ -2525,7 +2366,8 @@ class GenerationMixin:
             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores = logits_warper(input_ids, next_token_scores)
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
@@ -2547,9 +2389,12 @@ class GenerationMixin:
                         else (outputs.hidden_states,)
                     )
-            # sample
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            # token selection
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
@@ -2622,6 +2467,7 @@ class GenerationMixin:
                 past_key_values.reorder_cache(beam_idx)
         return past_key_values
+    # TODO (joao, v4.42): remove default for `logits_warper`
     def _beam_search(
         self,
         input_ids: torch.LongTensor,
@@ -2630,6 +2476,7 @@ class GenerationMixin:
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
@@ -2652,6 +2499,11 @@ class GenerationMixin:
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2672,6 +2524,12 @@ class GenerationMixin:
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         sequential = generation_config.low_memory
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )
         batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams
@@ -2768,6 +2626,8 @@ class GenerationMixin:
             )  # (batch_size * num_beams, vocab_size)
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
@@ -2795,11 +2655,20 @@ class GenerationMixin:
             vocab_size = next_token_scores.shape[-1]
             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
-            # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+            # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
+            # non eos token per beam.
             n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
-            next_token_scores, next_tokens = torch.topk(
-                next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
-            )
+            n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
+                next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+                next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+                next_tokens = torch.gather(next_tokens, -1, _indices)
+            else:
+                next_token_scores, next_tokens = torch.topk(
+                    next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
+                )
             next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
             next_tokens = next_tokens % vocab_size
@@ -2897,219 +2766,24 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
-        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            beam_scorer (`BeamScorer`):
-                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
-                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            model_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
-                an encoder-decoder model the kwargs should include `encoder_outputs`.
-        Return:
-            [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._beam_search()` instead, passing the same arguments.
         """
-        # init values
-        pad_token_id = generation_config.pad_token_id
-        eos_token_id = generation_config.eos_token_id
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
-        batch_size = len(beam_scorer._beam_hyps)
-        num_beams = beam_scorer.num_beams
-        batch_beam_size, cur_len = input_ids.shape
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        beam_indices = (
-            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
-        )
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
-        beam_scores = beam_scores.view((batch_size * num_beams,))
-        this_peer_finished = False
-        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-            if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
-            next_token_logits = outputs.logits[:, -1, :]
-            next_token_scores = nn.functional.log_softmax(
-                next_token_logits, dim=-1
-            )  # (batch_size * num_beams, vocab_size)
-            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
-                next_token_scores_processed
-            )
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_scores_processed,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-            # reshape for beam search
-            vocab_size = next_token_scores.shape[-1]
-            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
-            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
-            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
-            next_tokens = torch.gather(next_tokens, -1, _indices)
-            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
-            next_tokens = next_tokens % vocab_size
-            # stateless
-            beam_outputs = beam_scorer.process(
-                input_ids,
-                next_token_scores,
-                next_tokens,
-                next_indices,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                beam_indices=beam_indices,
-                decoder_prompt_len=decoder_prompt_len,
-            )
-            beam_scores = beam_outputs["next_beam_scores"]
-            beam_next_tokens = beam_outputs["next_beam_tokens"]
-            beam_idx = beam_outputs["next_beam_indices"]
-            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-            if model_kwargs.get("past_key_values", None) is not None:
-                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
-                    model_kwargs["past_key_values"], beam_idx
-                )
-            if return_dict_in_generate and output_scores:
-                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
-            # increase cur_len
-            cur_len = cur_len + 1
-            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                this_peer_finished = True
-        sequence_outputs = beam_scorer.finalize(
-            input_ids,
-            beam_scores,
-            next_tokens,
-            next_indices,
-            pad_token_id=pad_token_id,
-            eos_token_id=eos_token_id,
-            max_length=stopping_criteria.max_length,
-            beam_indices=beam_indices,
-            decoder_prompt_len=decoder_prompt_len,
-        )
-        if return_dict_in_generate:
-            if not output_scores:
-                sequence_outputs["sequence_scores"] = None
-            if self.config.is_encoder_decoder:
-                return GenerateBeamEncoderDecoderOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateBeamDecoderOnlyOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return sequence_outputs["sequences"]
+        logger.warning_once(
+            "Calling `._beam_sample()` directly is deprecated and will be removed in v4.42. Use `._beam_search()` "
+            "instead, passing the same arguments."
+        )
+        return self._beam_search(
+            input_ids=input_ids,
+            beam_scorer=beam_scorer,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            logits_warper=logits_warper,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            **model_kwargs,
+        )
     def _group_beam_search(
         self,
         input_ids: torch.LongTensor,
...
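The `logits_warper` argument that `_sample` and `_beam_search` now accept is only applied on the sampling path (`do_sample=True`); for greedy and plain beam search it stays `None`. As a minimal illustration of what a warper list does to the scores before multinomial sampling (the concrete warpers and values here are only examples, not what `generate()` would build for any particular config):

```python
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper

# Example warper list of the kind `_get_logits_warper(generation_config)` builds
# when `do_sample=True`; with `do_sample=False` no warping happens at all.
warpers = LogitsProcessorList([TemperatureLogitsWarper(0.7), TopKLogitsWarper(top_k=5)])

input_ids = torch.tensor([[1, 2, 3]])  # dummy prompt ids
scores = torch.randn(1, 50)            # dummy next-token scores over a 50-token vocabulary
warped = warpers(input_ids, scores)

# All but the top-5 entries are pushed to -inf, so multinomial sampling can only
# ever pick one of the 5 highest-scoring tokens.
print(torch.isfinite(warped).sum().item())  # -> 5
```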
@@ -1739,7 +1739,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
@@ -2832,7 +2832,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             )
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
...
@@ -1676,7 +1676,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
             )
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
@@ -2691,7 +2691,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
             )
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
...
@@ -1550,7 +1550,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                     f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                     " greedy search."
                 )
-            return self._greedy_search(
+            return self._sample(
                 input_ids,
                 logits_processor=pre_processor,
                 stopping_criteria=prepared_stopping_criteria,
...
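None of this changes the public interface: all four affected modes are still reached through `generate()`, which now routes greedy/sample to `_sample` and beam search/beam sample to `_beam_search` internally. A quick usage sketch (the `gpt2` checkpoint is only a placeholder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

greedy = model.generate(**inputs, do_sample=False, max_new_tokens=10)             # -> _sample (argmax)
sampled = model.generate(**inputs, do_sample=True, max_new_tokens=10)             # -> _sample (multinomial)
beam = model.generate(**inputs, num_beams=4, do_sample=False, max_new_tokens=10)  # -> _beam_search (top-k)
beam_sampled = model.generate(**inputs, num_beams=4, do_sample=True, max_new_tokens=10)  # -> _beam_search (multinomial)
print(tokenizer.decode(greedy[0], skip_special_tokens=True))
```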