Unverified Commit 94a7edd9 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[GenerationConfig] add additional kwargs handling (#21269)



* add additional kwargs handling

* fix issue when serializing

* correct order of kwargs removal for serialization in from dict

* add `dict_torch_dtype_to_str` in case a dtype is needed for generation

* add condition when adding the kwargs : not from config

* Add comment based on review
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* add test function

* default None when poping arg
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>
parent 9286039c
......@@ -282,6 +282,16 @@ class GenerationConfig(PushToHubMixin):
self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__)
# Additional attributes without default values
if not self._from_model_config:
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a model's default configuration file
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
def __eq__(self, other):
self_dict = self.__dict__.copy()
other_dict = other.__dict__.copy()
......@@ -537,7 +547,9 @@ class GenerationConfig(PushToHubMixin):
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]
config = cls(**config_dict)
# remove all the arguments that are in the config_dict
config = cls(**config_dict, **kwargs)
unused_kwargs = config.update(**kwargs)
logger.info(f"Generate config {config}")
......@@ -546,6 +558,18 @@ class GenerationConfig(PushToHubMixin):
else:
return config
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
"""
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
string, which can then be stored in the json format.
"""
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
for value in d.values():
if isinstance(value, dict):
self.dict_torch_dtype_to_str(value)
def to_diff_dict(self) -> Dict[str, Any]:
"""
Removes all attributes from config which correspond to the default config attributes for better readability and
......@@ -566,6 +590,7 @@ class GenerationConfig(PushToHubMixin):
if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]:
serializable_config_dict[key] = value
self.dict_torch_dtype_to_str(serializable_config_dict)
return serializable_config_dict
def to_dict(self) -> Dict[str, Any]:
......@@ -582,6 +607,7 @@ class GenerationConfig(PushToHubMixin):
# Transformers version when serializing this file
output["transformers_version"] = __version__
self.dict_torch_dtype_to_str(output)
return output
def to_json_string(self, use_diff: bool = True) -> str:
......@@ -630,7 +656,8 @@ class GenerationConfig(PushToHubMixin):
[`GenerationConfig`]: The configuration object instantiated from those parameters.
"""
config_dict = model_config.to_dict()
config = cls.from_dict(config_dict, return_unused_kwargs=False)
config_dict.pop("_from_model_config", None)
config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config.
......@@ -642,7 +669,6 @@ class GenerationConfig(PushToHubMixin):
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
setattr(config, attr, decoder_config[attr])
config._from_model_config = True
return config
def update(self, **kwargs):
......
......@@ -78,6 +78,20 @@ class GenerationConfigTest(unittest.TestCase):
# `.update()` returns a dictionary of unused kwargs
self.assertEqual(unused_kwargs, {"foo": "bar"})
def test_initialize_new_kwargs(self):
generation_config = GenerationConfig()
generation_config.foo = "bar"
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
# update_kwargs was used to update the config on valid attributes
self.assertEqual(new_config.foo, "bar")
generation_config = GenerationConfig.from_model_config(new_config)
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment