"...model/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "8e8103a8ce976ca883ce0fd78c0aa14b67075dd4"
Commit f5a1343f authored by A. Unique TensorFlower

Shift all attribute assignments in functional subclass models below the super().__init__ call, and enable setattr tracking on all functional subclasses.

PiperOrigin-RevId: 338759531
parent f16e42b4
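
The pattern applied in every file below is the same: build the functional graph using only local variables, call super().__init__, and only then assign attributes to `self`, where Keras setattr tracking (now enabled) can see them. As a minimal sketch of that pattern (the class and layers here are hypothetical, not part of this commit):

import tensorflow as tf

class TinyClassifier(tf.keras.Model):
  """Functional-API subclass following the pattern this commit enforces."""

  def __init__(self, hidden_size=4, **kwargs):
    # Build the functional graph first, using only local names.
    inputs = tf.keras.layers.Input(shape=(8,), dtype=tf.float32)
    hidden = tf.keras.layers.Dense(hidden_size, activation='relu')
    outputs = tf.keras.layers.Dense(1)(hidden(inputs))
    super().__init__(inputs=inputs, outputs=outputs, **kwargs)
    # Attribute assignments go below super().__init__, so setattr tracking
    # is active and `hidden` is tracked for checkpointing.
    self._hidden_layer = hidden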
@@ -15,6 +15,8 @@
 """Bert encoder network."""
 # pylint: disable=g-classes-have-attributes
+import collections
+from absl import logging
 import tensorflow as tf
 from official.nlp.keras_nlp import layers
@@ -65,6 +67,8 @@ class BertEncoder(tf.keras.Model):
       matrices in the shape of ['vocab_size', 'embedding_width'] and
       ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
       smaller than 'hidden_size').
+    embedding_layer: An optional Layer instance which will be called to
+      generate embeddings for the input word IDs.
   """
   def __init__(
@@ -82,27 +86,11 @@ class BertEncoder(tf.keras.Model):
       initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
       output_range=None,
       embedding_width=None,
+      embedding_layer=None,
       **kwargs):
     activation = tf.keras.activations.get(inner_activation)
     initializer = tf.keras.initializers.get(initializer)
-    self._self_setattr_tracking = False
-    self._config_dict = {
-        'vocab_size': vocab_size,
-        'hidden_size': hidden_size,
-        'num_layers': num_layers,
-        'num_attention_heads': num_attention_heads,
-        'max_sequence_length': max_sequence_length,
-        'type_vocab_size': type_vocab_size,
-        'inner_dim': inner_dim,
-        'inner_activation': tf.keras.activations.serialize(activation),
-        'output_dropout': output_dropout,
-        'attention_dropout': attention_dropout,
-        'initializer': tf.keras.initializers.serialize(initializer),
-        'output_range': output_range,
-        'embedding_width': embedding_width,
-    }
     word_ids = tf.keras.layers.Input(
         shape=(None,), dtype=tf.int32, name='input_word_ids')
     mask = tf.keras.layers.Input(
@@ -112,44 +100,54 @@ class BertEncoder(tf.keras.Model):
     if embedding_width is None:
       embedding_width = hidden_size
-    self._embedding_layer = self._build_embedding_layer()
-    word_embeddings = self._embedding_layer(word_ids)
+    if embedding_layer is None:
+      embedding_layer_inst = layers.OnDeviceEmbedding(
+          vocab_size=vocab_size,
+          embedding_width=embedding_width,
+          initializer=initializer,
+          name='word_embeddings')
+    else:
+      embedding_layer_inst = embedding_layer
+    word_embeddings = embedding_layer_inst(word_ids)
     # Always uses dynamic slicing for simplicity.
-    self._position_embedding_layer = layers.PositionEmbedding(
+    position_embedding_layer = layers.PositionEmbedding(
         initializer=initializer,
         max_length=max_sequence_length,
         name='position_embedding')
-    position_embeddings = self._position_embedding_layer(word_embeddings)
+    position_embeddings = position_embedding_layer(word_embeddings)
-    self._type_embedding_layer = layers.OnDeviceEmbedding(
+    type_embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=type_vocab_size,
         embedding_width=embedding_width,
         initializer=initializer,
         use_one_hot=True,
         name='type_embeddings')
-    type_embeddings = self._type_embedding_layer(type_ids)
+    type_embeddings = type_embedding_layer(type_ids)
     embeddings = tf.keras.layers.Add()(
         [word_embeddings, position_embeddings, type_embeddings])
-    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
+    embedding_norm_layer = tf.keras.layers.LayerNormalization(
         name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
-    embeddings = self._embedding_norm_layer(embeddings)
+    embeddings = embedding_norm_layer(embeddings)
     embeddings = (tf.keras.layers.Dropout(rate=output_dropout)(embeddings))
     # We project the 'embedding' output to 'hidden_size' if it is not already
     # 'hidden_size'.
     if embedding_width != hidden_size:
-      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+      embedding_projection = tf.keras.layers.experimental.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=initializer,
          name='embedding_projection')
-      embeddings = self._embedding_projection(embeddings)
+      embeddings = embedding_projection(embeddings)
+    else:
+      embedding_projection = None
-    self._transformer_layers = []
+    transformer_layers = []
     data = embeddings
     attention_mask = layers.SelfAttentionMask()(data, mask)
     encoder_outputs = []
@@ -167,7 +165,7 @@ class BertEncoder(tf.keras.Model):
           output_range=transformer_output_range,
           kernel_initializer=initializer,
           name='transformer/layer_%d' % i)
-      self._transformer_layers.append(layer)
+      transformer_layers.append(layer)
       data = layer([data, attention_mask])
       encoder_outputs.append(data)
@@ -176,38 +174,68 @@ class BertEncoder(tf.keras.Model):
     # like this will create a SliceOpLambda layer. This is better than a Lambda
     # layer with Python code, because that is fundamentally less portable.
     first_token_tensor = last_enocder_output[:, 0, :]
-    self._pooler_layer = tf.keras.layers.Dense(
+    pooler_layer = tf.keras.layers.Dense(
         units=hidden_size,
         activation='tanh',
         kernel_initializer=initializer,
         name='pooler_transform')
-    cls_output = self._pooler_layer(first_token_tensor)
+    cls_output = pooler_layer(first_token_tensor)
     outputs = dict(
         sequence_output=encoder_outputs[-1],
         pooled_output=cls_output,
         encoder_outputs=encoder_outputs,
     )
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
     super(BertEncoder, self).__init__(
         inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
+    config_dict = {
+        'vocab_size': vocab_size,
+        'hidden_size': hidden_size,
+        'num_layers': num_layers,
+        'num_attention_heads': num_attention_heads,
+        'max_sequence_length': max_sequence_length,
+        'type_vocab_size': type_vocab_size,
+        'inner_dim': inner_dim,
+        'inner_activation': tf.keras.activations.serialize(activation),
+        'output_dropout': output_dropout,
+        'attention_dropout': attention_dropout,
+        'initializer': tf.keras.initializers.serialize(initializer),
+        'output_range': output_range,
+        'embedding_width': embedding_width,
+        'embedding_layer': embedding_layer,
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self._pooler_layer = pooler_layer
+    self._transformer_layers = transformer_layers
+    self._embedding_norm_layer = embedding_norm_layer
+    self._embedding_layer = embedding_layer_inst
+    self._position_embedding_layer = position_embedding_layer
+    self._type_embedding_layer = type_embedding_layer
+    self._embedding_projection = embedding_projection

   def get_embedding_table(self):
     return self._embedding_layer.embeddings

-  def _build_embedding_layer(self):
-    embedding_width = self._config_dict[
-        'embedding_width'] or self._config_dict['hidden_size']
-    return layers.OnDeviceEmbedding(
-        vocab_size=self._config_dict['vocab_size'],
-        embedding_width=embedding_width,
-        initializer=self._config_dict['initializer'],
-        name='word_embeddings')

   def get_embedding_layer(self):
     return self._embedding_layer

   def get_config(self):
-    return self._config_dict
+    return dict(self._config._asdict())

   @property
   def transformer_layers(self):
@@ -221,4 +249,13 @@ class BertEncoder(tf.keras.Model):
   @classmethod
   def from_config(cls, config, custom_objects=None):
+    if config['embedding_layer'] is not None:
+      warn_string = (
+          'You are reloading a model that was saved with a '
+          'potentially-shared embedding layer object. If you continue to '
+          'train this model, the embedding layer will no longer be shared. '
+          'To work around this, load the model outside of the Keras API.')
+      print('WARNING: ' + warn_string)
+      logging.warn(warn_string)
     return cls(**config)
@@ -204,7 +204,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
         attention_dropout=0.22,
         initializer="glorot_uniform",
         output_range=-1,
-        embedding_width=16)
+        embedding_width=16,
+        embedding_layer=None)
     network = bert_encoder.BertEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["inner_activation"] = tf.keras.activations.serialize(
......
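
The namedtuple trick used in the hunk above (and repeated in every file below) is worth spelling out: TF's object-graph checkpointing skips immutable attributes that contain no Trackables, so wrapping the config in a namedtuple keeps `self._config` out of the checkpoint while `get_config()` can still rebuild the original dict. A minimal sketch with toy values (not taken from the diff):

import collections

config_dict = {'vocab_size': 30522, 'hidden_size': 768}
# A namedtuple is immutable and holds no Trackables, so assigning it to
# `self` after super().__init__ adds nothing to the checkpoint object graph,
# which keeps checkpoints written by the older code loadable.
Config = collections.namedtuple('Config', config_dict.keys())
config = Config(**config_dict)
assert dict(config._asdict()) == config_dict  # what get_config() returns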
@@ -14,7 +14,7 @@
 # ==============================================================================
 """BERT cls-token classifier."""
 # pylint: disable=g-classes-have-attributes
+import collections
 import tensorflow as tf
 from official.nlp.modeling import layers
@@ -55,14 +55,6 @@ class BertClassifier(tf.keras.Model):
       dropout_rate=0.1,
       use_encoder_pooler=True,
       **kwargs):
-    self._self_setattr_tracking = False
-    self._network = network
-    self._config = {
-        'network': network,
-        'num_classes': num_classes,
-        'initializer': initializer,
-        'use_encoder_pooler': use_encoder_pooler,
-    }
     # We want to use the inputs of the passed network as the inputs to this
     # Model. To do this, we need to keep a handle to the network inputs for use
@@ -79,29 +71,52 @@ class BertClassifier(tf.keras.Model):
       cls_output = outputs['pooled_output']
       cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
-      self.classifier = networks.Classification(
+      classifier = networks.Classification(
          input_width=cls_output.shape[-1],
          num_classes=num_classes,
          initializer=initializer,
          output='logits',
          name='sentence_prediction')
-      predictions = self.classifier(cls_output)
+      predictions = classifier(cls_output)
     else:
       outputs = network(inputs)
       if isinstance(outputs, list):
         sequence_output = outputs[0]
       else:
         sequence_output = outputs['sequence_output']
-      self.classifier = layers.ClassificationHead(
+      classifier = layers.ClassificationHead(
          inner_dim=sequence_output.shape[-1],
          num_classes=num_classes,
          initializer=initializer,
          dropout_rate=dropout_rate,
          name='sentence_prediction')
-      predictions = self.classifier(sequence_output)
+      predictions = classifier(sequence_output)
+    # b/164516224
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
     super(BertClassifier, self).__init__(
         inputs=inputs, outputs=predictions, **kwargs)
+    self._network = network
+    config_dict = {
+        'network': network,
+        'num_classes': num_classes,
+        'initializer': initializer,
+        'use_encoder_pooler': use_encoder_pooler,
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self.classifier = classifier

   @property
   def checkpoint_items(self):
@@ -112,7 +127,7 @@ class BertClassifier(tf.keras.Model):
     return items

   def get_config(self):
-    return self._config
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config, custom_objects=None):
......
@@ -14,7 +14,7 @@
 # ==============================================================================
 """BERT Pre-training model."""
 # pylint: disable=g-classes-have-attributes
+import collections
 import copy
 from typing import List, Optional
@@ -64,21 +64,12 @@ class BertPretrainer(tf.keras.Model):
       initializer='glorot_uniform',
       output='logits',
       **kwargs):
-    self._self_setattr_tracking = False
-    self._config = {
-        'network': network,
-        'num_classes': num_classes,
-        'num_token_predictions': num_token_predictions,
-        'activation': activation,
-        'initializer': initializer,
-        'output': output,
-    }
-    self.encoder = network
     # We want to use the inputs of the passed network as the inputs to this
     # Model. To do this, we need to keep a copy of the network inputs for use
     # when we construct the Model object at the end of init. (We keep a copy
     # because we'll be adding another tensor to the copy later.)
-    network_inputs = self.encoder.inputs
+    network_inputs = network.inputs
     inputs = copy.copy(network_inputs)
     # Because we have a copy of inputs to create this Model object, we can
@@ -86,7 +77,7 @@ class BertPretrainer(tf.keras.Model):
     # Note that, because of how deferred construction happens, we can't use
     # the copy of the list here - by the time the network is invoked, the list
     # object contains the additional input added below.
-    sequence_output, cls_output = self.encoder(network_inputs)
+    sequence_output, cls_output = network(network_inputs)
     # The encoder network may get outputs from all layers.
     if isinstance(sequence_output, list):
@@ -108,31 +99,59 @@ class BertPretrainer(tf.keras.Model):
     inputs.append(masked_lm_positions)
     if embedding_table is None:
-      embedding_table = self.encoder.get_embedding_table()
+      embedding_table = network.get_embedding_table()
-    self.masked_lm = layers.MaskedLM(
+    masked_lm = layers.MaskedLM(
         embedding_table=embedding_table,
         activation=activation,
         initializer=initializer,
         output=output,
         name='cls/predictions')
-    lm_outputs = self.masked_lm(
+    lm_outputs = masked_lm(
         sequence_output, masked_positions=masked_lm_positions)
-    self.classification = networks.Classification(
+    classification = networks.Classification(
         input_width=cls_output.shape[-1],
         num_classes=num_classes,
         initializer=initializer,
         output=output,
         name='classification')
-    sentence_outputs = self.classification(cls_output)
+    sentence_outputs = classification(cls_output)
     super(BertPretrainer, self).__init__(
         inputs=inputs,
         outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs),
         **kwargs)
+    # b/164516224
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
+    config_dict = {
+        'network': network,
+        'num_classes': num_classes,
+        'num_token_predictions': num_token_predictions,
+        'activation': activation,
+        'initializer': initializer,
+        'output': output,
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self.encoder = network
+    self.classification = classification
+    self.masked_lm = masked_lm

   def get_config(self):
-    return self._config
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config, custom_objects=None):
......
@@ -14,7 +14,7 @@
 # ==============================================================================
 """BERT Question Answering model."""
 # pylint: disable=g-classes-have-attributes
+import collections
 import tensorflow as tf
 from official.nlp.modeling import networks
@@ -49,13 +49,6 @@ class BertSpanLabeler(tf.keras.Model):
       initializer='glorot_uniform',
       output='logits',
       **kwargs):
-    self._self_setattr_tracking = False
-    self._network = network
-    self._config = {
-        'network': network,
-        'initializer': initializer,
-        'output': output,
-    }
     # We want to use the inputs of the passed network as the inputs to this
     # Model. To do this, we need to keep a handle to the network inputs for use
@@ -72,12 +65,12 @@ class BertSpanLabeler(tf.keras.Model):
     # This is an instance variable for ease of access to the underlying task
     # network.
-    self.span_labeling = networks.SpanLabeling(
+    span_labeling = networks.SpanLabeling(
         input_width=sequence_output.shape[-1],
         initializer=initializer,
         output=output,
         name='span_labeling')
-    start_logits, end_logits = self.span_labeling(sequence_output)
+    start_logits, end_logits = span_labeling(sequence_output)
     # Use identity layers wrapped in lambdas to explicitly name the output
     # tensors. This allows us to use string-keyed dicts in Keras fit/predict/
@@ -91,15 +84,36 @@ class BertSpanLabeler(tf.keras.Model):
     logits = [start_logits, end_logits]
+    # b/164516224
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
     super(BertSpanLabeler, self).__init__(
         inputs=inputs, outputs=logits, **kwargs)
+    self._network = network
+    config_dict = {
+        'network': network,
+        'initializer': initializer,
+        'output': output,
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self.span_labeling = span_labeling

   @property
   def checkpoint_items(self):
     return dict(encoder=self._network)

   def get_config(self):
-    return self._config
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config, custom_objects=None):
......
@@ -14,7 +14,7 @@
 # ==============================================================================
 """BERT token classifier."""
 # pylint: disable=g-classes-have-attributes
+import collections
 import tensorflow as tf
@@ -51,14 +51,6 @@ class BertTokenClassifier(tf.keras.Model):
       output='logits',
       dropout_rate=0.1,
       **kwargs):
-    self._self_setattr_tracking = False
-    self._network = network
-    self._config = {
-        'network': network,
-        'num_classes': num_classes,
-        'initializer': initializer,
-        'output': output,
-    }
     # We want to use the inputs of the passed network as the inputs to this
     # Model. To do this, we need to keep a handle to the network inputs for use
@@ -75,30 +67,56 @@ class BertTokenClassifier(tf.keras.Model):
     sequence_output = tf.keras.layers.Dropout(rate=dropout_rate)(
         sequence_output)
-    self.classifier = tf.keras.layers.Dense(
+    classifier = tf.keras.layers.Dense(
         num_classes,
         activation=None,
         kernel_initializer=initializer,
         name='predictions/transform/logits')
-    self.logits = self.classifier(sequence_output)
+    logits = classifier(sequence_output)
     if output == 'logits':
-      output_tensors = self.logits
+      output_tensors = logits
     elif output == 'predictions':
-      output_tensors = tf.keras.layers.Activation(tf.nn.log_softmax)(
-          self.logits)
+      output_tensors = tf.keras.layers.Activation(tf.nn.log_softmax)(logits)
     else:
       raise ValueError(
           ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)
+    # b/164516224
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
     super(BertTokenClassifier, self).__init__(
         inputs=inputs, outputs=output_tensors, **kwargs)
+    self._network = network
+    config_dict = {
+        'network': network,
+        'num_classes': num_classes,
+        'initializer': initializer,
+        'output': output,
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self.classifier = classifier
+    self.logits = logits

   @property
   def checkpoint_items(self):
     return dict(encoder=self._network)

   def get_config(self):
-    return self._config
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config, custom_objects=None):
......
@@ -14,7 +14,7 @@
 # ==============================================================================
 """Trainer network for dual encoder style models."""
 # pylint: disable=g-classes-have-attributes
+import collections
 import tensorflow as tf
 from official.nlp.modeling import layers
@@ -50,17 +50,6 @@ class DualEncoder(tf.keras.Model):
       logit_margin: float = 0.0,
       output: str = 'logits',
       **kwargs) -> None:
-    self._self_setattr_tracking = False
-    self._config = {
-        'network': network,
-        'max_seq_length': max_seq_length,
-        'normalize': normalize,
-        'logit_scale': logit_scale,
-        'logit_margin': logit_margin,
-        'output': output,
-    }
-    self.network = network
     if output == 'logits':
       left_word_ids = tf.keras.layers.Input(
@@ -132,13 +121,35 @@ class DualEncoder(tf.keras.Model):
     else:
       raise ValueError('output type %s is not supported' % output)
+    # b/164516224
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
     super(DualEncoder, self).__init__(inputs=inputs, outputs=outputs, **kwargs)
-    # Set _self_setattr_tracking to True so it can be exported with assets.
-    self._self_setattr_tracking = True
+    config_dict = {
+        'network': network,
+        'max_seq_length': max_seq_length,
+        'normalize': normalize,
+        'logit_scale': logit_scale,
+        'logit_margin': logit_margin,
+        'output': output,
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self.network = network

   def get_config(self):
-    return self._config
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config, custom_objects=None):
......
@@ -100,12 +100,12 @@ class ElectraPretrainer(tf.keras.Model):
         output=output_type,
         name='generator_masked_lm')
     self.classification = layers.ClassificationHead(
-        inner_dim=generator_network._config_dict['hidden_size'],
+        inner_dim=generator_network.get_config()['hidden_size'],
         num_classes=num_classes,
         initializer=mlm_initializer,
         name='generator_classification_head')
     self.discriminator_projection = tf.keras.layers.Dense(
-        units=discriminator_network._config_dict['hidden_size'],
+        units=discriminator_network.get_config()['hidden_size'],
         activation=mlm_activation,
         kernel_initializer=mlm_initializer,
         name='discriminator_projection_head')
......
@@ -14,7 +14,7 @@
 # ==============================================================================
 """ALBERT (https://arxiv.org/abs/1810.04805) text encoder network."""
 # pylint: disable=g-classes-have-attributes
+import collections
 import tensorflow as tf
 from official.modeling import activations
@@ -81,22 +81,6 @@ class AlbertEncoder(tf.keras.Model):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)
-    self._self_setattr_tracking = False
-    self._config_dict = {
-        'vocab_size': vocab_size,
-        'embedding_width': embedding_width,
-        'hidden_size': hidden_size,
-        'num_layers': num_layers,
-        'num_attention_heads': num_attention_heads,
-        'max_sequence_length': max_sequence_length,
-        'type_vocab_size': type_vocab_size,
-        'intermediate_size': intermediate_size,
-        'activation': tf.keras.activations.serialize(activation),
-        'dropout_rate': dropout_rate,
-        'attention_dropout_rate': attention_dropout_rate,
-        'initializer': tf.keras.initializers.serialize(initializer),
-    }
     word_ids = tf.keras.layers.Input(
         shape=(None,), dtype=tf.int32, name='input_word_ids')
     mask = tf.keras.layers.Input(
@@ -106,19 +90,19 @@ class AlbertEncoder(tf.keras.Model):
     if embedding_width is None:
       embedding_width = hidden_size
-    self._embedding_layer = layers.OnDeviceEmbedding(
+    embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=vocab_size,
         embedding_width=embedding_width,
         initializer=initializer,
         name='word_embeddings')
-    word_embeddings = self._embedding_layer(word_ids)
+    word_embeddings = embedding_layer(word_ids)
     # Always uses dynamic slicing for simplicity.
-    self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
+    position_embedding_layer = keras_nlp.layers.PositionEmbedding(
         initializer=initializer,
         max_length=max_sequence_length,
         name='position_embedding')
-    position_embeddings = self._position_embedding_layer(word_embeddings)
+    position_embeddings = position_embedding_layer(word_embeddings)
     type_embeddings = (
         layers.OnDeviceEmbedding(
@@ -182,14 +166,45 @@ class AlbertEncoder(tf.keras.Model):
     else:
       outputs = [data, cls_output]
+    # b/164516224
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
     super(AlbertEncoder, self).__init__(
         inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
+    config_dict = {
+        'vocab_size': vocab_size,
+        'embedding_width': embedding_width,
+        'hidden_size': hidden_size,
+        'num_layers': num_layers,
+        'num_attention_heads': num_attention_heads,
+        'max_sequence_length': max_sequence_length,
+        'type_vocab_size': type_vocab_size,
+        'intermediate_size': intermediate_size,
+        'activation': tf.keras.activations.serialize(activation),
+        'dropout_rate': dropout_rate,
+        'attention_dropout_rate': attention_dropout_rate,
+        'initializer': tf.keras.initializers.serialize(initializer),
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self._embedding_layer = embedding_layer
+    self._position_embedding_layer = position_embedding_layer

   def get_embedding_table(self):
     return self._embedding_layer.embeddings

   def get_config(self):
-    return self._config_dict
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config):
......
@@ -14,7 +14,7 @@
 # ==============================================================================
 """Transformer-based text encoder network."""
 # pylint: disable=g-classes-have-attributes
+import collections
 import tensorflow as tf
 from official.modeling import activations
@@ -99,9 +99,13 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
       dict_outputs=False,
       **kwargs):
-    self._self_setattr_tracking = False
-    self._embedding_layer_instance = embedding_layer
+    # b/164516224
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
     super(BertEncoder, self).__init__(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
@@ -115,16 +119,21 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
         attention_dropout=attention_dropout_rate,
         initializer=initializer,
         output_range=output_range,
-        embedding_width=embedding_width)
+        embedding_width=embedding_width,
+        embedding_layer=embedding_layer)
+    self._embedding_layer_instance = embedding_layer
     # Replace arguments from keras_nlp.encoders.BertEncoder.
-    self._config_dict['activation'] = self._config_dict.pop('inner_activation')
-    self._config_dict['intermediate_size'] = self._config_dict.pop('inner_dim')
-    self._config_dict['dropout_rate'] = self._config_dict.pop('output_dropout')
-    self._config_dict['attention_dropout_rate'] = self._config_dict.pop(
-        'attention_dropout')
-    self._config_dict['dict_outputs'] = dict_outputs
-    self._config_dict['return_all_encoder_outputs'] = return_all_encoder_outputs
+    config_dict = self._config._asdict()
+    config_dict['activation'] = config_dict.pop('inner_activation')
+    config_dict['intermediate_size'] = config_dict.pop('inner_dim')
+    config_dict['dropout_rate'] = config_dict.pop('output_dropout')
+    config_dict['attention_dropout_rate'] = config_dict.pop('attention_dropout')
+    config_dict['dict_outputs'] = dict_outputs
+    config_dict['return_all_encoder_outputs'] = return_all_encoder_outputs
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
     if dict_outputs:
       return
@@ -139,10 +148,3 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
     outputs = [sequence_output, cls_output]
     super(keras_nlp.encoders.BertEncoder, self).__init__(
         inputs=self.inputs, outputs=outputs, **kwargs)
-
-  # Override method for shared embedding use case.
-  def _build_embedding_layer(self):
-    if self._embedding_layer_instance is None:
-      return super(BertEncoder, self)._build_embedding_layer()
-    else:
-      return self._embedding_layer_instance
@@ -225,13 +225,15 @@ class BertEncoderTest(keras_parameterized.TestCase):
         return_all_encoder_outputs=False,
         output_range=-1,
         embedding_width=16,
-        dict_outputs=True)
+        dict_outputs=True,
+        embedding_layer=None)
     network = bert_encoder.BertEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["activation"] = tf.keras.activations.serialize(
         tf.keras.activations.get(expected_config["activation"]))
     expected_config["initializer"] = tf.keras.initializers.serialize(
         tf.keras.initializers.get(expected_config["initializer"]))
     self.assertEqual(network.get_config(), expected_config)
     # Create another network object from the first object's config.
     new_network = bert_encoder.BertEncoder.from_config(network.get_config())
......
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function
+import collections
 import tensorflow as tf
@@ -49,36 +49,27 @@ class Classification(tf.keras.Model):
       initializer='glorot_uniform',
       output='logits',
       **kwargs):
-    self._self_setattr_tracking = False
-    self._config_dict = {
-        'input_width': input_width,
-        'num_classes': num_classes,
-        'initializer': initializer,
-        'output': output,
-    }
     cls_output = tf.keras.layers.Input(
         shape=(input_width,), name='cls_output', dtype=tf.float32)
-    self.logits = tf.keras.layers.Dense(
+    logits = tf.keras.layers.Dense(
         num_classes,
         activation=None,
         kernel_initializer=initializer,
         name='predictions/transform/logits')(
             cls_output)
+    if output == 'logits':
+      output_tensors = logits
+    elif output == 'predictions':
       policy = tf.keras.mixed_precision.experimental.global_policy()
       if policy.name == 'mixed_bfloat16':
         # b/158514794: bf16 is not stable with post-softmax cross-entropy.
         policy = tf.float32
-    predictions = tf.keras.layers.Activation(
-        tf.nn.log_softmax, dtype=policy)(
-            self.logits)
+      output_tensors = tf.keras.layers.Activation(
+          tf.nn.log_softmax, dtype=policy)(
+              logits)
-    if output == 'logits':
-      output_tensors = self.logits
-    elif output == 'predictions':
-      output_tensors = predictions
     else:
       raise ValueError(
           ('Unknown `output` value "%s". `output` can be either "logits" or '
@@ -87,8 +78,30 @@ class Classification(tf.keras.Model):
     super(Classification, self).__init__(
         inputs=[cls_output], outputs=output_tensors, **kwargs)
+    # b/164516224
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
+    config_dict = {
+        'input_width': input_width,
+        'num_classes': num_classes,
+        'initializer': initializer,
+        'output': output,
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self.logits = logits

   def get_config(self):
-    return self._config_dict
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config, custom_objects=None):
......
@@ -109,29 +109,18 @@ class EncoderScaffold(tf.keras.Model):
       return_all_layer_outputs=False,
       dict_outputs=False,
       **kwargs):
-    self._self_setattr_tracking = False
-    self._hidden_cls = hidden_cls
-    self._hidden_cfg = hidden_cfg
-    self._num_hidden_instances = num_hidden_instances
-    self._pooled_output_dim = pooled_output_dim
-    self._pooler_layer_initializer = pooler_layer_initializer
-    self._embedding_cls = embedding_cls
-    self._embedding_cfg = embedding_cfg
-    self._embedding_data = embedding_data
-    self._return_all_layer_outputs = return_all_layer_outputs
-    self._dict_outputs = dict_outputs
-    self._kwargs = kwargs
     if embedding_cls:
       if inspect.isclass(embedding_cls):
-        self._embedding_network = embedding_cls(
+        embedding_network = embedding_cls(
            **embedding_cfg) if embedding_cfg else embedding_cls()
       else:
-        self._embedding_network = embedding_cls
+        embedding_network = embedding_cls
-      inputs = self._embedding_network.inputs
-      embeddings, attention_mask = self._embedding_network(inputs)
+      inputs = embedding_network.inputs
+      embeddings, attention_mask = embedding_network(inputs)
+      embedding_layer = None
     else:
-      self._embedding_network = None
+      embedding_network = None
       seq_length = embedding_cfg.get('seq_length', None)
       word_ids = tf.keras.layers.Input(
           shape=(seq_length,), dtype=tf.int32, name='input_word_ids')
@@ -141,38 +130,38 @@ class EncoderScaffold(tf.keras.Model):
           shape=(seq_length,), dtype=tf.int32, name='input_type_ids')
       inputs = [word_ids, mask, type_ids]
-      self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
+      embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=embedding_cfg['initializer'],
          name='word_embeddings')
-      word_embeddings = self._embedding_layer(word_ids)
+      word_embeddings = embedding_layer(word_ids)
       # Always uses dynamic slicing for simplicity.
-      self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
+      position_embedding_layer = keras_nlp.layers.PositionEmbedding(
          initializer=embedding_cfg['initializer'],
          max_length=embedding_cfg['max_seq_length'],
          name='position_embedding')
-      position_embeddings = self._position_embedding_layer(word_embeddings)
+      position_embeddings = position_embedding_layer(word_embeddings)
-      self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
+      type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['type_vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=embedding_cfg['initializer'],
          use_one_hot=True,
          name='type_embeddings')
-      type_embeddings = self._type_embedding_layer(type_ids)
+      type_embeddings = type_embedding_layer(type_ids)
       embeddings = tf.keras.layers.Add()(
          [word_embeddings, position_embeddings, type_embeddings])
-      self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
+      embedding_norm_layer = tf.keras.layers.LayerNormalization(
          name='embeddings/layer_norm',
          axis=-1,
          epsilon=1e-12,
          dtype=tf.float32)
-      embeddings = self._embedding_norm_layer(embeddings)
+      embeddings = embedding_norm_layer(embeddings)
       embeddings = (
          tf.keras.layers.Dropout(
@@ -183,7 +172,7 @@ class EncoderScaffold(tf.keras.Model):
     data = embeddings
     layer_output_data = []
-    self._hidden_layers = []
+    hidden_layers = []
     for _ in range(num_hidden_instances):
       if inspect.isclass(hidden_cls):
         layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
@@ -191,19 +180,19 @@ class EncoderScaffold(tf.keras.Model):
         layer = hidden_cls
       data = layer([data, attention_mask])
       layer_output_data.append(data)
-      self._hidden_layers.append(layer)
+      hidden_layers.append(layer)
     last_layer_output = layer_output_data[-1]
     # Applying a tf.slice op (through subscript notation) to a Keras tensor
     # like this will create a SliceOpLambda layer. This is better than a Lambda
     # layer with Python code, because that is fundamentally less portable.
     first_token_tensor = last_layer_output[:, 0, :]
-    self._pooler_layer = tf.keras.layers.Dense(
+    pooler_layer = tf.keras.layers.Dense(
        units=pooled_output_dim,
        activation='tanh',
        kernel_initializer=pooler_layer_initializer,
        name='cls_transform')
-    cls_output = self._pooler_layer(first_token_tensor)
+    cls_output = pooler_layer(first_token_tensor)
     if dict_outputs:
       outputs = dict(
@@ -216,9 +205,33 @@ class EncoderScaffold(tf.keras.Model):
     else:
       outputs = [layer_output_data[-1], cls_output]
+    # b/164516224
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
     super(EncoderScaffold, self).__init__(
         inputs=inputs, outputs=outputs, **kwargs)
+    self._hidden_cls = hidden_cls
+    self._hidden_cfg = hidden_cfg
+    self._num_hidden_instances = num_hidden_instances
+    self._pooled_output_dim = pooled_output_dim
+    self._pooler_layer_initializer = pooler_layer_initializer
+    self._embedding_cls = embedding_cls
+    self._embedding_cfg = embedding_cfg
+    self._embedding_data = embedding_data
+    self._return_all_layer_outputs = return_all_layer_outputs
+    self._dict_outputs = dict_outputs
+    self._kwargs = kwargs
+    self._embedding_layer = embedding_layer
+    self._embedding_network = embedding_network
+    self._hidden_layers = hidden_layers
+    self._pooler_layer = pooler_layer
     logging.info('EncoderScaffold configs: %s', self.get_config())

   def get_config(self):
......
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function
+import collections
 import tensorflow as tf
@@ -53,13 +53,6 @@ class SpanLabeling(tf.keras.Model):
       initializer='glorot_uniform',
       output='logits',
       **kwargs):
-    self._self_setattr_tracking = False
-    self._config = {
-        'input_width': input_width,
-        'activation': activation,
-        'initializer': initializer,
-        'output': output,
-    }
     sequence_data = tf.keras.layers.Input(
         shape=(None, input_width), name='sequence_data', dtype=tf.float32)
@@ -70,16 +63,14 @@ class SpanLabeling(tf.keras.Model):
         kernel_initializer=initializer,
         name='predictions/transform/logits')(
            sequence_data)
-    self.start_logits, self.end_logits = (
-        tf.keras.layers.Lambda(self._split_output_tensor)(intermediate_logits))
+    start_logits, end_logits = self._split_output_tensor(intermediate_logits)
     start_predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(
-        self.start_logits)
+        start_logits)
-    end_predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(
-        self.end_logits)
+    end_predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(end_logits)
     if output == 'logits':
-      output_tensors = [self.start_logits, self.end_logits]
+      output_tensors = [start_logits, end_logits]
     elif output == 'predictions':
       output_tensors = [start_predictions, end_predictions]
     else:
@@ -87,15 +78,37 @@ class SpanLabeling(tf.keras.Model):
         ('Unknown `output` value "%s". `output` can be either "logits" or '
          '"predictions"') % output)
+    # b/164516224
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
     super(SpanLabeling, self).__init__(
         inputs=[sequence_data], outputs=output_tensors, **kwargs)
+    config_dict = {
+        'input_width': input_width,
+        'activation': activation,
+        'initializer': initializer,
+        'output': output,
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self.start_logits = start_logits
+    self.end_logits = end_logits

   def _split_output_tensor(self, tensor):
     transposed_tensor = tf.transpose(tensor, [2, 0, 1])
     return tf.unstack(transposed_tensor)

   def get_config(self):
-    return self._config
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config, custom_objects=None):
......
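
With the config stored as a namedtuple, every `get_config()` above now returns a plain dict rebuilt via `_asdict()`, so the from_config round trip the updated tests exercise still holds. A hypothetical usage sketch (argument values are illustrative, not from the diff):

network = bert_encoder.BertEncoder(vocab_size=100, num_layers=2)
new_network = bert_encoder.BertEncoder.from_config(network.get_config())
assert new_network.get_config() == network.get_config()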