Unverified Commit 5bd8c011 authored by Joao Gante, committed by GitHub

Generate: add config-level validation (#25381)

parent 9e57e0c0
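
This commit moves generation-mode sanity checks out of `GenerationMixin.generate()` and into a single `GenerationConfig.validate()` method, so that invalid flag combinations can be caught from the configuration object alone, before any model or inputs are involved.

A minimal sketch of the new behavior (values are illustrative; depending on your version, constructing the config may already trigger validation):

```python
from transformers import GenerationConfig

# `do_sample` defaults to False (greedy decoding), so setting a sampling-only
# flag such as `temperature` is now rejected at the config level.
config = GenerationConfig(temperature=0.7)
config.validate()
# ValueError: `do_sample` is set to `False`. However, temperature is set to 0.7 -- this
# flag is only used in sample-based generation modes. Set `do_sample=True` or unset
# temperature to continue.
```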
@@ -332,12 +332,121 @@ class GenerationConfig(PushToHubMixin):
     def validate(self):
         """
-        Validates the values of the attributes of the GenerationConfig instance, and raises a `ValueError` if any of
-        the values are invalid.
+        Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
+        of parameterization that can be detected as incorrect from the configuration instance alone.
+
+        Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
+        model, such as parameters related to the generation length.
         """
+        # Validation of individual attributes
         if self.early_stopping not in {True, False, "never"}:
             raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
+
+        # Validation of attribute relations:
+        # 1. detect sampling-only parameterization when not in sampling mode
+        if self.do_sample is False:
+            greedy_wrong_parameter_msg = (
+                "`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
+                "in sample-based generation modes. Set `do_sample=True` or unset {flag_name} to continue."
+            )
+            if self.temperature != 1.0:
+                raise ValueError(
+                    greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature)
+                )
+            if self.top_p != 1.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p))
+            if self.typical_p != 1.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p))
+            if self.top_k != 50 and self.penalty_alpha is None:  # contrastive search uses top_k
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k))
+            if self.epsilon_cutoff != 0.0:
+                raise ValueError(
+                    greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff)
+                )
+            if self.eta_cutoff != 0.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff))
+
+        # 2. detect beam-only parameterization when not in beam mode
+        if self.num_beams == 1:
+            single_beam_wrong_parameter_msg = (
+                "`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
+                "beam-based generation modes. Set `num_beams>1` or unset {flag_name} to continue."
+            )
+            if self.early_stopping is not False:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping)
+                )
+            if self.num_beam_groups != 1:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(
+                        flag_name="num_beam_groups", flag_value=self.num_beam_groups
+                    )
+                )
+            if self.diversity_penalty != 0.0:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(
+                        flag_name="diversity_penalty", flag_value=self.diversity_penalty
+                    )
+                )
+            if self.length_penalty != 1.0:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty)
+                )
+            if self.constraints is not None:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints)
+                )
+        # 3. detect incorrect parameterization specific to advanced beam modes
+        else:
+            # constrained beam search
+            if self.constraints is not None:
+                constrained_wrong_parameter_msg = (
+                    "`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
+                    "{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
+                    "{flag_name} to continue."
+                )
+                if self.do_sample is True:
+                    raise ValueError(
+                        constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
+                    )
+                if self.num_beam_groups != 1:
+                    raise ValueError(
+                        constrained_wrong_parameter_msg.format(
+                            flag_name="num_beam_groups", flag_value=self.num_beam_groups
+                        )
+                    )
+            # group beam search
+            if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
+                group_error_prefix = (
+                    "`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
+                    "this generation mode, "
+                )
+                if self.do_sample is True:
+                    raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
+                if self.num_beams % self.num_beam_groups != 0:
+                    raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
+                if self.diversity_penalty == 0.0:
+                    raise ValueError(
+                        group_error_prefix
+                        + "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
+                    )
+
+        # 4. check `num_return_sequences`
+        if self.num_return_sequences != 1:
+            if self.num_beams == 1:
+                if self.do_sample is False:
+                    raise ValueError(
+                        "Greedy methods without beam search do not support `num_return_sequences` different than 1 "
+                        f"(got {self.num_return_sequences})."
+                    )
+            elif self.num_return_sequences > self.num_beams:
+                raise ValueError(
+                    f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
+                    f"({self.num_beams})."
+                )

     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
...
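
As a quick check of the relational rules above (blocks 1 through 4), one can call `validate()` directly on a hand-built config. The values below are illustrative:

```python
from transformers import GenerationConfig

# Block 3 (group beam search): `num_beams` must be divisible by `num_beam_groups`,
# and `diversity_penalty` must be non-zero or the groups degenerate into identical beams.
config = GenerationConfig(num_beams=4, num_beam_groups=3, diversity_penalty=0.5)
config.validate()
# ValueError: `diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering
# group beam search. In this generation mode, `num_beams` should be divisible by
# `num_beam_groups`
```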
@@ -1493,13 +1493,6 @@ class GenerationMixin:
         # 7. determine generation mode
         generation_mode = self._get_generation_mode(generation_config, assistant_model)

-        if generation_config.num_beam_groups > generation_config.num_beams:
-            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
-
-        if generation_mode == GenerationMode.GROUP_BEAM_SEARCH and generation_config.do_sample is True:
-            raise ValueError(
-                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
-            )
         if streamer is not None and (generation_config.num_beams > 1):
             raise ValueError(
                 "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
@@ -1572,12 +1565,6 @@ class GenerationMixin:
             **model_kwargs,
         )
         if generation_mode == GenerationMode.GREEDY_SEARCH:
-            if generation_config.num_return_sequences > 1:
-                raise ValueError(
-                    "num_return_sequences has to be 1 when doing greedy search, "
-                    f"but is {generation_config.num_return_sequences}."
-                )
-
             # 11. run greedy search
             return self.greedy_search(
                 input_ids,
@@ -1593,11 +1580,6 @@ class GenerationMixin:
         )
         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
-            if generation_config.num_return_sequences > 1:
-                raise ValueError(
-                    "num_return_sequences has to be 1 when doing contrastive search, "
-                    f"but is {generation_config.num_return_sequences}."
-                )
             if not model_kwargs["use_cache"]:
                 raise ValueError("Contrastive search requires `use_cache=True`")
@@ -1645,12 +1627,6 @@ class GenerationMixin:
         )
         elif generation_mode == GenerationMode.BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
@@ -1686,8 +1662,6 @@ class GenerationMixin:
             # 11. prepare logits warper
             logits_warper = self._get_logits_warper(generation_config)

-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size * generation_config.num_return_sequences,
@@ -1722,24 +1696,6 @@ class GenerationMixin:
         )
         elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if generation_config.num_beams % generation_config.num_beam_groups != 0:
-                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
-
-            if generation_config.diversity_penalty == 0.0:
-                raise ValueError(
-                    "`diversity_penalty` should be greater than `0.0`, otherwise your beam groups will be identical."
-                )
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
-            has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
-            if not has_default_typical_p:
-                raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
-
             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
@@ -1773,21 +1729,6 @@ class GenerationMixin:
         )
         elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
-            if generation_config.num_beams <= 1:
-                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
-
-            if generation_config.do_sample:
-                raise ValueError("`do_sample` needs to be false for constrained generation.")
-
-            if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
-                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
-
             final_constraints = []
             if generation_config.constraints is not None:
                 final_constraints = generation_config.constraints
...
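
The per-mode checks deleted from `generate()` above are not lost: they are subsumed by the config-level rules. For example, the old greedy-search guard on `num_return_sequences` now fires from the config itself (illustrative value):

```python
from transformers import GenerationConfig

# Greedy defaults (do_sample=False, num_beams=1) with num_return_sequences > 1.
config = GenerationConfig(num_return_sequences=2)
config.validate()
# ValueError: Greedy methods without beam search do not support `num_return_sequences`
# different than 1 (got 2).
```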