Unverified Commit b1ea6b4b authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: GenerationConfig can overwrite attributes at from_pretrained time (#24238)


Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 7bb6933b
......@@ -288,7 +288,8 @@ class GenerationConfig(PushToHubMixin):
# Additional attributes without default values
if not self._from_model_config:
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a model's default configuration file
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
# model's default configuration file
for key, value in kwargs.items():
try:
setattr(self, key, value)
......@@ -569,9 +570,9 @@ class GenerationConfig(PushToHubMixin):
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]
# remove all the arguments that are in the config_dict
config = cls(**config_dict, **kwargs)
# The line below allows model-specific config to be loaded as well through kwargs, with safety checks.
# See https://github.com/huggingface/transformers/pull/21269
config = cls(**{**config_dict, **kwargs})
unused_kwargs = config.update(**kwargs)
logger.info(f"Generate config {config}")
......
......@@ -93,6 +93,31 @@ class GenerationConfigTest(unittest.TestCase):
generation_config = GenerationConfig.from_model_config(new_config)
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
def test_kwarg_init(self):
"""Tests that we can overwrite attributes at `from_pretrained` time."""
default_config = GenerationConfig()
self.assertEqual(default_config.temperature, 1.0)
self.assertEqual(default_config.do_sample, False)
self.assertEqual(default_config.num_beams, 1)
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
bad_words_ids=[[1, 2, 3], [4, 5]],
)
self.assertEqual(config.temperature, 0.7)
self.assertEqual(config.do_sample, True)
self.assertEqual(config.num_beams, 1)
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir)
loaded_config = GenerationConfig.from_pretrained(tmp_dir, temperature=1.0)
self.assertEqual(loaded_config.temperature, 1.0)
self.assertEqual(loaded_config.do_sample, True)
self.assertEqual(loaded_config.num_beams, 1) # default value
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment