Commit 470753bc authored by Gunnlaugur Thor Briem's avatar Gunnlaugur Thor Briem
Browse files

Put @keras_serializable only on layers it works on

And only run the test on TF*MainLayer classes so marked.
parent 0c716ede
......@@ -397,7 +397,6 @@ class TFTransformer(tf.keras.layers.Layer):
return outputs # last-layer hidden state, (all hidden states), (all attentions)
@keras_serializable
class TFDistilBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......
......@@ -198,7 +198,6 @@ class TFBlock(tf.keras.layers.Layer):
return outputs # x, (attentions)
@keras_serializable
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
......
......@@ -359,7 +359,6 @@ class TFT5Block(tf.keras.layers.Layer):
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
####################################################
@keras_serializable
class TFT5MainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......
......@@ -71,6 +71,7 @@ def keras_serializable(cls):
cls.get_config = get_config
cls._keras_serializable = True
return tf.keras.utils.register_keras_serializable()(cls)
......
......@@ -203,7 +203,6 @@ class TFTransformerFFN(tf.keras.layers.Layer):
return x
@keras_serializable
class TFXLMMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......
......@@ -103,11 +103,9 @@ class TFModelTesterMixin:
if module_member_name.endswith("MainLayer")
for module_member in (getattr(module, module_member_name),)
if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
and getattr(module_member, '_keras_serializable', False)
)
for main_layer_class in tf_main_layer_classes:
if main_layer_class.__name__ == "TFT5MainLayer":
# Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly
continue
main_layer = main_layer_class(config)
symbolic_inputs = {
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment