Unverified Commit f270b960 authored by Joao Gante, committed by GitHub

Generate: move generation_*.py src files into generation/*.py (#20096)

* move generation_*.py src files into generation/*.py

* populate generation.__init__ with lazy loading

* move imports and references from generation.xxx.object to generation.object
parent bac2d29a
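The second commit-message bullet, populating `generation.__init__` with lazy loading, refers to the `_LazyModule` pattern transformers uses elsewhere: the subpackage's `__init__.py` maps submodule names to the symbols they export and defers the actual imports until an attribute is first accessed. The snippet below is a heavily abridged sketch of such an `__init__.py`, not the file from this commit; the entries in `_import_structure` are examples taken from the hunks further down.

import sys
from typing import TYPE_CHECKING

from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# Map "submodule name" -> [exported symbols]; nothing is imported yet.
_import_structure = {}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # the real file falls back to the dummy objects shown at the end of this diff
else:
    _import_structure["beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
    _import_structure["utils"] = ["GenerationMixin"]

if TYPE_CHECKING:
    # Static type checkers see plain imports...
    if is_torch_available():
        from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
        from .utils import GenerationMixin
else:
    # ...while at runtime the package is replaced by a proxy module that only
    # imports a submodule when one of its attributes is first requested.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)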
@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 
-from .generation_beam_constraints import Constraint, ConstraintListState
-from .utils import add_start_docstrings
+from ..utils import add_start_docstrings
+from .beam_constraints import Constraint, ConstraintListState
 
 
 PROCESS_INPUTS_DOCSTRING = r"""
...
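The hunk above shows the mechanical consequence of the move that repeats in the hunks below: each `generation_*.py` file went one package level deeper (from `src/transformers/` into `src/transformers/generation/`), so relative imports of top-level siblings such as `utils` gain a dot, while imports between the moved files lose their `generation_` prefix and become same-package imports. Schematically (file paths inferred from the commit title, shown for illustration):

# old: src/transformers/generation_beam_search.py
from .utils import add_start_docstrings              # top-level sibling module
from .generation_beam_constraints import Constraint  # prefixed sibling module

# new: src/transformers/generation/beam_search.py
from ..utils import add_start_docstrings             # reach up to the parent package
from .beam_constraints import Constraint             # same package, prefix dropped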
@@ -19,8 +19,8 @@ import jax
 import jax.lax as lax
 import jax.numpy as jnp
 
-from .utils import add_start_docstrings
-from .utils.logging import get_logger
+from ..utils import add_start_docstrings
+from ..utils.logging import get_logger
 
 
 logger = get_logger(__name__)
...
@@ -20,8 +20,8 @@ from typing import Callable, Iterable, List, Optional, Tuple
 import numpy as np
 import torch
 
-from .utils import add_start_docstrings
-from .utils.logging import get_logger
+from ..utils import add_start_docstrings
+from ..utils.logging import get_logger
 
 
 logger = get_logger(__name__)
...
@@ -6,7 +6,7 @@ from typing import Optional
 
 import torch
 
-from .utils import add_start_docstrings
+from ..utils import add_start_docstrings
 
 
 STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
...
@@ -19,9 +19,9 @@ from typing import List, Tuple
 import numpy as np
 import tensorflow as tf
 
-from .tf_utils import stable_softmax
-from .utils import add_start_docstrings
-from .utils.logging import get_logger
+from ..tf_utils import stable_softmax
+from ..utils import add_start_docstrings
+from ..utils.logging import get_logger
 
 
 logger = get_logger(__name__)
...
@@ -33,7 +33,7 @@ from jax.random import PRNGKey
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation_flax_utils import FlaxGenerationMixin
+from .generation import FlaxGenerationMixin
 from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
 from .utils import (
     FLAX_WEIGHTS_INDEX_NAME,
...
@@ -43,7 +43,7 @@ from . import DataCollatorWithPadding, DefaultDataCollator
 from .activations_tf import get_tf_activation
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation_tf_utils import TFGenerationMixin
+from .generation import TFGenerationMixin
 from .tf_utils import shape_list
 from .utils import (
     DUMMY_INPUTS,
...
@@ -38,7 +38,7 @@ from .activations import get_activation
 from .configuration_utils import PretrainedConfig
 from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
 from .dynamic_module_utils import custom_object_save
-from .generation_utils import GenerationMixin
+from .generation import GenerationMixin
 from .pytorch_utils import (  # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
...
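Taken together, the three `modeling_*_utils.py` hunks above show the user-visible effect of the refactor: the framework-specific generation mixins are now exported from the `generation` subpackage. A minimal usage sketch; whether the old `transformers.generation_utils` path is kept as a backward-compatibility alias is not visible in these hunks:

# Old import path, before this commit:
#     from transformers.generation_utils import GenerationMixin
# New import path; TFGenerationMixin and FlaxGenerationMixin are exposed from
# the same subpackage for the other frameworks:
from transformers.generation import GenerationMixin

print(callable(GenerationMixin.generate))  # the mixin carries `generate()`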
@@ -21,9 +21,7 @@ import torch
 from torch import nn
 
 from ...configuration_utils import PretrainedConfig
-from ...generation_beam_search import BeamSearchScorer
-from ...generation_logits_process import LogitsProcessorList
-from ...generation_stopping_criteria import StoppingCriteriaList
+from ...generation import BeamSearchScorer, LogitsProcessorList, StoppingCriteriaList
 from ...modeling_outputs import ModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@@ -925,8 +923,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         **model_kwargs
     ) -> torch.LongTensor:
         """
-        Implements RAG sequence "thorough" decoding. Read the [`~generation_utils.GenerationMixin.generate`]`
-        documentation for more information on how to set other generate input parameters.
+        Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`]` documentation
+        for more information on how to set other generate input parameters.
 
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -960,14 +958,14 @@ class RagSequenceForGeneration(RagPreTrainedModel):
                 to be set to `False` if used while training with distributed backend.
             num_return_sequences(`int`, *optional*, defaults to 1):
                 The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`]`
-                function, where we set `num_return_sequences` to `num_beams`.
+                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`]` function,
+                where we set `num_return_sequences` to `num_beams`.
             num_beams (`int`, *optional*, defaults to 1):
                 Number of beams for beam search. 1 means no beam search.
             n_docs (`int`, *optional*, defaults to `config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
             kwargs:
-                Additional kwargs will be passed to [`~generation_utils.GenerationMixin.generate`].
+                Additional kwargs will be passed to [`~generation.GenerationMixin.generate`].
 
         Return:
             `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
@@ -1486,8 +1484,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 enabled.
             num_return_sequences(`int`, *optional*, defaults to 1):
                 The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`] function,
-                where we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
+                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
+                we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
                 encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
             n_docs (`int`, *optional*, defaults to `config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
...
@@ -1073,8 +1073,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
                 Number of beams for beam search. 1 means no beam search.
             num_return_sequences(`int`, *optional*, defaults to 1):
                 The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`] function,
-                where we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
+                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
+                we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
                 encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
             n_docs (`int`, *optional*, defaults to `config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
@@ -1676,8 +1676,8 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
         **model_kwargs
     ):
         """
-        Implements RAG sequence "thorough" decoding. Read the [`~generation_utils.GenerationMixin.generate`]`
-        documentation for more information on how to set other generate input parameters
+        Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`]` documentation
+        for more information on how to set other generate input parameters
 
         Args:
             input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1705,14 +1705,14 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
                 to be set to `False` if used while training with distributed backend.
             num_return_sequences(`int`, *optional*, defaults to 1):
                 The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`]`
-                function, where we set `num_return_sequences` to `num_beams`.
+                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`]` function,
+                where we set `num_return_sequences` to `num_beams`.
             num_beams (`int`, *optional*, defaults to 1):
                 Number of beams for beam search. 1 means no beam search.
             n_docs (`int`, *optional*, defaults to `config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
             kwargs:
-                Additional kwargs will be passed to [`~generation_utils.GenerationMixin.generate`]
+                Additional kwargs will be passed to [`~generation.GenerationMixin.generate`]
 
         Return:
             `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
...
@@ -94,7 +94,7 @@ class ImageToTextPipeline(Pipeline):
     def _forward(self, model_inputs, generate_kwargs=None):
         if generate_kwargs is None:
             generate_kwargs = {}
-        # FIXME: We need to pop here due to a difference in how `generation_utils.py` and `generation_tf_utils.py`
+        # FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
         # parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
         # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
         # in the `_prepare_model_inputs` method.
...
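The FIXME in the hunk above describes an asymmetry between the two `generate` implementations: the TensorFlow one requires the tensor to arrive as `input_ids`, while the PyTorch one resolves the model's `main_input_name` (for example `pixel_values` for vision encoders) inside `_prepare_model_inputs`. The workaround the comment alludes to looks roughly like the hypothetical sketch below; the actual body of `_forward` is not part of the visible hunk.

def _forward(self, model_inputs, generate_kwargs=None):
    if generate_kwargs is None:
        generate_kwargs = {}
    # Hypothetical illustration: pop the main input out of the dict and pass it
    # positionally, so the call works whether `generate` insists on `input_ids`
    # (TensorFlow) or matches `self.model.main_input_name` (PyTorch).
    inputs = model_inputs.pop(self.model.main_input_name)
    return self.model.generate(inputs, **model_inputs, **generate_kwargs)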
@@ -34,7 +34,7 @@ class Text2TextGenerationPipeline(Pipeline):
     up-to-date list of available models on
     [huggingface.co/models](https://huggingface.co/models?filter=text2text-generation). For a list of available
     parameters, see the [following
-    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate)
+    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)
 
     Usage:
@@ -206,7 +206,7 @@ class SummarizationPipeline(Text2TextGenerationPipeline):
     currently, '*bart-large-cnn*', '*t5-small*', '*t5-base*', '*t5-large*', '*t5-3b*', '*t5-11b*'. See the up-to-date
     list of available models on [huggingface.co/models](https://huggingface.co/models?filter=summarization). For a list
     of available parameters, see the [following
-    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate)
+    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)
 
     Usage:
@@ -274,7 +274,7 @@ class TranslationPipeline(Text2TextGenerationPipeline):
     The models that this pipeline can use are models that have been fine-tuned on a translation task. See the
     up-to-date list of available models on [huggingface.co/models](https://huggingface.co/models?filter=translation).
     For a list of available parameters, see the [following
-    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate)
+    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)
 
     Usage:
...
@@ -17,6 +17,13 @@ class FlaxForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["flax"])
 
 
+class FlaxGenerationMixin(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxLogitsProcessor(metaclass=DummyObject):
     _backends = ["flax"]
...
@@ -80,63 +80,63 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
-class Constraint(metaclass=DummyObject):
+class BeamScorer(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class ConstraintListState(metaclass=DummyObject):
+class BeamSearchScorer(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class DisjunctiveConstraint(metaclass=DummyObject):
+class ConstrainedBeamSearchScorer(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class PhrasalConstraint(metaclass=DummyObject):
+class Constraint(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class BeamScorer(metaclass=DummyObject):
+class ConstraintListState(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class BeamSearchScorer(metaclass=DummyObject):
+class DisjunctiveConstraint(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class ConstrainedBeamSearchScorer(metaclass=DummyObject):
+class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
+class ForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class ForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
+class GenerationMixin(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
@@ -178,91 +178,98 @@ class LogitsWarper(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
-class MinLengthLogitsProcessor(metaclass=DummyObject):
+class MaxLengthCriteria(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class NoBadWordsLogitsProcessor(metaclass=DummyObject):
+class MaxTimeCriteria(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class NoRepeatNGramLogitsProcessor(metaclass=DummyObject):
+class MinLengthLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class PrefixConstrainedLogitsProcessor(metaclass=DummyObject):
+class NoBadWordsLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class RepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
+class NoRepeatNGramLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class TemperatureLogitsWarper(metaclass=DummyObject):
+class PhrasalConstraint(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class TopKLogitsWarper(metaclass=DummyObject):
+class PrefixConstrainedLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class TopPLogitsWarper(metaclass=DummyObject):
+class RepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class TypicalLogitsWarper(metaclass=DummyObject):
+class StoppingCriteria(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class MaxLengthCriteria(metaclass=DummyObject):
+class StoppingCriteriaList(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class MaxTimeCriteria(metaclass=DummyObject):
+class TemperatureLogitsWarper(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class StoppingCriteria(metaclass=DummyObject):
+class TopKLogitsWarper(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class StoppingCriteriaList(metaclass=DummyObject):
+class TopPLogitsWarper(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class TypicalLogitsWarper(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
...
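The two `dummy_*_objects.py` files edited above are the other half of the lazy-loading story: when a backend such as torch or flax is missing, `generation/__init__.py` exposes these placeholders instead of the real classes, so imports still succeed and a readable error is raised only when the class is actually used. Below is a minimal sketch of the mechanism, assuming `DummyObject` and `requires_backends` behave roughly like their counterparts in `transformers.utils`; it is illustrative, not the library source.

class DummyObject(type):
    """Metaclass for backend-gated placeholders: any public attribute access
    on the class triggers the missing-backend error."""

    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)


def requires_backends(obj, backends):
    # The real helper checks availability (e.g. is_torch_available()); this
    # sketch assumes the backend is absent and always raises.
    name = getattr(obj, "__name__", obj.__class__.__name__)
    raise ImportError(f"{name} requires the following backend(s): {', '.join(backends)}")


class GenerationMixin(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


# Importing and referencing the class works; instantiating it raises:
#   GenerationMixin()  ->  ImportError: GenerationMixin requires the following backend(s): torch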