Commit f5a1343f authored by A. Unique TensorFlower's avatar A. Unique TensorFlower


Shift all attribute assignments in functional subclass models below the super().__init__ call, and enable setattr tracking on all functional subclasses.

PiperOrigin-RevId: 338759531
parent f16e42b4
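
Every file below applies the same two-part pattern: build the layer graph with the Functional API, call `super().__init__` with that graph's inputs and outputs, and only then assign attributes to `self`, storing the config as a namedtuple so attribute tracking skips it. A minimal sketch of what each constructor now looks like (the class, layer, and sizes here are illustrative only, not part of this commit):

```python
import collections

import tensorflow as tf


class TinyFunctionalModel(tf.keras.Model):
  """Hypothetical example of the constructor pattern used in this commit."""

  def __init__(self, units=4, **kwargs):
    # 1) Build the graph with the Functional API (no `self` assignments yet).
    inputs = tf.keras.layers.Input(shape=(8,), name='features')
    dense = tf.keras.layers.Dense(units, name='hidden')
    outputs = dense(inputs)

    # 2) Initialize as a functional Model. Attribute tracking is active from
    #    here on, so all `self` assignments come after this call.
    super().__init__(inputs=inputs, outputs=outputs, **kwargs)

    # 3) Store the config as a namedtuple: TF does not track immutable
    #    attributes that contain no Trackables, so checkpoints stay
    #    compatible with the earlier version that never tracked the config.
    config_dict = {'units': units}
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)
    self._dense_layer = dense

  def get_config(self):
    return dict(self._config._asdict())
```

With tracking left on, layers assigned to `self` after `super().__init__` are picked up for checkpointing and export, which is presumably the point of enabling setattr tracking on these subclasses.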
......@@ -15,6 +15,8 @@
"""Bert encoder network."""
# pylint: disable=g-classes-have-attributes
import collections
from absl import logging
import tensorflow as tf
from official.nlp.keras_nlp import layers
......@@ -65,6 +67,8 @@ class BertEncoder(tf.keras.Model):
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to
generate embeddings for the input word IDs.
"""
def __init__(
......@@ -82,27 +86,11 @@ class BertEncoder(tf.keras.Model):
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
output_range=None,
embedding_width=None,
embedding_layer=None,
**kwargs):
activation = tf.keras.activations.get(inner_activation)
initializer = tf.keras.initializers.get(initializer)
self._self_setattr_tracking = False
self._config_dict = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
}
word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids')
mask = tf.keras.layers.Input(
......@@ -112,44 +100,54 @@ class BertEncoder(tf.keras.Model):
if embedding_width is None:
embedding_width = hidden_size
self._embedding_layer = self._build_embedding_layer()
word_embeddings = self._embedding_layer(word_ids)
if embedding_layer is None:
embedding_layer_inst = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
else:
embedding_layer_inst = embedding_layer
word_embeddings = embedding_layer_inst(word_ids)
# Always uses dynamic slicing for simplicity.
self._position_embedding_layer = layers.PositionEmbedding(
position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
position_embeddings = self._position_embedding_layer(word_embeddings)
self._type_embedding_layer = layers.OnDeviceEmbedding(
position_embeddings = position_embedding_layer(word_embeddings)
type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
type_embeddings = self._type_embedding_layer(type_ids)
type_embeddings = type_embedding_layer(type_ids)
embeddings = tf.keras.layers.Add()(
[word_embeddings, position_embeddings, type_embeddings])
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
embeddings = self._embedding_norm_layer(embeddings)
embeddings = embedding_norm_layer(embeddings)
embeddings = (tf.keras.layers.Dropout(rate=output_dropout)(embeddings))
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
embeddings = self._embedding_projection(embeddings)
embeddings = embedding_projection(embeddings)
else:
embedding_projection = None
self._transformer_layers = []
transformer_layers = []
data = embeddings
attention_mask = layers.SelfAttentionMask()(data, mask)
encoder_outputs = []
......@@ -167,7 +165,7 @@ class BertEncoder(tf.keras.Model):
output_range=transformer_output_range,
kernel_initializer=initializer,
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
transformer_layers.append(layer)
data = layer([data, attention_mask])
encoder_outputs.append(data)
......@@ -176,38 +174,68 @@ class BertEncoder(tf.keras.Model):
# like this will create a SliceOpLambda layer. This is better than a Lambda
# layer with Python code, because that is fundamentally less portable.
first_token_tensor = last_enocder_output[:, 0, :]
self._pooler_layer = tf.keras.layers.Dense(
pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')
cls_output = self._pooler_layer(first_token_tensor)
cls_output = pooler_layer(first_token_tensor)
outputs = dict(
sequence_output=encoder_outputs[-1],
pooled_output=cls_output,
encoder_outputs=encoder_outputs,
)
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(BertEncoder, self).__init__(
inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
config_dict = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self._pooler_layer = pooler_layer
self._transformer_layers = transformer_layers
self._embedding_norm_layer = embedding_norm_layer
self._embedding_layer = embedding_layer_inst
self._position_embedding_layer = position_embedding_layer
self._type_embedding_layer = type_embedding_layer
self._embedding_projection = embedding_projection
def get_embedding_table(self):
return self._embedding_layer.embeddings
def _build_embedding_layer(self):
embedding_width = self._config_dict[
'embedding_width'] or self._config_dict['hidden_size']
return layers.OnDeviceEmbedding(
vocab_size=self._config_dict['vocab_size'],
embedding_width=embedding_width,
initializer=self._config_dict['initializer'],
name='word_embeddings')
def get_embedding_layer(self):
return self._embedding_layer
def get_config(self):
return self._config_dict
return dict(self._config._asdict())
@property
def transformer_layers(self):
......@@ -221,4 +249,13 @@ class BertEncoder(tf.keras.Model):
@classmethod
def from_config(cls, config, custom_objects=None):
if config['embedding_layer'] is not None:
warn_string = (
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you continue to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.')
print('WARNING: ' + warn_string)
logging.warn(warn_string)
return cls(**config)
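
The new `embedding_layer` argument replaces the `_build_embedding_layer` override hook (removed later in this commit) as the way to share one embedding table across encoders. A hedged sketch of how that sharing might be wired; the import path and all hyperparameters below are assumptions for illustration:

```python
from official.nlp.keras_nlp import layers
from official.nlp.keras_nlp.encoders import bert_encoder  # assumed import path

# One embedding layer instance shared by two encoders; the sizes are
# deliberately tiny and purely illustrative.
shared_embedding = layers.OnDeviceEmbedding(
    vocab_size=100, embedding_width=16, name='word_embeddings')

encoder_a = bert_encoder.BertEncoder(
    vocab_size=100, hidden_size=16, num_layers=1, num_attention_heads=2,
    embedding_layer=shared_embedding)
encoder_b = bert_encoder.BertEncoder(
    vocab_size=100, hidden_size=16, num_layers=1, num_attention_heads=2,
    embedding_layer=shared_embedding)

# Both encoders now read and train the same embedding table.
assert encoder_a.get_embedding_table() is encoder_b.get_embedding_table()
```

As the `from_config` warning above notes, this sharing only holds for live objects; rebuilding a model from its serialized config creates an unshared embedding layer.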
......@@ -204,7 +204,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
attention_dropout=0.22,
initializer="glorot_uniform",
output_range=-1,
embedding_width=16)
embedding_width=16,
embedding_layer=None)
network = bert_encoder.BertEncoder(**kwargs)
expected_config = dict(kwargs)
expected_config["inner_activation"] = tf.keras.activations.serialize(
......
......@@ -14,7 +14,7 @@
# ==============================================================================
"""BERT cls-token classifier."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from official.nlp.modeling import layers
......@@ -55,14 +55,6 @@ class BertClassifier(tf.keras.Model):
dropout_rate=0.1,
use_encoder_pooler=True,
**kwargs):
self._self_setattr_tracking = False
self._network = network
self._config = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'use_encoder_pooler': use_encoder_pooler,
}
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use
......@@ -79,29 +71,52 @@ class BertClassifier(tf.keras.Model):
cls_output = outputs['pooled_output']
cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
self.classifier = networks.Classification(
classifier = networks.Classification(
input_width=cls_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
output='logits',
name='sentence_prediction')
predictions = self.classifier(cls_output)
predictions = classifier(cls_output)
else:
outputs = network(inputs)
if isinstance(outputs, list):
sequence_output = outputs[0]
else:
sequence_output = outputs['sequence_output']
self.classifier = layers.ClassificationHead(
classifier = layers.ClassificationHead(
inner_dim=sequence_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
dropout_rate=dropout_rate,
name='sentence_prediction')
predictions = self.classifier(sequence_output)
predictions = classifier(sequence_output)
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(BertClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs)
self._network = network
config_dict = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'use_encoder_pooler': use_encoder_pooler,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.classifier = classifier
@property
def checkpoint_items(self):
......@@ -112,7 +127,7 @@ class BertClassifier(tf.keras.Model):
return items
def get_config(self):
return self._config
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
......
......@@ -14,7 +14,7 @@
# ==============================================================================
"""BERT Pre-training model."""
# pylint: disable=g-classes-have-attributes
import collections
import copy
from typing import List, Optional
......@@ -64,21 +64,12 @@ class BertPretrainer(tf.keras.Model):
initializer='glorot_uniform',
output='logits',
**kwargs):
self._self_setattr_tracking = False
self._config = {
'network': network,
'num_classes': num_classes,
'num_token_predictions': num_token_predictions,
'activation': activation,
'initializer': initializer,
'output': output,
}
self.encoder = network
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a copy of the network inputs for use
# when we construct the Model object at the end of init. (We keep a copy
# because we'll be adding another tensor to the copy later.)
network_inputs = self.encoder.inputs
network_inputs = network.inputs
inputs = copy.copy(network_inputs)
# Because we have a copy of inputs to create this Model object, we can
......@@ -86,7 +77,7 @@ class BertPretrainer(tf.keras.Model):
# Note that, because of how deferred construction happens, we can't use
# the copy of the list here - by the time the network is invoked, the list
# object contains the additional input added below.
sequence_output, cls_output = self.encoder(network_inputs)
sequence_output, cls_output = network(network_inputs)
# The encoder network may get outputs from all layers.
if isinstance(sequence_output, list):
......@@ -108,31 +99,59 @@ class BertPretrainer(tf.keras.Model):
inputs.append(masked_lm_positions)
if embedding_table is None:
embedding_table = self.encoder.get_embedding_table()
self.masked_lm = layers.MaskedLM(
embedding_table = network.get_embedding_table()
masked_lm = layers.MaskedLM(
embedding_table=embedding_table,
activation=activation,
initializer=initializer,
output=output,
name='cls/predictions')
lm_outputs = self.masked_lm(
lm_outputs = masked_lm(
sequence_output, masked_positions=masked_lm_positions)
self.classification = networks.Classification(
classification = networks.Classification(
input_width=cls_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
output=output,
name='classification')
sentence_outputs = self.classification(cls_output)
sentence_outputs = classification(cls_output)
super(BertPretrainer, self).__init__(
inputs=inputs,
outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs),
**kwargs)
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
config_dict = {
'network': network,
'num_classes': num_classes,
'num_token_predictions': num_token_predictions,
'activation': activation,
'initializer': initializer,
'output': output,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.encoder = network
self.classification = classification
self.masked_lm = masked_lm
def get_config(self):
return self._config
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
......
......@@ -14,7 +14,7 @@
# ==============================================================================
"""BERT Question Answering model."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from official.nlp.modeling import networks
......@@ -49,13 +49,6 @@ class BertSpanLabeler(tf.keras.Model):
initializer='glorot_uniform',
output='logits',
**kwargs):
self._self_setattr_tracking = False
self._network = network
self._config = {
'network': network,
'initializer': initializer,
'output': output,
}
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use
......@@ -72,12 +65,12 @@ class BertSpanLabeler(tf.keras.Model):
# This is an instance variable for ease of access to the underlying task
# network.
self.span_labeling = networks.SpanLabeling(
span_labeling = networks.SpanLabeling(
input_width=sequence_output.shape[-1],
initializer=initializer,
output=output,
name='span_labeling')
start_logits, end_logits = self.span_labeling(sequence_output)
start_logits, end_logits = span_labeling(sequence_output)
# Use identity layers wrapped in lambdas to explicitly name the output
# tensors. This allows us to use string-keyed dicts in Keras fit/predict/
......@@ -91,15 +84,36 @@ class BertSpanLabeler(tf.keras.Model):
logits = [start_logits, end_logits]
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(BertSpanLabeler, self).__init__(
inputs=inputs, outputs=logits, **kwargs)
self._network = network
config_dict = {
'network': network,
'initializer': initializer,
'output': output,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.span_labeling = span_labeling
@property
def checkpoint_items(self):
return dict(encoder=self._network)
def get_config(self):
return self._config
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
......
......@@ -14,7 +14,7 @@
# ==============================================================================
"""BERT token classifier."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
......@@ -51,14 +51,6 @@ class BertTokenClassifier(tf.keras.Model):
output='logits',
dropout_rate=0.1,
**kwargs):
self._self_setattr_tracking = False
self._network = network
self._config = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
}
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use
......@@ -75,30 +67,56 @@ class BertTokenClassifier(tf.keras.Model):
sequence_output = tf.keras.layers.Dropout(rate=dropout_rate)(
sequence_output)
self.classifier = tf.keras.layers.Dense(
classifier = tf.keras.layers.Dense(
num_classes,
activation=None,
kernel_initializer=initializer,
name='predictions/transform/logits')
self.logits = self.classifier(sequence_output)
logits = classifier(sequence_output)
if output == 'logits':
output_tensors = self.logits
output_tensors = logits
elif output == 'predictions':
output_tensors = tf.keras.layers.Activation(tf.nn.log_softmax)(
self.logits)
output_tensors = tf.keras.layers.Activation(tf.nn.log_softmax)(logits)
else:
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(BertTokenClassifier, self).__init__(
inputs=inputs, outputs=output_tensors, **kwargs)
self._network = network
config_dict = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.classifier = classifier
self.logits = logits
@property
def checkpoint_items(self):
return dict(encoder=self._network)
def get_config(self):
return self._config
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
......
......@@ -14,7 +14,7 @@
# ==============================================================================
"""Trainer network for dual encoder style models."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from official.nlp.modeling import layers
......@@ -50,17 +50,6 @@ class DualEncoder(tf.keras.Model):
logit_margin: float = 0.0,
output: str = 'logits',
**kwargs) -> None:
self._self_setattr_tracking = False
self._config = {
'network': network,
'max_seq_length': max_seq_length,
'normalize': normalize,
'logit_scale': logit_scale,
'logit_margin': logit_margin,
'output': output,
}
self.network = network
if output == 'logits':
left_word_ids = tf.keras.layers.Input(
......@@ -132,13 +121,35 @@ class DualEncoder(tf.keras.Model):
else:
raise ValueError('output type %s is not supported' % output)
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(DualEncoder, self).__init__(inputs=inputs, outputs=outputs, **kwargs)
# Set _self_setattr_tracking to True so it can be exported with assets.
self._self_setattr_tracking = True
config_dict = {
'network': network,
'max_seq_length': max_seq_length,
'normalize': normalize,
'logit_scale': logit_scale,
'logit_margin': logit_margin,
'output': output,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.network = network
def get_config(self):
return self._config
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
......
......@@ -100,12 +100,12 @@ class ElectraPretrainer(tf.keras.Model):
output=output_type,
name='generator_masked_lm')
self.classification = layers.ClassificationHead(
inner_dim=generator_network._config_dict['hidden_size'],
inner_dim=generator_network.get_config()['hidden_size'],
num_classes=num_classes,
initializer=mlm_initializer,
name='generator_classification_head')
self.discriminator_projection = tf.keras.layers.Dense(
units=discriminator_network._config_dict['hidden_size'],
units=discriminator_network.get_config()['hidden_size'],
activation=mlm_activation,
kernel_initializer=mlm_initializer,
name='discriminator_projection_head')
......
......@@ -14,7 +14,7 @@
# ==============================================================================
"""ALBERT (https://arxiv.org/abs/1810.04805) text encoder network."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from official.modeling import activations
......@@ -81,22 +81,6 @@ class AlbertEncoder(tf.keras.Model):
activation = tf.keras.activations.get(activation)
initializer = tf.keras.initializers.get(initializer)
self._self_setattr_tracking = False
self._config_dict = {
'vocab_size': vocab_size,
'embedding_width': embedding_width,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'intermediate_size': intermediate_size,
'activation': tf.keras.activations.serialize(activation),
'dropout_rate': dropout_rate,
'attention_dropout_rate': attention_dropout_rate,
'initializer': tf.keras.initializers.serialize(initializer),
}
word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids')
mask = tf.keras.layers.Input(
......@@ -106,19 +90,19 @@ class AlbertEncoder(tf.keras.Model):
if embedding_width is None:
embedding_width = hidden_size
self._embedding_layer = layers.OnDeviceEmbedding(
embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
word_embeddings = self._embedding_layer(word_ids)
word_embeddings = embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
position_embedding_layer = keras_nlp.layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
position_embeddings = self._position_embedding_layer(word_embeddings)
position_embeddings = position_embedding_layer(word_embeddings)
type_embeddings = (
layers.OnDeviceEmbedding(
......@@ -182,14 +166,45 @@ class AlbertEncoder(tf.keras.Model):
else:
outputs = [data, cls_output]
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(AlbertEncoder, self).__init__(
inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
config_dict = {
'vocab_size': vocab_size,
'embedding_width': embedding_width,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'intermediate_size': intermediate_size,
'activation': tf.keras.activations.serialize(activation),
'dropout_rate': dropout_rate,
'attention_dropout_rate': attention_dropout_rate,
'initializer': tf.keras.initializers.serialize(initializer),
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self._embedding_layer = embedding_layer
self._position_embedding_layer = position_embedding_layer
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_config(self):
return self._config_dict
return dict(self._config._asdict())
@classmethod
def from_config(cls, config):
......
......@@ -14,7 +14,7 @@
# ==============================================================================
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf
from official.modeling import activations
......@@ -99,9 +99,13 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
dict_outputs=False,
**kwargs):
self._self_setattr_tracking = False
self._embedding_layer_instance = embedding_layer
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(BertEncoder, self).__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
......@@ -115,16 +119,21 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
attention_dropout=attention_dropout_rate,
initializer=initializer,
output_range=output_range,
embedding_width=embedding_width)
embedding_width=embedding_width,
embedding_layer=embedding_layer)
self._embedding_layer_instance = embedding_layer
# Replace arguments from keras_nlp.encoders.BertEncoder.
self._config_dict['activation'] = self._config_dict.pop('inner_activation')
self._config_dict['intermediate_size'] = self._config_dict.pop('inner_dim')
self._config_dict['dropout_rate'] = self._config_dict.pop('output_dropout')
self._config_dict['attention_dropout_rate'] = self._config_dict.pop(
'attention_dropout')
self._config_dict['dict_outputs'] = dict_outputs
self._config_dict['return_all_encoder_outputs'] = return_all_encoder_outputs
config_dict = self._config._asdict()
config_dict['activation'] = config_dict.pop('inner_activation')
config_dict['intermediate_size'] = config_dict.pop('inner_dim')
config_dict['dropout_rate'] = config_dict.pop('output_dropout')
config_dict['attention_dropout_rate'] = config_dict.pop('attention_dropout')
config_dict['dict_outputs'] = dict_outputs
config_dict['return_all_encoder_outputs'] = return_all_encoder_outputs
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
if dict_outputs:
return
......@@ -139,10 +148,3 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
outputs = [sequence_output, cls_output]
super(keras_nlp.encoders.BertEncoder, self).__init__(
inputs=self.inputs, outputs=outputs, **kwargs)
# Override method for shared embedding use case.
def _build_embedding_layer(self):
if self._embedding_layer_instance is None:
return super(BertEncoder, self)._build_embedding_layer()
else:
return self._embedding_layer_instance
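
Because the wrapper above can no longer mutate `_config_dict` in place, it renames the inherited keys by round-tripping the namedtuple through a dict and rebuilding it. A standalone illustration of that rebuild (the field names mirror the real config; the values are made up):

```python
import collections

Config = collections.namedtuple('Config', ['inner_dim', 'inner_activation'])
inherited = Config(inner_dim=3072, inner_activation='gelu')

# Namedtuples are immutable, so "renaming a field" means copying to a dict,
# swapping the keys, and constructing a fresh namedtuple type from the result.
config_dict = inherited._asdict()
config_dict['intermediate_size'] = config_dict.pop('inner_dim')
config_dict['activation'] = config_dict.pop('inner_activation')
config_cls = collections.namedtuple('Config', config_dict.keys())
renamed = config_cls(**config_dict)

assert renamed.intermediate_size == 3072 and renamed.activation == 'gelu'
```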
......@@ -225,13 +225,15 @@ class BertEncoderTest(keras_parameterized.TestCase):
return_all_encoder_outputs=False,
output_range=-1,
embedding_width=16,
dict_outputs=True)
dict_outputs=True,
embedding_layer=None)
network = bert_encoder.BertEncoder(**kwargs)
expected_config = dict(kwargs)
expected_config["activation"] = tf.keras.activations.serialize(
tf.keras.activations.get(expected_config["activation"]))
expected_config["initializer"] = tf.keras.initializers.serialize(
tf.keras.initializers.get(expected_config["initializer"]))
self.assertEqual(network.get_config(), expected_config)
# Create another network object from the first object's config.
new_network = bert_encoder.BertEncoder.from_config(network.get_config())
......
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import collections
import tensorflow as tf
......@@ -49,36 +49,27 @@ class Classification(tf.keras.Model):
initializer='glorot_uniform',
output='logits',
**kwargs):
self._self_setattr_tracking = False
self._config_dict = {
'input_width': input_width,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
}
cls_output = tf.keras.layers.Input(
shape=(input_width,), name='cls_output', dtype=tf.float32)
self.logits = tf.keras.layers.Dense(
logits = tf.keras.layers.Dense(
num_classes,
activation=None,
kernel_initializer=initializer,
name='predictions/transform/logits')(
cls_output)
policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == 'mixed_bfloat16':
# b/158514794: bf16 is not stable with post-softmax cross-entropy.
policy = tf.float32
predictions = tf.keras.layers.Activation(
tf.nn.log_softmax, dtype=policy)(
self.logits)
if output == 'logits':
output_tensors = self.logits
output_tensors = logits
elif output == 'predictions':
output_tensors = predictions
policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == 'mixed_bfloat16':
# b/158514794: bf16 is not stable with post-softmax cross-entropy.
policy = tf.float32
output_tensors = tf.keras.layers.Activation(
tf.nn.log_softmax, dtype=policy)(
logits)
else:
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
......@@ -87,8 +78,30 @@ class Classification(tf.keras.Model):
super(Classification, self).__init__(
inputs=[cls_output], outputs=output_tensors, **kwargs)
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
config_dict = {
'input_width': input_width,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.logits = logits
def get_config(self):
return self._config_dict
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
......
......@@ -109,29 +109,18 @@ class EncoderScaffold(tf.keras.Model):
return_all_layer_outputs=False,
dict_outputs=False,
**kwargs):
self._self_setattr_tracking = False
self._hidden_cls = hidden_cls
self._hidden_cfg = hidden_cfg
self._num_hidden_instances = num_hidden_instances
self._pooled_output_dim = pooled_output_dim
self._pooler_layer_initializer = pooler_layer_initializer
self._embedding_cls = embedding_cls
self._embedding_cfg = embedding_cfg
self._embedding_data = embedding_data
self._return_all_layer_outputs = return_all_layer_outputs
self._dict_outputs = dict_outputs
self._kwargs = kwargs
if embedding_cls:
if inspect.isclass(embedding_cls):
self._embedding_network = embedding_cls(
embedding_network = embedding_cls(
**embedding_cfg) if embedding_cfg else embedding_cls()
else:
self._embedding_network = embedding_cls
inputs = self._embedding_network.inputs
embeddings, attention_mask = self._embedding_network(inputs)
embedding_network = embedding_cls
inputs = embedding_network.inputs
embeddings, attention_mask = embedding_network(inputs)
embedding_layer = None
else:
self._embedding_network = None
embedding_network = None
seq_length = embedding_cfg.get('seq_length', None)
word_ids = tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_word_ids')
......@@ -141,38 +130,38 @@ class EncoderScaffold(tf.keras.Model):
shape=(seq_length,), dtype=tf.int32, name='input_type_ids')
inputs = [word_ids, mask, type_ids]
self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=embedding_cfg['vocab_size'],
embedding_width=embedding_cfg['hidden_size'],
initializer=embedding_cfg['initializer'],
name='word_embeddings')
word_embeddings = self._embedding_layer(word_ids)
word_embeddings = embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
position_embedding_layer = keras_nlp.layers.PositionEmbedding(
initializer=embedding_cfg['initializer'],
max_length=embedding_cfg['max_seq_length'],
name='position_embedding')
position_embeddings = self._position_embedding_layer(word_embeddings)
position_embeddings = position_embedding_layer(word_embeddings)
self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=embedding_cfg['type_vocab_size'],
embedding_width=embedding_cfg['hidden_size'],
initializer=embedding_cfg['initializer'],
use_one_hot=True,
name='type_embeddings')
type_embeddings = self._type_embedding_layer(type_ids)
type_embeddings = type_embedding_layer(type_ids)
embeddings = tf.keras.layers.Add()(
[word_embeddings, position_embeddings, type_embeddings])
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm',
axis=-1,
epsilon=1e-12,
dtype=tf.float32)
embeddings = self._embedding_norm_layer(embeddings)
embeddings = embedding_norm_layer(embeddings)
embeddings = (
tf.keras.layers.Dropout(
......@@ -183,7 +172,7 @@ class EncoderScaffold(tf.keras.Model):
data = embeddings
layer_output_data = []
self._hidden_layers = []
hidden_layers = []
for _ in range(num_hidden_instances):
if inspect.isclass(hidden_cls):
layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
......@@ -191,19 +180,19 @@ class EncoderScaffold(tf.keras.Model):
layer = hidden_cls
data = layer([data, attention_mask])
layer_output_data.append(data)
self._hidden_layers.append(layer)
hidden_layers.append(layer)
last_layer_output = layer_output_data[-1]
# Applying a tf.slice op (through subscript notation) to a Keras tensor
# like this will create a SliceOpLambda layer. This is better than a Lambda
# layer with Python code, because that is fundamentally less portable.
first_token_tensor = last_layer_output[:, 0, :]
self._pooler_layer = tf.keras.layers.Dense(
pooler_layer = tf.keras.layers.Dense(
units=pooled_output_dim,
activation='tanh',
kernel_initializer=pooler_layer_initializer,
name='cls_transform')
cls_output = self._pooler_layer(first_token_tensor)
cls_output = pooler_layer(first_token_tensor)
if dict_outputs:
outputs = dict(
......@@ -216,9 +205,33 @@ class EncoderScaffold(tf.keras.Model):
else:
outputs = [layer_output_data[-1], cls_output]
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(EncoderScaffold, self).__init__(
inputs=inputs, outputs=outputs, **kwargs)
self._hidden_cls = hidden_cls
self._hidden_cfg = hidden_cfg
self._num_hidden_instances = num_hidden_instances
self._pooled_output_dim = pooled_output_dim
self._pooler_layer_initializer = pooler_layer_initializer
self._embedding_cls = embedding_cls
self._embedding_cfg = embedding_cfg
self._embedding_data = embedding_data
self._return_all_layer_outputs = return_all_layer_outputs
self._dict_outputs = dict_outputs
self._kwargs = kwargs
self._embedding_layer = embedding_layer
self._embedding_network = embedding_network
self._hidden_layers = hidden_layers
self._pooler_layer = pooler_layer
logging.info('EncoderScaffold configs: %s', self.get_config())
def get_config(self):
......
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import collections
import tensorflow as tf
......@@ -53,13 +53,6 @@ class SpanLabeling(tf.keras.Model):
initializer='glorot_uniform',
output='logits',
**kwargs):
self._self_setattr_tracking = False
self._config = {
'input_width': input_width,
'activation': activation,
'initializer': initializer,
'output': output,
}
sequence_data = tf.keras.layers.Input(
shape=(None, input_width), name='sequence_data', dtype=tf.float32)
......@@ -70,16 +63,14 @@ class SpanLabeling(tf.keras.Model):
kernel_initializer=initializer,
name='predictions/transform/logits')(
sequence_data)
self.start_logits, self.end_logits = (
tf.keras.layers.Lambda(self._split_output_tensor)(intermediate_logits))
start_logits, end_logits = self._split_output_tensor(intermediate_logits)
start_predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(
self.start_logits)
end_predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(
self.end_logits)
start_logits)
end_predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(end_logits)
if output == 'logits':
output_tensors = [self.start_logits, self.end_logits]
output_tensors = [start_logits, end_logits]
elif output == 'predictions':
output_tensors = [start_predictions, end_predictions]
else:
......@@ -87,15 +78,37 @@ class SpanLabeling(tf.keras.Model):
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
# b/164516224
# Once we've created the network using the Functional API, we call
# super().__init__ as though we were invoking the Functional API Model
# constructor, resulting in this object having all the properties of a model
# created using the Functional API. Once super().__init__ is called, we
# can assign attributes to `self` - note that all `self` assignments are
# below this line.
super(SpanLabeling, self).__init__(
inputs=[sequence_data], outputs=output_tensors, **kwargs)
config_dict = {
'input_width': input_width,
'activation': activation,
'initializer': initializer,
'output': output,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
# compatibility with an earlier version of this model which did not track
# the config dict attribute. TF does not track immutable attrs which
# do not contain Trackables, so by creating a config namedtuple instead of
# a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self.start_logits = start_logits
self.end_logits = end_logits
def _split_output_tensor(self, tensor):
transposed_tensor = tf.transpose(tensor, [2, 0, 1])
return tf.unstack(transposed_tensor)
def get_config(self):
return self._config
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
......
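
SpanLabeling now calls `_split_output_tensor` directly on the Keras tensor instead of wrapping it in a `Lambda` layer. For reference, a standalone sketch of what that split does (shapes are illustrative):

```python
import tensorflow as tf


def split_output_tensor(tensor):
  # [batch, seq_length, 2] -> two [batch, seq_length] tensors
  # (start logits and end logits), as in SpanLabeling._split_output_tensor.
  transposed = tf.transpose(tensor, [2, 0, 1])
  return tf.unstack(transposed)


logits = tf.zeros([3, 5, 2])  # batch of 3, sequence length 5
start_logits, end_logits = split_output_tensor(logits)
assert start_logits.shape == (3, 5)
assert end_logits.shape == (3, 5)
```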