Unverified Commit a7755d24 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: unset GenerationConfig parameters do not raise warning (#29119)

parent 7d312ad2
...@@ -271,7 +271,6 @@ class GenerationConfig(PushToHubMixin): ...@@ -271,7 +271,6 @@ class GenerationConfig(PushToHubMixin):
def __init__(self, **kwargs): def __init__(self, **kwargs):
# Parameters that control the length of the output # Parameters that control the length of the output
# if the default `max_length` is updated here, make sure to update the `generate` tests following https://github.com/huggingface/transformers/pull/25030
self.max_length = kwargs.pop("max_length", 20) self.max_length = kwargs.pop("max_length", 20)
self.max_new_tokens = kwargs.pop("max_new_tokens", None) self.max_new_tokens = kwargs.pop("max_new_tokens", None)
self.min_length = kwargs.pop("min_length", 0) self.min_length = kwargs.pop("min_length", 0)
...@@ -407,32 +406,34 @@ class GenerationConfig(PushToHubMixin): ...@@ -407,32 +406,34 @@ class GenerationConfig(PushToHubMixin):
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`." "used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
+ fix_location + fix_location
) )
if self.temperature != 1.0: if self.temperature is not None and self.temperature != 1.0:
warnings.warn( warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature), greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
UserWarning, UserWarning,
) )
if self.top_p != 1.0: if self.top_p is not None and self.top_p != 1.0:
warnings.warn( warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p), greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
UserWarning, UserWarning,
) )
if self.typical_p != 1.0: if self.typical_p is not None and self.typical_p != 1.0:
warnings.warn( warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p), greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
UserWarning, UserWarning,
) )
if self.top_k != 50 and self.penalty_alpha is None: # contrastive search uses top_k if (
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
): # contrastive search uses top_k
warnings.warn( warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k), greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
UserWarning, UserWarning,
) )
if self.epsilon_cutoff != 0.0: if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
warnings.warn( warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff), greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
UserWarning, UserWarning,
) )
if self.eta_cutoff != 0.0: if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
warnings.warn( warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff), greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
UserWarning, UserWarning,
...@@ -453,21 +454,21 @@ class GenerationConfig(PushToHubMixin): ...@@ -453,21 +454,21 @@ class GenerationConfig(PushToHubMixin):
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping), single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
UserWarning, UserWarning,
) )
if self.num_beam_groups != 1: if self.num_beam_groups is not None and self.num_beam_groups != 1:
warnings.warn( warnings.warn(
single_beam_wrong_parameter_msg.format( single_beam_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups flag_name="num_beam_groups", flag_value=self.num_beam_groups
), ),
UserWarning, UserWarning,
) )
if self.diversity_penalty != 0.0: if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
warnings.warn( warnings.warn(
single_beam_wrong_parameter_msg.format( single_beam_wrong_parameter_msg.format(
flag_name="diversity_penalty", flag_value=self.diversity_penalty flag_name="diversity_penalty", flag_value=self.diversity_penalty
), ),
UserWarning, UserWarning,
) )
if self.length_penalty != 1.0: if self.length_penalty is not None and self.length_penalty != 1.0:
warnings.warn( warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty), single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
UserWarning, UserWarning,
...@@ -491,7 +492,7 @@ class GenerationConfig(PushToHubMixin): ...@@ -491,7 +492,7 @@ class GenerationConfig(PushToHubMixin):
raise ValueError( raise ValueError(
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample) constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
) )
if self.num_beam_groups != 1: if self.num_beam_groups is not None and self.num_beam_groups != 1:
raise ValueError( raise ValueError(
constrained_wrong_parameter_msg.format( constrained_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups flag_name="num_beam_groups", flag_value=self.num_beam_groups
...@@ -1000,6 +1001,9 @@ class GenerationConfig(PushToHubMixin): ...@@ -1000,6 +1001,9 @@ class GenerationConfig(PushToHubMixin):
setattr(self, key, value) setattr(self, key, value)
to_remove.append(key) to_remove.append(key)
# remove all the attributes that were updated, without modifying the input dict # Confirm that the updated instance is still valid
self.validate()
# Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs return unused_kwargs
...@@ -330,7 +330,6 @@ class FlaxGenerationMixin: ...@@ -330,7 +330,6 @@ class FlaxGenerationMixin:
generation_config = copy.deepcopy(generation_config) generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy()) self._validate_model_kwargs(model_kwargs.copy())
logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList() logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
......
...@@ -736,7 +736,6 @@ class TFGenerationMixin: ...@@ -736,7 +736,6 @@ class TFGenerationMixin:
generation_config = copy.deepcopy(generation_config) generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy()) self._validate_model_kwargs(model_kwargs.copy())
# 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models) # 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
......
...@@ -1347,7 +1347,6 @@ class GenerationMixin: ...@@ -1347,7 +1347,6 @@ class GenerationMixin:
generation_config = copy.deepcopy(generation_config) generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy()) self._validate_model_kwargs(model_kwargs.copy())
# 2. Set generation parameters if not already defined # 2. Set generation parameters if not already defined
......
...@@ -152,7 +152,6 @@ class QuantizationConfigMixin: ...@@ -152,7 +152,6 @@ class QuantizationConfigMixin:
config_dict = self.to_dict() config_dict = self.to_dict()
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
# Copied from transformers.generation.configuration_utils.GenerationConfig.update
def update(self, **kwargs): def update(self, **kwargs):
""" """
Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes, Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes,
...@@ -171,7 +170,7 @@ class QuantizationConfigMixin: ...@@ -171,7 +170,7 @@ class QuantizationConfigMixin:
setattr(self, key, value) setattr(self, key, value)
to_remove.append(key) to_remove.append(key)
# remove all the attributes that were updated, without modifying the input dict # Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs return unused_kwargs
......
...@@ -124,26 +124,44 @@ class GenerationConfigTest(unittest.TestCase): ...@@ -124,26 +124,44 @@ class GenerationConfigTest(unittest.TestCase):
""" """
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
""" """
# Case 1: A correct configuration will not throw any warning # A correct configuration will not throw any warning
with warnings.catch_warnings(record=True) as captured_warnings: with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig() GenerationConfig()
self.assertEqual(len(captured_warnings), 0) self.assertEqual(len(captured_warnings), 0)
# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling # Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future. # parameters with `do_sample=False`). May be escalated to an error in the future.
with warnings.catch_warnings(record=True) as captured_warnings: with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(temperature=0.5) GenerationConfig(do_sample=False, temperature=0.5)
self.assertEqual(len(captured_warnings), 1) self.assertEqual(len(captured_warnings), 1)
# Case 3: Impossible sets of contraints/parameters will raise an exception # Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
# that is done by unsetting the parameter (i.e. setting it to None)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# BAD - 0.9 means it is still set, we should warn
generation_config_bad_temperature.update(temperature=0.9)
self.assertEqual(len(captured_warnings), 1)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
generation_config_bad_temperature.update(temperature=1.0)
self.assertEqual(len(captured_warnings), 0)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# OK - None means it is unset, nothing to warn about
generation_config_bad_temperature.update(temperature=None)
self.assertEqual(len(captured_warnings), 0)
# Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
GenerationConfig(num_return_sequences=2) GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
# Case 4: Passing `generate()`-only flags to `validate` will raise an exception # Passing `generate()`-only flags to `validate` will raise an exception
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
GenerationConfig(logits_processor="foo") GenerationConfig(logits_processor="foo")
# Case 5: Model-specific parameters will NOT raise an exception or a warning # Model-specific parameters will NOT raise an exception or a warning
with warnings.catch_warnings(record=True) as captured_warnings: with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(foo="bar") GenerationConfig(foo="bar")
self.assertEqual(len(captured_warnings), 0) self.assertEqual(len(captured_warnings), 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment