Unverified Commit 5e704a2c authored by oOraph's avatar oOraph Committed by GitHub
Browse files

keep _use_default_values as a list type (#4040)


Signed-off-by: default avatarRaphael <oOraph@users.noreply.github.com>
Co-authored-by: default avatarRaphael <oOraph@users.noreply.github.com>
parent 8bff7823
......@@ -607,7 +607,7 @@ def register_to_config(init):
# Take note of the parameters that were not present in the loaded config
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs)
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
new_kwargs = {**config_init_kwargs, **new_kwargs}
getattr(self, "register_to_config")(**new_kwargs)
......@@ -655,7 +655,7 @@ def flax_register_to_config(cls):
# Take note of the parameters that were not present in the loaded config
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs)
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
getattr(self, "register_to_config")(**new_kwargs)
original_init(self, *args, **kwargs)
......
......@@ -264,7 +264,7 @@ class ConfigTester(unittest.TestCase):
config_dict = {k: v for k, v in config.config.items() if not k.startswith("_")}
# make sure that default config has all keys in `_use_default_values`
assert set(config_dict.keys()) == config.config._use_default_values
assert set(config_dict.keys()) == set(config.config._use_default_values)
with tempfile.TemporaryDirectory() as tmpdirname:
config.save_config(tmpdirname)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment