Unverified Commit afc45b13 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: refuse to save bad generation config files (#28477)

parent dc01cf9c
...@@ -551,16 +551,13 @@ class GenerationConfig(PushToHubMixin): ...@@ -551,16 +551,13 @@ class GenerationConfig(PushToHubMixin):
try: try:
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
self.validate() self.validate()
for w in caught_warnings: if len(caught_warnings) > 0:
raise ValueError(w.message) raise ValueError(str([w.message for w in caught_warnings]))
except ValueError as exc: except ValueError as exc:
warnings.warn( raise ValueError(
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. " "The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
"Fix these issues to save the configuration. This warning will be raised to an exception in v4.34." "Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc)
"\n\nThrown during validation:\n" + str(exc),
UserWarning,
) )
return
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
......
...@@ -152,14 +152,13 @@ class GenerationConfigTest(unittest.TestCase): ...@@ -152,14 +152,13 @@ class GenerationConfigTest(unittest.TestCase):
"""Tests that we refuse to save a generation config that fails validation.""" """Tests that we refuse to save a generation config that fails validation."""
# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that # setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
# is caught, doesn't save, and raises a warning # is caught, doesn't save, and raises an exception
config = GenerationConfig() config = GenerationConfig()
config.temperature = 0.5 config.temperature = 0.5
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings: with self.assertRaises(ValueError) as exc:
config.save_pretrained(tmp_dir) config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 1) self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
self.assertTrue(len(os.listdir(tmp_dir)) == 0) self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is # greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
...@@ -167,13 +166,12 @@ class GenerationConfigTest(unittest.TestCase): ...@@ -167,13 +166,12 @@ class GenerationConfigTest(unittest.TestCase):
config = GenerationConfig() config = GenerationConfig()
config.num_return_sequences = 2 config.num_return_sequences = 2
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings: with self.assertRaises(ValueError) as exc:
config.save_pretrained(tmp_dir) config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 1) self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
self.assertTrue(len(os.listdir(tmp_dir)) == 0) self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# final check: no warnings thrown if it is correct, and file is saved # final check: no warnings/exceptions thrown if it is correct, and file is saved
config = GenerationConfig() config = GenerationConfig()
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings: with warnings.catch_warnings(record=True) as captured_warnings:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment