Unverified Commit d57ffb48 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: remove deprecated public decoding functions and streamline logic 🧼 (#29956)

parent dc401d3a
...@@ -123,7 +123,7 @@ class AssistedCandidateGenerator(CandidateGenerator): ...@@ -123,7 +123,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
) )
assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, assistant_kwargs, model_input_name inputs_tensor, assistant_kwargs, model_input_name, assistant_model.generation_config
) )
elif "encoder_outputs" in model_kwargs: elif "encoder_outputs" in model_kwargs:
assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"] assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
......
...@@ -65,25 +65,16 @@ class GenerationConfig(PushToHubMixin): ...@@ -65,25 +65,16 @@ class GenerationConfig(PushToHubMixin):
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models: for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and - *greedy decoding* if `num_beams=1` and `do_sample=False`
`do_sample=False` - *contrastive search* if `penalty_alpha>0.` and `top_k>1`
- *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0.` - *multinomial sampling* if `num_beams=1` and `do_sample=True`
and `top_k>1` - *beam-search decoding* if `num_beams>1` and `do_sample=False`
- *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
`do_sample=True` - *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
- *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and - *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
`do_sample=False` - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if
`num_beams>1` and `do_sample=True` To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`
- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
<Tip> <Tip>
......
This diff is collapsed.
...@@ -1650,8 +1650,6 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): ...@@ -1650,8 +1650,6 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
batch_size = input_ids.shape[0] // self.num_codebooks batch_size = input_ids.shape[0] // self.num_codebooks
# 4. Define other model kwargs # 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale model_kwargs["guidance_scale"] = generation_config.guidance_scale
...@@ -1748,10 +1746,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): ...@@ -1748,10 +1746,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id, generation_config=generation_config,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus, synced_gpus=synced_gpus,
streamer=streamer, streamer=streamer,
**model_kwargs, **model_kwargs,
...@@ -1774,10 +1769,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): ...@@ -1774,10 +1769,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
logits_processor=logits_processor, logits_processor=logits_processor,
logits_warper=logits_warper, logits_warper=logits_warper,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id, generation_config=generation_config,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus, synced_gpus=synced_gpus,
streamer=streamer, streamer=streamer,
**model_kwargs, **model_kwargs,
...@@ -2423,8 +2415,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel): ...@@ -2423,8 +2415,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
self, self,
inputs_tensor: torch.Tensor, inputs_tensor: torch.Tensor,
model_kwargs, model_kwargs,
model_input_name: Optional[str] = None, model_input_name: Optional[str],
guidance_scale: Optional[float] = None, generation_config: GenerationConfig,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# 1. get text encoder # 1. get text encoder
encoder = self.get_text_encoder() encoder = self.get_text_encoder()
...@@ -2446,6 +2438,9 @@ class MusicgenForConditionalGeneration(PreTrainedModel): ...@@ -2446,6 +2438,9 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
encoder_kwargs = { encoder_kwargs = {
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
} }
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
guidance_scale = generation_config.guidance_scale
# 3. make sure that encoder returns `ModelOutput` # 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
...@@ -2708,8 +2703,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel): ...@@ -2708,8 +2703,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
batch_size = inputs_tensor.shape[0] batch_size = inputs_tensor.shape[0]
# 4. Define other model kwargs # 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale model_kwargs["guidance_scale"] = generation_config.guidance_scale
...@@ -2723,10 +2716,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): ...@@ -2723,10 +2716,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
if "encoder_outputs" not in model_kwargs: if "encoder_outputs" not in model_kwargs:
# encoder_outputs are created and added to `model_kwargs` # encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_text_encoder_kwargs_for_generation( model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
inputs_tensor, inputs_tensor, model_kwargs, model_input_name, generation_config
model_kwargs,
model_input_name,
guidance_scale=generation_config.guidance_scale,
) )
if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs: if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
...@@ -2831,10 +2821,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): ...@@ -2831,10 +2821,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id, generation_config=generation_config,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus, synced_gpus=synced_gpus,
streamer=streamer, streamer=streamer,
**model_kwargs, **model_kwargs,
...@@ -2858,10 +2845,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): ...@@ -2858,10 +2845,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
logits_processor=logits_processor, logits_processor=logits_processor,
logits_warper=logits_warper, logits_warper=logits_warper,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id, generation_config=generation_config,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus, synced_gpus=synced_gpus,
streamer=streamer, streamer=streamer,
**model_kwargs, **model_kwargs,
......
...@@ -1586,8 +1586,6 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): ...@@ -1586,8 +1586,6 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
batch_size = input_ids.shape[0] // self.num_codebooks batch_size = input_ids.shape[0] // self.num_codebooks
# 4. Define other model kwargs # 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale model_kwargs["guidance_scale"] = generation_config.guidance_scale
...@@ -1684,10 +1682,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): ...@@ -1684,10 +1682,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id, generation_config=generation_config,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus, synced_gpus=synced_gpus,
streamer=streamer, streamer=streamer,
**model_kwargs, **model_kwargs,
...@@ -1710,10 +1705,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): ...@@ -1710,10 +1705,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
logits_processor=logits_processor, logits_processor=logits_processor,
logits_warper=logits_warper, logits_warper=logits_warper,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id, generation_config=generation_config,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus, synced_gpus=synced_gpus,
streamer=streamer, streamer=streamer,
**model_kwargs, **model_kwargs,
...@@ -2318,12 +2310,13 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): ...@@ -2318,12 +2310,13 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
self, self,
inputs_tensor: torch.Tensor, inputs_tensor: torch.Tensor,
model_kwargs, model_kwargs,
model_input_name: Optional[str] = None, model_input_name: Optional[str],
guidance_scale: Optional[float] = None, generation_config: GenerationConfig,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
encoder_hidden_states = None encoder_hidden_states = None
# attention mask is consumed once to produce text conditional hidden states through the text encoder # attention mask is consumed once to produce text conditional hidden states through the text encoder
encoder_attention_mask = model_kwargs.pop("attention_mask") encoder_attention_mask = model_kwargs.pop("attention_mask")
guidance_scale = generation_config.guidance_scale
# 1. condition on text # 1. condition on text
if inputs_tensor is not None: if inputs_tensor is not None:
...@@ -2346,6 +2339,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): ...@@ -2346,6 +2339,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
encoder_kwargs = { encoder_kwargs = {
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
} }
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
# make sure that encoder returns `ModelOutput` # make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
...@@ -2572,8 +2567,6 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): ...@@ -2572,8 +2567,6 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
batch_size = inputs_tensor.shape[0] batch_size = inputs_tensor.shape[0]
# 4. Define other model kwargs # 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale model_kwargs["guidance_scale"] = generation_config.guidance_scale
...@@ -2585,10 +2578,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): ...@@ -2585,10 +2578,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
if "encoder_hidden_states" not in model_kwargs: if "encoder_hidden_states" not in model_kwargs:
# encoder_hidden_states are created and added to `model_kwargs` # encoder_hidden_states are created and added to `model_kwargs`
model_kwargs = self._prepare_encoder_hidden_states_kwargs_for_generation( model_kwargs = self._prepare_encoder_hidden_states_kwargs_for_generation(
inputs_tensor, inputs_tensor, model_kwargs, model_input_name, generation_config
model_kwargs,
model_input_name,
guidance_scale=generation_config.guidance_scale,
) )
# 5. Prepare `input_ids` which will be used for auto-regressive generation # 5. Prepare `input_ids` which will be used for auto-regressive generation
...@@ -2684,14 +2674,11 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): ...@@ -2684,14 +2674,11 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
) )
# 11. run greedy search # 11. run greedy search
outputs = self.greedy_search( outputs = self._greedy_search(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id, generation_config=generation_config,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus, synced_gpus=synced_gpus,
streamer=streamer, streamer=streamer,
**model_kwargs, **model_kwargs,
...@@ -2710,15 +2697,12 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): ...@@ -2710,15 +2697,12 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
) )
# 12. run sample # 12. run sample
outputs = self.sample( outputs = self._sample(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
logits_warper=logits_warper, logits_warper=logits_warper,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id, generation_config=generation_config,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus, synced_gpus=synced_gpus,
streamer=streamer, streamer=streamer,
**model_kwargs, **model_kwargs,
......
...@@ -1537,6 +1537,10 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1537,6 +1537,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
logits_processor=logits_processor, logits_processor=logits_processor,
) )
prepared_stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
if generation_config.num_beams == 1: if generation_config.num_beams == 1:
if generation_config.num_return_sequences > 1: if generation_config.num_return_sequences > 1:
raise ValueError( raise ValueError(
...@@ -1546,9 +1550,10 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1546,9 +1550,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
return self._greedy_search( return self._greedy_search(
input_ids, input_ids,
logits_processor=pre_processor, logits_processor=pre_processor,
max_length=generation_config.max_length, stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id, generation_config=generation_config,
eos_token_id=generation_config.eos_token_id, synced_gpus=False,
streamer=None,
**model_kwargs, **model_kwargs,
) )
elif generation_config.num_beams > 1: elif generation_config.num_beams > 1:
...@@ -1567,9 +1572,9 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1567,9 +1572,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
input_ids, input_ids,
beam_scorer, beam_scorer,
logits_processor=pre_processor, logits_processor=pre_processor,
max_length=generation_config.max_length, stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id, generation_config=generation_config,
eos_token_id=generation_config.eos_token_id, synced_gpus=False,
**model_kwargs, **model_kwargs,
) )
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment