"vscode:/vscode.git/clone" did not exist on "91953e51b42a2112c3aa171ec1c04959a090a78e"
Unverified Commit f270b960 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: move generation_*.py src files into generation/*.py (#20096)

* move generation_*.py src files into generation/*.py

* populate generation.__init__ with lazy loading

* move imports and references from generation.xxx.object to generation.object
parent bac2d29a
......@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
from .generation_beam_constraints import Constraint, ConstraintListState
from .utils import add_start_docstrings
from ..utils import add_start_docstrings
from .beam_constraints import Constraint, ConstraintListState
PROCESS_INPUTS_DOCSTRING = r"""
......
......@@ -19,8 +19,8 @@ import jax
import jax.lax as lax
import jax.numpy as jnp
from .utils import add_start_docstrings
from .utils.logging import get_logger
from ..utils import add_start_docstrings
from ..utils.logging import get_logger
logger = get_logger(__name__)
......
This diff is collapsed.
......@@ -20,8 +20,8 @@ from typing import Callable, Iterable, List, Optional, Tuple
import numpy as np
import torch
from .utils import add_start_docstrings
from .utils.logging import get_logger
from ..utils import add_start_docstrings
from ..utils.logging import get_logger
logger = get_logger(__name__)
......
......@@ -6,7 +6,7 @@ from typing import Optional
import torch
from .utils import add_start_docstrings
from ..utils import add_start_docstrings
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
......
......@@ -19,9 +19,9 @@ from typing import List, Tuple
import numpy as np
import tensorflow as tf
from .tf_utils import stable_softmax
from .utils import add_start_docstrings
from .utils.logging import get_logger
from ..tf_utils import stable_softmax
from ..utils import add_start_docstrings
from ..utils.logging import get_logger
logger = get_logger(__name__)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -33,7 +33,7 @@ from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation_flax_utils import FlaxGenerationMixin
from .generation import FlaxGenerationMixin
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import (
FLAX_WEIGHTS_INDEX_NAME,
......
......@@ -43,7 +43,7 @@ from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation_tf_utils import TFGenerationMixin
from .generation import TFGenerationMixin
from .tf_utils import shape_list
from .utils import (
DUMMY_INPUTS,
......
......@@ -38,7 +38,7 @@ from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
from .generation_utils import GenerationMixin
from .generation import GenerationMixin
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
......
......@@ -21,9 +21,7 @@ import torch
from torch import nn
from ...configuration_utils import PretrainedConfig
from ...generation_beam_search import BeamSearchScorer
from ...generation_logits_process import LogitsProcessorList
from ...generation_stopping_criteria import StoppingCriteriaList
from ...generation import BeamSearchScorer, LogitsProcessorList, StoppingCriteriaList
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
......@@ -925,8 +923,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
**model_kwargs
) -> torch.LongTensor:
"""
Implements RAG sequence "thorough" decoding. Read the [`~generation_utils.GenerationMixin.generate`]`
documentation for more information on how to set other generate input parameters.
Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`]` documentation
for more information on how to set other generate input parameters.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
......@@ -960,14 +958,14 @@ class RagSequenceForGeneration(RagPreTrainedModel):
to be set to `False` if used while training with distributed backend.
num_return_sequences(`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch. Note that this
is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`]`
function, where we set `num_return_sequences` to `num_beams`.
is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`]` function,
where we set `num_return_sequences` to `num_beams`.
num_beams (`int`, *optional*, defaults to 1):
Number of beams for beam search. 1 means no beam search.
n_docs (`int`, *optional*, defaults to `config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer.
kwargs:
Additional kwargs will be passed to [`~generation_utils.GenerationMixin.generate`].
Additional kwargs will be passed to [`~generation.GenerationMixin.generate`].
Return:
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
......@@ -1486,8 +1484,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
enabled.
num_return_sequences(`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch. Note that this
is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`] function,
where we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
n_docs (`int`, *optional*, defaults to `config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer.
......
......@@ -1073,8 +1073,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
Number of beams for beam search. 1 means no beam search.
num_return_sequences(`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch. Note that this
is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`] function,
where we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
n_docs (`int`, *optional*, defaults to `config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer.
......@@ -1676,8 +1676,8 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
**model_kwargs
):
"""
Implements RAG sequence "thorough" decoding. Read the [`~generation_utils.GenerationMixin.generate`]`
documentation for more information on how to set other generate input parameters
Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`]` documentation
for more information on how to set other generate input parameters
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
......@@ -1705,14 +1705,14 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
to be set to `False` if used while training with distributed backend.
num_return_sequences(`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch. Note that this
is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`]`
function, where we set `num_return_sequences` to `num_beams`.
is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`]` function,
where we set `num_return_sequences` to `num_beams`.
num_beams (`int`, *optional*, defaults to 1):
Number of beams for beam search. 1 means no beam search.
n_docs (`int`, *optional*, defaults to `config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer.
kwargs:
Additional kwargs will be passed to [`~generation_utils.GenerationMixin.generate`]
Additional kwargs will be passed to [`~generation.GenerationMixin.generate`]
Return:
`tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
......
......@@ -94,7 +94,7 @@ class ImageToTextPipeline(Pipeline):
def _forward(self, model_inputs, generate_kwargs=None):
if generate_kwargs is None:
generate_kwargs = {}
# FIXME: We need to pop here due to a difference in how `generation_utils.py` and `generation_tf_utils.py`
# FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
# in the `_prepare_model_inputs` method.
......
......@@ -34,7 +34,7 @@ class Text2TextGenerationPipeline(Pipeline):
up-to-date list of available models on
[huggingface.co/models](https://huggingface.co/models?filter=text2text-generation). For a list of available
parameters, see the [following
documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate)
documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)
Usage:
......@@ -206,7 +206,7 @@ class SummarizationPipeline(Text2TextGenerationPipeline):
currently, '*bart-large-cnn*', '*t5-small*', '*t5-base*', '*t5-large*', '*t5-3b*', '*t5-11b*'. See the up-to-date
list of available models on [huggingface.co/models](https://huggingface.co/models?filter=summarization). For a list
of available parameters, see the [following
documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate)
documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)
Usage:
......@@ -274,7 +274,7 @@ class TranslationPipeline(Text2TextGenerationPipeline):
The models that this pipeline can use are models that have been fine-tuned on a translation task. See the
up-to-date list of available models on [huggingface.co/models](https://huggingface.co/models?filter=translation).
For a list of available parameters, see the [following
documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate)
documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)
Usage:
......
......@@ -17,6 +17,13 @@ class FlaxForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxGenerationMixin(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLogitsProcessor(metaclass=DummyObject):
_backends = ["flax"]
......
......@@ -80,63 +80,63 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Constraint(metaclass=DummyObject):
class BeamScorer(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConstraintListState(metaclass=DummyObject):
class BeamSearchScorer(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DisjunctiveConstraint(metaclass=DummyObject):
class ConstrainedBeamSearchScorer(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PhrasalConstraint(metaclass=DummyObject):
class Constraint(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BeamScorer(metaclass=DummyObject):
class ConstraintListState(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BeamSearchScorer(metaclass=DummyObject):
class DisjunctiveConstraint(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConstrainedBeamSearchScorer(metaclass=DummyObject):
class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
class ForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
class GenerationMixin(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
......@@ -178,91 +178,98 @@ class LogitsWarper(metaclass=DummyObject):
requires_backends(self, ["torch"])
class MinLengthLogitsProcessor(metaclass=DummyObject):
class MaxLengthCriteria(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NoBadWordsLogitsProcessor(metaclass=DummyObject):
class MaxTimeCriteria(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NoRepeatNGramLogitsProcessor(metaclass=DummyObject):
class MinLengthLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PrefixConstrainedLogitsProcessor(metaclass=DummyObject):
class NoBadWordsLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
class NoRepeatNGramLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TemperatureLogitsWarper(metaclass=DummyObject):
class PhrasalConstraint(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TopKLogitsWarper(metaclass=DummyObject):
class PrefixConstrainedLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TopPLogitsWarper(metaclass=DummyObject):
class RepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TypicalLogitsWarper(metaclass=DummyObject):
class StoppingCriteria(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MaxLengthCriteria(metaclass=DummyObject):
class StoppingCriteriaList(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MaxTimeCriteria(metaclass=DummyObject):
class TemperatureLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class StoppingCriteria(metaclass=DummyObject):
class TopKLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class StoppingCriteriaList(metaclass=DummyObject):
class TopPLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TypicalLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment