Unverified Commit 123ad536 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generation: strict generation config validation at save time (#25411)

* strict gen config save; Add tests

* add note that the warning will be an exception in v4.34
parent 16edf4d9
...@@ -354,8 +354,8 @@ class GenerationConfig(PushToHubMixin): ...@@ -354,8 +354,8 @@ class GenerationConfig(PushToHubMixin):
# 1. detect sampling-only parameterization when not in sampling mode # 1. detect sampling-only parameterization when not in sampling mode
if self.do_sample is False: if self.do_sample is False:
greedy_wrong_parameter_msg = ( greedy_wrong_parameter_msg = (
"`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used " "`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
"in sample-based generation modes. You should set `do_sample=True` or unset {flag_name}." "used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
+ fix_location + fix_location
) )
if self.temperature != 1.0: if self.temperature != 1.0:
...@@ -392,8 +392,8 @@ class GenerationConfig(PushToHubMixin): ...@@ -392,8 +392,8 @@ class GenerationConfig(PushToHubMixin):
# 2. detect beam-only parameterization when not in beam mode # 2. detect beam-only parameterization when not in beam mode
if self.num_beams == 1: if self.num_beams == 1:
single_beam_wrong_parameter_msg = ( single_beam_wrong_parameter_msg = (
"`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in " "`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
"beam-based generation modes. You should set `num_beams>1` or unset {flag_name}." + fix_location "in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
) )
if self.early_stopping is not False: if self.early_stopping is not False:
warnings.warn( warnings.warn(
...@@ -430,9 +430,9 @@ class GenerationConfig(PushToHubMixin): ...@@ -430,9 +430,9 @@ class GenerationConfig(PushToHubMixin):
# constrained beam search # constrained beam search
if self.constraints is not None: if self.constraints is not None:
constrained_wrong_parameter_msg = ( constrained_wrong_parameter_msg = (
"`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to " "`constraints` is not `None`, triggering constrained beam search. However, `{flag_name}` is set "
"{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset " "to `{flag_value}`, which is incompatible with this generation mode. Set `constraints=None` or "
"{flag_name} to continue." + fix_location "unset `{flag_name}` to continue." + fix_location
) )
if self.do_sample is True: if self.do_sample is True:
raise ValueError( raise ValueError(
...@@ -497,6 +497,22 @@ class GenerationConfig(PushToHubMixin): ...@@ -497,6 +497,22 @@ class GenerationConfig(PushToHubMixin):
kwargs (`Dict[str, Any]`, *optional*): kwargs (`Dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance
try:
with warnings.catch_warnings(record=True) as caught_warnings:
self.validate()
for w in caught_warnings:
raise ValueError(w.message)
except ValueError as exc:
warnings.warn(
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
"Fix these issues to save the configuration. This warning will be raised to an exception in v4.34."
"\n\nThrown during validation:\n" + str(exc),
UserWarning,
)
return
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None: if use_auth_token is not None:
......
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
# limitations under the License. # limitations under the License.
import copy import copy
import os
import tempfile import tempfile
import unittest import unittest
import warnings
from huggingface_hub import HfFolder, delete_repo from huggingface_hub import HfFolder, delete_repo
from parameterized import parameterized from parameterized import parameterized
...@@ -118,6 +120,39 @@ class GenerationConfigTest(unittest.TestCase): ...@@ -118,6 +120,39 @@ class GenerationConfigTest(unittest.TestCase):
self.assertEqual(loaded_config.do_sample, True) self.assertEqual(loaded_config.do_sample, True)
self.assertEqual(loaded_config.num_beams, 1) # default value self.assertEqual(loaded_config.num_beams, 1) # default value
def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation."""
# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
# is caught, doesn't save, and raises a warning
config = GenerationConfig()
config.temperature = 0.5
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 1)
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
# caught, doesn't save, and raises a warning
config = GenerationConfig()
config.num_return_sequences = 2
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 1)
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# final check: no warnings thrown if it is correct, and file is saved
config = GenerationConfig()
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 0)
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
@is_staging_test @is_staging_test
class ConfigPushToHubTester(unittest.TestCase): class ConfigPushToHubTester(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment