Unverified Commit f270b960 authored by Joao Gante, committed by GitHub

Generate: move generation_*.py src files into generation/*.py (#20096)

* move generation_*.py src files into generation/*.py

* populate generation.__init__ with lazy loading

* move imports and references from generation.xxx.object to generation.object
parent bac2d29a
@@ -56,7 +56,7 @@ If you have more than one input, pass the input as a list:
... ) # doctest: +SKIP
```
-Any additional parameters for your task can also be included in the [`pipeline`]. The `Text-Generierung` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. If, for example, you want to generate more than one output, set the `num_return_sequences` parameter:
+Any additional parameters for your task can also be included in the [`pipeline`]. The `Text-Generierung` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. If, for example, you want to generate more than one output, set the `num_return_sequences` parameter:
```py
>>> generator(
...
@@ -12,22 +12,22 @@ specific language governing permissions and limitations under the License.
# Utilities for Generation
-This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`],
-[`~generation_utils.GenerationMixin.greedy_search`],
-[`~generation_utils.GenerationMixin.contrastive_search`],
-[`~generation_utils.GenerationMixin.sample`],
-[`~generation_utils.GenerationMixin.beam_search`],
-[`~generation_utils.GenerationMixin.beam_sample`],
-[`~generation_utils.GenerationMixin.group_beam_search`], and
-[`~generation_utils.GenerationMixin.constrained_beam_search`].
+This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
+[`~generation.GenerationMixin.greedy_search`],
+[`~generation.GenerationMixin.contrastive_search`],
+[`~generation.GenerationMixin.sample`],
+[`~generation.GenerationMixin.beam_search`],
+[`~generation.GenerationMixin.beam_sample`],
+[`~generation.GenerationMixin.group_beam_search`], and
+[`~generation.GenerationMixin.constrained_beam_search`].
Most of those are only useful if you are studying the code of the generate methods in the library.
## Generate Outputs
-The output of [`~generation_utils.GenerationMixin.generate`] is an instance of a subclass of
+The output of [`~generation.GenerationMixin.generate`] is an instance of a subclass of
[`~utils.ModelOutput`]. This output is a data structure containing all the information returned
-by [`~generation_utils.GenerationMixin.generate`], but that can also be used as a tuple or a dictionary.
+by [`~generation.GenerationMixin.generate`], but that can also be used as a tuple or a dictionary.
Here's an example:
@@ -41,7 +41,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
```
-The `generation_output` object is a [`~generation_utils.GreedySearchDecoderOnlyOutput`], as we can
+The `generation_output` object is a [`~generation.GreedySearchDecoderOnlyOutput`], as we can
see in the documentation of that class below, it has the following attributes:
- `sequences`: the generated sequences of tokens
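Since all generate outputs are [`~utils.ModelOutput`] subclasses, these attributes can be read by name, by index, or by key. A minimal sketch, reusing the `generation_output` object from the example above:

```python
# ModelOutput subclasses behave like a dataclass, a tuple, and a dict at once
sequences = generation_output.sequences          # attribute access
same_sequences = generation_output[0]            # tuple-style access
also_sequences = generation_output["sequences"]  # dict-style access

# scores holds one logits tensor per generated token (needs output_scores=True)
print(len(generation_output.scores))
```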
@@ -73,31 +73,31 @@ We document here all output types.
### GreedySearchOutput
-[[autodoc]] generation_utils.GreedySearchDecoderOnlyOutput
-[[autodoc]] generation_utils.GreedySearchEncoderDecoderOutput
-[[autodoc]] generation_flax_utils.FlaxGreedySearchOutput
+[[autodoc]] generation.GreedySearchDecoderOnlyOutput
+[[autodoc]] generation.GreedySearchEncoderDecoderOutput
+[[autodoc]] generation.FlaxGreedySearchOutput
### SampleOutput
-[[autodoc]] generation_utils.SampleDecoderOnlyOutput
-[[autodoc]] generation_utils.SampleEncoderDecoderOutput
-[[autodoc]] generation_flax_utils.FlaxSampleOutput
+[[autodoc]] generation.SampleDecoderOnlyOutput
+[[autodoc]] generation.SampleEncoderDecoderOutput
+[[autodoc]] generation.FlaxSampleOutput
### BeamSearchOutput
-[[autodoc]] generation_utils.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation_utils.BeamSearchEncoderDecoderOutput
+[[autodoc]] generation.BeamSearchDecoderOnlyOutput
+[[autodoc]] generation.BeamSearchEncoderDecoderOutput
### BeamSampleOutput
-[[autodoc]] generation_utils.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation_utils.BeamSampleEncoderDecoderOutput
+[[autodoc]] generation.BeamSampleDecoderOnlyOutput
+[[autodoc]] generation.BeamSampleEncoderDecoderOutput
## LogitsProcessor
...
@@ -25,9 +25,9 @@ are common among all the models to:
The other methods that are common to each model are defined in [`~modeling_utils.ModuleUtilsMixin`]
(for the PyTorch models) and [`~modeling_tf_utils.TFModuleUtilsMixin`] (for the TensorFlow models) or
-for text generation, [`~generation_utils.GenerationMixin`] (for the PyTorch models),
-[`~generation_tf_utils.TFGenerationMixin`] (for the TensorFlow models) and
-[`~generation_flax_utils.FlaxGenerationMixin`] (for the Flax/JAX models).
+for text generation, [`~generation.GenerationMixin`] (for the PyTorch models),
+[`~generation.TFGenerationMixin`] (for the TensorFlow models) and
+[`~generation.FlaxGenerationMixin`] (for the Flax/JAX models).
## PreTrainedModel
...
@@ -14,13 +14,13 @@ specific language governing permissions and limitations under the License.
Each framework has a generate method for auto-regressive text generation implemented in its respective `GenerationMixin` class:
-- PyTorch [`~generation_utils.GenerationMixin.generate`] is implemented in [`~generation_utils.GenerationMixin`].
-- TensorFlow [`~generation_tf_utils.TFGenerationMixin.generate`] is implemented in [`~generation_tf_utils.TFGenerationMixin`].
-- Flax/JAX [`~generation_flax_utils.FlaxGenerationMixin.generate`] is implemented in [`~generation_flax_utils.FlaxGenerationMixin`].
+- PyTorch [`~generation.GenerationMixin.generate`] is implemented in [`~generation.GenerationMixin`].
+- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
+- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
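After this commit, all three mixins resolve from the same `transformers.generation` namespace. A quick sketch of the new import paths (the TF and Flax imports assume those frameworks are installed, since the lazy `__init__` guards them behind availability checks):

```python
# PyTorch mixin, under the consolidated namespace introduced here
from transformers.generation import GenerationMixin

# TensorFlow and Flax mixins live in the same namespace, but resolving
# them requires the corresponding framework to be available
from transformers.generation import TFGenerationMixin    # needs TensorFlow
from transformers.generation import FlaxGenerationMixin  # needs Flax/JAX
```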
## GenerationMixin
-[[autodoc]] generation_utils.GenerationMixin
+[[autodoc]] generation.GenerationMixin
- generate
- greedy_search
- sample
@@ -32,10 +32,10 @@ Each framework has a generate method for auto-regressive text generation impleme
## TFGenerationMixin
-[[autodoc]] generation_tf_utils.TFGenerationMixin
+[[autodoc]] generation.TFGenerationMixin
- generate
## FlaxGenerationMixin
-[[autodoc]] generation_flax_utils.FlaxGenerationMixin
+[[autodoc]] generation.FlaxGenerationMixin
- generate
@@ -58,7 +58,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The
- Model predictions are intended to be identical to the original implementation when
  `forced_bos_token_id=0`. This only works, however, if the string you pass to
  [`fairseq.encode`] starts with a space.
-- [`~generation_utils.GenerationMixin.generate`] should be used for conditional generation tasks like
+- [`~generation.GenerationMixin.generate`] should be used for conditional generation tasks like
  summarization; see the example in its docstring and the sketch below.
- Models that load the *facebook/bart-large-cnn* weights will not have a `mask_token_id`, or be able to perform
  mask-filling tasks.
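The docstring example itself is not part of this diff. As a hedged sketch of the conditional generation flow meant here (the article text and generation settings are illustrative):

```python
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

article = "The tower is 324 metres tall, about the same height as an 81-storey building."
inputs = tokenizer([article], max_length=1024, truncation=True, return_tensors="pt")

# generate() drives the conditional (here: summarization) decoding loop
summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=60)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
```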
@@ -188,4 +188,4 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
## FlaxBartForCausalLM
[[autodoc]] FlaxBartForCausalLM
- __call__
\ No newline at end of file
@@ -23,7 +23,7 @@ The abstract from the paper is the following:
*Understanding document images (e.g., invoices) is a core but challenging task since it requires complex functions such as reading text and a holistic understanding of the document. Current Visual Document Understanding (VDU) methods outsource the task of reading text to off-the-shelf Optical Character Recognition (OCR) engines and focus on the understanding task with the OCR outputs. Although such OCR-based approaches have shown promising performance, they suffer from 1) high computational costs for using OCR; 2) inflexibility of OCR models on languages or types of document; 3) OCR error propagation to the subsequent process. To address these issues, in this paper, we introduce a novel OCR-free VDU model named Donut, which stands for Document understanding transformer. As the first step in OCR-free VDU research, we propose a simple architecture (i.e., Transformer) with a pre-training objective (i.e., cross-entropy loss). Donut is conceptually simple yet effective. Through extensive experiments and analyses, we show a simple OCR-free VDU model, Donut, achieves state-of-the-art performances on various VDU tasks in terms of both speed and accuracy. In addition, we offer a synthetic data generator that helps the model pre-training to be flexible in various languages and domains.*
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/donut_architecture.jpg"
alt="drawing" width="600"/>
<small> Donut high-level overview. Taken from the <a href="https://arxiv.org/abs/2111.15664">original paper</a>. </small>
@@ -40,7 +40,7 @@ Tips:
## Inference
Donut's [`VisionEncoderDecoder`] model accepts images as input and makes use of
-[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
+[`~generation.GenerationMixin.generate`] to autoregressively generate text given the input image.
The [`DonutFeatureExtractor`] class is responsible for preprocessing the input image and
[`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`] decodes the generated target tokens to the target string. The
@@ -211,4 +211,4 @@ We refer to the [tutorial notebooks](https://github.com/NielsRogge/Transformers-
## DonutSwinModel
[[autodoc]] DonutSwinModel
- forward
\ No newline at end of file
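The full inference example for the Donut section above is truncated in this view. As a hedged sketch of the flow it describes, using the combined [`DonutProcessor`] in place of the separate feature extractor and tokenizer (the checkpoint name and image path are illustrative):

```python
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

image = Image.open("document.png").convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values

# generate() autoregressively produces the target token ids from the image
outputs = model.generate(pixel_values, max_length=512)
print(processor.batch_decode(outputs, skip_special_tokens=True)[0])
```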
@@ -53,7 +53,7 @@ Tips:
### Generation
-The [`~generation_utils.GenerationMixin.generate`] method can be used to generate text using the GPT-J
+The [`~generation.GenerationMixin.generate`] method can be used to generate text using the GPT-J
model.
```python
...
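The body of that example is cut off by the diff view. A hedged sketch of what GPT-J generation looks like (the checkpoint and sampling settings are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

prompt = "In a shocking finding, scientists discovered a herd of unicorns"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# sample up to 100 new tokens on top of the prompt
gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_new_tokens=100)
print(tokenizer.batch_decode(gen_tokens)[0])
```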
@@ -38,7 +38,7 @@ Tips:
## Inference
Speech2Text2's [`SpeechEncoderDecoderModel`] model accepts raw waveform input values from speech and
-makes use of [`~generation_utils.GenerationMixin.generate`] to translate the input speech
+makes use of [`~generation.GenerationMixin.generate`] to translate the input speech
autoregressively to the target language.
The [`Wav2Vec2FeatureExtractor`] class is responsible for preprocessing the input speech and
...
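The rest of that section is truncated here. A hedged sketch of the speech-translation flow just described (the checkpoint and dummy dataset follow the public Speech2Text2 examples):

```python
from datasets import load_dataset
from transformers import Speech2Text2Processor, SpeechEncoderDecoderModel

model = SpeechEncoderDecoderModel.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
processor = Speech2Text2Processor.from_pretrained("facebook/s2t-wav2vec2-large-en-de")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

# generate() translates the encoded speech autoregressively into the target language
generated_ids = model.generate(inputs=inputs["input_values"], attention_mask=inputs["attention_mask"])
print(processor.batch_decode(generated_ids)[0])
```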
@@ -225,7 +225,7 @@ batch) leads to very slow training on TPU.
## Inference
-At inference time, it is recommended to use [`~generation_utils.GenerationMixin.generate`]. This
+At inference time, it is recommended to use [`~generation.GenerationMixin.generate`]. This
method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder
and auto-regressively generates the decoder output. Check out [this blog post](https://huggingface.co/blog/how-to-generate) to learn all the details about generating text with Transformers.
There's also [this blog post](https://huggingface.co/blog/encoder-decoder#encoder-decoder), which explains how
@@ -244,7 +244,7 @@ Das Haus ist wunderbar.
```
Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when doing generation without using
-[`~generation_utils.GenerationMixin.generate`], make sure you start it with the `pad_token_id`.
+[`~generation.GenerationMixin.generate`], make sure you start it with the `pad_token_id`.
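As a minimal sketch of that caveat, bypassing `generate` for a single decoding step (the model size and prompt are illustrative):

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids

# when decoding manually, the first decoder token must be the pad token
decoder_input_ids = torch.tensor([[model.config.pad_token_id]])
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
next_token_id = outputs.logits[0, -1].argmax()  # greedy pick of the first real token
```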
The translation example above shows only a single input. You can also do batched inference, like so:
...
@@ -30,7 +30,7 @@ show that the TrOCR model outperforms the current state-of-the-art models on bot
tasks.*
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/trocr_architecture.jpg"
alt="drawing" width="600"/>
<small> TrOCR architecture. Taken from the <a href="https://arxiv.org/abs/2109.10282">original paper</a>. </small>
@@ -53,7 +53,7 @@ Tips:
## Inference
TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of
-[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
+[`~generation.GenerationMixin.generate`] to autoregressively generate text given the input image.
The [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] class is responsible for preprocessing the input image and
[`RobertaTokenizer`/`XLMRobertaTokenizer`] decodes the generated target tokens to the target string. The
@@ -64,20 +64,20 @@ into a single instance to both extract the input features and decode the predict
``` py
>>> from transformers import TrOCRProcessor, VisionEncoderDecoderModel
>>> import requests
>>> from PIL import Image

>>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
>>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

>>> # load image from the IAM dataset
>>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

>>> pixel_values = processor(image, return_tensors="pt").pixel_values
>>> generated_ids = model.generate(pixel_values)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```
See the [model hub](https://huggingface.co/models?filter=trocr) to look for TrOCR checkpoints.
...
@@ -24,7 +24,7 @@ The abstract from the paper is the following:
Tips:
- The model usually performs well without requiring any finetuning.
-- The architecture follows a classic encoder-decoder architecture, which means that it relies on the [`~generation_utils.GenerationMixin.generate`] function for inference.
+- The architecture follows a classic encoder-decoder architecture, which means that it relies on the [`~generation.GenerationMixin.generate`] function for inference.
- Inference is currently only implemented for short-form, i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
- One can use [`WhisperProcessor`] to prepare audio for the model, and to decode the predicted IDs back into text (see the sketch below).
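A hedged sketch of that processor-plus-`generate` flow (the checkpoint and dummy dataset are illustrative):

```python
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt").input_features

# generate() transcribes the pre-segmented (<=30s) audio autoregressively
predicted_ids = model.generate(input_features)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```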
...
@@ -56,7 +56,7 @@ If you have more than one input, pass your input as a list:
... ) # doctest: +SKIP
```
-Any additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
+Any additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
```py
>>> generator(
...
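The call itself is truncated in this view. A hedged sketch of setting `num_return_sequences` on a text-generation pipeline (the prompt is illustrative; `do_sample=True` is added because greedy decoding only supports a single returned sequence):

```python
from transformers import pipeline

generator = pipeline(task="text-generation")

# num_return_sequences (and other generate() parameters) pass straight through
outputs = generator(
    "In this course, we will teach you how to",
    num_return_sequences=2,
    do_sample=True,
)
print([out["generated_text"] for out in outputs])
```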
@@ -544,7 +544,7 @@ Hugging Face is based in DUMBO, New York City, and ...
This outputs a (hopefully) coherent next token following the original sequence, which in our case is the word *is* or
*features*.
-In the next section, we show how [`generation_utils.GenerationMixin.generate`] can be used to
+In the next section, we show how [`generation.GenerationMixin.generate`] can be used to
generate multiple tokens up to a specified length instead of one token at a time.
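A minimal sketch of that, reusing the prompt from the hunk header above (the GPT-2 checkpoint is an illustrative choice):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hugging Face is based in DUMBO, New York City, and", return_tensors="pt")

# generate() repeats the single-token step shown above until max_new_tokens is reached
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```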
### Text Generation
@@ -1094,10 +1094,10 @@ The following examples demonstrate how to use a [`pipeline`] and a model and tok
...     images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
... )
>>> print("\n".join([f"Class {d['label']} with score {round(d['score'], 4)}" for d in result]))
Class lynx, catamount with score 0.4335
Class cougar, puma, catamount, mountain lion, painter, panther, Felis concolor with score 0.0348
Class snow leopard, ounce, Panthera uncia with score 0.0324
Class Egyptian cat with score 0.0239
Class tiger cat with score 0.0229
```
...
@@ -54,7 +54,7 @@ If you have more than one input, pass it as a list:
... )
```
-Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
+Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
```py
>>> generator(
...
@@ -26,7 +26,7 @@ Take a look at the [`pipeline`] documentation for a complete list of the
## Using the Pipeline
Although each task has an associated [`pipeline`], it is simpler to use the generic [`pipeline`] abstraction, which contains all the task-specific pipelines. The [`pipeline`] automatically loads a default model and a tokenizer capable of running inference for your task.
1. Start by creating a [`pipeline`] and specifying the task you want to run inference on:
@@ -56,7 +56,7 @@ If you have more than one input, pass it as a list:
... ) # doctest: +SKIP
```
-Any additional parameter for your task can be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, use the `num_return_sequences` parameter:
+Any additional parameter for your task can be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, use the `num_return_sequences` parameter:
```py
>>> generator(
...
@@ -61,7 +61,7 @@ If you have more than one input, pass it as a list:
```
Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a
-[`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output.
+[`~generation.GenerationMixin.generate`] method with several parameters for controlling the output.
For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
```py
...
@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from transformers import BartConfig
-from transformers.generation_utils import GenerationMixin
+from transformers.generation import GenerationMixin
def _convert_past_list_to_tuple(past_key_values):
...
@@ -97,6 +97,7 @@ _import_structure = {
    "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
    "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
    "file_utils": [],
+   "generation": [],
    "hf_argparser": ["HfArgumentParser"],
    "integrations": [
        "is_comet_available",
@@ -821,38 +822,40 @@ else:
        "TextDatasetForNextSentencePrediction",
    ]
    _import_structure["deepspeed"] = []
-    _import_structure["generation_beam_constraints"] = [
-        "Constraint",
-        "ConstraintListState",
-        "DisjunctiveConstraint",
-        "PhrasalConstraint",
-    ]
-    _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
-    _import_structure["generation_logits_process"] = [
-        "ForcedBOSTokenLogitsProcessor",
-        "ForcedEOSTokenLogitsProcessor",
-        "HammingDiversityLogitsProcessor",
-        "InfNanRemoveLogitsProcessor",
-        "LogitsProcessor",
-        "LogitsProcessorList",
-        "LogitsWarper",
-        "MinLengthLogitsProcessor",
-        "NoBadWordsLogitsProcessor",
-        "NoRepeatNGramLogitsProcessor",
-        "PrefixConstrainedLogitsProcessor",
-        "RepetitionPenaltyLogitsProcessor",
-        "TemperatureLogitsWarper",
-        "TopKLogitsWarper",
-        "TopPLogitsWarper",
-        "TypicalLogitsWarper",
-    ]
-    _import_structure["generation_stopping_criteria"] = [
-        "MaxLengthCriteria",
-        "MaxTimeCriteria",
-        "StoppingCriteria",
-        "StoppingCriteriaList",
-    ]
-    _import_structure["generation_utils"] = ["top_k_top_p_filtering"]
+    _import_structure["generation_utils"] = []
+    _import_structure["generation"].extend(
+        [
+            "Constraint",
+            "ConstraintListState",
+            "DisjunctiveConstraint",
+            "PhrasalConstraint",
+            "BeamScorer",
+            "BeamSearchScorer",
+            "ConstrainedBeamSearchScorer",
+            "ForcedBOSTokenLogitsProcessor",
+            "ForcedEOSTokenLogitsProcessor",
+            "HammingDiversityLogitsProcessor",
+            "InfNanRemoveLogitsProcessor",
+            "LogitsProcessor",
+            "LogitsProcessorList",
+            "LogitsWarper",
+            "MinLengthLogitsProcessor",
+            "NoBadWordsLogitsProcessor",
+            "NoRepeatNGramLogitsProcessor",
+            "PrefixConstrainedLogitsProcessor",
+            "RepetitionPenaltyLogitsProcessor",
+            "TemperatureLogitsWarper",
+            "TopKLogitsWarper",
+            "TopPLogitsWarper",
+            "TypicalLogitsWarper",
+            "MaxLengthCriteria",
+            "MaxTimeCriteria",
+            "StoppingCriteria",
+            "StoppingCriteriaList",
+            "GenerationMixin",
+            "top_k_top_p_filtering",
+        ]
+    )
_import_structure["modeling_outputs"] = [] _import_structure["modeling_outputs"] = []
_import_structure["modeling_utils"] = ["PreTrainedModel"] _import_structure["modeling_utils"] = ["PreTrainedModel"]
@@ -2278,21 +2281,25 @@ else:
    _import_structure["activations_tf"] = []
    _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
    _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
-    _import_structure["generation_tf_logits_process"] = [
-        "TFForcedBOSTokenLogitsProcessor",
-        "TFForcedEOSTokenLogitsProcessor",
-        "TFLogitsProcessor",
-        "TFLogitsProcessorList",
-        "TFLogitsWarper",
-        "TFMinLengthLogitsProcessor",
-        "TFNoBadWordsLogitsProcessor",
-        "TFNoRepeatNGramLogitsProcessor",
-        "TFRepetitionPenaltyLogitsProcessor",
-        "TFTemperatureLogitsWarper",
-        "TFTopKLogitsWarper",
-        "TFTopPLogitsWarper",
-    ]
-    _import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
+    _import_structure["generation_tf_utils"] = []
+    _import_structure["generation"].extend(
+        [
+            "TFForcedBOSTokenLogitsProcessor",
+            "TFForcedEOSTokenLogitsProcessor",
+            "TFLogitsProcessor",
+            "TFLogitsProcessorList",
+            "TFLogitsWarper",
+            "TFMinLengthLogitsProcessor",
+            "TFNoBadWordsLogitsProcessor",
+            "TFNoRepeatNGramLogitsProcessor",
+            "TFRepetitionPenaltyLogitsProcessor",
+            "TFTemperatureLogitsWarper",
+            "TFTopKLogitsWarper",
+            "TFTopPLogitsWarper",
+            "TFGenerationMixin",
+            "tf_top_k_top_p_filtering",
+        ]
+    )
    _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
    _import_structure["modeling_tf_outputs"] = []
    _import_structure["modeling_tf_utils"] = [
@@ -2915,18 +2922,21 @@ except OptionalDependencyNotAvailable:
        name for name in dir(dummy_flax_objects) if not name.startswith("_")
    ]
else:
-    _import_structure["generation_flax_logits_process"] = [
-        "FlaxForcedBOSTokenLogitsProcessor",
-        "FlaxForcedEOSTokenLogitsProcessor",
-        "FlaxLogitsProcessor",
-        "FlaxLogitsProcessorList",
-        "FlaxLogitsWarper",
-        "FlaxMinLengthLogitsProcessor",
-        "FlaxTemperatureLogitsWarper",
-        "FlaxTopKLogitsWarper",
-        "FlaxTopPLogitsWarper",
-    ]
    _import_structure["generation_flax_utils"] = []
+    _import_structure["generation"].extend(
+        [
+            "FlaxForcedBOSTokenLogitsProcessor",
+            "FlaxForcedEOSTokenLogitsProcessor",
+            "FlaxLogitsProcessor",
+            "FlaxLogitsProcessorList",
+            "FlaxLogitsWarper",
+            "FlaxMinLengthLogitsProcessor",
+            "FlaxTemperatureLogitsWarper",
+            "FlaxTopKLogitsWarper",
+            "FlaxTopPLogitsWarper",
+            "FlaxGenerationMixin",
+        ]
+    )
    _import_structure["modeling_flax_outputs"] = []
    _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
    _import_structure["models.albert"].extend(
@@ -3834,38 +3844,37 @@ if TYPE_CHECKING:
        TextDataset,
        TextDatasetForNextSentencePrediction,
    )
-    from .generation_beam_constraints import (
+    from .generation import (
+        BeamScorer,
+        BeamSearchScorer,
+        ConstrainedBeamSearchScorer,
        Constraint,
        ConstraintListState,
        DisjunctiveConstraint,
-        PhrasalConstraint,
-    )
-    from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
-    from .generation_logits_process import (
        ForcedBOSTokenLogitsProcessor,
        ForcedEOSTokenLogitsProcessor,
+        GenerationMixin,
        HammingDiversityLogitsProcessor,
        InfNanRemoveLogitsProcessor,
        LogitsProcessor,
        LogitsProcessorList,
        LogitsWarper,
+        MaxLengthCriteria,
+        MaxTimeCriteria,
        MinLengthLogitsProcessor,
        NoBadWordsLogitsProcessor,
        NoRepeatNGramLogitsProcessor,
+        PhrasalConstraint,
        PrefixConstrainedLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
+        StoppingCriteria,
+        StoppingCriteriaList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
        TypicalLogitsWarper,
+        top_k_top_p_filtering,
    )
-    from .generation_stopping_criteria import (
-        MaxLengthCriteria,
-        MaxTimeCriteria,
-        StoppingCriteria,
-        StoppingCriteriaList,
-    )
-    from .generation_utils import top_k_top_p_filtering
    from .modeling_utils import PreTrainedModel

    # PyTorch model imports
@@ -5037,9 +5046,10 @@ if TYPE_CHECKING:
    # Benchmarks
    from .benchmark.benchmark_tf import TensorFlowBenchmark
-    from .generation_tf_logits_process import (
+    from .generation import (
        TFForcedBOSTokenLogitsProcessor,
        TFForcedEOSTokenLogitsProcessor,
+        TFGenerationMixin,
        TFLogitsProcessor,
        TFLogitsProcessorList,
        TFLogitsWarper,
@@ -5050,8 +5060,8 @@ if TYPE_CHECKING:
        TFTemperatureLogitsWarper,
        TFTopKLogitsWarper,
        TFTopPLogitsWarper,
+        tf_top_k_top_p_filtering,
    )
-    from .generation_tf_utils import tf_top_k_top_p_filtering
    from .keras_callbacks import KerasMetricCallback, PushToHubCallback
    from .modeling_tf_layoutlm import (
        TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -5541,9 +5551,10 @@ if TYPE_CHECKING:
    # They will raise an import error if the user tries to instantiate / use them.
    from .utils.dummy_flax_objects import *
else:
-    from .generation_flax_logits_process import (
+    from .generation import (
        FlaxForcedBOSTokenLogitsProcessor,
        FlaxForcedEOSTokenLogitsProcessor,
+        FlaxGenerationMixin,
        FlaxLogitsProcessor,
        FlaxLogitsProcessorList,
        FlaxLogitsWarper,
...

What follows is the new `generation/__init__.py` introduced by this commit, populated with lazy loading:
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["beam_constraints"] = [
"Constraint",
"ConstraintListState",
"DisjunctiveConstraint",
"PhrasalConstraint",
]
_import_structure["beam_search"] = [
"BeamHypotheses",
"BeamScorer",
"BeamSearchScorer",
"ConstrainedBeamSearchScorer",
]
_import_structure["logits_process"] = [
"ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor",
"HammingDiversityLogitsProcessor",
"InfNanRemoveLogitsProcessor",
"LogitsProcessor",
"LogitsProcessorList",
"LogitsWarper",
"MinLengthLogitsProcessor",
"NoBadWordsLogitsProcessor",
"NoRepeatNGramLogitsProcessor",
"PrefixConstrainedLogitsProcessor",
"RepetitionPenaltyLogitsProcessor",
"TemperatureLogitsWarper",
"TopKLogitsWarper",
"TopPLogitsWarper",
"TypicalLogitsWarper",
"EncoderNoRepeatNGramLogitsProcessor",
"ExponentialDecayLengthPenalty",
"LogitNormalization",
]
_import_structure["stopping_criteria"] = [
"MaxNewTokensCriteria",
"MaxLengthCriteria",
"MaxTimeCriteria",
"StoppingCriteria",
"StoppingCriteriaList",
"validate_stopping_criteria",
]
_import_structure["utils"] = [
"GenerationMixin",
"top_k_top_p_filtering",
"GreedySearchEncoderDecoderOutput",
"GreedySearchDecoderOnlyOutput",
"SampleEncoderDecoderOutput",
"SampleDecoderOnlyOutput",
"BeamSearchEncoderDecoderOutput",
"BeamSearchDecoderOnlyOutput",
"BeamSampleEncoderDecoderOutput",
"BeamSampleDecoderOnlyOutput",
"ContrastiveSearchEncoderDecoderOutput",
"ContrastiveSearchDecoderOnlyOutput",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tf_logits_process"] = [
"TFForcedBOSTokenLogitsProcessor",
"TFForcedEOSTokenLogitsProcessor",
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFLogitsWarper",
"TFMinLengthLogitsProcessor",
"TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor",
"TFTemperatureLogitsWarper",
"TFTopKLogitsWarper",
"TFTopPLogitsWarper",
"TFForceTokensLogitsProcessor",
"TFSuppressTokensAtBeginLogitsProcessor",
"TFSuppressTokensLogitsProcessor",
]
_import_structure["tf_utils"] = [
"TFGenerationMixin",
"tf_top_k_top_p_filtering",
"TFGreedySearchDecoderOnlyOutput",
"TFGreedySearchEncoderDecoderOutput",
"TFSampleEncoderDecoderOutput",
"TFSampleDecoderOnlyOutput",
"TFBeamSearchEncoderDecoderOutput",
"TFBeamSearchDecoderOnlyOutput",
"TFBeamSampleEncoderDecoderOutput",
"TFBeamSampleDecoderOnlyOutput",
"TFContrastiveSearchEncoderDecoderOutput",
"TFContrastiveSearchDecoderOnlyOutput",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["flax_logits_process"] = [
"FlaxForcedBOSTokenLogitsProcessor",
"FlaxForcedEOSTokenLogitsProcessor",
"FlaxLogitsProcessor",
"FlaxLogitsProcessorList",
"FlaxLogitsWarper",
"FlaxMinLengthLogitsProcessor",
"FlaxTemperatureLogitsWarper",
"FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper",
]
_import_structure["flax_utils"] = [
"FlaxGenerationMixin",
"FlaxGreedySearchOutput",
"FlaxSampleOutput",
"FlaxBeamSearchOutput",
]
if TYPE_CHECKING:
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessor,
LogitsProcessorList,
LogitsWarper,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
from .stopping_criteria import (
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
StoppingCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
)
from .utils import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
ContrastiveSearchDecoderOnlyOutput,
ContrastiveSearchEncoderDecoderOutput,
GenerationMixin,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
top_k_top_p_filtering,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tf_logits_process import (
TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor,
TFForceTokensLogitsProcessor,
TFLogitsProcessor,
TFLogitsProcessorList,
TFLogitsWarper,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFSuppressTokensAtBeginLogitsProcessor,
TFSuppressTokensLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .tf_utils import (
TFBeamSampleDecoderOnlyOutput,
TFBeamSampleEncoderDecoderOutput,
TFBeamSearchDecoderOnlyOutput,
TFBeamSearchEncoderDecoderOutput,
TFContrastiveSearchDecoderOnlyOutput,
TFContrastiveSearchEncoderDecoderOutput,
TFGenerationMixin,
TFGreedySearchDecoderOnlyOutput,
TFGreedySearchEncoderDecoderOutput,
TFSampleDecoderOnlyOutput,
TFSampleEncoderDecoderOutput,
tf_top_k_top_p_filtering,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessor,
FlaxLogitsProcessorList,
FlaxLogitsWarper,
FlaxMinLengthLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
)
from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
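In effect, importing `transformers.generation` stays cheap: `sys.modules` gets a `_LazyModule` that only imports a submodule such as `logits_process` or `flax_utils` when one of its names is first accessed, and the torch/TF/Flax availability checks keep unavailable frameworks out of `_import_structure` entirely. A small sketch of the observable behavior (assuming PyTorch is installed):

```python
import transformers.generation as generation

# the module object is the _LazyModule placeholder, not a plain module
print(type(generation).__name__)  # expected: "_LazyModule"

# first attribute access triggers the real import of generation.logits_process
from transformers.generation import LogitsProcessorList
print(LogitsProcessorList())  # an empty processor list, []
```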