"test/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "56ab0368f295defbbb04de61f2677c7e86052630"
Unverified commit f270b960, authored by Joao Gante, committed by GitHub

Generate: move generation_*.py src files into generation/*.py (#20096)

* move generation_*.py src files into generation/*.py

* populate generation.__init__ with lazy loading

* move imports and references from generation.xxx.object to generation.object
parent bac2d29a
@@ -56,7 +56,7 @@ If you have more than one input, pass the input as a list:
... ) # doctest: +SKIP
```
-Any additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. If you want to generate more than one output, for example, set the `num_return_sequences` parameter:
+Any additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. If you want to generate more than one output, for example, set the `num_return_sequences` parameter:
```py
>>> generator(
......
@@ -12,22 +12,22 @@ specific language governing permissions and limitations under the License.
# Utilities for Generation
-This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`],
-[`~generation_utils.GenerationMixin.greedy_search`],
-[`~generation_utils.GenerationMixin.contrastive_search`],
-[`~generation_utils.GenerationMixin.sample`],
-[`~generation_utils.GenerationMixin.beam_search`],
-[`~generation_utils.GenerationMixin.beam_sample`],
-[`~generation_utils.GenerationMixin.group_beam_search`], and
-[`~generation_utils.GenerationMixin.constrained_beam_search`].
+This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
+[`~generation.GenerationMixin.greedy_search`],
+[`~generation.GenerationMixin.contrastive_search`],
+[`~generation.GenerationMixin.sample`],
+[`~generation.GenerationMixin.beam_search`],
+[`~generation.GenerationMixin.beam_sample`],
+[`~generation.GenerationMixin.group_beam_search`], and
+[`~generation.GenerationMixin.constrained_beam_search`].
Most of those are only useful if you are studying the code of the generate methods in the library.
## Generate Outputs
-The output of [`~generation_utils.GenerationMixin.generate`] is an instance of a subclass of
+The output of [`~generation.GenerationMixin.generate`] is an instance of a subclass of
[`~utils.ModelOutput`]. This output is a data structure containing all the information returned
-by [`~generation_utils.GenerationMixin.generate`], but that can also be used as tuple or dictionary.
+by [`~generation.GenerationMixin.generate`], but that can also be used as a tuple or dictionary.
Here's an example:
@@ -41,7 +41,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
```
-The `generation_output` object is a [`~generation_utils.GreedySearchDecoderOnlyOutput`], as we can
+The `generation_output` object is a [`~generation.GreedySearchDecoderOnlyOutput`], as we can
see in the documentation of that class below; it means it has the following attributes:
- `sequences`: the generated sequences of tokens
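The dual access described above (attributes or dictionary keys) is easy to verify. A minimal sketch, assuming the `gpt2` checkpoint stands in for the truncated setup in the example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 is an assumed stand-in for the checkpoint used in the truncated example
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)

# Attribute access on the ModelOutput subclass...
print(generation_output.sequences.shape)
# ...or dictionary-style access to the same field:
print(generation_output["sequences"].shape)
# Fields that were not requested (e.g. hidden states here) come back as None:
print(generation_output.hidden_states)
```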
@@ -73,31 +73,31 @@ We document here all output types.
### GreedySearchOutput
-[[autodoc]] generation_utils.GreedySearchDecoderOnlyOutput
-[[autodoc]] generation_utils.GreedySearchEncoderDecoderOutput
-[[autodoc]] generation_flax_utils.FlaxGreedySearchOutput
+[[autodoc]] generation.GreedySearchDecoderOnlyOutput
+[[autodoc]] generation.GreedySearchEncoderDecoderOutput
+[[autodoc]] generation.FlaxGreedySearchOutput
### SampleOutput
-[[autodoc]] generation_utils.SampleDecoderOnlyOutput
-[[autodoc]] generation_utils.SampleEncoderDecoderOutput
-[[autodoc]] generation_flax_utils.FlaxSampleOutput
+[[autodoc]] generation.SampleDecoderOnlyOutput
+[[autodoc]] generation.SampleEncoderDecoderOutput
+[[autodoc]] generation.FlaxSampleOutput
### BeamSearchOutput
-[[autodoc]] generation_utils.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation_utils.BeamSearchEncoderDecoderOutput
+[[autodoc]] generation.BeamSearchDecoderOnlyOutput
+[[autodoc]] generation.BeamSearchEncoderDecoderOutput
### BeamSampleOutput
-[[autodoc]] generation_utils.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation_utils.BeamSampleEncoderDecoderOutput
+[[autodoc]] generation.BeamSampleDecoderOnlyOutput
+[[autodoc]] generation.BeamSampleEncoderDecoderOutput
## LogitsProcessor
......
@@ -25,9 +25,9 @@ are common among all the models to:
The other methods that are common to each model are defined in [`~modeling_utils.ModuleUtilsMixin`]
(for the PyTorch models) and [`~modeling_tf_utils.TFModuleUtilsMixin`] (for the TensorFlow models) or
-for text generation, [`~generation_utils.GenerationMixin`] (for the PyTorch models),
-[`~generation_tf_utils.TFGenerationMixin`] (for the TensorFlow models) and
-[`~generation_flax_utils.FlaxGenerationMixin`] (for the Flax/JAX models).
+for text generation, [`~generation.GenerationMixin`] (for the PyTorch models),
+[`~generation.TFGenerationMixin`] (for the TensorFlow models) and
+[`~generation.FlaxGenerationMixin`] (for the Flax/JAX models).
## PreTrainedModel
......
@@ -14,13 +14,13 @@ specific language governing permissions and limitations under the License.
Each framework has a generate method for auto-regressive text generation implemented in its respective `GenerationMixin` class:
-- PyTorch [`~generation_utils.GenerationMixin.generate`] is implemented in [`~generation_utils.GenerationMixin`].
-- TensorFlow [`~generation_tf_utils.TFGenerationMixin.generate`] is implemented in [`~generation_tf_utils.TFGenerationMixin`].
-- Flax/JAX [`~generation_flax_utils.FlaxGenerationMixin.generate`] is implemented in [`~generation_flax_utils.FlaxGenerationMixin`].
+- PyTorch [`~generation.GenerationMixin.generate`] is implemented in [`~generation.GenerationMixin`].
+- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
+- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
## GenerationMixin
-[[autodoc]] generation_utils.GenerationMixin
+[[autodoc]] generation.GenerationMixin
- generate
- greedy_search
- sample
@@ -32,10 +32,10 @@ Each framework has a generate method for auto-regressive text generation impleme
## TFGenerationMixin
-[[autodoc]] generation_tf_utils.TFGenerationMixin
+[[autodoc]] generation.TFGenerationMixin
- generate
## FlaxGenerationMixin
-[[autodoc]] generation_flax_utils.FlaxGenerationMixin
+[[autodoc]] generation.FlaxGenerationMixin
- generate
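To make the new import paths for the three per-framework entry points concrete, here is a small sketch (each mixin requires its backend installed; `gpt2` is an assumed checkpoint):

```python
# All three mixins now live under transformers.generation
from transformers.generation import GenerationMixin        # PyTorch
# from transformers.generation import TFGenerationMixin    # TensorFlow
# from transformers.generation import FlaxGenerationMixin  # Flax/JAX

from transformers import AutoModelForCausalLM

# PyTorch generative models inherit from GenerationMixin, which is what
# puts the generate() method on the model instance.
model = AutoModelForCausalLM.from_pretrained("gpt2")
assert isinstance(model, GenerationMixin)
```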
@@ -58,7 +58,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The
- Model predictions are intended to be identical to the original implementation when
`forced_bos_token_id=0`. This only works, however, if the string you pass to
[`fairseq.encode`] starts with a space.
-- [`~generation_utils.GenerationMixin.generate`] should be used for conditional generation tasks like
+- [`~generation.GenerationMixin.generate`] should be used for conditional generation tasks like
summarization; see the example in its docstring.
- Models that load the *facebook/bart-large-cnn* weights will not have a `mask_token_id`, or be able to perform
mask-filling tasks.
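As a concrete illustration of the generate tip above, a hedged sketch of summarization with the *facebook/bart-large-cnn* checkpoint mentioned in these notes (the input text is made up):

```python
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

ARTICLE = (
    "PG&E stated it scheduled the blackouts in response to forecasts for high "
    "winds amid dry conditions. The aim is to reduce the risk of wildfires."
)
inputs = tokenizer(ARTICLE, return_tensors="pt", max_length=1024, truncation=True)

# beam search settings here are illustrative, not prescribed by the docs
summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=60)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```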
......
@@ -40,7 +40,7 @@ Tips:
## Inference
Donut's [`VisionEncoderDecoder`] model accepts images as input and makes use of
-[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
+[`~generation.GenerationMixin.generate`] to autoregressively generate text given the input image.
The [`DonutFeatureExtractor`] class is responsible for preprocessing the input image and
[`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`] decodes the generated target tokens to the target string. The
......
@@ -53,7 +53,7 @@ Tips:
### Generation
-The [`~generation_utils.GenerationMixin.generate`] method can be used to generate text using a GPT-J
+The [`~generation.GenerationMixin.generate`] method can be used to generate text using a GPT-J
model.
```python
......
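The GPT-J code block above is truncated in this view. A minimal equivalent, assuming the `EleutherAI/gpt-j-6B` checkpoint discussed on that page:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# EleutherAI/gpt-j-6B is an assumption; the original snippet is cut off above
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

prompt = "In a shocking finding, scientists discovered a herd of unicorns"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# sampling with a mild temperature is one common decoding choice, not the only one
gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100)
print(tokenizer.batch_decode(gen_tokens)[0])
```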
@@ -38,7 +38,7 @@ Tips:
## Inference
Speech2Text2's [`SpeechEncoderDecoderModel`] model accepts raw waveform input values from speech and
-makes use of [`~generation_utils.GenerationMixin.generate`] to translate the input speech
+makes use of [`~generation.GenerationMixin.generate`] to translate the input speech
autoregressively to the target language.
The [`Wav2Vec2FeatureExtractor`] class is responsible for preprocessing the input speech and
......
@@ -225,7 +225,7 @@ batch) leads to very slow training on TPU.
## Inference
-At inference time, it is recommended to use [`~generation_utils.GenerationMixin.generate`]. This
+At inference time, it is recommended to use [`~generation.GenerationMixin.generate`]. This
method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder
and auto-regressively generates the decoder output. Check out [this blog post](https://huggingface.co/blog/how-to-generate) to learn all the details about generating text with Transformers.
There's also [this blog post](https://huggingface.co/blog/encoder-decoder#encoder-decoder) which explains how
@@ -244,7 +244,7 @@ Das Haus ist wunderbar.
```
Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when doing generation without using
-[`~generation_utils.GenerationMixin.generate`], make sure you start it with the `pad_token_id`.
+[`~generation.GenerationMixin.generate`], make sure you start it with the `pad_token_id`.
The snippet above shows only a single example. You can also do batched inference, like so:
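A short sketch of that caveat, assuming the `t5-small` checkpoint: `generate` prepends the start token for you, while a manual decoder pass must begin with `pad_token_id` explicitly.

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")  # assumed checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer(
    "translate English to German: The house is wonderful.", return_tensors="pt"
).input_ids

# generate() inserts the decoder start token (the pad token for T5) itself:
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Driving the decoder by hand: the first decoder input must be pad_token_id.
decoder_input_ids = torch.tensor([[model.config.pad_token_id]])
logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
first_token = logits[0, -1].argmax()  # greedy pick of the first generated token
```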
......
@@ -53,7 +53,7 @@ Tips:
## Inference
TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of
-[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
+[`~generation.GenerationMixin.generate`] to autoregressively generate text given the input image.
The [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] class is responsible for preprocessing the input image and
[`RobertaTokenizer`/`XLMRobertaTokenizer`] decodes the generated target tokens to the target string. The
......
@@ -24,7 +24,7 @@ The abstract from the paper is the following:
Tips:
- The model usually performs well without requiring any finetuning.
-- The architecture follows a classic encoder-decoder architecture, which means that it relies on the [`~generation_utils.GenerationMixin.generate`] function for inference.
+- The architecture follows a classic encoder-decoder architecture, which means that it relies on the [`~generation.GenerationMixin.generate`] function for inference.
- Inference is currently only implemented for short-form, i.e., audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted IDs back into text.
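A hedged sketch of the short-form inference flow described in these tips, assuming the `openai/whisper-tiny` checkpoint and a 16 kHz mono waveform `audio` (for example, loaded from a dataset):

```python
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# `audio` is an assumed <=30s, 16 kHz mono waveform (e.g. a 1-D numpy array)
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features

# generate() drives the encoder-decoder autoregressively, as noted above
predicted_ids = model.generate(input_features)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))
```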
......
@@ -56,7 +56,7 @@ If you have more than one input, pass your input as a list:
... ) # doctest: +SKIP
```
-Any additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
+Any additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
```py
>>> generator(
......
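A runnable sketch of that parameter (the prompt is arbitrary, and the pipeline's default text-generation model is assumed to be acceptable):

```python
from transformers import pipeline

generator = pipeline(task="text-generation")

# num_return_sequences asks generate() for two independent continuations;
# sampling is enabled so the two returned sequences can actually differ
outputs = generator("Hello, I'm a language model,", num_return_sequences=2, do_sample=True)
for out in outputs:
    print(out["generated_text"])
```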
@@ -544,7 +544,7 @@ Hugging Face is based in DUMBO, New York City, and ...
This outputs a (hopefully) coherent next token following the original sequence, which in our case is the word *is* or
*features*.
-In the next section, we show how [`generation_utils.GenerationMixin.generate`] can be used to
+In the next section, we show how [`generation.GenerationMixin.generate`] can be used to
generate multiple tokens up to a specified length instead of one token at a time.
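A minimal sketch of that difference, assuming `gpt2` and the prompt from the hunk header above: a single call to `generate` produces a whole continuation rather than one token.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hugging Face is based in DUMBO, New York City, and", return_tensors="pt")

# one call, up to 20 new tokens, instead of one forward pass per token
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```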
### Text Generation ### Text Generation
......
@@ -54,7 +54,7 @@ If you have more than one input, pass it as a list:
... )
```
-Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
+Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
```py
>>> generator(
......
@@ -56,7 +56,7 @@ If you have more than one input, pass it in a list:
... ) # doctest: +SKIP
```
-Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, use the `num_return_sequences` parameter:
+Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, use the `num_return_sequences` parameter:
```py
>>> generator(
......
@@ -61,7 +61,7 @@ If you have more than one input, pass it as a list:
```
Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a
-[`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output.
+[`~generation.GenerationMixin.generate`] method with several parameters for controlling the output.
For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
```py
......
@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from transformers import BartConfig
-from transformers.generation_utils import GenerationMixin
+from transformers.generation import GenerationMixin
def _convert_past_list_to_tuple(past_key_values):
......
@@ -97,6 +97,7 @@ _import_structure = {
    "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
    "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
    "file_utils": [],
+    "generation": [],
    "hf_argparser": ["HfArgumentParser"],
    "integrations": [
        "is_comet_available",
@@ -821,14 +822,16 @@ else:
        "TextDatasetForNextSentencePrediction",
    ]
    _import_structure["deepspeed"] = []
-    _import_structure["generation_beam_constraints"] = [
+    _import_structure["generation_utils"] = []
+    _import_structure["generation"].extend(
+        [
            "Constraint",
            "ConstraintListState",
            "DisjunctiveConstraint",
            "PhrasalConstraint",
-    ]
-    _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
-    _import_structure["generation_logits_process"] = [
+            "BeamScorer",
+            "BeamSearchScorer",
+            "ConstrainedBeamSearchScorer",
            "ForcedBOSTokenLogitsProcessor",
            "ForcedEOSTokenLogitsProcessor",
            "HammingDiversityLogitsProcessor",
@@ -845,14 +848,14 @@ else:
            "TopKLogitsWarper",
            "TopPLogitsWarper",
            "TypicalLogitsWarper",
-    ]
-    _import_structure["generation_stopping_criteria"] = [
            "MaxLengthCriteria",
            "MaxTimeCriteria",
            "StoppingCriteria",
            "StoppingCriteriaList",
+            "GenerationMixin",
+            "top_k_top_p_filtering",
        ]
-    _import_structure["generation_utils"] = ["top_k_top_p_filtering"]
+    )
    _import_structure["modeling_outputs"] = []
    _import_structure["modeling_utils"] = ["PreTrainedModel"]
@@ -2278,7 +2281,9 @@ else:
    _import_structure["activations_tf"] = []
    _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
    _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
-    _import_structure["generation_tf_logits_process"] = [
+    _import_structure["generation_tf_utils"] = []
+    _import_structure["generation"].extend(
+        [
            "TFForcedBOSTokenLogitsProcessor",
            "TFForcedEOSTokenLogitsProcessor",
            "TFLogitsProcessor",
@@ -2291,8 +2296,10 @@ else:
            "TFTemperatureLogitsWarper",
            "TFTopKLogitsWarper",
            "TFTopPLogitsWarper",
+            "TFGenerationMixin",
+            "tf_top_k_top_p_filtering",
        ]
-    _import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
+    )
    _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
    _import_structure["modeling_tf_outputs"] = []
    _import_structure["modeling_tf_utils"] = [
@@ -2915,7 +2922,9 @@ except OptionalDependencyNotAvailable:
        name for name in dir(dummy_flax_objects) if not name.startswith("_")
    ]
else:
-    _import_structure["generation_flax_logits_process"] = [
+    _import_structure["generation_flax_utils"] = []
+    _import_structure["generation"].extend(
+        [
            "FlaxForcedBOSTokenLogitsProcessor",
            "FlaxForcedEOSTokenLogitsProcessor",
            "FlaxLogitsProcessor",
@@ -2925,8 +2934,9 @@ else:
            "FlaxTemperatureLogitsWarper",
            "FlaxTopKLogitsWarper",
            "FlaxTopPLogitsWarper",
+            "FlaxGenerationMixin",
        ]
-    _import_structure["generation_flax_utils"] = []
+    )
    _import_structure["modeling_flax_outputs"] = []
    _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
    _import_structure["models.albert"].extend(
@@ -3834,38 +3844,37 @@ if TYPE_CHECKING:
        TextDataset,
        TextDatasetForNextSentencePrediction,
    )
-    from .generation_beam_constraints import (
+    from .generation import (
+        BeamScorer,
+        BeamSearchScorer,
+        ConstrainedBeamSearchScorer,
        Constraint,
        ConstraintListState,
        DisjunctiveConstraint,
-        PhrasalConstraint,
-    )
-    from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
-    from .generation_logits_process import (
        ForcedBOSTokenLogitsProcessor,
        ForcedEOSTokenLogitsProcessor,
+        GenerationMixin,
        HammingDiversityLogitsProcessor,
        InfNanRemoveLogitsProcessor,
        LogitsProcessor,
        LogitsProcessorList,
        LogitsWarper,
+        MaxLengthCriteria,
+        MaxTimeCriteria,
        MinLengthLogitsProcessor,
        NoBadWordsLogitsProcessor,
        NoRepeatNGramLogitsProcessor,
+        PhrasalConstraint,
        PrefixConstrainedLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
+        StoppingCriteria,
+        StoppingCriteriaList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
        TypicalLogitsWarper,
+        top_k_top_p_filtering,
    )
-    from .generation_stopping_criteria import (
-        MaxLengthCriteria,
-        MaxTimeCriteria,
-        StoppingCriteria,
-        StoppingCriteriaList,
-    )
-    from .generation_utils import top_k_top_p_filtering
    from .modeling_utils import PreTrainedModel

    # PyTorch model imports
@@ -5037,9 +5046,10 @@ if TYPE_CHECKING:
    # Benchmarks
    from .benchmark.benchmark_tf import TensorFlowBenchmark
-    from .generation_tf_logits_process import (
+    from .generation import (
        TFForcedBOSTokenLogitsProcessor,
        TFForcedEOSTokenLogitsProcessor,
+        TFGenerationMixin,
        TFLogitsProcessor,
        TFLogitsProcessorList,
        TFLogitsWarper,
@@ -5050,8 +5060,8 @@
        TFTemperatureLogitsWarper,
        TFTopKLogitsWarper,
        TFTopPLogitsWarper,
+        tf_top_k_top_p_filtering,
    )
-    from .generation_tf_utils import tf_top_k_top_p_filtering
    from .keras_callbacks import KerasMetricCallback, PushToHubCallback
    from .modeling_tf_layoutlm import (
        TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -5541,9 +5551,10 @@ if TYPE_CHECKING:
    # They will raise an import error if the user tries to instantiate / use them.
    from .utils.dummy_flax_objects import *
else:
-    from .generation_flax_logits_process import (
+    from .generation import (
        FlaxForcedBOSTokenLogitsProcessor,
        FlaxForcedEOSTokenLogitsProcessor,
+        FlaxGenerationMixin,
        FlaxLogitsProcessor,
        FlaxLogitsProcessorList,
        FlaxLogitsWarper,
......
What follows is the new `generation` module's `__init__.py`, which implements the lazy loading described in the commit message:
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["beam_constraints"] = [
"Constraint",
"ConstraintListState",
"DisjunctiveConstraint",
"PhrasalConstraint",
]
_import_structure["beam_search"] = [
"BeamHypotheses",
"BeamScorer",
"BeamSearchScorer",
"ConstrainedBeamSearchScorer",
]
_import_structure["logits_process"] = [
"ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor",
"HammingDiversityLogitsProcessor",
"InfNanRemoveLogitsProcessor",
"LogitsProcessor",
"LogitsProcessorList",
"LogitsWarper",
"MinLengthLogitsProcessor",
"NoBadWordsLogitsProcessor",
"NoRepeatNGramLogitsProcessor",
"PrefixConstrainedLogitsProcessor",
"RepetitionPenaltyLogitsProcessor",
"TemperatureLogitsWarper",
"TopKLogitsWarper",
"TopPLogitsWarper",
"TypicalLogitsWarper",
"EncoderNoRepeatNGramLogitsProcessor",
"ExponentialDecayLengthPenalty",
"LogitNormalization",
]
_import_structure["stopping_criteria"] = [
"MaxNewTokensCriteria",
"MaxLengthCriteria",
"MaxTimeCriteria",
"StoppingCriteria",
"StoppingCriteriaList",
"validate_stopping_criteria",
]
_import_structure["utils"] = [
"GenerationMixin",
"top_k_top_p_filtering",
"GreedySearchEncoderDecoderOutput",
"GreedySearchDecoderOnlyOutput",
"SampleEncoderDecoderOutput",
"SampleDecoderOnlyOutput",
"BeamSearchEncoderDecoderOutput",
"BeamSearchDecoderOnlyOutput",
"BeamSampleEncoderDecoderOutput",
"BeamSampleDecoderOnlyOutput",
"ContrastiveSearchEncoderDecoderOutput",
"ContrastiveSearchDecoderOnlyOutput",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tf_logits_process"] = [
"TFForcedBOSTokenLogitsProcessor",
"TFForcedEOSTokenLogitsProcessor",
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFLogitsWarper",
"TFMinLengthLogitsProcessor",
"TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor",
"TFTemperatureLogitsWarper",
"TFTopKLogitsWarper",
"TFTopPLogitsWarper",
"TFForceTokensLogitsProcessor",
"TFSuppressTokensAtBeginLogitsProcessor",
"TFSuppressTokensLogitsProcessor",
]
_import_structure["tf_utils"] = [
"TFGenerationMixin",
"tf_top_k_top_p_filtering",
"TFGreedySearchDecoderOnlyOutput",
"TFGreedySearchEncoderDecoderOutput",
"TFSampleEncoderDecoderOutput",
"TFSampleDecoderOnlyOutput",
"TFBeamSearchEncoderDecoderOutput",
"TFBeamSearchDecoderOnlyOutput",
"TFBeamSampleEncoderDecoderOutput",
"TFBeamSampleDecoderOnlyOutput",
"TFContrastiveSearchEncoderDecoderOutput",
"TFContrastiveSearchDecoderOnlyOutput",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["flax_logits_process"] = [
"FlaxForcedBOSTokenLogitsProcessor",
"FlaxForcedEOSTokenLogitsProcessor",
"FlaxLogitsProcessor",
"FlaxLogitsProcessorList",
"FlaxLogitsWarper",
"FlaxMinLengthLogitsProcessor",
"FlaxTemperatureLogitsWarper",
"FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper",
]
_import_structure["flax_utils"] = [
"FlaxGenerationMixin",
"FlaxGreedySearchOutput",
"FlaxSampleOutput",
"FlaxBeamSearchOutput",
]
if TYPE_CHECKING:
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessor,
LogitsProcessorList,
LogitsWarper,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
from .stopping_criteria import (
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
StoppingCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
)
from .utils import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
ContrastiveSearchDecoderOnlyOutput,
ContrastiveSearchEncoderDecoderOutput,
GenerationMixin,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
top_k_top_p_filtering,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tf_logits_process import (
TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor,
TFForceTokensLogitsProcessor,
TFLogitsProcessor,
TFLogitsProcessorList,
TFLogitsWarper,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFSuppressTokensAtBeginLogitsProcessor,
TFSuppressTokensLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .tf_utils import (
TFBeamSampleDecoderOnlyOutput,
TFBeamSampleEncoderDecoderOutput,
TFBeamSearchDecoderOnlyOutput,
TFBeamSearchEncoderDecoderOutput,
TFContrastiveSearchDecoderOnlyOutput,
TFContrastiveSearchEncoderDecoderOutput,
TFGenerationMixin,
TFGreedySearchDecoderOnlyOutput,
TFGreedySearchEncoderDecoderOutput,
TFSampleDecoderOnlyOutput,
TFSampleEncoderDecoderOutput,
tf_top_k_top_p_filtering,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessor,
FlaxLogitsProcessorList,
FlaxLogitsWarper,
FlaxMinLengthLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
)
from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
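A quick check (hypothetical session) of the lazy-loading behaviour this `__init__.py` sets up: importing the submodule is cheap, and the heavy torch-backed code is only pulled in when an attribute is first touched.

```python
import transformers.generation as generation

# The module object is the _LazyModule installed in sys.modules above
print(type(generation))

# First attribute access triggers the real import (requires torch here)
mixin = generation.GenerationMixin
print(mixin.__module__)  # e.g. "transformers.generation.utils"
```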