Unverified Commit bff4313b authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: get generation mode as an enum (#25292)

parent fab1a0aa
...@@ -33,7 +33,7 @@ from ..models.auto import ( ...@@ -33,7 +33,7 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING,
) )
from ..utils import ModelOutput, logging from ..utils import ExplicitEnum, ModelOutput, logging
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .configuration_utils import GenerationConfig from .configuration_utils import GenerationConfig
...@@ -468,6 +468,23 @@ ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, Contrasti ...@@ -468,6 +468,23 @@ ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, Contrasti
GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput] GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput]
class GenerationMode(ExplicitEnum):
"""
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
"""
# Non-beam methods
CONTRASTIVE_SEARCH = "contrastive_search"
GREEDY_SEARCH = "greedy_search"
SAMPLE = "sample"
ASSISTED_GENERATION = "assisted_generation"
# Beam methods
BEAM_SEARCH = "beam_search"
BEAM_SAMPLE = "beam_sample"
CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
GROUP_BEAM_SEARCH = "group_beam_search"
class GenerationMixin: class GenerationMixin:
""" """
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
...@@ -829,6 +846,46 @@ class GenerationMixin: ...@@ -829,6 +846,46 @@ class GenerationMixin:
warpers.append(LogitNormalization()) warpers.append(LogitNormalization())
return warpers return warpers
def _get_generation_mode(
self, generation_config: GenerationConfig, assistant_model: Optional["PreTrainedModel"]
) -> GenerationMode:
"""
Returns the generation mode triggered by a [`GenerationConfig`] instance.
"""
if generation_config.constraints is not None or generation_config.force_words_ids is not None:
generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
elif generation_config.num_beams == 1:
if generation_config.do_sample is False:
if (
generation_config.top_k is not None
and generation_config.top_k > 1
and generation_config.penalty_alpha is not None
and generation_config.penalty_alpha > 0
):
generation_mode = GenerationMode.CONTRASTIVE_SEARCH
else:
generation_mode = GenerationMode.GREEDY_SEARCH
else:
generation_mode = GenerationMode.SAMPLE
else:
if generation_config.num_beam_groups > 1:
generation_mode = GenerationMode.GROUP_BEAM_SEARCH
elif generation_config.do_sample is True:
generation_mode = GenerationMode.BEAM_SAMPLE
else:
generation_mode = GenerationMode.BEAM_SEARCH
# Assisted generation may extend some generation modes
if assistant_model is not None:
if generation_mode in ("greedy_search", "sample"):
generation_mode = GenerationMode.ASSISTED_GENERATION
else:
raise ValueError(
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)
return generation_mode
def _get_logits_processor( def _get_logits_processor(
self, self,
generation_config: GenerationConfig, generation_config: GenerationConfig,
...@@ -1422,65 +1479,11 @@ class GenerationMixin: ...@@ -1422,65 +1479,11 @@ class GenerationMixin:
) )
# 7. determine generation mode # 7. determine generation mode
is_constraint_gen_mode = ( generation_mode = self._get_generation_mode(generation_config, assistant_model)
generation_config.constraints is not None or generation_config.force_words_ids is not None
)
is_contrastive_search_gen_mode = (
(generation_config.num_beams == 1)
and generation_config.top_k is not None
and generation_config.top_k > 1
and generation_config.do_sample is False
and generation_config.penalty_alpha is not None
and generation_config.penalty_alpha > 0
)
is_greedy_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_sample_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_gen_mode = (
(generation_config.num_beams > 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_sample_gen_mode = (
(generation_config.num_beams > 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_group_beam_gen_mode = (
(generation_config.num_beams > 1)
and (generation_config.num_beam_groups > 1)
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_assisted_gen_mode = False
if assistant_model is not None:
if not (is_greedy_gen_mode or is_sample_gen_mode):
raise ValueError(
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)
is_assisted_gen_mode = True
if generation_config.num_beam_groups > generation_config.num_beams: if generation_config.num_beam_groups > generation_config.num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
if is_group_beam_gen_mode and generation_config.do_sample is True: if generation_mode == GenerationMode.GROUP_BEAM_SEARCH and generation_config.do_sample is True:
raise ValueError( raise ValueError(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
) )
...@@ -1515,7 +1518,7 @@ class GenerationMixin: ...@@ -1515,7 +1518,7 @@ class GenerationMixin:
generation_config=generation_config, stopping_criteria=stopping_criteria generation_config=generation_config, stopping_criteria=stopping_criteria
) )
# 10. go into different generation modes # 10. go into different generation modes
if is_assisted_gen_mode: if generation_mode == GenerationMode.ASSISTED_GENERATION:
if generation_config.num_return_sequences > 1: if generation_config.num_return_sequences > 1:
raise ValueError( raise ValueError(
"num_return_sequences has to be 1 when doing assisted generate, " "num_return_sequences has to be 1 when doing assisted generate, "
...@@ -1553,7 +1556,7 @@ class GenerationMixin: ...@@ -1553,7 +1556,7 @@ class GenerationMixin:
streamer=streamer, streamer=streamer,
**model_kwargs, **model_kwargs,
) )
if is_greedy_gen_mode: if generation_mode == GenerationMode.GREEDY_SEARCH:
if generation_config.num_return_sequences > 1: if generation_config.num_return_sequences > 1:
raise ValueError( raise ValueError(
"num_return_sequences has to be 1 when doing greedy search, " "num_return_sequences has to be 1 when doing greedy search, "
...@@ -1574,7 +1577,7 @@ class GenerationMixin: ...@@ -1574,7 +1577,7 @@ class GenerationMixin:
**model_kwargs, **model_kwargs,
) )
elif is_contrastive_search_gen_mode: elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
if generation_config.num_return_sequences > 1: if generation_config.num_return_sequences > 1:
raise ValueError( raise ValueError(
"num_return_sequences has to be 1 when doing contrastive search, " "num_return_sequences has to be 1 when doing contrastive search, "
...@@ -1599,7 +1602,7 @@ class GenerationMixin: ...@@ -1599,7 +1602,7 @@ class GenerationMixin:
**model_kwargs, **model_kwargs,
) )
elif is_sample_gen_mode: elif generation_mode == GenerationMode.SAMPLE:
# 11. prepare logits warper # 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config)
...@@ -1626,7 +1629,7 @@ class GenerationMixin: ...@@ -1626,7 +1629,7 @@ class GenerationMixin:
**model_kwargs, **model_kwargs,
) )
elif is_beam_gen_mode: elif generation_mode == GenerationMode.BEAM_SEARCH:
if generation_config.num_return_sequences > generation_config.num_beams: if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
...@@ -1664,7 +1667,7 @@ class GenerationMixin: ...@@ -1664,7 +1667,7 @@ class GenerationMixin:
**model_kwargs, **model_kwargs,
) )
elif is_beam_sample_gen_mode: elif generation_mode == GenerationMode.BEAM_SAMPLE:
# 11. prepare logits warper # 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config)
...@@ -1703,7 +1706,7 @@ class GenerationMixin: ...@@ -1703,7 +1706,7 @@ class GenerationMixin:
**model_kwargs, **model_kwargs,
) )
elif is_group_beam_gen_mode: elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
if generation_config.num_return_sequences > generation_config.num_beams: if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
...@@ -1754,7 +1757,7 @@ class GenerationMixin: ...@@ -1754,7 +1757,7 @@ class GenerationMixin:
**model_kwargs, **model_kwargs,
) )
elif is_constraint_gen_mode: elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
if generation_config.num_return_sequences > generation_config.num_beams: if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment