Unverified commit 87a0783d authored by Joao Gante, committed by GitHub

Generate: inner decoding methods are no longer public (#29437)

parent 4d892b72
@@ -389,3 +389,6 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
 >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['Alice and Bob are going to the same party. It is a small party, in a small']
 ```
+
+Alternatively, you can also set `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
+to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
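The n-gram lookup behind `prompt_lookup_num_tokens` can be illustrated with a small standalone sketch. This is a hypothetical toy version of the idea (the real logic lives in the library's candidate-generator machinery): find the most recent earlier occurrence of the sequence's trailing n-gram and propose the tokens that followed it as draft candidates, with no helper model needed.

```python
def prompt_lookup_candidates(tokens, num_pred=3, max_ngram=3):
    """Toy prompt-lookup decoding: match the trailing n-gram earlier in the
    sequence and return the tokens that followed it as draft candidates."""
    for n in range(max_ngram, 0, -1):  # prefer longer n-gram matches
        if len(tokens) < n:
            continue
        ngram = tokens[-n:]
        # scan backwards for the most recent earlier occurrence of the n-gram
        for start in range(len(tokens) - n - 1, -1, -1):
            if tokens[start:start + n] == ngram:
                continuation = tokens[start + n:start + n + num_pred]
                if continuation:
                    return continuation
    return []  # no match: fall back to normal one-token-at-a-time decoding

# "the" at the end matches "the" at the start, so the tokens that followed
# it ("cat sat on") are proposed as draft tokens for the model to verify.
print(prompt_lookup_candidates(["the", "cat", "sat", "on", "the"]))  # → ['cat', 'sat', 'on']
```

The drafted tokens are then verified in a single forward pass of the main model, exactly as in model-based assisted decoding.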
@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.

 # Utilities for Generation

-This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
-[`~generation.GenerationMixin.greedy_search`],
-[`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`],
-[`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`],
-[`~generation.GenerationMixin.group_beam_search`], and
-[`~generation.GenerationMixin.constrained_beam_search`].
-Most of those are only useful if you are studying the code of the generate methods in the library.
+This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].

 ## Generate Outputs
@@ -376,4 +367,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens

 [[autodoc]] StaticCache
     - update
     - get_seq_length
\ No newline at end of file
@@ -43,13 +43,6 @@ like token streaming.

 [[autodoc]] generation.GenerationMixin
     - generate
     - compute_transition_scores
-    - greedy_search
-    - sample
-    - beam_search
-    - beam_sample
-    - contrastive_search
-    - group_beam_search
-    - constrained_beam_search

 ## TFGenerationMixin

...
@@ -17,15 +17,6 @@ rendered properly in your Markdown viewer.

 # Utilities for Generation

 This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
-[`~generation.GenerationMixin.greedy_search`],
-[`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`],
-[`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`],
-[`~generation.GenerationMixin.group_beam_search`], and
-[`~generation.GenerationMixin.constrained_beam_search`].
-Most of these are only useful if you are studying the code of the generation methods in the library.

 ## Generate Outputs

...
@@ -43,13 +43,6 @@ rendered properly in your Markdown viewer.

 [[autodoc]] generation.GenerationMixin
     - generate
     - compute_transition_scores
-    - greedy_search
-    - sample
-    - beam_search
-    - beam_sample
-    - contrastive_search
-    - group_beam_search
-    - constrained_beam_search

 ## TFGenerationMixin

...
@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.

 # Utilities for Generation

-This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
-[`~generation.GenerationMixin.greedy_search`],
-[`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`],
-[`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`],
-[`~generation.GenerationMixin.group_beam_search`], and
-[`~generation.GenerationMixin.constrained_beam_search`].
-Most of these are only useful if you are studying the code of the generation methods in the library.
+This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].

 ## Generate Outputs

...
@@ -38,13 +38,6 @@ rendered properly in your Markdown viewer.

 [[autodoc]] generation.GenerationMixin
     - generate
     - compute_transition_scores
-    - greedy_search
-    - sample
-    - beam_search
-    - beam_sample
-    - contrastive_search
-    - group_beam_search
-    - constrained_beam_search

 ## TFGenerationMixin

...
@@ -43,22 +43,22 @@ class GenerationConfig(PushToHubMixin):
     Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
     for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

-        - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
+        - *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
           `do_sample=False`
-        - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
+        - *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0.`
           and `top_k>1`
-        - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
+        - *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
           `do_sample=True`
-        - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
+        - *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
           `do_sample=False`
-        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
+        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if
           `num_beams>1` and `do_sample=True`
-        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
+        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if
          `num_beams>1` and `num_beam_groups>1`
-        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
+        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
          `constraints!=None` or `force_words_ids!=None`
-        - *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if
-          `assistant_model` is passed to `.generate()`
+        - *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
+          `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`

     You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
     more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
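The trigger conditions in the docstring above can be sketched as a small dispatch function. This is a simplified approximation for illustration only, not the library's actual mode-selection code; the parameter names mirror `GenerationConfig` fields, and the precedence order is an assumption:

```python
from enum import Enum, auto

class GenerationMode(Enum):
    GREEDY_SEARCH = auto()
    CONTRASTIVE_SEARCH = auto()
    SAMPLE = auto()
    BEAM_SEARCH = auto()
    BEAM_SAMPLE = auto()
    GROUP_BEAM_SEARCH = auto()
    CONSTRAINED_BEAM_SEARCH = auto()
    ASSISTED_GENERATION = auto()

def pick_mode(num_beams=1, do_sample=False, penalty_alpha=None, top_k=50,
              num_beam_groups=1, constraints=None, assistant_model=None,
              prompt_lookup_num_tokens=None):
    # Assisted decoding is triggered by an assistant model OR prompt lookup
    # (the second trigger is exactly what this commit documents).
    if assistant_model is not None or prompt_lookup_num_tokens is not None:
        return GenerationMode.ASSISTED_GENERATION
    if constraints is not None:
        return GenerationMode.CONSTRAINED_BEAM_SEARCH
    if num_beams == 1:
        if penalty_alpha is not None and penalty_alpha > 0 and top_k > 1:
            return GenerationMode.CONTRASTIVE_SEARCH
        return GenerationMode.SAMPLE if do_sample else GenerationMode.GREEDY_SEARCH
    if num_beam_groups > 1:
        return GenerationMode.GROUP_BEAM_SEARCH
    return GenerationMode.BEAM_SAMPLE if do_sample else GenerationMode.BEAM_SEARCH

print(pick_mode())                              # GenerationMode.GREEDY_SEARCH
print(pick_mode(num_beams=4, do_sample=True))   # GenerationMode.BEAM_SAMPLE
```

In the real library the selected mode then routes to the corresponding underscore-prefixed inner method, which is the change this commit makes.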
...
@@ -347,20 +347,22 @@ class GenerationMixin:
     A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].

     The class exposes [`~generation.GenerationMixin.generate`], which can be used for:

-        - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
+        - *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
          `do_sample=False`
-        - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0` and
+        - *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0` and
          `top_k>1`
-        - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
+        - *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
          `do_sample=True`
-        - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
+        - *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
          `do_sample=False`
-        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if `num_beams>1`
+        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if `num_beams>1`
          and `do_sample=True`
-        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if `num_beams>1`
+        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if `num_beams>1`
          and `num_beam_groups>1`
-        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
+        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
          `constraints!=None` or `force_words_ids!=None`
+        - *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
+          `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`

     You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
     learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
@@ -1547,7 +1549,7 @@ class GenerationMixin:
             )

         if generation_mode == GenerationMode.GREEDY_SEARCH:
             # 11. run greedy search
-            result = self.greedy_search(
+            result = self._greedy_search(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
@@ -1565,7 +1567,7 @@ class GenerationMixin:
             if not model_kwargs["use_cache"]:
                 raise ValueError("Contrastive search requires `use_cache=True`")

-            result = self.contrastive_search(
+            result = self._contrastive_search(
                 input_ids,
                 top_k=generation_config.top_k,
                 penalty_alpha=generation_config.penalty_alpha,
@@ -1595,7 +1597,7 @@ class GenerationMixin:
             )

             # 13. run sample
-            result = self.sample(
+            result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
@@ -1629,7 +1631,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            result = self.beam_search(
+            result = self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1668,7 +1670,7 @@ class GenerationMixin:
             )
             # 14. run beam sample
-            result = self.beam_sample(
+            result = self._beam_sample(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1703,7 +1705,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            result = self.group_beam_search(
+            result = self._group_beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1777,7 +1779,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            result = self.constrained_beam_search(
+            result = self._constrained_beam_search(
                 input_ids,
                 constrained_beam_scorer=constrained_beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1801,8 +1803,15 @@ class GenerationMixin:

         return result
+    def contrastive_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._contrastive_search(*args, **kwargs)
+
     @torch.no_grad()
-    def contrastive_search(
+    def _contrastive_search(
         self,
         input_ids: torch.LongTensor,
         top_k: Optional[int] = 1,
@@ -1828,7 +1837,7 @@ class GenerationMixin:
         <Tip warning={true}>

-        In most cases, you do not need to call [`~generation.GenerationMixin.contrastive_search`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._contrastive_search`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -1902,7 +1911,7 @@ class GenerationMixin:
         >>> input_prompt = "DeepMind Company is"
         >>> input_ids = tokenizer(input_prompt, return_tensors="pt")
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
-        >>> outputs = model.contrastive_search(
+        >>> outputs = model._contrastive_search(
         ...     **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria
         ... )
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
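The shim pattern this commit introduces for every decoding method (the public name warns once, then forwards to the new underscore-prefixed inner method) can be sketched generically. The class and method names below are illustrative placeholders, not the library's, and plain `warnings` stands in for the library's `logger.warning_once`:

```python
import warnings

class ToyModel:
    """Hypothetical class demonstrating the deprecate-and-delegate pattern."""

    def _inner_search(self, prompt):
        # private implementation; the orchestrating `generate` calls this directly
        return prompt + " ..."

    def inner_search(self, prompt):
        # thin public shim kept for one deprecation cycle: warn, then delegate,
        # so existing callers keep working while being nudged toward `generate`
        warnings.warn(
            "Calling `inner_search` directly is deprecated. "
            "Use `generate` or a custom generation loop instead.",
            FutureWarning,
        )
        return self._inner_search(prompt)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = ToyModel().inner_search("Hello")
print(out)                            # → Hello ...
print(caught[0].category.__name__)    # → FutureWarning
```

Because the shim takes `*args, **kwargs` in the real commit, it stays signature-compatible with every existing call site until the public name is removed.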
@@ -2243,7 +2252,14 @@ class GenerationMixin:
         else:
             return input_ids

-    def greedy_search(
+    def greedy_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `greedy_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._greedy_search(*args, **kwargs)
+
+    def _greedy_search(
         self,
         input_ids: torch.LongTensor,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -2266,7 +2282,7 @@ class GenerationMixin:
         <Tip warning={true}>

-        In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
+        In most cases, you do not need to call [`~generation.GenerationMixin._greedy_search`] directly. Use generate()
         instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -2348,7 +2364,7 @@ class GenerationMixin:
         ... )
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

-        >>> outputs = model.greedy_search(
+        >>> outputs = model._greedy_search(
         ...     input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
         ... )
@@ -2514,7 +2530,14 @@ class GenerationMixin:
         else:
             return input_ids

-    def sample(
+    def sample(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._sample(*args, **kwargs)
+
+    def _sample(
         self,
         input_ids: torch.LongTensor,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -2538,7 +2561,7 @@ class GenerationMixin:
         <Tip warning={true}>

-        In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
+        In most cases, you do not need to call [`~generation.GenerationMixin._sample`] directly. Use generate() instead.
         For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -2635,7 +2658,7 @@ class GenerationMixin:
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

         >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
-        >>> outputs = model.sample(
+        >>> outputs = model._sample(
         ...     input_ids,
         ...     logits_processor=logits_processor,
         ...     logits_warper=logits_warper,
@@ -2832,7 +2855,14 @@ class GenerationMixin:
                 past_key_values.reorder_cache(beam_idx)
         return past_key_values

-    def beam_search(
+    def beam_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._beam_search(*args, **kwargs)
+
+    def _beam_search(
         self,
         input_ids: torch.LongTensor,
         beam_scorer: BeamScorer,
@@ -2856,7 +2886,7 @@ class GenerationMixin:
         <Tip warning={true}>

-        In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate()
+        In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate()
         instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -2958,7 +2988,7 @@ class GenerationMixin:
         ...     ]
         ... )

-        >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
+        >>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ['Wie alt bist du?']
@@ -3214,7 +3244,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]

-    def beam_sample(
+    def beam_sample(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `beam_sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._beam_sample(*args, **kwargs)
+
+    def _beam_sample(
         self,
         input_ids: torch.LongTensor,
         beam_scorer: BeamScorer,
@@ -3238,7 +3275,7 @@ class GenerationMixin:
         <Tip warning={true}>

-        In most cases, you do not need to call [`~generation.GenerationMixin.beam_sample`] directly. Use generate()
+        In most cases, you do not need to call [`~generation.GenerationMixin._beam_sample`] directly. Use generate()
         instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -3346,7 +3383,7 @@ class GenerationMixin:
         ...     ]
         ... )

-        >>> outputs = model.beam_sample(
+        >>> outputs = model._beam_sample(
         ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
         ... )
@@ -3561,7 +3598,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]

-    def group_beam_search(
+    def group_beam_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `group_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._group_beam_search(*args, **kwargs)
+
+    def _group_beam_search(
         self,
         input_ids: torch.LongTensor,
         beam_scorer: BeamScorer,
@@ -3584,7 +3628,7 @@ class GenerationMixin:
         <Tip warning={true}>

-        In most cases, you do not need to call [`~generation.GenerationMixin.group_beam_search`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._group_beam_search`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -3686,7 +3730,7 @@ class GenerationMixin:
         ...     ]
         ... )

-        >>> outputs = model.group_beam_search(
+        >>> outputs = model._group_beam_search(
         ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
         ... )
@@ -3958,7 +4002,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]

-    def constrained_beam_search(
+    def constrained_beam_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `constrained_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._constrained_beam_search(*args, **kwargs)
+
+    def _constrained_beam_search(
         self,
         input_ids: torch.LongTensor,
         constrained_beam_scorer: ConstrainedBeamSearchScorer,
@@ -3981,7 +4032,7 @@ class GenerationMixin:
         <Tip warning={true}>

-        In most cases, you do not need to call [`~generation.GenerationMixin.constrained_beam_search`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._constrained_beam_search`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -4088,7 +4139,7 @@ class GenerationMixin:
         ...     ]
         ... )

-        >>> outputs = model.constrained_beam_search(
+        >>> outputs = model._constrained_beam_search(
         ...     input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
         ... )
@@ -4311,7 +4362,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]

-    def assisted_decoding(
+    def assisted_decoding(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `assisted_decoding` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._assisted_decoding(*args, **kwargs)
+
+    def _assisted_decoding(
         self,
         input_ids: torch.LongTensor,
         candidate_generator: Optional["CandidateGenerator"] = None,
@@ -4338,7 +4396,7 @@ class GenerationMixin:
         <Tip warning={true}>

-        In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._assisted_decoding`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -4429,7 +4487,7 @@ class GenerationMixin:
         ...     logits_processor=logits_processor,
         ...     model_kwargs={},
         ... )
-        >>> outputs = model.assisted_decoding(
+        >>> outputs = model._assisted_decoding(
         ...     input_ids,
         ...     candidate_generator=candidate_generator,
         ...     logits_processor=logits_processor,

...
@@ -1336,7 +1336,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )

             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
@@ -1361,7 +1361,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )

             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
@@ -2402,7 +2402,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             )

             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
@@ -2428,7 +2428,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             )

             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,

...
@@ -1539,7 +1539,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                     f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                     " greedy search."
                 )
-            return self.greedy_search(
+            return self._greedy_search(
                 input_ids,
                 logits_processor=pre_processor,
                 max_length=generation_config.max_length,
@@ -1559,7 +1559,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
                 max_length=generation_config.max_length,
             )
-            return self.beam_search(
+            return self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=pre_processor,

...