Unverified Commit 4bc723f8 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: use `GenerationConfig` as the basis for `.generate()` parametrization (#20388)



* generate from config mvp

* fix failing tests

* max_time test

* Load default gen config at model load time; Update docs

* further documentation; add tests

* adapt rag to the new structure

* handle models not instantiated with from_pretained (like in tests)

* better default generation config

* add can_generate fn

* handle legacy use case of ad hoc model config changes

* initialize gen config from config in individual methods, if gen config is none

* fix _get_decoder_start_token_id when called outside GenerationMixin

* correct model config load order (set attr > model config > decoder config)

* update rag to match latest changes

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* load gen config from model config in model.from_pretrained

* fix can_generate fn

* handle generate calls without a previous from_pretrained (e.g. tests)

* add legacy behavior (and a warning)

* lower logger severity
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent b1706f69
......@@ -18,12 +18,79 @@ Each framework has a generate method for auto-regressive text generation impleme
- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
<!--- TODO: add a brief description of GenerationConfig (with examples) when it becomes usable with generate --->
Regardless of your framework of choice, you can parameterize the generate method with a [`~generation.GenerationConfig`]
class instance. Please refer to this class for the complete list of generation parameters, which control the behavior
of the generation method.
All models have a default generation configuration that will be used if you don't provide one. If you have a loaded
model instance `model`, you can inspect the default generation configuration with `model.generation_config`. If you'd
like to set a new default generation configuration, you can create a new [`~generation.GenerationConfig`] instance and
store it with `save_pretrained`, making sure to leave its `config_file_name` argument empty.
```python
from transformers import AutoModelForCausalLM, GenerationConfig
model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
# Inspect the default generation configuration
print(model.generation_config)
# Set a new default generation configuration
generation_config = GenerationConfig(
max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
)
generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
```
<Tip>
If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that
default values are omitted. Some attributes, like `max_length`, have a conservative default value, to avoid running
into resource limitations. Make sure you double-check the defaults in the documentation.
</Tip>
You can also store several generation parametrizations in a single directory, making use of the `config_file_name`
argument in `save_pretrained`. You can latter instantiate them with `from_pretrained`. This is useful if you want to
store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
other for summarization with beam search).
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
translation_generation_config = GenerationConfig(
num_beams=4,
early_stopping=True,
decoder_start_token_id=0,
eos_token_id=model.config.eos_token_id,
pad_token=model.config.pad_token_id,
)
# If you were working on a model for which your had the right Hub permissions, you could store a named generation
# config as follows
translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)
# You could then use the named generation config file to parameterize generation
generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# ['Les fichiers de configuration sont faciles à utiliser !']
```
Finally, you can specify ad hoc modifications to the used generation configuration by passing the attribute you
wish to override directly to the generate method (e.g. `model.generate(inputs, max_new_tokens=512)`). Each
framework's `generate` method docstring (available below) has a few illustrative examples on the different strategies
to parameterize it.
## GenerationConfig
[[autodoc]] generation.GenerationConfig
- from_pretrained
- from_model_config
- save_pretrained
## GenerationMixin
......
......@@ -20,6 +20,7 @@ import os
from typing import Any, Dict, Optional, Union
from .. import __version__
from ..configuration_utils import PretrainedConfig
from ..utils import (
GENERATION_CONFIG_NAME,
PushToHubMixin,
......@@ -36,7 +37,23 @@ logger = logging.get_logger(__name__)
class GenerationConfig(PushToHubMixin):
r"""
Class that holds a configuration for a generation task.
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
`do_sample=False`.
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
and `top_k>1`
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
`do_sample=False`.
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
`num_beams>1` and `do_sample=True`.
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`.
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`.
<Tip>
......@@ -45,6 +62,9 @@ class GenerationConfig(PushToHubMixin):
</Tip>
Most of these parameters are explained in more detail in [this blog
post](https://huggingface.co/blog/how-to-generate).
Arg:
> Parameters that control the length of the output
......@@ -73,6 +93,9 @@ class GenerationConfig(PushToHubMixin):
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
penalty_alpha (`float`, *optional*):
The values balance the model confidence and the degeneration penalty in contrastive search decoding.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
> Parameters for manipulation of the model output logits
......@@ -108,13 +131,13 @@ class GenerationConfig(PushToHubMixin):
words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
can allow different forms of each word.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
renormalize_logits (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the custom
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
are normalized but some logit processors or warpers break the normalization.
constraints (`List[Constraint]`, *optional*):
Custom constraints that can be added to the generation to ensure that the output will contain the use of
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
......@@ -191,6 +214,7 @@ class GenerationConfig(PushToHubMixin):
self.num_beams = kwargs.pop("num_beams", 1)
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
self.use_cache = kwargs.pop("use_cache", True)
# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)
......@@ -202,7 +226,9 @@ class GenerationConfig(PushToHubMixin):
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
self.force_word_ids = kwargs.pop("force_word_ids", None)
self.force_words_ids = kwargs.pop("force_words_ids", None)
self.renormalize_logits = kwargs.pop("renormalize_logits", False)
self.constraints = kwargs.pop("constraints", None)
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
......@@ -230,12 +256,20 @@ class GenerationConfig(PushToHubMixin):
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the the hub interface.
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the the hub
# interface.
self._from_model_config = kwargs.pop("_from_model_config", False)
self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__)
def __eq__(self, other):
return self.__dict__ == other.__dict__
self_dict = self.__dict__.copy()
other_dict = other.__dict__.copy()
# ignore metadata
for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"):
self_dict.pop(metadata_field, None)
other_dict.pop(metadata_field, None)
return self_dict == other_dict
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
......@@ -484,18 +518,11 @@ class GenerationConfig(PushToHubMixin):
kwargs["_commit_hash"] = config_dict["_commit_hash"]
config = cls(**config_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
unused_kwargs = config.update(**kwargs)
logger.info(f"Generate config {config}")
if return_unused_kwargs:
return config, kwargs
return config, unused_kwargs
else:
return config
......@@ -568,3 +595,54 @@ class GenerationConfig(PushToHubMixin):
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string(use_diff=use_diff))
@classmethod
def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
"""
Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
[`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].
Args:
model_config (`PretrainedConfig`):
The model config that will be used to instantiate the generation config.
Returns:
[`GenerationConfig`]: The configuration object instantiated from those parameters.
"""
config_dict = model_config.to_dict()
config = cls.from_dict(config_dict, return_unused_kwargs=False)
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config.
for decoder_name in ("decoder", "generator"):
if decoder_name in config_dict:
default_generation_config = GenerationConfig()
decoder_config = config_dict[decoder_name]
for attr in config.to_dict().keys():
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
setattr(config, attr, decoder_config[attr])
config._from_model_config = True
return config
def update(self, **kwargs):
"""
Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes,
returning all the unused kwargs.
Args:
kwargs (`Dict[str, Any]`):
Dictionary of attributes to tentatively update this class.
Returns:
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
"""
to_remove = []
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
to_remove.append(key)
# remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs
......@@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
......@@ -33,8 +34,9 @@ from ..models.auto import (
)
from ..pytorch_utils import torch_int_div
from ..utils import ModelOutput, logging
from .beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .configuration_utils import GenerationConfig
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
ExponentialDecayLengthPenalty,
......@@ -478,6 +480,11 @@ class GenerationMixin:
`constraints!=None` or `force_words_ids!=None`.
"""
def prepare_inputs_for_generation(self, *args, **kwargs):
raise NotImplementedError(
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
)
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
......@@ -620,26 +627,16 @@ class GenerationMixin:
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "decoder_start_token_id")
and self.config.decoder.decoder_start_token_id is not None
):
return self.config.decoder.decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "bos_token_id")
and self.config.decoder.bos_token_id is not None
):
return self.config.decoder.bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
......@@ -722,166 +719,149 @@ class GenerationMixin:
def _get_logits_warper(
self,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
temperature: Optional[float] = None,
num_beams: Optional[int] = None,
renormalize_logits: Optional[bool] = None,
generation_config: GenerationConfig,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
used for multinomial sampling.
"""
# init warp parameters
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
typical_p = typical_p if typical_p is not None else self.config.typical_p
temperature = temperature if temperature is not None else self.config.temperature
# instantiate warpers list
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if temperature is not None and temperature != 1.0:
warpers.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if typical_p is not None and typical_p < 1.0:
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(
TopKLogitsWarper(
top_k=generation_config.top_k, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1)
)
)
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(
TopPLogitsWarper(
top_p=generation_config.top_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1)
)
)
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
warpers.append(
TypicalLogitsWarper(
mass=generation_config.typical_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1)
)
)
# `LogitNormalization` should always be the last logit processor, when present
if renormalize_logits is True:
if generation_config.renormalize_logits is True:
warpers.append(LogitNormalization())
return warpers
def _get_logits_processor(
self,
repetition_penalty: float,
no_repeat_ngram_size: int,
encoder_no_repeat_ngram_size: int,
generation_config: GenerationConfig,
input_ids_seq_length: int,
encoder_input_ids: torch.LongTensor,
bad_words_ids: List[List[int]],
min_length: int,
max_length: int,
eos_token_id: int,
forced_bos_token_id: int,
forced_eos_token_id: int,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
num_beams: int,
num_beam_groups: int,
diversity_penalty: float,
remove_invalid_values: bool,
exponential_decay_length_penalty: Tuple,
logits_processor: Optional[LogitsProcessorList],
renormalize_logits: Optional[bool],
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[List[int]]] = None,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
instances used to modify the scores of the language model head.
"""
processors = LogitsProcessorList()
# init warp parameters
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
encoder_no_repeat_ngram_size = (
encoder_no_repeat_ngram_size
if encoder_no_repeat_ngram_size is not None
else self.config.encoder_no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
forced_bos_token_id = (
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
)
forced_eos_token_id = (
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
)
remove_invalid_values = (
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
)
exponential_decay_length_penalty = (
exponential_decay_length_penalty
if exponential_decay_length_penalty is not None
else self.config.exponential_decay_length_penalty
)
suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens
begin_suppress_tokens = (
begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens
)
if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"):
forced_decoder_ids = self.config.forced_decoder_ids
# instantiate processors list
processors = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if diversity_penalty is not None and diversity_penalty > 0.0:
if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
processors.append(
HammingDiversityLogitsProcessor(
diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
diversity_penalty=generation_config.diversity_penalty,
num_beams=generation_config.num_beams,
num_beam_groups=generation_config.num_beam_groups,
)
)
if repetition_penalty is not None and repetition_penalty != 1.0:
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
if (
generation_config.encoder_no_repeat_ngram_size is not None
and generation_config.encoder_no_repeat_ngram_size > 0
):
if self.config.is_encoder_decoder:
processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
processors.append(
EncoderNoRepeatNGramLogitsProcessor(
generation_config.encoder_no_repeat_ngram_size, encoder_input_ids
)
)
else:
raise ValueError(
"It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
)
if bad_words_ids is not None:
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > 0:
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
if generation_config.bad_words_ids is not None:
processors.append(
NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
)
if (
generation_config.min_length is not None
and generation_config.eos_token_id is not None
and generation_config.min_length > 0
):
processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
if prefix_allowed_tokens_fn is not None:
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
if forced_bos_token_id is not None:
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None:
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
if remove_invalid_values is True:
processors.append(
PrefixConstrainedLogitsProcessor(
prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups
)
)
if generation_config.forced_bos_token_id is not None:
processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
if generation_config.forced_eos_token_id is not None:
processors.append(
ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
)
if generation_config.remove_invalid_values is True:
processors.append(InfNanRemoveLogitsProcessor())
if exponential_decay_length_penalty is not None:
if generation_config.exponential_decay_length_penalty is not None:
processors.append(
ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
ExponentialDecayLengthPenalty(
generation_config.exponential_decay_length_penalty,
generation_config.eos_token_id,
generation_config.input_ids_seq_length,
)
)
if suppress_tokens is not None:
processors.append(SuppressTokensLogitsProcessor(suppress_tokens))
if begin_suppress_tokens is not None:
if generation_config.suppress_tokens is not None:
processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens))
if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1
if forced_decoder_ids is not None:
begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced
processors.append(SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index))
if forced_decoder_ids is not None:
processors.append(ForceTokensLogitsProcessor(forced_decoder_ids))
begin_index = (
begin_index
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
else begin_index + 1
)
if generation_config.forced_decoder_ids is not None:
# generation starts after the last token that is forced
begin_index += generation_config.forced_decoder_ids[-1][0]
processors.append(
SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
)
if generation_config.forced_decoder_ids is not None:
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if renormalize_logits is True:
if generation_config.renormalize_logits is True:
processors.append(LogitNormalization())
return processors
def _get_stopping_criteria(
self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList]
self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList]
) -> StoppingCriteriaList:
criteria = StoppingCriteriaList()
if max_length is not None:
criteria.append(MaxLengthCriteria(max_length=max_length))
if max_time is not None:
criteria.append(MaxTimeCriteria(max_time=max_time))
if generation_config.max_length is not None:
criteria.append(MaxLengthCriteria(max_length=generation_config.max_length))
if generation_config.max_time is not None:
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
return criteria
......@@ -951,7 +931,7 @@ class GenerationMixin:
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if not hasattr(self, "prepare_inputs_for_generation"):
if not self.can_generate():
generate_compatible_mappings = [
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
......@@ -999,81 +979,27 @@ class GenerationMixin:
def generate(
self,
inputs: Optional[torch.Tensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
early_stopping: Optional[bool] = None,
num_beams: Optional[int] = None,
temperature: Optional[float] = None,
penalty_alpha: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[Iterable[int]] = None,
force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
max_time: Optional[float] = None,
max_new_tokens: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
num_beam_groups: Optional[int] = None,
diversity_penalty: Optional[float] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
constraints: Optional[List[Constraint]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
remove_invalid_values: Optional[bool] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = False,
exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[List[int]]] = None,
**model_kwargs,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head. The method supports the following
generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
`do_sample=False`.
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
and `top_k>1`
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
`do_sample=False`.
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
`num_beams>1` and `do_sample=True`.
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`.
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`.
Generates sequences of token ids for models with a language modeling head.
<Tip warning={true}>
Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
defined in the model's config (`config.json`) which in turn defaults to the
[`~modeling_utils.PretrainedConfig`] of the model.
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
</Tip>
For a complete overview of generate, check the [following
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
Most of these parameters are explained in more detail in [this blog
post](https://huggingface.co/blog/how-to-generate).
</Tip>
Parameters:
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
......@@ -1081,81 +1007,21 @@ class GenerationMixin:
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
max_length (`int`, *optional*, defaults to `model.config.max_length`):
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
`max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
the prompt.
max_new_tokens (`int`, *optional*):
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
min_length (`int`, *optional*, defaults to `model.config.min_length` or 10 if the config does not set any value):
The minimum length of the sequence to be generated.
do_sample (`bool`, *optional*, defaults to `model.config.do_sample` or `False` if the config does not set any value):
Whether or not to use sampling ; use greedy decoding otherwise.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
num_beams (`int`, *optional*, defaults to `model.config.num_beams` or 1 if the config does not set any value):
Number of beams for beam search. 1 means no beam search.
temperature (`float`, *optional*, defaults to `model.config.temperature` or 1.0 if the config does not set any value):
The value used to module the next token probabilities.
penalty_alpha (`float`, *optional*, defaults to `model.config.penalty_alpha` or None if the config does not set any value):
The values balance the model confidence and the degeneration penalty in contrastive search decoding.
top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value):
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
`top_p` or higher are kept for generation.
typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value):
The amount of probability mass from the original distribution to be considered in typical decoding. If
set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
repetition_penalty (`float`, *optional*, defaults to `model.config.repetition_penalty` or 1.0 if the config does not set any value):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
pad_token_id (`int`, *optional*, defaults to `model.config.pad_token_id`):
The id of the *padding* token.
bos_token_id (`int`, *optional*, defaults to `model.config.bos_token_id`):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*, defaults to `model.config.eos_token_id`):
The id of the *end-of-sequence* token.
length_penalty (`float`, *optional*, defaults to `model.config.length_penalty` or 1.0 if the config does not set any value):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
while `length_penalty` < 0.0 encourages shorter sequences.
no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.no_repeat_ngram_size` or 0 if the config does not set any value):
If set to int > 0, all ngrams of that size can only occur once.
encoder_no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.encoder_no_repeat_ngram_size` or 0 if the config does not set any value):
If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
`decoder_input_ids`.
bad_words_ids(`List[List[int]]`, *optional*, defaults to `model.config.bad_words_ids`):
List of token ids that are not allowed to be generated. In order to get the token ids of the words that
should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
add_special_tokens=False).input_ids`.
force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple
list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`,
this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
where one can allow different forms of each word.
num_return_sequences(`int`, *optional*, defaults to `model.config.num_return_sequences` or 1 if the config does not set any value):
The number of independently computed returned sequences for each element in the batch.
max_time(`float`, *optional*):
The maximum amount of time you allow the computation to run for in seconds. generation will still
finish the current pass after allocated time has been passed.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape
as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask)
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
num_beam_groups (`int`, *optional*, defaults to `model.config.num_beam_groups` or 1 if the config does not set any value):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
diversity_penalty (`float`, *optional*, defaults to `model.config.diversity_penalty` or 0.0 if the config does not set any value):
This value is subtracted from a beam's score if it generates a token same as any beam from other group
at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
enabled.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which had the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
......@@ -1163,60 +1029,12 @@ class GenerationMixin:
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and a
model's config. If a logit processor is passed that is already created with the arguments or a model's
config an error is thrown. This feature is intended for advanced users.
renormalize_logits (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the
custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
score logits are normalized but some logit processors or warpers break the normalization.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
model's config. If a stopping criteria is passed that is already created with the arguments or a
model's config an error is thrown. This feature is intended for advanced users.
constraints (`List[Constraint]`, *optional*):
Custom constraints that can be added to the generation to ensure that the output will contain the use
of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
output_attentions (`bool`, *optional*, defaults to `model.config.output_attentions` or `False` if the config does not set any value):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `model.config.output_hidden_states` or `False` if the config does not set any value):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `model.config.output_scores` or `False` if the config does not set any value):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `model.config.return_dict_in_generate` or `False` if the config does not set any value):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
the target language token.
forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`):
The id of the token to force as the last generated token when `max_length` is reached.
remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to
crash. Note that using `remove_invalid_values` can slow down generation.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
exponential_decay_length_penalty (`tuple(int, float)`, *optional*, defaults to `model.config.exponential_decay_length_penalty`):
This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates
where penalty starts and `decay_factor` represents the factor of exponential decay
suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`):
A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set
their log probs to `-inf` so that they are not sampled.
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
logit processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of pairs of integers which indicates a mapping from generation indices to token indices that
will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
be a token of index 123.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
should be prefixed with *decoder_*.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
......@@ -1240,7 +1058,7 @@ class GenerationMixin:
Examples:
Greedy Decoding:
Greedy decoding, using the default generation configuration and ad hoc modifications:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
......@@ -1251,16 +1069,16 @@ class GenerationMixin:
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
>>> # generate up to 30 tokens
>>> # Generate up to 30 tokens
>>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
```
Multinomial Sampling:
Multinomial sampling, modifying an existing generation configuration:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
......@@ -1269,17 +1087,20 @@ class GenerationMixin:
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
>>> # sample up to 30 tokens
>>> # Sample up to 30 tokens
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
>>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
>>> generation_config = GenerationConfig.from_pretrained("gpt2")
>>> generation_config.max_length = 30
>>> generation_config.do_sample = True
>>> outputs = model.generate(input_ids, generation_config=generation_config)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
```
Beam-search decoding:
Beam-search decoding, using a freshly initialized generation configuration:
```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig
>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
......@@ -1287,75 +1108,86 @@ class GenerationMixin:
>>> sentence = "Paris is one of the densest populated areas in Europe."
>>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids
>>> outputs = model.generate(input_ids, num_beams=5)
>>> generation_config = GenerationConfig(
... max_length=64,
... num_beams=5,
... bos_token_id=0,
... eos_token_id=0,
... decoder_start_token_id=58100,
... pad_token_id=58100,
... bad_words_ids=[[58100]],
... )
>>> outputs = model.generate(input_ids, generation_config=generation_config)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
# 0. Validate the `.generate()` call
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
self._validate_model_kwargs(model_kwargs.copy())
# 1. Set generation parameters if not already defined
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
num_beams = num_beams if num_beams is not None else self.config.num_beams
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
do_sample = do_sample if do_sample is not None else self.config.do_sample
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if self.generation_config._from_model_config:
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation)"
)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
self.generation_config = new_generation_config
generation_config = self.generation_config
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
self._validate_model_kwargs(model_kwargs.copy())
if eos_token_id is None and hasattr(self.config, "decoder"):
eos_token_id = self.config.decoder.eos_token_id
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if pad_token_id is None and eos_token_id is not None:
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
logger.warning(
f"Setting `pad_token_id` to `eos_token_id`:{generation_config.eos_token_id} for open-end generation."
)
generation_config.pad_token_id = generation_config.eos_token_id
# 2. Define model inputs
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
# 3. Define other model kwargs
model_kwargs["output_attentions"] = output_attentions
model_kwargs["output_hidden_states"] = output_hidden_states
model_kwargs["use_cache"] = use_cache
# 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, pad_token_id, eos_token_id
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)
# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
if pad_token_id is not None and torch.sum(inputs_tensor[:, -1] == pad_token_id) > 0:
if (
generation_config.pad_token_id is not None
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
......@@ -1368,12 +1200,12 @@ class GenerationMixin:
inputs_tensor, model_kwargs, model_input_name
)
# 4. Prepare `input_ids` which will be used for auto-regressive generation
# 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=decoder_start_token_id,
bos_token_id=bos_token_id,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
model_kwargs=model_kwargs,
device=inputs_tensor.device,
)
......@@ -1381,87 +1213,91 @@ class GenerationMixin:
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor
# 5. Prepare `max_length` depending on other stopping criteria.
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
if max_length is None and max_new_tokens is None:
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length == 20
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
"Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to "
f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
"deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
"using `max_new_tokens` to control the maximum length of the generation.",
"Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
" config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif max_length is None and max_new_tokens is not None:
max_length = max_new_tokens + input_ids_seq_length
elif max_length is not None and max_new_tokens is not None:
elif has_default_max_length and generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
elif not has_default_max_length and generation_config.max_new_tokens is not None:
raise ValueError(
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
" limit to the generated output length. Remove one of those arguments. Please refer to the"
" documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
# default to config if still None
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
if min_length is not None and min_length > max_length:
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= max_length:
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {max_length}. This can lead to unexpected behavior. You should consider increasing "
"`max_new_tokens`."
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
# 6. determine generation mode
is_constraint_gen_mode = constraints is not None or force_words_ids is not None
# 7. determine generation mode
is_constraint_gen_mode = (
generation_config.constraints is not None or generation_config.force_words_ids is not None
)
is_contrastive_search_gen_mode = (
top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0
generation_config.top_k is not None
and generation_config.top_k > 1
and generation_config.do_sample is False
and generation_config.penalty_alpha is not None
and generation_config.penalty_alpha > 0
)
is_greedy_gen_mode = (
(num_beams == 1)
and (num_beam_groups == 1)
and do_sample is False
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_sample_gen_mode = (
(num_beams == 1)
and (num_beam_groups == 1)
and do_sample is True
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_gen_mode = (
(num_beams > 1)
and (num_beam_groups == 1)
and do_sample is False
(generation_config.num_beams > 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_sample_gen_mode = (
(num_beams > 1)
and (num_beam_groups == 1)
and do_sample is True
(generation_config.num_beams > 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_group_beam_gen_mode = (
(num_beams > 1)
and (num_beam_groups > 1)
(generation_config.num_beams > 1)
and (generation_config.num_beam_groups > 1)
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
if num_beam_groups > num_beams:
if generation_config.num_beam_groups > generation_config.num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
if is_group_beam_gen_mode and do_sample is True:
if is_group_beam_gen_mode and generation_config.do_sample is True:
raise ValueError(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
)
......@@ -1477,269 +1313,244 @@ class GenerationMixin:
UserWarning,
)
# 7. prepare distribution pre_processing samplers
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
diversity_penalty=diversity_penalty,
remove_invalid_values=remove_invalid_values,
exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
renormalize_logits=renormalize_logits,
suppress_tokens=suppress_tokens,
begin_suppress_tokens=begin_suppress_tokens,
forced_decoder_ids=forced_decoder_ids,
)
# 8. prepare stopping criteria
# 9. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
generation_config=generation_config, stopping_criteria=stopping_criteria
)
# 9. go into different generation modes
# 10. go into different generation modes
if is_greedy_gen_mode:
if num_return_sequences > 1:
if generation_config.num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" greedy search."
)
# 10. run greedy search
# 11. run greedy search
return self.greedy_search(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_contrastive_search_gen_mode:
if num_return_sequences > 1:
if generation_config.num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search."
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" contrastive search."
)
return self.contrastive_search(
input_ids,
top_k=top_k,
penalty_alpha=penalty_alpha,
top_k=generation_config.top_k,
penalty_alpha=generation_config.penalty_alpha,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
num_beams=num_beams,
renormalize_logits=renormalize_logits,
)
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=num_return_sequences,
expand_size=generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run sample
# 13. run sample
return self.sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_beam_gen_mode:
if num_return_sequences > num_beams:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 10. prepare beam search scorer
# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
)
# 11. interleave input_ids with `num_beams` additional sequences per batch
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=num_beams,
expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run beam search
# 13. run beam search
return self.beam_search(
input_ids,
beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_beam_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
num_beams=num_beams,
renormalize_logits=renormalize_logits,
)
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 11. prepare beam search scorer
# 12. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size * num_return_sequences,
num_beams=num_beams,
batch_size=batch_size * generation_config.num_return_sequences,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
# 13. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=num_beams * num_return_sequences,
expand_size=generation_config.num_beams * generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 13. run beam sample
# 14. run beam sample
return self.beam_sample(
input_ids,
beam_scorer,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_group_beam_gen_mode:
if num_return_sequences > num_beams:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if num_beams % num_beam_groups != 0:
if generation_config.num_beams % generation_config.num_beam_groups != 0:
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
if typical_p is not None:
has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
if not has_default_typical_p:
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
# 10. prepare beam search scorer
# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
num_beams=generation_config.num_beams,
max_length=stopping_criteria.max_length,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
num_beam_groups=num_beam_groups,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
num_beam_groups=generation_config.num_beam_groups,
)
# 11. interleave input_ids with `num_beams` additional sequences per batch
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=num_beams,
expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run beam search
# 13. run beam search
return self.group_beam_search(
input_ids,
beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_constraint_gen_mode:
if num_return_sequences > num_beams:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
if num_beams <= 1:
if generation_config.num_beams <= 1:
raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
if do_sample:
if generation_config.do_sample:
raise ValueError("`do_sample` needs to be false for constrained generation.")
if num_beam_groups is not None and num_beam_groups > 1:
if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
final_constraints = []
if constraints is not None:
final_constraints = constraints
if generation_config.constraints is not None:
final_constraints = generation_config.constraints
if force_words_ids is not None:
if generation_config.force_words_ids is not None:
def typeerror():
raise ValueError(
"`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
f"of positive integers, but is {force_words_ids}."
f"of positive integers, but is {generation_config.force_words_ids}."
)
if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
if (
not isinstance(generation_config.force_words_ids, list)
or len(generation_config.force_words_ids) == 0
):
typeerror()
for word_ids in force_words_ids:
for word_ids in generation_config.force_words_ids:
if isinstance(word_ids[0], list):
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
......@@ -1761,33 +1572,33 @@ class GenerationMixin:
constraint = PhrasalConstraint(word_ids)
final_constraints.append(constraint)
# 10. prepare beam search scorer
# 11. prepare beam search scorer
constrained_beam_scorer = ConstrainedBeamSearchScorer(
constraints=final_constraints,
batch_size=batch_size,
num_beams=num_beams,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
)
# 11. interleave input_ids with `num_beams` additional sequences per batch
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=num_beams,
expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run beam search
# 13. run beam search
return self.constrained_beam_search(
input_ids,
constrained_beam_scorer=constrained_beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
......@@ -1884,15 +1695,19 @@ class GenerationMixin:
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
......@@ -2239,15 +2054,19 @@ class GenerationMixin:
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
......@@ -2487,15 +2306,19 @@ class GenerationMixin:
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
......@@ -2739,15 +2562,19 @@ class GenerationMixin:
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
if len(stopping_criteria) == 0:
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
batch_size = len(beam_scorer._beam_hyps)
......@@ -3059,15 +2886,19 @@ class GenerationMixin:
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
batch_size = len(beam_scorer._beam_hyps)
......@@ -3368,15 +3199,19 @@ class GenerationMixin:
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
batch_size = len(beam_scorer._beam_hyps)
......@@ -3737,15 +3572,19 @@ class GenerationMixin:
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
if len(stopping_criteria) == 0:
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
......
......@@ -39,7 +39,7 @@ from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
from .generation import GenerationMixin
from .generation import GenerationConfig, GenerationMixin
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
......@@ -1024,6 +1024,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
self.config = config
self.name_or_path = config.name_or_path
self.warnings_issued = {}
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
def post_init(self):
"""
......@@ -1106,6 +1107,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
return getattr(self, self.base_model_prefix, self)
def can_generate(self) -> bool:
"""
Returns whether this model can generate sequences with `.generate()`.
Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation):
return False
return True
def get_input_embeddings(self) -> nn.Module:
"""
Returns the model's input embeddings.
......@@ -2477,6 +2490,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
# If it is a model with generation capabilities, attempt to load the generation config
if model.can_generate():
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass
# Dispatch model with hooks on all devices if necessary
if device_map is not None:
dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
......
......@@ -14,6 +14,7 @@
# limitations under the License.
"""RAG model implementation."""
import copy
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union
......@@ -21,7 +22,7 @@ import torch
from torch import nn
from ...configuration_utils import PretrainedConfig
from ...generation import BeamSearchScorer, LogitsProcessorList, StoppingCriteriaList
from ...generation import BeamSearchScorer, GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
......@@ -1384,33 +1385,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask: Optional[torch.LongTensor] = None,
doc_scores: Optional[torch.FloatTensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
early_stopping: Optional[bool] = None,
use_cache: Optional[bool] = None,
num_beams: Optional[int] = None,
num_beam_groups: Optional[int] = None,
diversity_penalty: Optional[float] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[List[List[int]]] = None,
num_return_sequences: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
n_docs: Optional[int] = None,
generation_config: Optional[GenerationConfig] = None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
remove_invalid_values: Optional[bool] = None,
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
**model_kwargs
**kwargs
) -> torch.LongTensor:
"""
Implements RAG token decoding.
......@@ -1444,51 +1424,15 @@ class RagTokenForGeneration(RagPreTrainedModel):
If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
min_length (`int`, *optional*, defaults to 10):
The minimum length of the sequence to be generated.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or
not.
use_cache: (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
while `length_penalty` < 0.0 encourages shorter sequences.
no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
`decoder_input_ids`.
bad_words_ids(`List[int]`, *optional*):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
num_beams (`int`, *optional*, defaults to 1):
Number of beams for beam search. 1 means no beam search.
num_beam_groups (`int`, *optional*, defaults to 1):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
diversity_penalty (`float`, *optional*, defaults to 0.0):
This value is subtracted from a beam's score if it generates a token same as any beam from other group
at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
enabled.
num_return_sequences(`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch. Note that this
is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
n_docs (`int`, *optional*, defaults to `config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
prefix_allowed_tokens_fn: (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments `inputs_ids` and the batch ID
......@@ -1504,46 +1448,23 @@ class RagTokenForGeneration(RagPreTrainedModel):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
model's config. If a stopping criteria is passed that is already created with the arguments or a
model's config an error is thrown.
forced_bos_token_id (`int`, *optional*):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
the target language token.
forced_eos_token_id (`int`, *optional*):
The id of the token to force as the last generated token when `max_length` is reached.
remove_invalid_values (`bool`, *optional*):
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to
crash. Note that using `remove_invalid_values` can slow down generation.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
Return:
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches
finished early due to the `eos_token_id`.
"""
# Handle `generation_config` and kwargs that might update it
if generation_config is None:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
# set default parameters
n_docs = n_docs if n_docs is not None else self.config.n_docs
num_beams = num_beams if num_beams is not None else self.config.num_beams
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
max_length = max_length if max_length is not None else self.config.max_length
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
use_cache = use_cache if use_cache is not None else self.config.use_cache
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.config.generator.decoder_start_token_id
)
remove_invalid_values = (
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
)
exponential_decay_length_penalty = (
exponential_decay_length_penalty
if exponential_decay_length_penalty is not None
else self.config.exponential_decay_length_penalty
)
# retrieve docs
if self.retriever is not None and context_input_ids is None:
......@@ -1583,8 +1504,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
input_ids = torch.full(
(batch_size * num_beams, 1),
decoder_start_token_id,
(batch_size * generation_config.num_beams, 1),
generation_config.decoder_start_token_id,
dtype=torch.long,
device=next(self.parameters()).device,
)
......@@ -1600,10 +1521,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:])
# correctly extend last_hidden_state and attention mask
context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams)
encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams)
context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
encoder_outputs["last_hidden_state"] = extend_enc_output(
last_hidden_state, num_beams=generation_config.num_beams
)
doc_scores = doc_scores.repeat_interleave(num_beams, dim=0)
doc_scores = doc_scores.repeat_interleave(generation_config.num_beams, dim=0)
# define start_len & additional parameters
model_kwargs["doc_scores"] = doc_scores
......@@ -1612,64 +1535,51 @@ class RagTokenForGeneration(RagPreTrainedModel):
model_kwargs["n_docs"] = n_docs
pre_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=context_input_ids,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
diversity_penalty=diversity_penalty,
remove_invalid_values=remove_invalid_values,
exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
renormalize_logits=renormalize_logits,
)
if num_beams == 1:
if num_return_sequences > 1:
if generation_config.num_beams == 1:
if generation_config.num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" greedy search."
)
return self.greedy_search(
input_ids,
logits_processor=pre_processor,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=generation_config.max_length,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
**model_kwargs,
)
elif num_beams > 1:
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
if num_return_sequences > num_beams:
elif generation_config.num_beams > 1:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
num_beams=generation_config.num_beams,
device=self.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
)
return self.beam_search(
input_ids,
beam_scorer,
logits_processor=pre_processor,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=generation_config.max_length,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
**model_kwargs,
)
else:
raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
raise ValueError(
f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
)
def get_input_embeddings(self):
return self.rag.generator.get_input_embeddings()
......
......@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import tempfile
import unittest
from parameterized import parameterized
from transformers.generation import GenerationConfig
from transformers import AutoConfig, GenerationConfig
class LogitsProcessorTest(unittest.TestCase):
......@@ -43,3 +44,33 @@ class LogitsProcessorTest(unittest.TestCase):
self.assertEqual(loaded_config.top_k, 50)
self.assertEqual(loaded_config.max_length, 20)
self.assertEqual(loaded_config.max_time, None)
def test_from_model_config(self):
model_config = AutoConfig.from_pretrained("gpt2")
generation_config_from_model = GenerationConfig.from_model_config(model_config)
default_generation_config = GenerationConfig()
# The generation config has loaded a few non-default parameters from the model config
self.assertNotEqual(generation_config_from_model, default_generation_config)
# One of those parameters is eos_token_id -- check if it matches
self.assertNotEqual(generation_config_from_model.eos_token_id, default_generation_config.eos_token_id)
self.assertEqual(generation_config_from_model.eos_token_id, model_config.eos_token_id)
def test_update(self):
generation_config = GenerationConfig()
update_kwargs = {
"max_new_tokens": 1024,
"foo": "bar",
}
update_kwargs_copy = copy.deepcopy(update_kwargs)
unused_kwargs = generation_config.update(**update_kwargs)
# update_kwargs was not modified (no side effects)
self.assertEqual(update_kwargs, update_kwargs_copy)
# update_kwargs was used to update the config on valid attributes
self.assertEqual(generation_config.max_new_tokens, 1024)
# `.update()` returns a dictionary of unused kwargs
self.assertEqual(unused_kwargs, {"foo": "bar"})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment