Unverified commit eb3ded16, authored by Joao Gante, committed by GitHub
Browse files

Generate: lower severity of parameterization checks (#25407)

parent ef74da65
...@@ -313,7 +313,7 @@ class GenerationConfig(PushToHubMixin): ...@@ -313,7 +313,7 @@ class GenerationConfig(PushToHubMixin):
raise err raise err
# Validate the values of the attributes # Validate the values of the attributes
self.validate() self.validate(is_init=True)
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, GenerationConfig): if not isinstance(other, GenerationConfig):
...@@ -330,7 +330,7 @@ class GenerationConfig(PushToHubMixin): ...@@ -330,7 +330,7 @@ class GenerationConfig(PushToHubMixin):
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}" return f"{self.__class__.__name__} {self.to_json_string()}"
def validate(self): def validate(self, is_init=False):
""" """
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
of parameterization that can be detected as incorrect from the configuration instance alone. of parameterization that can be detected as incorrect from the configuration instance alone.
...@@ -344,58 +344,85 @@ class GenerationConfig(PushToHubMixin): ...@@ -344,58 +344,85 @@ class GenerationConfig(PushToHubMixin):
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.") raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
# Validation of attribute relations: # Validation of attribute relations:
fix_location = ""
if is_init:
fix_location = (
" This was detected when initializing the generation config instance, which means the corresponding "
"file may hold incorrect parameterization and should be fixed."
)
# 1. detect sampling-only parameterization when not in sampling mode # 1. detect sampling-only parameterization when not in sampling mode
if self.do_sample is False: if self.do_sample is False:
greedy_wrong_parameter_msg = ( greedy_wrong_parameter_msg = (
"`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used " "`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
"in sample-based generation modes. Set `do_sample=True` or unset {flag_name} to continue." "in sample-based generation modes. You should set `do_sample=True` or unset {flag_name}."
+ fix_location
) )
if self.temperature != 1.0: if self.temperature != 1.0:
raise ValueError( warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature) greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
UserWarning,
) )
if self.top_p != 1.0: if self.top_p != 1.0:
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)) warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
UserWarning,
)
if self.typical_p != 1.0: if self.typical_p != 1.0:
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p)) warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
UserWarning,
)
if self.top_k != 50 and self.penalty_alpha is None: # contrastive search uses top_k if self.top_k != 50 and self.penalty_alpha is None: # contrastive search uses top_k
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k)) warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
UserWarning,
)
if self.epsilon_cutoff != 0.0: if self.epsilon_cutoff != 0.0:
raise ValueError( warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff) greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
UserWarning,
) )
if self.eta_cutoff != 0.0: if self.eta_cutoff != 0.0:
raise ValueError(greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff)) warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
UserWarning,
)
# 2. detect beam-only parameterization when not in beam mode # 2. detect beam-only parameterization when not in beam mode
if self.num_beams == 1: if self.num_beams == 1:
single_beam_wrong_parameter_msg = ( single_beam_wrong_parameter_msg = (
"`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in " "`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
"beam-based generation modes. Set `num_beams>1` or unset {flag_name} to continue." "beam-based generation modes. You should set `num_beams>1` or unset {flag_name}." + fix_location
) )
if self.early_stopping is not False: if self.early_stopping is not False:
raise ValueError( warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping) single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
UserWarning,
) )
if self.num_beam_groups != 1: if self.num_beam_groups != 1:
raise ValueError( warnings.warn(
single_beam_wrong_parameter_msg.format( single_beam_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups flag_name="num_beam_groups", flag_value=self.num_beam_groups
) ),
UserWarning,
) )
if self.diversity_penalty != 0.0: if self.diversity_penalty != 0.0:
raise ValueError( warnings.warn(
single_beam_wrong_parameter_msg.format( single_beam_wrong_parameter_msg.format(
flag_name="diversity_penalty", flag_value=self.diversity_penalty flag_name="diversity_penalty", flag_value=self.diversity_penalty
) ),
UserWarning,
) )
if self.length_penalty != 1.0: if self.length_penalty != 1.0:
raise ValueError( warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty) single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
UserWarning,
) )
if self.constraints is not None: if self.constraints is not None:
raise ValueError( warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints) single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
UserWarning,
) )
# 3. detect incorrect parameterization specific to advanced beam modes # 3. detect incorrect parameterization specific to advanced beam modes
...@@ -405,7 +432,7 @@ class GenerationConfig(PushToHubMixin): ...@@ -405,7 +432,7 @@ class GenerationConfig(PushToHubMixin):
constrained_wrong_parameter_msg = ( constrained_wrong_parameter_msg = (
"`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to " "`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
"{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset " "{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
"{flag_name} to continue." "{flag_name} to continue." + fix_location
) )
if self.do_sample is True: if self.do_sample is True:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment