"docs/source/ja/perf_train_special.md" did not exist on "8abe4930d39d9fb379fe89ee19327d371daee29f"
Unverified commit f270b960, authored by Joao Gante, committed by GitHub

Generate: move generation_*.py src files into generation/*.py (#20096)

* move generation_*.py src files into generation/*.py

* populate generation.__init__ with lazy loading

* move imports and references from generation.xxx.object to generation.object
parent bac2d29a
@@ -56,7 +56,7 @@ If you have more than one input, pass the input as a list:
... ) # doctest: +SKIP
```
-All additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
+All additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
```py
>>> generator(
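The paragraph above is easy to try out; a minimal sketch of `num_return_sequences` in action, assuming the hypothetical choice of the `gpt2` checkpoint:

```python
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")

# `num_return_sequences` is forwarded to `generate`; sampling is enabled because
# greedy decoding can only return a single sequence per prompt
outputs = generator(
    "Hello, I'm a language model,",
    do_sample=True,
    num_return_sequences=2,
    max_length=30,
)
for out in outputs:
    print(out["generated_text"])
```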
@@ -12,22 +12,22 @@ specific language governing permissions and limitations under the License.
# Utilities for Generation
-This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`],
-[`~generation_utils.GenerationMixin.greedy_search`],
-[`~generation_utils.GenerationMixin.contrastive_search`],
-[`~generation_utils.GenerationMixin.sample`],
-[`~generation_utils.GenerationMixin.beam_search`],
-[`~generation_utils.GenerationMixin.beam_sample`],
-[`~generation_utils.GenerationMixin.group_beam_search`], and
-[`~generation_utils.GenerationMixin.constrained_beam_search`].
+This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
+[`~generation.GenerationMixin.greedy_search`],
+[`~generation.GenerationMixin.contrastive_search`],
+[`~generation.GenerationMixin.sample`],
+[`~generation.GenerationMixin.beam_search`],
+[`~generation.GenerationMixin.beam_sample`],
+[`~generation.GenerationMixin.group_beam_search`], and
+[`~generation.GenerationMixin.constrained_beam_search`].
Most of those are only useful if you are studying the code of the generate methods in the library.
## Generate Outputs
-The output of [`~generation_utils.GenerationMixin.generate`] is an instance of a subclass of
+The output of [`~generation.GenerationMixin.generate`] is an instance of a subclass of
[`~utils.ModelOutput`]. This output is a data structure containing all the information returned
-by [`~generation_utils.GenerationMixin.generate`], but that can also be used as a tuple or a dictionary.
+by [`~generation.GenerationMixin.generate`], but that can also be used as a tuple or a dictionary.
Here's an example:
@@ -41,7 +41,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
```
-The `generation_output` object is a [`~generation_utils.GreedySearchDecoderOnlyOutput`]; as we can
+The `generation_output` object is a [`~generation.GreedySearchDecoderOnlyOutput`]; as we can
see in the documentation of that class below, it has the following attributes:
- `sequences`: the generated sequences of tokens
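Because the returned object subclasses [`~utils.ModelOutput`], attribute and dictionary access are interchangeable; a small sketch continuing the example above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)

print(generation_output.sequences.shape)  # generated token ids, shape (batch_size, sequence_length)
print(len(generation_output.scores))  # one logits tensor per generated token
print(generation_output["sequences"] is generation_output.sequences)  # True: dict access works too
```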
@@ -73,31 +73,31 @@ We document here all output types.
### GreedySearchOutput
-[[autodoc]] generation_utils.GreedySearchDecoderOnlyOutput
+[[autodoc]] generation.GreedySearchDecoderOnlyOutput
-[[autodoc]] generation_utils.GreedySearchEncoderDecoderOutput
+[[autodoc]] generation.GreedySearchEncoderDecoderOutput
-[[autodoc]] generation_flax_utils.FlaxGreedySearchOutput
+[[autodoc]] generation.FlaxGreedySearchOutput
### SampleOutput
-[[autodoc]] generation_utils.SampleDecoderOnlyOutput
+[[autodoc]] generation.SampleDecoderOnlyOutput
-[[autodoc]] generation_utils.SampleEncoderDecoderOutput
+[[autodoc]] generation.SampleEncoderDecoderOutput
-[[autodoc]] generation_flax_utils.FlaxSampleOutput
+[[autodoc]] generation.FlaxSampleOutput
### BeamSearchOutput
-[[autodoc]] generation_utils.BeamSearchDecoderOnlyOutput
+[[autodoc]] generation.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation_utils.BeamSearchEncoderDecoderOutput
+[[autodoc]] generation.BeamSearchEncoderDecoderOutput
### BeamSampleOutput
-[[autodoc]] generation_utils.BeamSampleDecoderOnlyOutput
+[[autodoc]] generation.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation_utils.BeamSampleEncoderDecoderOutput
+[[autodoc]] generation.BeamSampleEncoderDecoderOutput
## LogitsProcessor
@@ -25,9 +25,9 @@ are common among all the models to:
The other methods that are common to each model are defined in [`~modeling_utils.ModuleUtilsMixin`]
(for the PyTorch models) and [`~modeling_tf_utils.TFModuleUtilsMixin`] (for the TensorFlow models) or
-for text generation, [`~generation_utils.GenerationMixin`] (for the PyTorch models),
-[`~generation_tf_utils.TFGenerationMixin`] (for the TensorFlow models) and
-[`~generation_flax_utils.FlaxGenerationMixin`] (for the Flax/JAX models).
+for text generation, [`~generation.GenerationMixin`] (for the PyTorch models),
+[`~generation.TFGenerationMixin`] (for the TensorFlow models) and
+[`~generation.FlaxGenerationMixin`] (for the Flax/JAX models).
## PreTrainedModel
@@ -14,13 +14,13 @@ specific language governing permissions and limitations under the License.
Each framework has a generate method for auto-regressive text generation implemented in their respective `GenerationMixin` class:
-- PyTorch [`~generation_utils.GenerationMixin.generate`] is implemented in [`~generation_utils.GenerationMixin`].
-- TensorFlow [`~generation_tf_utils.TFGenerationMixin.generate`] is implemented in [`~generation_tf_utils.TFGenerationMixin`].
-- Flax/JAX [`~generation_flax_utils.FlaxGenerationMixin.generate`] is implemented in [`~generation_flax_utils.FlaxGenerationMixin`].
+- PyTorch [`~generation.GenerationMixin.generate`] is implemented in [`~generation.GenerationMixin`].
+- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
+- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
## GenerationMixin
-[[autodoc]] generation_utils.GenerationMixin
+[[autodoc]] generation.GenerationMixin
    - generate
    - greedy_search
    - sample
@@ -32,10 +32,10 @@ Each framework has a generate method for auto-regressive text generation implemented in their respective `GenerationMixin` class:
## TFGenerationMixin
-[[autodoc]] generation_tf_utils.TFGenerationMixin
+[[autodoc]] generation.TFGenerationMixin
    - generate
## FlaxGenerationMixin
-[[autodoc]] generation_flax_utils.FlaxGenerationMixin
+[[autodoc]] generation.FlaxGenerationMixin
    - generate
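A sketch of the shared API across frameworks; the `gpt2` checkpoint and prompt are assumptions, and the TensorFlow half requires TF to be installed:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# PyTorch: `generate` is provided by GenerationMixin
pt_model = AutoModelForCausalLM.from_pretrained("gpt2")
pt_ids = pt_model.generate(**tokenizer("Hello", return_tensors="pt"), max_new_tokens=5)

# TensorFlow: the same call, provided by TFGenerationMixin
tf_model = TFAutoModelForCausalLM.from_pretrained("gpt2")
tf_ids = tf_model.generate(**tokenizer("Hello", return_tensors="tf"), max_new_tokens=5)

print(tokenizer.decode(pt_ids[0]), tokenizer.decode(tf_ids[0]))
```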
@@ -58,7 +58,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The
- Model predictions are intended to be identical to the original implementation when
  `forced_bos_token_id=0`. This only works, however, if the string you pass to
  [`fairseq.encode`] starts with a space.
-- [`~generation_utils.GenerationMixin.generate`] should be used for conditional generation tasks like
+- [`~generation.GenerationMixin.generate`] should be used for conditional generation tasks like
  summarization; see the example in its docstring.
- Models that load the *facebook/bart-large-cnn* weights will not have a `mask_token_id`, or be able to perform
  mask-filling tasks.
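A minimal sketch of the conditional-generation usage the tip above refers to, assuming the `facebook/bart-large-cnn` checkpoint mentioned in the same list:

```python
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

article = "PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions."
inputs = tokenizer(article, return_tensors="pt")

# beam search is a common choice for summarization
summary_ids = model.generate(inputs.input_ids, num_beams=4, max_length=60)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```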
@@ -40,7 +40,7 @@ Tips:
## Inference
Donut's [`VisionEncoderDecoder`] model accepts images as input and makes use of
-[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
+[`~generation.GenerationMixin.generate`] to autoregressively generate text given the input image.
The [`DonutFeatureExtractor`] class is responsible for preprocessing the input image and
[`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`] decodes the generated target tokens to the target string. The
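A hedged sketch of that flow; the CORD-finetuned checkpoint and the local image path are assumptions, and `DonutProcessor` (which bundles the feature extractor and tokenizer) stands in for the separate classes named above:

```python
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

image = Image.open("receipt.png").convert("RGB")  # hypothetical local file
pixel_values = processor(image, return_tensors="pt").pixel_values

# the decoder autoregressively generates tokens from the encoded image
output_ids = model.generate(pixel_values, max_length=512)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```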
@@ -53,7 +53,7 @@ Tips:
### Generation
-The [`~generation_utils.GenerationMixin.generate`] method can be used to generate text using GPT-J
+The [`~generation.GenerationMixin.generate`] method can be used to generate text using GPT-J
model.
```python
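The code block above is cut off by the collapsed diff; a sketch of the kind of example it introduces, assuming sampling with the `EleutherAI/gpt-j-6B` checkpoint:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

prompt = "In a shocking finding, scientists discovered a herd of unicorns"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100)
print(tokenizer.batch_decode(gen_tokens)[0])
```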
@@ -38,7 +38,7 @@ Tips:
## Inference
Speech2Text2's [`SpeechEncoderDecoderModel`] model accepts raw waveform input values from speech and
-makes use of [`~generation_utils.GenerationMixin.generate`] to translate the input speech
+makes use of [`~generation.GenerationMixin.generate`] to translate the input speech
autoregressively to the target language.
The [`Wav2Vec2FeatureExtractor`] class is responsible for preprocessing the input speech and
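A sketch of that inference flow, assuming the `facebook/s2t-wav2vec2-large-en-de` checkpoint and the dummy LibriSpeech split used elsewhere in the docs:

```python
from datasets import load_dataset
from transformers import Speech2Text2Processor, SpeechEncoderDecoderModel

model = SpeechEncoderDecoderModel.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
processor = Speech2Text2Processor.from_pretrained("facebook/s2t-wav2vec2-large-en-de")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

# the wav2vec2 encoder feeds the decoder, which generates the German translation
generated_ids = model.generate(inputs=inputs["input_values"], attention_mask=inputs["attention_mask"])
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```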
@@ -225,7 +225,7 @@ batch) leads to very slow training on TPU.
## Inference
-At inference time, it is recommended to use [`~generation_utils.GenerationMixin.generate`]. This
+At inference time, it is recommended to use [`~generation.GenerationMixin.generate`]. This
method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder
and auto-regressively generates the decoder output. Check out [this blog post](https://huggingface.co/blog/how-to-generate) to learn all the details about generating text with Transformers.
There's also [this blog post](https://huggingface.co/blog/encoder-decoder#encoder-decoder) which explains how
@@ -244,7 +244,7 @@ Das Haus ist wunderbar.
```
Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when doing generation without using
-[`~generation_utils.GenerationMixin.generate`], make sure you start it with the `pad_token_id`.
+[`~generation.GenerationMixin.generate`], make sure you start it with the `pad_token_id`.
The example above shows only a single input. You can also do batched inference, like so:
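A short sketch of the note above: when stepping the decoder manually instead of calling `generate`, the first decoder token must be the `pad_token_id` (T5's `decoder_start_token_id`); the `t5-small` checkpoint is an assumption:

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids

# start decoding from pad_token_id, which T5 uses as decoder_start_token_id
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

first_token = logits[0, -1].argmax()
print(tokenizer.decode(first_token))
```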
@@ -53,7 +53,7 @@ Tips:
## Inference
TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of
-[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
+[`~generation.GenerationMixin.generate`] to autoregressively generate text given the input image.
The [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] class is responsible for preprocessing the input image and
[`RobertaTokenizer`/`XLMRobertaTokenizer`] decodes the generated target tokens to the target string. The
@@ -24,7 +24,7 @@ The abstract from the paper is the following:
Tips:
- The model usually performs well without requiring any finetuning.
-- The architecture follows a classic encoder-decoder architecture, which means that it relies on the [`~generation_utils.GenerationMixin.generate`] function for inference.
+- The architecture follows a classic encoder-decoder architecture, which means that it relies on the [`~generation.GenerationMixin.generate`] function for inference.
- Inference is currently only implemented for short-form, i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
- One can use [`WhisperProcessor`] to prepare audio for the model and decode the predicted IDs back into text.
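A minimal sketch of the processor round trip described in the tips, assuming the `openai/whisper-tiny.en` checkpoint and the dummy LibriSpeech split:

```python
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt").input_features

# encoder-decoder generation over a single pre-segmented (<=30s) clip
predicted_ids = model.generate(input_features)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```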
@@ -56,7 +56,7 @@ If you have more than one input, pass your input as a list:
... ) # doctest: +SKIP
```
-Any additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
+Any additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
```py
>>> generator(
@@ -544,7 +544,7 @@ Hugging Face is based in DUMBO, New York City, and ...
This outputs a (hopefully) coherent next token following the original sequence, which in our case is the word *is* or
*features*.
-In the next section, we show how [`generation_utils.GenerationMixin.generate`] can be used to
+In the next section, we show how [`generation.GenerationMixin.generate`] can be used to
generate multiple tokens up to a specified length instead of one token at a time.
### Text Generation
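As a preview of that section, a minimal sketch of multi-token generation, with `gpt2` as a hypothetical checkpoint:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hugging Face is based in DUMBO, New York City, and", return_tensors="pt")

# one call produces a whole continuation instead of a single next token
output_ids = model.generate(**inputs, max_length=30, do_sample=True, top_k=50)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```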
@@ -54,7 +54,7 @@ If you have more than one input, pass it as a list:
... )
```
-Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
+Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
```py
>>> generator(
@@ -56,7 +56,7 @@ If you have more than one input, put it in a list:
... ) # doctest: +SKIP
```
-Any additional parameter for your task can be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, use the `num_return_sequences` parameter:
+Any additional parameter for your task can be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, use the `num_return_sequences` parameter:
```py
>>> generator(
@@ -61,7 +61,7 @@ If you have more than one input, pass it as a list:
```
Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a
-[`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output.
+[`~generation.GenerationMixin.generate`] method with several parameters for controlling the output.
For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
```py
@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from transformers import BartConfig
-from transformers.generation_utils import GenerationMixin
+from transformers.generation import GenerationMixin
def _convert_past_list_to_tuple(past_key_values):
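Downstream code follows the same one-line change; a sketch of the updated import, with `LogitsProcessorList` as an illustrative second name that also moved:

```python
# new location after this commit; the flat generation_utils entry is kept in
# _import_structure but left empty
from transformers.generation import GenerationMixin, LogitsProcessorList
```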
@@ -97,6 +97,7 @@ _import_structure = {
    "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
    "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
    "file_utils": [],
+    "generation": [],
    "hf_argparser": ["HfArgumentParser"],
    "integrations": [
        "is_comet_available",
@@ -821,14 +822,16 @@ else:
        "TextDatasetForNextSentencePrediction",
    ]
    _import_structure["deepspeed"] = []
-    _import_structure["generation_beam_constraints"] = [
+    _import_structure["generation_utils"] = []
+    _import_structure["generation"].extend(
+        [
            "Constraint",
            "ConstraintListState",
            "DisjunctiveConstraint",
            "PhrasalConstraint",
-    ]
-    _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
-    _import_structure["generation_logits_process"] = [
+            "BeamScorer",
+            "BeamSearchScorer",
+            "ConstrainedBeamSearchScorer",
            "ForcedBOSTokenLogitsProcessor",
            "ForcedEOSTokenLogitsProcessor",
            "HammingDiversityLogitsProcessor",
@@ -845,14 +848,14 @@ else:
            "TopKLogitsWarper",
            "TopPLogitsWarper",
            "TypicalLogitsWarper",
-    ]
-    _import_structure["generation_stopping_criteria"] = [
            "MaxLengthCriteria",
            "MaxTimeCriteria",
            "StoppingCriteria",
            "StoppingCriteriaList",
+            "GenerationMixin",
+            "top_k_top_p_filtering",
+        ]
-    _import_structure["generation_utils"] = ["top_k_top_p_filtering"]
+    )
    _import_structure["modeling_outputs"] = []
    _import_structure["modeling_utils"] = ["PreTrainedModel"]
@@ -2278,7 +2281,9 @@ else:
    _import_structure["activations_tf"] = []
    _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
    _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
-    _import_structure["generation_tf_logits_process"] = [
+    _import_structure["generation_tf_utils"] = []
+    _import_structure["generation"].extend(
+        [
            "TFForcedBOSTokenLogitsProcessor",
            "TFForcedEOSTokenLogitsProcessor",
            "TFLogitsProcessor",
@@ -2291,8 +2296,10 @@ else:
            "TFTemperatureLogitsWarper",
            "TFTopKLogitsWarper",
            "TFTopPLogitsWarper",
+            "TFGenerationMixin",
+            "tf_top_k_top_p_filtering",
        ]
-    _import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
+    )
    _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
    _import_structure["modeling_tf_outputs"] = []
    _import_structure["modeling_tf_utils"] = [
@@ -2915,7 +2922,9 @@ except OptionalDependencyNotAvailable:
        name for name in dir(dummy_flax_objects) if not name.startswith("_")
    ]
else:
-    _import_structure["generation_flax_logits_process"] = [
+    _import_structure["generation_flax_utils"] = []
+    _import_structure["generation"].extend(
+        [
            "FlaxForcedBOSTokenLogitsProcessor",
            "FlaxForcedEOSTokenLogitsProcessor",
            "FlaxLogitsProcessor",
@@ -2925,8 +2934,9 @@ else:
            "FlaxTemperatureLogitsWarper",
            "FlaxTopKLogitsWarper",
            "FlaxTopPLogitsWarper",
+            "FlaxGenerationMixin",
        ]
-    _import_structure["generation_flax_utils"] = []
+    )
    _import_structure["modeling_flax_outputs"] = []
    _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
    _import_structure["models.albert"].extend(
@@ -3834,38 +3844,37 @@ if TYPE_CHECKING:
            TextDataset,
            TextDatasetForNextSentencePrediction,
        )
-        from .generation_beam_constraints import (
+        from .generation import (
+            BeamScorer,
+            BeamSearchScorer,
+            ConstrainedBeamSearchScorer,
            Constraint,
            ConstraintListState,
            DisjunctiveConstraint,
-            PhrasalConstraint,
-        )
-        from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
-        from .generation_logits_process import (
            ForcedBOSTokenLogitsProcessor,
            ForcedEOSTokenLogitsProcessor,
+            GenerationMixin,
            HammingDiversityLogitsProcessor,
            InfNanRemoveLogitsProcessor,
            LogitsProcessor,
            LogitsProcessorList,
            LogitsWarper,
+            MaxLengthCriteria,
+            MaxTimeCriteria,
            MinLengthLogitsProcessor,
            NoBadWordsLogitsProcessor,
            NoRepeatNGramLogitsProcessor,
+            PhrasalConstraint,
            PrefixConstrainedLogitsProcessor,
            RepetitionPenaltyLogitsProcessor,
+            StoppingCriteria,
+            StoppingCriteriaList,
            TemperatureLogitsWarper,
            TopKLogitsWarper,
            TopPLogitsWarper,
            TypicalLogitsWarper,
+            top_k_top_p_filtering,
        )
-        from .generation_stopping_criteria import (
-            MaxLengthCriteria,
-            MaxTimeCriteria,
-            StoppingCriteria,
-            StoppingCriteriaList,
-        )
-        from .generation_utils import top_k_top_p_filtering
        from .modeling_utils import PreTrainedModel
        # PyTorch model imports
@@ -5037,9 +5046,10 @@ if TYPE_CHECKING:
        # Benchmarks
        from .benchmark.benchmark_tf import TensorFlowBenchmark
-        from .generation_tf_logits_process import (
+        from .generation import (
            TFForcedBOSTokenLogitsProcessor,
            TFForcedEOSTokenLogitsProcessor,
+            TFGenerationMixin,
            TFLogitsProcessor,
            TFLogitsProcessorList,
            TFLogitsWarper,
@@ -5050,8 +5060,8 @@ if TYPE_CHECKING:
            TFTemperatureLogitsWarper,
            TFTopKLogitsWarper,
            TFTopPLogitsWarper,
+            tf_top_k_top_p_filtering,
        )
-        from .generation_tf_utils import tf_top_k_top_p_filtering
        from .keras_callbacks import KerasMetricCallback, PushToHubCallback
        from .modeling_tf_layoutlm import (
            TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -5541,9 +5551,10 @@ if TYPE_CHECKING:
        # They will raise an import error if the user tries to instantiate / use them.
        from .utils.dummy_flax_objects import *
    else:
-        from .generation_flax_logits_process import (
+        from .generation import (
            FlaxForcedBOSTokenLogitsProcessor,
            FlaxForcedEOSTokenLogitsProcessor,
+            FlaxGenerationMixin,
            FlaxLogitsProcessor,
            FlaxLogitsProcessorList,
            FlaxLogitsWarper,
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {}
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["beam_constraints"] = [
        "Constraint",
        "ConstraintListState",
        "DisjunctiveConstraint",
        "PhrasalConstraint",
    ]
    _import_structure["beam_search"] = [
        "BeamHypotheses",
        "BeamScorer",
        "BeamSearchScorer",
        "ConstrainedBeamSearchScorer",
    ]
    _import_structure["logits_process"] = [
        "ForcedBOSTokenLogitsProcessor",
        "ForcedEOSTokenLogitsProcessor",
        "HammingDiversityLogitsProcessor",
        "InfNanRemoveLogitsProcessor",
        "LogitsProcessor",
        "LogitsProcessorList",
        "LogitsWarper",
        "MinLengthLogitsProcessor",
        "NoBadWordsLogitsProcessor",
        "NoRepeatNGramLogitsProcessor",
        "PrefixConstrainedLogitsProcessor",
        "RepetitionPenaltyLogitsProcessor",
        "TemperatureLogitsWarper",
        "TopKLogitsWarper",
        "TopPLogitsWarper",
        "TypicalLogitsWarper",
        "EncoderNoRepeatNGramLogitsProcessor",
        "ExponentialDecayLengthPenalty",
        "LogitNormalization",
    ]
    _import_structure["stopping_criteria"] = [
        "MaxNewTokensCriteria",
        "MaxLengthCriteria",
        "MaxTimeCriteria",
        "StoppingCriteria",
        "StoppingCriteriaList",
        "validate_stopping_criteria",
    ]
    _import_structure["utils"] = [
        "GenerationMixin",
        "top_k_top_p_filtering",
        "GreedySearchEncoderDecoderOutput",
        "GreedySearchDecoderOnlyOutput",
        "SampleEncoderDecoderOutput",
        "SampleDecoderOnlyOutput",
        "BeamSearchEncoderDecoderOutput",
        "BeamSearchDecoderOnlyOutput",
        "BeamSampleEncoderDecoderOutput",
        "BeamSampleDecoderOnlyOutput",
        "ContrastiveSearchEncoderDecoderOutput",
        "ContrastiveSearchDecoderOnlyOutput",
    ]

try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tf_logits_process"] = [
        "TFForcedBOSTokenLogitsProcessor",
        "TFForcedEOSTokenLogitsProcessor",
        "TFLogitsProcessor",
        "TFLogitsProcessorList",
        "TFLogitsWarper",
        "TFMinLengthLogitsProcessor",
        "TFNoBadWordsLogitsProcessor",
        "TFNoRepeatNGramLogitsProcessor",
        "TFRepetitionPenaltyLogitsProcessor",
        "TFTemperatureLogitsWarper",
        "TFTopKLogitsWarper",
        "TFTopPLogitsWarper",
        "TFForceTokensLogitsProcessor",
        "TFSuppressTokensAtBeginLogitsProcessor",
        "TFSuppressTokensLogitsProcessor",
    ]
    _import_structure["tf_utils"] = [
        "TFGenerationMixin",
        "tf_top_k_top_p_filtering",
        "TFGreedySearchDecoderOnlyOutput",
        "TFGreedySearchEncoderDecoderOutput",
        "TFSampleEncoderDecoderOutput",
        "TFSampleDecoderOnlyOutput",
        "TFBeamSearchEncoderDecoderOutput",
        "TFBeamSearchDecoderOnlyOutput",
        "TFBeamSampleEncoderDecoderOutput",
        "TFBeamSampleDecoderOnlyOutput",
        "TFContrastiveSearchEncoderDecoderOutput",
        "TFContrastiveSearchDecoderOnlyOutput",
    ]

try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["flax_logits_process"] = [
        "FlaxForcedBOSTokenLogitsProcessor",
        "FlaxForcedEOSTokenLogitsProcessor",
        "FlaxLogitsProcessor",
        "FlaxLogitsProcessorList",
        "FlaxLogitsWarper",
        "FlaxMinLengthLogitsProcessor",
        "FlaxTemperatureLogitsWarper",
        "FlaxTopKLogitsWarper",
        "FlaxTopPLogitsWarper",
    ]
    _import_structure["flax_utils"] = [
        "FlaxGenerationMixin",
        "FlaxGreedySearchOutput",
        "FlaxSampleOutput",
        "FlaxBeamSearchOutput",
    ]
if TYPE_CHECKING:
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
        from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
        from .logits_process import (
            EncoderNoRepeatNGramLogitsProcessor,
            ExponentialDecayLengthPenalty,
            ForcedBOSTokenLogitsProcessor,
            ForcedEOSTokenLogitsProcessor,
            HammingDiversityLogitsProcessor,
            InfNanRemoveLogitsProcessor,
            LogitNormalization,
            LogitsProcessor,
            LogitsProcessorList,
            LogitsWarper,
            MinLengthLogitsProcessor,
            NoBadWordsLogitsProcessor,
            NoRepeatNGramLogitsProcessor,
            PrefixConstrainedLogitsProcessor,
            RepetitionPenaltyLogitsProcessor,
            TemperatureLogitsWarper,
            TopKLogitsWarper,
            TopPLogitsWarper,
            TypicalLogitsWarper,
        )
        from .stopping_criteria import (
            MaxLengthCriteria,
            MaxNewTokensCriteria,
            MaxTimeCriteria,
            StoppingCriteria,
            StoppingCriteriaList,
            validate_stopping_criteria,
        )
        from .utils import (
            BeamSampleDecoderOnlyOutput,
            BeamSampleEncoderDecoderOutput,
            BeamSearchDecoderOnlyOutput,
            BeamSearchEncoderDecoderOutput,
            ContrastiveSearchDecoderOnlyOutput,
            ContrastiveSearchEncoderDecoderOutput,
            GenerationMixin,
            GreedySearchDecoderOnlyOutput,
            GreedySearchEncoderDecoderOutput,
            SampleDecoderOnlyOutput,
            SampleEncoderDecoderOutput,
            top_k_top_p_filtering,
        )

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tf_logits_process import (
            TFForcedBOSTokenLogitsProcessor,
            TFForcedEOSTokenLogitsProcessor,
            TFForceTokensLogitsProcessor,
            TFLogitsProcessor,
            TFLogitsProcessorList,
            TFLogitsWarper,
            TFMinLengthLogitsProcessor,
            TFNoBadWordsLogitsProcessor,
            TFNoRepeatNGramLogitsProcessor,
            TFRepetitionPenaltyLogitsProcessor,
            TFSuppressTokensAtBeginLogitsProcessor,
            TFSuppressTokensLogitsProcessor,
            TFTemperatureLogitsWarper,
            TFTopKLogitsWarper,
            TFTopPLogitsWarper,
        )
        from .tf_utils import (
            TFBeamSampleDecoderOnlyOutput,
            TFBeamSampleEncoderDecoderOutput,
            TFBeamSearchDecoderOnlyOutput,
            TFBeamSearchEncoderDecoderOutput,
            TFContrastiveSearchDecoderOnlyOutput,
            TFContrastiveSearchEncoderDecoderOutput,
            TFGenerationMixin,
            TFGreedySearchDecoderOnlyOutput,
            TFGreedySearchEncoderDecoderOutput,
            TFSampleDecoderOnlyOutput,
            TFSampleEncoderDecoderOutput,
            tf_top_k_top_p_filtering,
        )

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .flax_logits_process import (
            FlaxForcedBOSTokenLogitsProcessor,
            FlaxForcedEOSTokenLogitsProcessor,
            FlaxLogitsProcessor,
            FlaxLogitsProcessorList,
            FlaxLogitsWarper,
            FlaxMinLengthLogitsProcessor,
            FlaxTemperatureLogitsWarper,
            FlaxTopKLogitsWarper,
            FlaxTopPLogitsWarper,
        )
        from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
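For readers unfamiliar with `_LazyModule`: names listed in `_import_structure` are resolved only when first accessed, so importing `transformers.generation` stays cheap. A simplified, hypothetical sketch of that pattern (not the actual transformers implementation):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Illustrative stand-in for transformers' _LazyModule."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # map every public name to the submodule that defines it
        self._name_to_module = {
            attr: submodule for submodule, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        # import the defining submodule on first access only
        module = importlib.import_module("." + self._name_to_module[attr], self.__name__)
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache, so __getattr__ is skipped next time
        return value
```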