Commit 18f4b927 authored by Gunnlaugur Thor Briem's avatar Gunnlaugur Thor Briem
Browse files

fix: work with Tensorflow < 2.1.0

tf.keras.utils.register_keras_serializable was added in TF 2.1.0, so
don't rely on it being there; just decorate the class with it if it
exists.
parent 96c49901
...@@ -72,7 +72,9 @@ def keras_serializable(cls): ...@@ -72,7 +72,9 @@ def keras_serializable(cls):
cls.get_config = get_config cls.get_config = get_config
cls._keras_serializable = True cls._keras_serializable = True
return tf.keras.utils.register_keras_serializable()(cls) if hasattr(tf.keras.utils, "register_keras_serializable"):
cls = tf.keras.utils.register_keras_serializable()(cls)
return cls
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment