Unverified Commit 94a7edd9 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[GenerationConfig] add additional kwargs handling (#21269)



* add additional kwargs handling

* fix issue when serializing

* correct order of kwargs removal for serialization in from dict

* add `dict_torch_dtype_to_str` in case a dtype is needed for generation

* add condition when adding the kwargs : not from config

* Add comment based on review
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* add test function

* default None when poping arg
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>
parent 9286039c
...@@ -282,6 +282,16 @@ class GenerationConfig(PushToHubMixin): ...@@ -282,6 +282,16 @@ class GenerationConfig(PushToHubMixin):
self._commit_hash = kwargs.pop("_commit_hash", None) self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__) self.transformers_version = kwargs.pop("transformers_version", __version__)
# Additional attributes without default values
if not self._from_model_config:
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a model's default configuration file
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
def __eq__(self, other): def __eq__(self, other):
self_dict = self.__dict__.copy() self_dict = self.__dict__.copy()
other_dict = other.__dict__.copy() other_dict = other.__dict__.copy()
...@@ -537,7 +547,9 @@ class GenerationConfig(PushToHubMixin): ...@@ -537,7 +547,9 @@ class GenerationConfig(PushToHubMixin):
if "_commit_hash" in kwargs and "_commit_hash" in config_dict: if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"] kwargs["_commit_hash"] = config_dict["_commit_hash"]
config = cls(**config_dict) # remove all the arguments that are in the config_dict
config = cls(**config_dict, **kwargs)
unused_kwargs = config.update(**kwargs) unused_kwargs = config.update(**kwargs)
logger.info(f"Generate config {config}") logger.info(f"Generate config {config}")
...@@ -546,6 +558,18 @@ class GenerationConfig(PushToHubMixin): ...@@ -546,6 +558,18 @@ class GenerationConfig(PushToHubMixin):
else: else:
return config return config
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
"""
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
string, which can then be stored in the json format.
"""
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
for value in d.values():
if isinstance(value, dict):
self.dict_torch_dtype_to_str(value)
def to_diff_dict(self) -> Dict[str, Any]: def to_diff_dict(self) -> Dict[str, Any]:
""" """
Removes all attributes from config which correspond to the default config attributes for better readability and Removes all attributes from config which correspond to the default config attributes for better readability and
...@@ -566,6 +590,7 @@ class GenerationConfig(PushToHubMixin): ...@@ -566,6 +590,7 @@ class GenerationConfig(PushToHubMixin):
if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]: if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]:
serializable_config_dict[key] = value serializable_config_dict[key] = value
self.dict_torch_dtype_to_str(serializable_config_dict)
return serializable_config_dict return serializable_config_dict
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
...@@ -582,6 +607,7 @@ class GenerationConfig(PushToHubMixin): ...@@ -582,6 +607,7 @@ class GenerationConfig(PushToHubMixin):
# Transformers version when serializing this file # Transformers version when serializing this file
output["transformers_version"] = __version__ output["transformers_version"] = __version__
self.dict_torch_dtype_to_str(output)
return output return output
def to_json_string(self, use_diff: bool = True) -> str: def to_json_string(self, use_diff: bool = True) -> str:
...@@ -630,7 +656,8 @@ class GenerationConfig(PushToHubMixin): ...@@ -630,7 +656,8 @@ class GenerationConfig(PushToHubMixin):
[`GenerationConfig`]: The configuration object instantiated from those parameters. [`GenerationConfig`]: The configuration object instantiated from those parameters.
""" """
config_dict = model_config.to_dict() config_dict = model_config.to_dict()
config = cls.from_dict(config_dict, return_unused_kwargs=False) config_dict.pop("_from_model_config", None)
config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config. # generation config.
...@@ -642,7 +669,6 @@ class GenerationConfig(PushToHubMixin): ...@@ -642,7 +669,6 @@ class GenerationConfig(PushToHubMixin):
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr): if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
setattr(config, attr, decoder_config[attr]) setattr(config, attr, decoder_config[attr])
config._from_model_config = True
return config return config
def update(self, **kwargs): def update(self, **kwargs):
......
...@@ -78,6 +78,20 @@ class GenerationConfigTest(unittest.TestCase): ...@@ -78,6 +78,20 @@ class GenerationConfigTest(unittest.TestCase):
# `.update()` returns a dictionary of unused kwargs # `.update()` returns a dictionary of unused kwargs
self.assertEqual(unused_kwargs, {"foo": "bar"}) self.assertEqual(unused_kwargs, {"foo": "bar"})
def test_initialize_new_kwargs(self):
generation_config = GenerationConfig()
generation_config.foo = "bar"
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
# update_kwargs was used to update the config on valid attributes
self.assertEqual(new_config.foo, "bar")
generation_config = GenerationConfig.from_model_config(new_config)
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
@is_staging_test @is_staging_test
class ConfigPushToHubTester(unittest.TestCase): class ConfigPushToHubTester(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment