Commit 0c716ede authored by Gunnlaugur Thor Briem

Use class decorator instead of superclass

When supplied by Keras deserialization, the config parameter to initializers
will be a dict. So intercept it and convert it to a PretrainedConfig object
(storing it in an instance attribute so that get_config can retrieve it) before
passing it on to the actual initializer. To accomplish this with as little
repeated code as possible, use a class decorator on the TF*MainLayer classes.
parent b8da16f3
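For orientation, here is a minimal, self-contained sketch of the pattern this commit applies; it is not the code from this diff. ToyConfig, toy_keras_serializable and ToyMainLayer are made-up stand-ins for PretrainedConfig, the keras_serializable decorator and the TF*MainLayer classes; the real decorator below additionally looks up the right config class via AutoConfig and registers the class with tf.keras.utils.register_keras_serializable().

import tensorflow as tf


class ToyConfig:
    """Stand-in for PretrainedConfig: built from, and serialized to, a plain dict."""

    def __init__(self, hidden_size=4):
        self.hidden_size = hidden_size

    @classmethod
    def from_dict(cls, d):
        return cls(**d)

    def to_dict(self):
        return {"hidden_size": self.hidden_size}


def toy_keras_serializable(cls):
    """Class decorator: intercept a dict config, rebuild the config object,
    and expose it through get_config() so Keras can serialize the layer."""
    original_init = cls.__init__

    def wrapped_init(self, config, *args, **kwargs):
        if isinstance(config, dict):
            # Keras deserialization hands the stored config back as a dict.
            config = ToyConfig.from_dict(config)
        original_init(self, config, *args, **kwargs)
        self._toy_config = config  # kept around for get_config()

    def get_config(self):
        cfg = super(cls, self).get_config()
        cfg["config"] = self._toy_config.to_dict()
        return cfg

    cls.__init__ = wrapped_init
    cls.get_config = get_config
    return cls


@toy_keras_serializable
class ToyMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(config.hidden_size)

    def call(self, inputs):
        return self.dense(inputs)


layer = ToyMainLayer(ToyConfig(hidden_size=8))
restored = ToyMainLayer.from_config(layer.get_config())  # config arrives as a dict here
print(restored.get_config()["config"])  # -> {'hidden_size': 8}

Calling ToyMainLayer.from_config(layer.get_config()) mimics what Keras does when reloading a saved model: the stored config comes back as a plain dict, and the wrapped initializer converts it before the real __init__ runs.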
@@ -17,6 +17,7 @@
 import logging
 from collections import OrderedDict
+from importlib import import_module
 from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
 from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
@@ -100,6 +101,20 @@ class AutoConfig:
             "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
         )
+    @classmethod
+    def config_class_for_model_class(cls, model_class):
+        module = import_module(model_class.__module__)
+        return next(
+            (
+                module_attribute
+                for module_attribute_name in dir(module)
+                if module_attribute_name.endswith("Config")
+                for module_attribute in (getattr(module, module_attribute_name),)
+                if issubclass(module_attribute, PretrainedConfig)
+            ),
+            None,
+        )
     @classmethod
     def for_model(cls, model_type, *args, **kwargs):
         for pattern, config_class in CONFIG_MAPPING.items():
@@ -23,7 +23,7 @@ import tensorflow as tf
 from .configuration_albert import AlbertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 logger = logging.getLogger(__name__)
@@ -478,9 +478,10 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
         return hidden_states
-class TFAlbertMainLayer(TFMainLayer):
+@keras_serializable
+class TFAlbertMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
         self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
@@ -23,7 +23,7 @@ import tensorflow as tf
 from .configuration_bert import BertConfig
 from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 logger = logging.getLogger(__name__)
@@ -471,9 +471,10 @@ class TFBertNSPHead(tf.keras.layers.Layer):
         return seq_relationship_score
-class TFBertMainLayer(TFMainLayer):
+@keras_serializable
+class TFBertMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
         self.embeddings = TFBertEmbeddings(config, name="embeddings")
@@ -23,7 +23,7 @@ import tensorflow as tf
 from .configuration_ctrl import CTRLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
 logger = logging.getLogger(__name__)
@@ -164,9 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
         return outputs
-class TFCTRLMainLayer(TFMainLayer):
+@keras_serializable
+class TFCTRLMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
         self.output_past = config.output_past
@@ -24,7 +24,7 @@ import tensorflow as tf
 from .configuration_distilbert import DistilBertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list
 logger = logging.getLogger(__name__)
@@ -397,9 +397,10 @@ class TFTransformer(tf.keras.layers.Layer):
         return outputs  # last-layer hidden state, (all hidden states), (all attentions)
-class TFDistilBertMainLayer(TFMainLayer):
+@keras_serializable
+class TFDistilBertMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
         self.embeddings = TFEmbeddings(config, name="embeddings")  # Embeddings
@@ -25,11 +25,11 @@ from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
     TFConv1D,
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
@@ -197,9 +197,10 @@ class TFBlock(tf.keras.layers.Layer):
         return outputs  # x, present, (attentions)
-class TFGPT2MainLayer(TFMainLayer):
+@keras_serializable
+class TFGPT2MainLayer(tf.keras.layers.Layer):
     def __init__(self, config, *inputs, **kwargs):
-        super().__init__(config, *inputs, **kwargs)
+        super().__init__(*inputs, **kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
         self.num_hidden_layers = config.n_layer
@@ -25,11 +25,11 @@ from .configuration_openai import OpenAIGPTConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
     TFConv1D,
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
@@ -198,9 +198,10 @@ class TFBlock(tf.keras.layers.Layer):
         return outputs  # x, (attentions)
-class TFOpenAIGPTMainLayer(TFMainLayer):
+@keras_serializable
+class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, *inputs, **kwargs):
-        super().__init__(config, *inputs, **kwargs)
+        super().__init__(*inputs, **kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
         self.num_hidden_layers = config.n_layer
@@ -20,10 +20,11 @@ import logging
 import tensorflow as tf
+from . import PretrainedConfig
 from .configuration_roberta import RobertaConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 logger = logging.getLogger(__name__)
@@ -25,7 +25,7 @@ import tensorflow as tf
 from .configuration_t5 import T5Config
 from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
 logger = logging.getLogger(__name__)
@@ -359,9 +359,10 @@ class TFT5Block(tf.keras.layers.Layer):
 # The full model without a specific pretrained or finetuning head is
 # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
 ####################################################
-class TFT5MainLayer(TFMainLayer):
+@keras_serializable
+class TFT5MainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.is_decoder = config.is_decoder
@@ -383,14 +384,21 @@ class TFT5MainLayer(TFMainLayer):
     def call(
         self,
-        hidden_states,
+        inputs,
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         head_mask=None,
         training=False,
     ):
+        if isinstance(inputs, (tuple, list)):
+            hidden_states = inputs[0]
+            assert len(inputs) <= 1, "Too many inputs."
+        elif isinstance(inputs, dict):
+            hidden_states = inputs["hidden_states"]
+            assert len(inputs) <= 1, "Too many inputs."
+        else:
+            hidden_states = inputs
         batch_size, seq_length = shape_list(hidden_states)[:2]
         if attention_mask is None:
             attention_mask = tf.fill((batch_size, seq_length), 1)
@@ -24,7 +24,7 @@ import tensorflow as tf
 from .configuration_transfo_xl import TransfoXLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 logger = logging.getLogger(__name__)
@@ -378,9 +378,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
         return embed
-class TFTransfoXLMainLayer(TFMainLayer):
+@keras_serializable
+class TFTransfoXLMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
@@ -47,21 +47,31 @@ class TFModelUtilsMixin:
         return self.count_params()
-class TFMainLayer(tf.keras.layers.Layer):
-    """
-    A common superclass for main layers of models, to support `get_config` and thus Keras JSON serialization.
-    """
+def keras_serializable(cls):
+    initializer = cls.__init__
-    def __init__(self, config, **kwargs):
-        super().__init__(**kwargs)
+    def wrapped_init(self, config, *args, **kwargs):
         if isinstance(config, dict):
-            config = PretrainedConfig.from_dict(config)
+            from transformers import AutoConfig
+            config = AutoConfig.config_class_for_model_class(cls).from_dict(config)
+        initializer(self, config, *args, **kwargs)
         self._transformers_config = config
-    def get_config(self):
-        cfg = super().get_config()
-        cfg["config"] = self._transformers_config.to_dict()
-        return cfg
+    cls.__init__ = wrapped_init
+    if not hasattr(cls, "get_config"):
+        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
+    if hasattr(cls.get_config, "_is_default"):
+        def get_config(self):
+            cfg = super(cls, self).get_config()
+            cfg["config"] = self._transformers_config.to_dict()
+            return cfg
+        cls.get_config = get_config
+    return tf.keras.utils.register_keras_serializable()(cls)
 class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
@@ -26,11 +26,11 @@ import tensorflow as tf
 from .configuration_xlm import XLMConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
@@ -203,9 +203,10 @@ class TFTransformerFFN(tf.keras.layers.Layer):
         return x
-class TFXLMMainLayer(TFMainLayer):
+@keras_serializable
+class TFXLMMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
@@ -25,11 +25,11 @@ import tensorflow as tf
 from .configuration_xlnet import XLNetConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
@@ -349,9 +349,10 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
         return hidden_states
-class TFXLNetMainLayer(TFMainLayer):
+@keras_serializable
+class TFXLNetMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.output_past = config.output_past
@@ -22,7 +22,6 @@ import unittest
 from importlib import import_module
 from transformers import is_tf_available, is_torch_available
-from transformers.modeling_tf_utils import TFMainLayer
 from .utils import _tf_gpu_memory_limit, require_tf
@@ -90,6 +89,7 @@ class TFModelTesterMixin:
                 model.save_pretrained(tmpdirname)
                 model = model_class.from_pretrained(tmpdirname)
                 after_outputs = model(inputs_dict)
+                self.assert_outputs_same(after_outputs, outputs)
     def test_keras_save_load(self):
@@ -100,10 +100,14 @@ class TFModelTesterMixin:
             for model_class in self.all_model_classes
            for module in (import_module(model_class.__module__),)
             for module_member_name in dir(module)
+            if module_member_name.endswith("MainLayer")
             for module_member in (getattr(module, module_member_name),)
-            if isinstance(module_member, type) and TFMainLayer in module_member.__bases__
+            if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
         )
         for main_layer_class in tf_main_layer_classes:
+            if main_layer_class.__name__ == "TFT5MainLayer":
+                # Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly
+                continue
             main_layer = main_layer_class(config)
             symbolic_inputs = {
                 name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
@@ -125,6 +129,7 @@ class TFModelTesterMixin:
         # Make sure we don't have nans
         out_1 = after_outputs[0].numpy()
         out_2 = outputs[0].numpy()
+        self.assertEqual(out_1.shape, out_2.shape)
         out_1 = out_1[~np.isnan(out_1)]
         out_2 = out_2[~np.isnan(out_2)]
         max_diff = np.amax(np.abs(out_1 - out_2))
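The new test_keras_save_load builds each discovered main layer into a functional tf.keras.Model from symbolic inputs and checks that it survives a serialization round trip. A rough, self-contained illustration of that usage follows; it is a sketch, not the test code. DemoMainLayer and its plain-dict config are made-up names, and the real test feeds a dict of symbolic inputs matching the model's input names rather than a single tensor.

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="demo")
class DemoMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # A plain dict config for simplicity; the real layers take a PretrainedConfig.
        self._config = dict(config)
        self.dense = tf.keras.layers.Dense(self._config["hidden_size"])

    def call(self, inputs):
        return self.dense(inputs)

    def get_config(self):
        cfg = super().get_config()
        cfg["config"] = dict(self._config)
        return cfg


symbolic_input = tf.keras.Input(shape=(16,), dtype="float32")
layer = DemoMainLayer({"hidden_size": 8}, name="demo_main_layer")
model = tf.keras.Model(symbolic_input, layer(symbolic_input))

# Round-trip the whole model through its JSON config, as Keras saving does.
restored = tf.keras.models.model_from_json(model.to_json())
assert restored.get_layer("demo_main_layer").get_config()["config"] == {"hidden_size": 8}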