Commit 0c716ede authored by Gunnlaugur Thor Briem

Use class decorator instead of superclass

When Keras deserialization supplies the config parameter to an initializer, it is a
plain dict rather than a PretrainedConfig object. So intercept it, convert it to the
appropriate PretrainedConfig subclass, and store it in an instance attribute (so that
get_config can get at it) before passing it on to the actual initializer. To
accomplish this while repeating as little code as possible, apply a class decorator
to the TF*MainLayer classes.
parent b8da16f3
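For illustration (not part of the commit), a minimal sketch of the round trip this change is meant to support, assuming TensorFlow 2.x and TFBertMainLayer decorated as in the diff below; the small config values are arbitrary:

    import tensorflow as tf

    from transformers import BertConfig
    from transformers.modeling_tf_bert import TFBertMainLayer

    # Build a small functional model around the @keras_serializable main layer.
    config = BertConfig(num_hidden_layers=2, hidden_size=128, num_attention_heads=2)
    input_ids = tf.keras.Input(shape=(64,), dtype=tf.int32)
    sequence_output = TFBertMainLayer(config, name="bert")(input_ids)[0]
    model = tf.keras.Model(inputs=input_ids, outputs=sequence_output)

    # to_json() serializes the layer through get_config(), which embeds the
    # transformers config as a plain dict ...
    json_spec = model.to_json()

    # ... and model_from_json() hands that dict back to TFBertMainLayer.__init__,
    # where the wrapped initializer converts it into a BertConfig again. The
    # decorator's register_keras_serializable() call is what lets Keras find
    # the class here without custom_objects.
    restored = tf.keras.models.model_from_json(json_spec)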
src/transformers/configuration_auto.py
@@ -17,6 +17,7 @@
 import logging
 from collections import OrderedDict
+from importlib import import_module
 
 from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
 from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
@@ -100,6 +101,20 @@ class AutoConfig:
             "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
         )
 
+    @classmethod
+    def config_class_for_model_class(cls, model_class):
+        module = import_module(model_class.__module__)
+        return next(
+            (
+                module_attribute
+                for module_attribute_name in dir(module)
+                if module_attribute_name.endswith("Config")
+                for module_attribute in (getattr(module, module_attribute_name),)
+                if issubclass(module_attribute, PretrainedConfig)
+            ),
+            None,
+        )
+
     @classmethod
     def for_model(cls, model_type, *args, **kwargs):
         for pattern, config_class in CONFIG_MAPPING.items():
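As a hypothetical usage note (not in the commit): the new helper scans the module that defines a given model class and returns the first PretrainedConfig subclass whose name ends in "Config". For example:

    from transformers import AutoConfig
    from transformers.modeling_tf_bert import TFBertMainLayer

    # modeling_tf_bert imports BertConfig, so the scan finds and returns it.
    config_class = AutoConfig.config_class_for_model_class(TFBertMainLayer)
    assert config_class.__name__ == "BertConfig"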
src/transformers/modeling_tf_albert.py
@@ -23,7 +23,7 @@ import tensorflow as tf
 from .configuration_albert import AlbertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 
 logger = logging.getLogger(__name__)
@@ -478,9 +478,10 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
         return hidden_states
 
-class TFAlbertMainLayer(TFMainLayer):
+@keras_serializable
+class TFAlbertMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
         self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
src/transformers/modeling_tf_bert.py
@@ -23,7 +23,7 @@ import tensorflow as tf
 from .configuration_bert import BertConfig
 from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 
 logger = logging.getLogger(__name__)
@@ -471,9 +471,10 @@ class TFBertNSPHead(tf.keras.layers.Layer):
         return seq_relationship_score
 
-class TFBertMainLayer(TFMainLayer):
+@keras_serializable
+class TFBertMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
         self.embeddings = TFBertEmbeddings(config, name="embeddings")
src/transformers/modeling_tf_ctrl.py
@@ -23,7 +23,7 @@ import tensorflow as tf
 from .configuration_ctrl import CTRLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
 
 logger = logging.getLogger(__name__)
@@ -164,9 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
         return outputs
 
-class TFCTRLMainLayer(TFMainLayer):
+@keras_serializable
+class TFCTRLMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
         self.output_past = config.output_past
src/transformers/modeling_tf_distilbert.py
@@ -24,7 +24,7 @@ import tensorflow as tf
 from .configuration_distilbert import DistilBertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list
 
 logger = logging.getLogger(__name__)
@@ -397,9 +397,10 @@ class TFTransformer(tf.keras.layers.Layer):
         return outputs  # last-layer hidden state, (all hidden states), (all attentions)
 
-class TFDistilBertMainLayer(TFMainLayer):
+@keras_serializable
+class TFDistilBertMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
         self.embeddings = TFEmbeddings(config, name="embeddings")  # Embeddings
src/transformers/modeling_tf_gpt2.py
@@ -25,11 +25,11 @@ from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
     TFConv1D,
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
@@ -197,9 +197,10 @@ class TFBlock(tf.keras.layers.Layer):
         return outputs  # x, present, (attentions)
 
-class TFGPT2MainLayer(TFMainLayer):
+@keras_serializable
+class TFGPT2MainLayer(tf.keras.layers.Layer):
     def __init__(self, config, *inputs, **kwargs):
-        super().__init__(config, *inputs, **kwargs)
+        super().__init__(*inputs, **kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
         self.num_hidden_layers = config.n_layer
src/transformers/modeling_tf_openai.py
@@ -25,11 +25,11 @@ from .configuration_openai import OpenAIGPTConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
     TFConv1D,
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
@@ -198,9 +198,10 @@ class TFBlock(tf.keras.layers.Layer):
         return outputs  # x, (attentions)
 
-class TFOpenAIGPTMainLayer(TFMainLayer):
+@keras_serializable
+class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, *inputs, **kwargs):
-        super().__init__(config, *inputs, **kwargs)
+        super().__init__(*inputs, **kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
         self.num_hidden_layers = config.n_layer
src/transformers/modeling_tf_roberta.py
@@ -20,10 +20,11 @@ import logging
 import tensorflow as tf
 
+from . import PretrainedConfig
 from .configuration_roberta import RobertaConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 
 logger = logging.getLogger(__name__)
src/transformers/modeling_tf_t5.py
@@ -25,7 +25,7 @@ import tensorflow as tf
 from .configuration_t5 import T5Config
 from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
 
 logger = logging.getLogger(__name__)
@@ -359,9 +359,10 @@ class TFT5Block(tf.keras.layers.Layer):
 # The full model without a specific pretrained or finetuning head is
 # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
 ####################################################
-class TFT5MainLayer(TFMainLayer):
+@keras_serializable
+class TFT5MainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.is_decoder = config.is_decoder
@@ -383,14 +384,21 @@ class TFT5MainLayer(TFMainLayer):
     def call(
         self,
-        hidden_states,
+        inputs,
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         head_mask=None,
         training=False,
     ):
+        if isinstance(inputs, (tuple, list)):
+            hidden_states = inputs[0]
+            assert len(inputs) <= 1, "Too many inputs."
+        elif isinstance(inputs, dict):
+            hidden_states = inputs["hidden_states"]
+            assert len(inputs) <= 1, "Too many inputs."
+        else:
+            hidden_states = inputs
         batch_size, seq_length = shape_list(hidden_states)[:2]
 
         if attention_mask is None:
             attention_mask = tf.fill((batch_size, seq_length), 1)
src/transformers/modeling_tf_transfo_xl.py
@@ -24,7 +24,7 @@ import tensorflow as tf
 from .configuration_transfo_xl import TransfoXLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
-from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
 
 logger = logging.getLogger(__name__)
@@ -378,9 +378,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
         return embed
 
-class TFTransfoXLMainLayer(TFMainLayer):
+@keras_serializable
+class TFTransfoXLMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
src/transformers/modeling_tf_utils.py
@@ -47,21 +47,31 @@ class TFModelUtilsMixin:
         return self.count_params()
 
-class TFMainLayer(tf.keras.layers.Layer):
-    """
-    A common superclass for main layers of models, to support `get_config` and thus Keras JSON serialization.
-    """
-
-    def __init__(self, config, **kwargs):
-        super().__init__(**kwargs)
-        if isinstance(config, dict):
-            config = PretrainedConfig.from_dict(config)
-        self._transformers_config = config
-
-    def get_config(self):
-        cfg = super().get_config()
-        cfg["config"] = self._transformers_config.to_dict()
-        return cfg
+def keras_serializable(cls):
+    initializer = cls.__init__
+
+    def wrapped_init(self, config, *args, **kwargs):
+        if isinstance(config, dict):
+            from transformers import AutoConfig
+
+            config = AutoConfig.config_class_for_model_class(cls).from_dict(config)
+        initializer(self, config, *args, **kwargs)
+        self._transformers_config = config
+
+    cls.__init__ = wrapped_init
+
+    if not hasattr(cls, "get_config"):
+        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
+    if hasattr(cls.get_config, "_is_default"):
+
+        def get_config(self):
+            cfg = super(cls, self).get_config()
+            cfg["config"] = self._transformers_config.to_dict()
+            return cfg
+
+        cls.get_config = get_config
+
+    return tf.keras.utils.register_keras_serializable()(cls)
 
 
 class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
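A sketch of the layer-level contract the decorator establishes (illustrative only, not part of the commit; _transformers_config is the private attribute set by wrapped_init above):

    from transformers import BertConfig
    from transformers.modeling_tf_bert import TFBertMainLayer

    layer = TFBertMainLayer(BertConfig(), name="bert")

    # The generated get_config() returns the usual Keras layer config plus the
    # transformers config serialized as a plain dict under the "config" key.
    cfg = layer.get_config()
    assert isinstance(cfg["config"], dict)

    # The default Layer.from_config() calls cls(**cfg); the wrapped initializer
    # sees config as a dict and rebuilds a BertConfig before __init__ proper runs.
    restored = TFBertMainLayer.from_config(cfg)
    assert isinstance(restored._transformers_config, BertConfig)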
src/transformers/modeling_tf_xlm.py
@@ -26,11 +26,11 @@ import tensorflow as tf
 from .configuration_xlm import XLMConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
@@ -203,9 +203,10 @@ class TFTransformerFFN(tf.keras.layers.Layer):
         return x
 
-class TFXLMMainLayer(TFMainLayer):
+@keras_serializable
+class TFXLMMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
src/transformers/modeling_tf_xlnet.py
@@ -25,11 +25,11 @@ import tensorflow as tf
 from .configuration_xlnet import XLNetConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import (
-    TFMainLayer,
     TFPreTrainedModel,
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
@@ -349,9 +349,10 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
         return hidden_states
 
-class TFXLNetMainLayer(TFMainLayer):
+@keras_serializable
+class TFXLNetMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
+        super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.output_past = config.output_past
tests/test_modeling_tf_common.py
@@ -22,7 +22,6 @@ import unittest
 from importlib import import_module
 
 from transformers import is_tf_available, is_torch_available
-from transformers.modeling_tf_utils import TFMainLayer
 
 from .utils import _tf_gpu_memory_limit, require_tf
@@ -90,6 +89,7 @@ class TFModelTesterMixin:
             model.save_pretrained(tmpdirname)
             model = model_class.from_pretrained(tmpdirname)
             after_outputs = model(inputs_dict)
+
             self.assert_outputs_same(after_outputs, outputs)
 
     def test_keras_save_load(self):
@@ -100,10 +100,14 @@ class TFModelTesterMixin:
             for model_class in self.all_model_classes
            for module in (import_module(model_class.__module__),)
             for module_member_name in dir(module)
+            if module_member_name.endswith("MainLayer")
             for module_member in (getattr(module, module_member_name),)
-            if isinstance(module_member, type) and TFMainLayer in module_member.__bases__
+            if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
         )
         for main_layer_class in tf_main_layer_classes:
+            if main_layer_class.__name__ == "TFT5MainLayer":
+                # Not really a "main layer" as in the other models, as this one doesn't receive the test inputs directly
+                continue
             main_layer = main_layer_class(config)
             symbolic_inputs = {
                 name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
@@ -125,6 +129,7 @@ class TFModelTesterMixin:
             # Make sure we don't have nans
             out_1 = after_outputs[0].numpy()
             out_2 = outputs[0].numpy()
+            self.assertEqual(out_1.shape, out_2.shape)
             out_1 = out_1[~np.isnan(out_1)]
             out_2 = out_2[~np.isnan(out_2)]
             max_diff = np.amax(np.abs(out_1 - out_2))