Unverified Commit 700d48fb authored by Joao Gante, committed by GitHub

Generate: get generation mode from the generation config instance 🧼 (#29441)
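This change moves `GenerationMode` from `generation/utils.py` into `generation/configuration_utils.py` and replaces the private `GenerationMixin._get_generation_mode` helper with a public `GenerationConfig.get_generation_mode` method, so the generation mode can be read directly from a config instance. A minimal usage sketch, mirroring the tests added in this commit:

```python
from transformers import GenerationConfig
from transformers.generation import GenerationMode

# Default config (num_beams=1, do_sample=False) resolves to greedy search
assert GenerationConfig().get_generation_mode() == GenerationMode.GREEDY_SEARCH

# Multiple beams without sampling resolve to beam search
assert GenerationConfig(num_beams=2).get_generation_mode() == GenerationMode.BEAM_SEARCH
```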

parent 41f7b7ae
......@@ -37,6 +37,9 @@ like token streaming.
- from_pretrained
- from_model_config
- save_pretrained
- update
- validate
- get_generation_mode
## GenerationMixin
......
......@@ -18,7 +18,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
_import_structure = {
"configuration_utils": ["GenerationConfig"],
"configuration_utils": ["GenerationConfig", "GenerationMode"],
"streamers": ["TextIteratorStreamer", "TextStreamer"],
}
......@@ -172,7 +172,7 @@ else:
]
if TYPE_CHECKING:
from .configuration_utils import GenerationConfig
from .configuration_utils import GenerationConfig, GenerationMode
from .streamers import TextIteratorStreamer, TextStreamer
try:
......
......@@ -18,12 +18,13 @@ import copy
import json
import os
import warnings
from typing import Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from .. import __version__
from ..configuration_utils import PretrainedConfig
from ..utils import (
GENERATION_CONFIG_NAME,
ExplicitEnum,
PushToHubMixin,
cached_file,
download_url,
......@@ -33,10 +34,31 @@ from ..utils import (
)
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
logger = logging.get_logger(__name__)
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
class GenerationMode(ExplicitEnum):
"""
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
"""
# Non-beam methods
CONTRASTIVE_SEARCH = "contrastive_search"
GREEDY_SEARCH = "greedy_search"
SAMPLE = "sample"
ASSISTED_GENERATION = "assisted_generation"
# Beam methods
BEAM_SEARCH = "beam_search"
BEAM_SAMPLE = "beam_sample"
CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
GROUP_BEAM_SEARCH = "group_beam_search"
class GenerationConfig(PushToHubMixin):
# no-format
r"""
......@@ -376,13 +398,65 @@ class GenerationConfig(PushToHubMixin):
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = None) -> GenerationMode:
"""
Returns the generation mode triggered by the [`GenerationConfig`] instance.
Args:
assistant_model (`PreTrainedModel`, *optional*):
The assistant model to be used for assisted generation. If set, the generation mode will be
assisted generation.
Returns:
`GenerationMode`: The generation mode triggered by the instance.
"""
# TODO joao: find out a way of not depending on external fields (e.g. `assistant_model`), then make this a
# property and part of the `__repr__`
if self.constraints is not None or self.force_words_ids is not None:
generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
elif self.num_beams == 1:
if self.do_sample is False:
if (
self.top_k is not None
and self.top_k > 1
and self.penalty_alpha is not None
and self.penalty_alpha > 0
):
generation_mode = GenerationMode.CONTRASTIVE_SEARCH
else:
generation_mode = GenerationMode.GREEDY_SEARCH
else:
generation_mode = GenerationMode.SAMPLE
else:
if self.num_beam_groups > 1:
generation_mode = GenerationMode.GROUP_BEAM_SEARCH
elif self.do_sample is True:
generation_mode = GenerationMode.BEAM_SAMPLE
else:
generation_mode = GenerationMode.BEAM_SEARCH
# Assisted generation may extend some generation modes
if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
if generation_mode in ("greedy_search", "sample"):
generation_mode = GenerationMode.ASSISTED_GENERATION
else:
raise ValueError(
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)
return generation_mode
def validate(self, is_init=False):
"""
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
of parameterization that can be detected as incorrect from the configuration instance alone.
Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
model, such as parameters related to the generation length.
Note that some parameters not validated here are best validated at generate runtime, as they may depend on
other inputs and/or the model, such as parameters related to the generation length.
Args:
is_init (`bool`, *optional*, defaults to `False`):
Whether the validation is performed during the initialization of the instance.
"""
# Validation of individual attributes
......
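For illustration, a hedged sketch of how the new method resolves two more configurations: the contrastive-search case mirrors the test added further down in this diff, while the beam-sample case simply follows the branching above and is an assumption rather than something covered by the new tests:

```python
from transformers import GenerationConfig
from transformers.generation import GenerationMode

# do_sample=False with top_k > 1 and penalty_alpha > 0 selects contrastive search
config = GenerationConfig(top_k=10, do_sample=False, penalty_alpha=0.6)
assert config.get_generation_mode() == GenerationMode.CONTRASTIVE_SEARCH

# num_beams > 1 combined with do_sample=True selects beam sample (assumed from the
# branching in `get_generation_mode`, not taken from the new tests)
config = GenerationConfig(num_beams=4, do_sample=True)
assert config.get_generation_mode() == GenerationMode.BEAM_SAMPLE
```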
......@@ -34,7 +34,7 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
from ..utils import ModelOutput, is_accelerate_available, logging
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import (
......@@ -45,7 +45,7 @@ from .candidate_generator import (
_prepare_attention_mask,
_prepare_token_type_ids,
)
from .configuration_utils import GenerationConfig
from .configuration_utils import GenerationConfig, GenerationMode
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
......@@ -325,23 +325,6 @@ GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDec
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
class GenerationMode(ExplicitEnum):
"""
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
"""
# Non-beam methods
CONTRASTIVE_SEARCH = "contrastive_search"
GREEDY_SEARCH = "greedy_search"
SAMPLE = "sample"
ASSISTED_GENERATION = "assisted_generation"
# Beam methods
BEAM_SEARCH = "beam_search"
BEAM_SAMPLE = "beam_sample"
CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
GROUP_BEAM_SEARCH = "group_beam_search"
class GenerationMixin:
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
......@@ -764,46 +747,6 @@ class GenerationMixin:
warpers.append(LogitNormalization())
return warpers
def _get_generation_mode(
self, generation_config: GenerationConfig, assistant_model: Optional["PreTrainedModel"]
) -> GenerationMode:
"""
Returns the generation mode triggered by a [`GenerationConfig`] instance.
"""
if generation_config.constraints is not None or generation_config.force_words_ids is not None:
generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
elif generation_config.num_beams == 1:
if generation_config.do_sample is False:
if (
generation_config.top_k is not None
and generation_config.top_k > 1
and generation_config.penalty_alpha is not None
and generation_config.penalty_alpha > 0
):
generation_mode = GenerationMode.CONTRASTIVE_SEARCH
else:
generation_mode = GenerationMode.GREEDY_SEARCH
else:
generation_mode = GenerationMode.SAMPLE
else:
if generation_config.num_beam_groups > 1:
generation_mode = GenerationMode.GROUP_BEAM_SEARCH
elif generation_config.do_sample is True:
generation_mode = GenerationMode.BEAM_SAMPLE
else:
generation_mode = GenerationMode.BEAM_SEARCH
# Assisted generation may extend some generation modes
if assistant_model is not None or generation_config.prompt_lookup_num_tokens is not None:
if generation_mode in ("greedy_search", "sample"):
generation_mode = GenerationMode.ASSISTED_GENERATION
else:
raise ValueError(
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)
return generation_mode
def _get_logits_processor(
self,
generation_config: GenerationConfig,
......@@ -1474,7 +1417,7 @@ class GenerationMixin:
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 7. determine generation mode
generation_mode = self._get_generation_mode(generation_config, assistant_model)
generation_mode = generation_config.get_generation_mode(assistant_model)
if streamer is not None and (generation_config.num_beams > 1):
raise ValueError(
......
......@@ -24,6 +24,7 @@ from parameterized import parameterized
from requests.exceptions import HTTPError
from transformers import AutoConfig, GenerationConfig
from transformers.generation import GenerationMode
from transformers.testing_utils import TOKEN, USER, is_staging_test
......@@ -202,6 +203,23 @@ class GenerationConfigTest(unittest.TestCase):
self.assertEqual(len(captured_warnings), 0)
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
def test_generation_mode(self):
"""Tests that the `get_generation_mode` method is working as expected."""
config = GenerationConfig()
self.assertEqual(config.get_generation_mode(), GenerationMode.GREEDY_SEARCH)
config = GenerationConfig(do_sample=True)
self.assertEqual(config.get_generation_mode(), GenerationMode.SAMPLE)
config = GenerationConfig(num_beams=2)
self.assertEqual(config.get_generation_mode(), GenerationMode.BEAM_SEARCH)
config = GenerationConfig(top_k=10, do_sample=False, penalty_alpha=0.6)
self.assertEqual(config.get_generation_mode(), GenerationMode.CONTRASTIVE_SEARCH)
config = GenerationConfig()
self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
......
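One more hedged sketch, this time of the error path: assisted generation only extends greedy search and sample, so requesting it on top of beam search raises. As in the new test, any non-`None` stand-in is enough to trigger the `assistant_model is not None` check; a real `PreTrainedModel` would be passed in practice.

```python
from transformers import GenerationConfig

try:
    # Beam search (num_beams=2) cannot be combined with assisted generation
    GenerationConfig(num_beams=2).get_generation_mode(assistant_model="stand-in")
except ValueError as err:
    print(err)  # assisted generate is only supported with Greedy Search and Sample
```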