"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4bafc43b0ebf65dc1e9df70c4fe1a81dfa2475cf"
Unverified Commit 6ce11c2c authored by Patrick von Platen, committed by GitHub

[Docs] Improve PyTorch, Flax generate API (#15988)

* Move generate docs

* up

* Update docs/source/_toctree.yml

* correct

* correct some stuff

* correct tests

* more fixes

* finish generate

* add to doc tests

* finish

* finalize

* add warning to generate method
parent 0951d317
@@ -114,6 +114,8 @@
    title: Logging
  - local: main_classes/model
    title: Models
+ - local: main_classes/text_generation
+   title: Text Generation
  - local: main_classes/onnx
    title: ONNX
  - local: main_classes/optimizer_schedules
@@ -86,14 +86,6 @@ Due to Pytorch design, this functionality is only available for floating dtypes.
    - push_to_hub
    - all

-## Generation
-
-[[autodoc]] generation_utils.GenerationMixin
-
-[[autodoc]] generation_tf_utils.TFGenerationMixin
-
-[[autodoc]] generation_flax_utils.FlaxGenerationMixin
-
## Pushing to the Hub

[[autodoc]] file_utils.PushToHubMixin
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# Generation
+
+The methods for auto-regressive text generation, namely [`~generation_utils.GenerationMixin.generate`] (for the PyTorch models), [`~generation_tf_utils.TFGenerationMixin.generate`] (for the TensorFlow models), and [`~generation_flax_utils.FlaxGenerationMixin.generate`] (for the Flax/JAX models), are implemented in [`~generation_utils.GenerationMixin`], [`~generation_tf_utils.TFGenerationMixin`], and [`~generation_flax_utils.FlaxGenerationMixin`] respectively.
+
+These mixin classes are inherited by the corresponding base model classes ([`PreTrainedModel`], [`TFPreTrainedModel`], and [`FlaxPreTrainedModel`] respectively), thereby exposing all methods for auto-regressive text generation to every model class.
+
+## GenerationMixin
+
+[[autodoc]] generation_utils.GenerationMixin
+    - generate
+    - greedy_search
+    - sample
+    - beam_search
+    - beam_sample
+    - group_beam_search
+    - constrained_beam_search
+
+## TFGenerationMixin
+
+[[autodoc]] generation_tf_utils.TFGenerationMixin
+    - generate
+
+## FlaxGenerationMixin
+
+[[autodoc]] generation_flax_utils.FlaxGenerationMixin
+    - generate
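For a quick feel of the API the new page above documents, here is a minimal sketch (not part of the commit; the `gpt2` checkpoint and length settings are placeholder assumptions) showing that the same `generate` entry point is exposed on the model classes of all three frameworks:

```python
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,      # PyTorch models inherit GenerationMixin
    TFAutoModelForCausalLM,    # TensorFlow models inherit TFGenerationMixin
    FlaxAutoModelForCausalLM,  # Flax/JAX models inherit FlaxGenerationMixin
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt = "Today I believe we can finally"

# PyTorch
pt_model = AutoModelForCausalLM.from_pretrained("gpt2")
pt_ids = tokenizer(prompt, return_tensors="pt").input_ids
print(tokenizer.batch_decode(pt_model.generate(pt_ids, max_length=20), skip_special_tokens=True))

# TensorFlow
tf_model = TFAutoModelForCausalLM.from_pretrained("gpt2")
tf_ids = tokenizer(prompt, return_tensors="tf").input_ids
print(tokenizer.batch_decode(tf_model.generate(tf_ids, max_length=20), skip_special_tokens=True))

# Flax/JAX (generate returns an output object; the token ids live in `.sequences`)
flax_model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
np_ids = tokenizer(prompt, return_tensors="np").input_ids
print(tokenizer.batch_decode(flax_model.generate(np_ids, max_length=20).sequences, skip_special_tokens=True))
```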
@@ -118,7 +118,16 @@ class BeamSearchState:

class FlaxGenerationMixin:
    """
-    A class containing all of the functions supporting generation, to be used as a mixin in [`FlaxPreTrainedModel`].
+    A class containing all functions for auto-regressive text generation, to be used as a mixin in
+    [`FlaxPreTrainedModel`].
+
+    The class exposes [`~generation_flax_utils.FlaxGenerationMixin.generate`], which can be used for:
+
+    - *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
+      `num_beams=1` and `do_sample=False`.
+    - *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
+      and `do_sample=True`.
+    - *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1`
+      and `do_sample=False`.
    """

    @staticmethod
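The bullet list above is a dispatch table on `num_beams` and `do_sample`. A hedged paraphrase of that rule as plain Python (an illustration of the documented behaviour, not the library's actual implementation):

```python
def pick_flax_decoding_method(num_beams: int, do_sample: bool) -> str:
    """Mirror of the dispatch rule documented for FlaxGenerationMixin.generate."""
    if num_beams == 1 and not do_sample:
        return "_greedy_search"  # greedy decoding
    if num_beams == 1 and do_sample:
        return "_sample"  # multinomial sampling
    if num_beams > 1 and not do_sample:
        return "_beam_search"  # beam-search decoding
    raise NotImplementedError("beam sampling is not listed for the Flax mixin")


assert pick_flax_decoding_method(1, False) == "_greedy_search"
assert pick_flax_decoding_method(1, True) == "_sample"
assert pick_flax_decoding_method(4, False) == "_beam_search"
```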
@@ -176,12 +185,23 @@ class FlaxGenerationMixin:
        **model_kwargs,
    ):
        r"""
-        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
-        and, multinomial sampling.
+        Generates sequences of token ids for models with a language modeling head. The method supports the following
+        generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
+
+        - *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
+          `num_beams=1` and `do_sample=False`.
+        - *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
+          and `do_sample=True`.
+        - *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1`
+          and `do_sample=False`.

-        Apart from `input_ids`, all the arguments below will default to the value of the attribute of the same name
-        inside the [`PretrainedConfig`] of the model. The default values indicated are the default values of those
-        config.
+        <Tip warning={true}>
+
+        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
+        defined in the model's config (`config.json`) which in turn defaults to the
+        [`~modeling_utils.PretrainedConfig`] of the model.
+
+        </Tip>

        Most of these parameters are explained in more detail in [this blog
        post](https://huggingface.co/blog/how-to-generate).
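The `<Tip>` describes the argument-defaulting rule: any generation argument left unset falls back to the value stored in the model config. A minimal sketch of that rule (assuming the `gpt2` checkpoint, whose config does not override `max_length`):

```python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("The dog", return_tensors="np").input_ids

# The value used whenever max_length is not passed explicitly
# (20, the PretrainedConfig default, unless the checkpoint overrides it).
print(model.config.max_length)

default_out = model.generate(input_ids)                # uses config.max_length
longer_out = model.generate(input_ids, max_length=40)  # per-call override
```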
@@ -236,7 +256,7 @@ class FlaxGenerationMixin:
        >>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
        >>> # generate candidates using sampling
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ```"""
        # set init values
        max_length = max_length if max_length is not None else self.config.max_length
@@ -377,7 +377,21 @@ BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]

class GenerationMixin:
    """
-    A class containing all of the functions supporting generation, to be used as a mixin in [`PreTrainedModel`].
+    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
+
+    The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
+
+    - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
+      `do_sample=False`.
+    - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
+      `do_sample=True`.
+    - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
+      `do_sample=False`.
+    - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
+      `num_beams>1` and `do_sample=True`.
+    - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
+      `num_beams>1` and `num_beam_groups>1`.
+    - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`],
+      if `constraints!=None` or `force_words_ids!=None`.
    """

    def _prepare_model_inputs(
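Each bullet in the class docstring above corresponds to one combination of `generate` flags. A hedged sketch (the `gpt2` checkpoint and flag values are placeholder assumptions) of how a caller reaches each decoding method:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("The dog", return_tensors="pt").input_ids

model.generate(input_ids)                                  # -> greedy_search
model.generate(input_ids, do_sample=True)                  # -> sample
model.generate(input_ids, num_beams=4)                     # -> beam_search
model.generate(input_ids, num_beams=4, do_sample=True)     # -> beam_sample
model.generate(input_ids, num_beams=4, num_beam_groups=2)  # -> group_beam_search
model.generate(                                            # -> constrained_beam_search
    input_ids,
    num_beams=4,
    force_words_ids=[tokenizer(" park").input_ids],
)
```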
@@ -847,18 +861,37 @@ class GenerationMixin:
        **model_kwargs,
    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
        r"""
-        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
-        multinomial sampling, beam-search decoding, and beam-search multinomial sampling.
-
-        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name inside
-        the [`PretrainedConfig`] of the model. The default values indicated are the default values of those config.
+        Generates sequences of token ids for models with a language modeling head. The method supports the following
+        generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
+
+        - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
+          `do_sample=False`.
+        - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
+          `do_sample=True`.
+        - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
+          `do_sample=False`.
+        - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
+          `num_beams>1` and `do_sample=True`.
+        - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
+          `num_beams>1` and `num_beam_groups>1`.
+        - *constrained beam-search decoding* by calling
+          [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or
+          `force_words_ids!=None`.
+
+        <Tip warning={true}>
+
+        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
+        defined in the model's config (`config.json`) which in turn defaults to the
+        [`~modeling_utils.PretrainedConfig`] of the model.
+
+        </Tip>

        Most of these parameters are explained in more detail in [this blog
        post](https://huggingface.co/blog/how-to-generate).

        Parameters:
-            inputs (`torch.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length,
-            feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*):
+            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
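A short sketch of the `inputs` conventions described above (the checkpoints are placeholder assumptions): decoder-only models take token ids directly, while for encoder-decoder models `inputs` are fed to the encoder and generation happens in the decoder:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

# Decoder-only: `inputs` are token ids of shape (batch_size, sequence_length).
gpt2_tok = AutoTokenizer.from_pretrained("gpt2")
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
prompt_ids = gpt2_tok("The dog", return_tensors="pt").input_ids
gpt2.generate(prompt_ids)  # passed positionally as `inputs`

# Encoder-decoder: `inputs` condition the encoder.
t5_tok = AutoTokenizer.from_pretrained("t5-small")
t5 = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
enc_ids = t5_tok("translate English to German: How old are you?", return_tensors="pt").input_ids
t5.generate(enc_ids)
```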
@@ -997,66 +1030,56 @@ class GenerationMixin:
        Examples:

+        Greedy Decoding:
+
        ```python
-        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

-        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
-        >>> # do greedy decoding without providing a prompt
-        >>> outputs = model.generate(max_length=40)
-        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

-        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
-        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
-        >>> document = (
-        ...     "at least two people were killed in a suspected bomb attack on a passenger bus "
-        ...     "in the strife-torn southern philippines on monday , the military said."
-        ... )
-        >>> # encode input context
-        >>> input_ids = tokenizer(document, return_tensors="pt").input_ids
-        >>> # generate 3 independent sequences using beam search decoding (5 beams)
-        >>> # with T5 encoder-decoder model conditioned on short news article.
-        >>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3)
-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

-        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
-        >>> input_context = "The dog"
-        >>> # encode input context
-        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
-        >>> # generate 3 candidates using sampling
-        >>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True)
-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> # generate up to 30 tokens
+        >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
+        ```

-        >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
-        >>> model = AutoModelForCausalLM.from_pretrained("ctrl")
-        >>> # "Legal" is one of the control codes for ctrl
-        >>> input_context = "Legal My neighbor is"
-        >>> # encode input context
-        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
-        >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
-        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
+        Multinomial Sampling:

-        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
-        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
-        >>> input_context = "My cute dog"
-        >>> # get tokens of words that should not be generated
-        >>> bad_words_ids = tokenizer(
-        ...     ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False
-        ... ).input_ids
-        >>> # get tokens of words that we want generated
-        >>> force_words_ids = tokenizer(["runs", "loves"], add_prefix_space=True, add_special_tokens=False).input_ids
-        >>> # encode input context
-        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
-        >>> # generate sequences without allowing bad_words to be generated
-        >>> outputs = model.generate(
-        ...     input_ids=input_ids,
-        ...     max_length=20,
-        ...     do_sample=True,
-        ...     bad_words_ids=bad_words_ids,
-        ...     force_words_ids=force_words_ids,
-        ... )
-        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+        >>> # sample up to 30 tokens
+        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
+        >>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
+        ```
+
+        Beam-search decoding:
+
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+        >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+
+        >>> sentence = "Paris is one of the densest populated areas in Europe."
+        >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids
+
+        >>> outputs = model.generate(input_ids)
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
        ```"""
        # 1. Set generation parameters if not already defined
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
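The three new examples cover greedy decoding, sampling, and beam search. For completeness, a hedged sketch of beam-search multinomial sampling, the fourth strategy listed in the docstring (the checkpoint and flag values are assumptions, not from the commit):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Today I believe we can finally", return_tensors="pt").input_ids

torch.manual_seed(0)  # sampling is stochastic, so fix the seed for repeatability
# num_beams>1 together with do_sample=True routes generate() to beam_sample
outputs = model.generate(input_ids, num_beams=4, do_sample=True, max_length=30)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```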
@@ -1457,7 +1480,8 @@
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        r"""
-        Generates sequences for models with a language modeling head using greedy decoding.
+        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
+        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
@@ -1508,6 +1532,8 @@
        ...     AutoModelForCausalLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
+        ...     StoppingCriteriaList,
+        ...     MaxLengthCriteria,
        ... )

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -1516,26 +1542,30 @@
        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

-        >>> input_prompt = "Today is a beautiful day, and"
+        >>> input_prompt = "It might be possible to"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
-        ...         MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
+        ...         MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
+        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

-        >>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor)
+        >>> outputs = model.greedy_search(
+        ...     input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
+        ... )

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
-                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
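The corrected warning text spells out the migration: wrap the criterion in a list. A small sketch of the before/after for callers of `greedy_search`:

```python
from transformers import MaxLengthCriteria, StoppingCriteriaList

# Deprecated:
#   model.greedy_search(input_ids, max_length=20)
# Recommended replacement (note the list around MaxLengthCriteria):
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
# model.greedy_search(input_ids, stopping_criteria=stopping_criteria)
```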
@@ -1683,7 +1713,8 @@
        **model_kwargs,
    ) -> Union[SampleOutput, torch.LongTensor]:
        r"""
-        Generates sequences for models with a language modeling head using multinomial sampling.
+        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
@@ -1739,7 +1770,10 @@
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
+        ...     StoppingCriteriaList,
+        ...     MaxLengthCriteria,
        ... )
+        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
@@ -1764,9 +1798,18 @@
        ...     ]
        ... )

-        >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)
+        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
+        >>> outputs = model.sample(
+        ...     input_ids,
+        ...     logits_processor=logits_processor,
+        ...     logits_warper=logits_warper,
+        ...     stopping_criteria=stopping_criteria,
+        ... )

+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
        ```"""
        # init values
@@ -1926,7 +1969,8 @@
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        r"""
-        Generates sequences for models with a language modeling head using beam search decoding.
+        Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
+        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
@@ -2020,7 +2064,8 @@
        >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
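The hunk above shows only the tail of the `beam_search` example. For orientation, a hedged, self-contained sketch of the manual setup such a call needs (the Marian checkpoint and beam count are assumptions):

```python
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BeamSearchScorer,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")

encoder_input_ids = tokenizer("How old are you?", return_tensors="pt").input_ids
num_beams = 3

# Beam search expects one decoder start row per beam.
input_ids = torch.full((num_beams, 1), model.config.decoder_start_token_id, dtype=torch.long)

# Run the encoder once and share its output across all beams.
model_kwargs = {
    "encoder_outputs": model.get_encoder()(
        encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    )
}

beam_scorer = BeamSearchScorer(batch_size=1, num_beams=num_beams, device=model.device)
logits_processor = LogitsProcessorList(
    [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
)

outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```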
@@ -2237,7 +2282,8 @@
        **model_kwargs,
    ) -> Union[BeamSampleOutput, torch.LongTensor]:
        r"""
-        Generates sequences for models with a language modeling head using beam search with multinomial sampling.
+        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
+        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
@@ -2343,7 +2389,8 @@
        ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
        ... )

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -2556,7 +2603,8 @@
        **model_kwargs,
    ):
        r"""
-        Generates sequences for models with a language modeling head using beam search decoding.
+        Generates sequences of token ids for models with a language modeling head using **diverse beam search
+        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
@@ -2656,7 +2704,8 @@
        ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
        ... )

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
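`group_beam_search` itself needs the same manual scaffolding as `beam_search`; in practice it is usually reached through `generate`. A hedged sketch (the checkpoint and penalty values are assumptions):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
input_ids = tokenizer(
    "summarize: The tower is 324 metres tall, about the same height as an 81-storey building.",
    return_tensors="pt",
).input_ids

# num_beams>1 with num_beam_groups>1 routes generate() to group_beam_search;
# diversity_penalty pushes the beam groups apart.
outputs = model.generate(
    input_ids,
    num_beams=4,
    num_beam_groups=2,
    diversity_penalty=1.0,
    num_return_sequences=2,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```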
@@ -2920,7 +2969,8 @@
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        r"""
-        Generates sequences for models with a language modeling head using beam search decoding.
+        Generates sequences of token ids for models with a language modeling head using **constrained beam search
+        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -3024,8 +3074,8 @@
        ...     input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
        ... )

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
-        # => ['Wie alter sind Sie?']
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Wie alt sind Sie?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
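Like diverse beam search, constrained beam search is normally reached through `generate`, here via `force_words_ids` (mentioned in the docstring above). A hedged sketch with an assumed t5-base checkpoint and forced word:

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

input_ids = tokenizer("translate English to German: How old are you?", return_tensors="pt").input_ids
# Every returned sequence must contain the formal "Sie".
force_words_ids = tokenizer(["Sie"], add_special_tokens=False).input_ids

# force_words_ids!=None routes generate() to constrained_beam_search (requires num_beams>1).
outputs = model.generate(input_ids, num_beams=5, force_words_ids=force_words_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```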
@@ -28,5 +28,6 @@ src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/blenderbot/modeling_blenderbot.py
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
src/transformers/models/plbart/modeling_plbart.py
+src/transformers/generation_utils.py
docs/source/quicktour.mdx
docs/source/task_summary.mdx
\ No newline at end of file