Commit 4be01e5c authored by Gunnlaugur Thor Briem's avatar Gunnlaugur Thor Briem
Browse files

Use name transformers_config in Keras serialization

Be explicit that this is config for the transformers package (as these
layers may coexist with other custom stuff in a Keras model, plus the
Keras container itself is called config, and config["config"] is not
great)

Add explicit error handling for initializer calls that have neither
the `config` nor the `transformers_config` argument, or have both.
parent a355f4f0
......@@ -54,10 +54,20 @@ def keras_serializable(cls):
raise AttributeError("Must set `config_class` to use @keras_serializable")
@functools.wraps(initializer)
def wrapped_init(self, config, *args, **kwargs):
if isinstance(config, dict):
config = config_class.from_dict(config)
initializer(self, config, *args, **kwargs)
def wrapped_init(self, *args, **kwargs):
transformers_config = kwargs.pop("transformers_config", None)
config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
if config is not None and transformers_config is not None:
raise ValueError("Must pass either `config` or `transformers_config`, not both")
elif config is not None:
# normal layer construction, call with unchanged args (config is already in there)
initializer(self, *args, **kwargs)
elif transformers_config is not None:
# Keras deserialization, convert dict to config
config = config_class.from_dict(transformers_config)
initializer(self, config, *args, **kwargs)
else:
raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
self._transformers_config = config
cls.__init__ = wrapped_init
......@@ -68,7 +78,7 @@ def keras_serializable(cls):
def get_config(self):
cfg = super(cls, self).get_config()
cfg["config"] = self._transformers_config.to_dict()
cfg["transformers_config"] = self._transformers_config.to_dict()
return cfg
cls.get_config = get_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment