Commit ba281707 authored by Gunnlaugur Thor Briem's avatar Gunnlaugur Thor Briem
Browse files

Support keras JSON/HDF5 serialization of main layers

Fixes #3101
parent a088d75e
......@@ -23,7 +23,7 @@ import tensorflow as tf
from .configuration_albert import AlbertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
logger = logging.getLogger(__name__)
......@@ -478,7 +478,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
return hidden_states
class TFAlbertMainLayer(tf.keras.layers.Layer):
class TFAlbertMainLayer(TFMainLayer):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.num_hidden_layers = config.num_hidden_layers
......
......@@ -23,7 +23,7 @@ import tensorflow as tf
from .configuration_bert import BertConfig
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
logger = logging.getLogger(__name__)
......@@ -471,9 +471,9 @@ class TFBertNSPHead(tf.keras.layers.Layer):
return seq_relationship_score
class TFBertMainLayer(tf.keras.layers.Layer):
class TFBertMainLayer(TFMainLayer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
super().__init__(config, **kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.embeddings = TFBertEmbeddings(config, name="embeddings")
......
......@@ -23,7 +23,7 @@ import tensorflow as tf
from .configuration_ctrl import CTRLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
logger = logging.getLogger(__name__)
......@@ -164,9 +164,9 @@ class TFEncoderLayer(tf.keras.layers.Layer):
return outputs
class TFCTRLMainLayer(tf.keras.layers.Layer):
class TFCTRLMainLayer(TFMainLayer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
super().__init__(config, **kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
......
......@@ -24,7 +24,7 @@ import tensorflow as tf
from .configuration_distilbert import DistilBertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
logger = logging.getLogger(__name__)
......@@ -397,9 +397,9 @@ class TFTransformer(tf.keras.layers.Layer):
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class TFDistilBertMainLayer(tf.keras.layers.Layer):
class TFDistilBertMainLayer(TFMainLayer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
super().__init__(config, **kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings
......
......@@ -25,6 +25,7 @@ from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
TFConv1D,
TFMainLayer,
TFPreTrainedModel,
TFSequenceSummary,
TFSharedEmbeddings,
......@@ -196,9 +197,9 @@ class TFBlock(tf.keras.layers.Layer):
return outputs # x, present, (attentions)
class TFGPT2MainLayer(tf.keras.layers.Layer):
class TFGPT2MainLayer(TFMainLayer):
def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
super().__init__(config, *inputs, **kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.num_hidden_layers = config.n_layer
......
......@@ -25,6 +25,7 @@ from .configuration_openai import OpenAIGPTConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
TFConv1D,
TFMainLayer,
TFPreTrainedModel,
TFSequenceSummary,
TFSharedEmbeddings,
......@@ -197,7 +198,7 @@ class TFBlock(tf.keras.layers.Layer):
return outputs # x, (attentions)
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
class TFOpenAIGPTMainLayer(TFMainLayer):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.output_hidden_states = config.output_hidden_states
......
......@@ -25,7 +25,7 @@ import tensorflow as tf
from .configuration_t5 import T5Config
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
logger = logging.getLogger(__name__)
......@@ -359,9 +359,9 @@ class TFT5Block(tf.keras.layers.Layer):
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
####################################################
class TFT5MainLayer(tf.keras.layers.Layer):
class TFT5MainLayer(TFMainLayer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
super().__init__(config, **kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.is_decoder = config.is_decoder
......
......@@ -24,7 +24,7 @@ import tensorflow as tf
from .configuration_transfo_xl import TransfoXLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
logger = logging.getLogger(__name__)
......@@ -378,9 +378,9 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
return embed
class TFTransfoXLMainLayer(tf.keras.layers.Layer):
class TFTransfoXLMainLayer(TFMainLayer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
super().__init__(config, **kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
......
......@@ -47,6 +47,23 @@ class TFModelUtilsMixin:
return self.count_params()
class TFMainLayer(tf.keras.layers.Layer):
    """Shared base class for the models' "main layer" implementations.

    Keeps the transformers configuration on the layer and folds it into
    ``get_config`` so Keras can round-trip the layer through its JSON/HDF5
    serialization machinery.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # During Keras deserialization (`from_config`) the config arrives as a
        # plain dict produced by `get_config`; rebuild the config object then.
        self._transformers_config = (
            PretrainedConfig.from_dict(config) if isinstance(config, dict) else config
        )

    def get_config(self):
        keras_config = super().get_config()
        keras_config["config"] = self._transformers_config.to_dict()
        return keras_config
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
r""" Base class for all TF models.
......
......@@ -25,7 +25,14 @@ import tensorflow as tf
from .configuration_xlm import XLMConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
from .modeling_tf_utils import (
TFMainLayer,
TFPreTrainedModel,
TFSequenceSummary,
TFSharedEmbeddings,
get_initializer,
shape_list,
)
logger = logging.getLogger(__name__)
......@@ -196,9 +203,9 @@ class TFTransformerFFN(tf.keras.layers.Layer):
return x
class TFXLMMainLayer(tf.keras.layers.Layer):
class TFXLMMainLayer(TFMainLayer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
super().__init__(config, **kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
......
......@@ -24,7 +24,14 @@ import tensorflow as tf
from .configuration_xlnet import XLNetConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
from .modeling_tf_utils import (
TFMainLayer,
TFPreTrainedModel,
TFSequenceSummary,
TFSharedEmbeddings,
get_initializer,
shape_list,
)
logger = logging.getLogger(__name__)
......@@ -342,9 +349,9 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
return hidden_states
class TFXLNetMainLayer(tf.keras.layers.Layer):
class TFXLNetMainLayer(TFMainLayer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
super().__init__(config, **kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.output_past = config.output_past
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment