"web/git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "ece5acb8e8025d8ca26aa880f604d971d245475d"
Commit 470753bc authored by Gunnlaugur Thor Briem's avatar Gunnlaugur Thor Briem
Browse files

Put @keras_serializable only on layers it works on

And only run the test on TF*MainLayer classes so marked.
parent 0c716ede
...@@ -397,7 +397,6 @@ class TFTransformer(tf.keras.layers.Layer): ...@@ -397,7 +397,6 @@ class TFTransformer(tf.keras.layers.Layer):
return outputs # last-layer hidden state, (all hidden states), (all attentions) return outputs # last-layer hidden state, (all hidden states), (all attentions)
@keras_serializable
class TFDistilBertMainLayer(tf.keras.layers.Layer): class TFDistilBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
......
...@@ -198,7 +198,6 @@ class TFBlock(tf.keras.layers.Layer): ...@@ -198,7 +198,6 @@ class TFBlock(tf.keras.layers.Layer):
return outputs # x, (attentions) return outputs # x, (attentions)
@keras_serializable
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs) super().__init__(*inputs, **kwargs)
......
...@@ -359,7 +359,6 @@ class TFT5Block(tf.keras.layers.Layer): ...@@ -359,7 +359,6 @@ class TFT5Block(tf.keras.layers.Layer):
# The full model without a specific pretrained or finetuning head is # The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer" # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
#################################################### ####################################################
@keras_serializable
class TFT5MainLayer(tf.keras.layers.Layer): class TFT5MainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
......
...@@ -71,6 +71,7 @@ def keras_serializable(cls): ...@@ -71,6 +71,7 @@ def keras_serializable(cls):
cls.get_config = get_config cls.get_config = get_config
cls._keras_serializable = True
return tf.keras.utils.register_keras_serializable()(cls) return tf.keras.utils.register_keras_serializable()(cls)
......
...@@ -203,7 +203,6 @@ class TFTransformerFFN(tf.keras.layers.Layer): ...@@ -203,7 +203,6 @@ class TFTransformerFFN(tf.keras.layers.Layer):
return x return x
@keras_serializable
class TFXLMMainLayer(tf.keras.layers.Layer): class TFXLMMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
......
...@@ -103,11 +103,9 @@ class TFModelTesterMixin: ...@@ -103,11 +103,9 @@ class TFModelTesterMixin:
if module_member_name.endswith("MainLayer") if module_member_name.endswith("MainLayer")
for module_member in (getattr(module, module_member_name),) for module_member in (getattr(module, module_member_name),)
if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__ if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
and getattr(module_member, '_keras_serializable', False)
) )
for main_layer_class in tf_main_layer_classes: for main_layer_class in tf_main_layer_classes:
if main_layer_class.__name__ == "TFT5MainLayer":
# Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly
continue
main_layer = main_layer_class(config) main_layer = main_layer_class(config)
symbolic_inputs = { symbolic_inputs = {
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items() name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment