Commit 802488f1 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 316593329
parent cabb22cd
@@ -230,9 +230,10 @@ def pretrain_model(bert_config,
       initializer=initializer,
       output='predictions')
-  lm_output, sentence_output = pretrainer_model(
+  outputs = pretrainer_model(
       [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
+  lm_output = outputs['masked_lm']
+  sentence_output = outputs['classification']
   pretrain_loss_layer = BertPretrainLossAndMetricLayer(
       vocab_size=bert_config.vocab_size)
   output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
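Editor's sketch, not part of the commit: how the pretrainer's new dict-valued outputs are consumed end to end. The `BertPretrainer` constructor arguments and its module path (`official.nlp.modeling.models`) are assumptions inferred from context, not shown verbatim in this diff; the shapes follow the expectations in `bert_pretrainer_test.py` below.

```python
import numpy as np
import tensorflow as tf

from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer  # path assumed

# Small, illustrative encoder; hyperparameters are not taken from this commit.
encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, hidden_size=32,
    sequence_length=16, num_attention_heads=2)
pretrainer = bert_pretrainer.BertPretrainer(
    network=encoder, num_classes=2, num_token_predictions=8)

batch_size = 2
word_ids = tf.constant(np.random.randint(100, size=(batch_size, 16)), tf.int32)
mask = tf.ones((batch_size, 16), tf.int32)
type_ids = tf.zeros((batch_size, 16), tf.int32)
lm_positions = tf.constant(np.random.randint(16, size=(batch_size, 8)), tf.int32)

# The model now returns a dict rather than a [lm, classification] list.
outputs = pretrainer([word_ids, mask, type_ids, lm_positions])
lm_output = outputs['masked_lm']             # (2, 8, 100)
sentence_output = outputs['classification']  # (2, 2)
```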
......
@@ -45,6 +45,9 @@ assemble new layers, networks, or models.
   should be masked), the output will have masked positions set to
   approximately zero.
+* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes the
+  embedding table variable is passed to it.
 * [ClassificationHead](cls_head.py) A pooling head over a sequence of
   embeddings, commonly used by classification tasks.
......
@@ -18,6 +18,7 @@ from official.nlp.modeling.layers.attention import *
 from official.nlp.modeling.layers.cls_head import *
 from official.nlp.modeling.layers.dense_einsum import DenseEinsum
 from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
+from official.nlp.modeling.layers.masked_lm import MaskedLM
 from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
 from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
 from official.nlp.modeling.layers.position_embedding import PositionEmbedding
......
@@ -25,91 +25,74 @@ from official.modeling import tf_utils
 @tf.keras.utils.register_keras_serializable(package='Text')
-class MaskedLM(tf.keras.Model):
+class MaskedLM(tf.keras.layers.Layer):
   """Masked language model network head for BERT modeling.

   This network implements a masked language model based on the provided network.
   It assumes that the network being passed has a "get_embedding_table()" method.

   Arguments:
-    input_width: The innermost dimension of the input tensor to this network.
-    num_predictions: The number of predictions to make per sequence.
-    source_network: The network with the embedding layer to use for the
-      embedding layer.
-    embedding_table: The embedding table of a source network, If None, the
-      `source_network.get_embedding_table()` method is used.
-    activation: The activation, if any, for the dense layer in this network.
-    initializer: The intializer for the dense layer in this network. Defaults to
-      a Glorot uniform initializer.
+    embedding_table: The embedding table of the targets.
+    activation: The activation, if any, for the dense layer.
+    initializer: The initializer for the dense layer. Defaults to a Glorot
+      uniform initializer.
     output: The output style for this network. Can be either 'logits' or
       'predictions'.
   """

   def __init__(self,
-               input_width,
-               num_predictions,
-               source_network,
-               embedding_table=None,
+               embedding_table,
                activation=None,
                initializer='glorot_uniform',
                output='logits',
+               name='cls/predictions',
                **kwargs):
+    super(MaskedLM, self).__init__(name=name, **kwargs)
+    self.embedding_table = embedding_table
+    self.activation = activation
+    self.initializer = tf.keras.initializers.get(initializer)

-    if embedding_table is None:
-      embedding_table = source_network.get_embedding_table()
-    vocab_size, hidden_size = embedding_table.shape
-    sequence_data = tf.keras.layers.Input(
-        shape=(None, input_width), name='sequence_data', dtype=tf.float32)
-    masked_lm_positions = tf.keras.layers.Input(
-        shape=(num_predictions,), name='masked_lm_positions', dtype=tf.int32)
-    masked_lm_input = tf.keras.layers.Lambda(
-        lambda x: self._gather_indexes(x[0], x[1]))(
-            [sequence_data, masked_lm_positions])
-    lm_data = (
-        tf.keras.layers.Dense(
-            hidden_size,
-            activation=activation,
-            kernel_initializer=initializer,
-            name='cls/predictions/transform/dense')(masked_lm_input))
-    lm_data = tf.keras.layers.LayerNormalization(
-        axis=-1, epsilon=1e-12, name='cls/predictions/transform/LayerNorm')(
-            lm_data)
-    lm_data = tf.keras.layers.Lambda(
-        lambda x: tf.matmul(x, embedding_table, transpose_b=True))(
-            lm_data)
-    logits = Bias(
-        initializer=tf.keras.initializers.Zeros(),
-        name='cls/predictions/output_bias')(
-            lm_data)
-    # We can't use the standard Keras reshape layer here, since it expects
-    # the input and output batch size to be the same.
-    reshape_layer = tf.keras.layers.Lambda(
-        lambda x: tf.reshape(x, [-1, num_predictions, vocab_size]))
-    self.logits = reshape_layer(logits)
-    predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(self.logits)
-    if output == 'logits':
-      output_tensors = self.logits
-    elif output == 'predictions':
-      output_tensors = predictions
-    else:
+    if output not in ('predictions', 'logits'):
       raise ValueError(
           ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)
+    self._output_type = output

-    super(MaskedLM, self).__init__(
-        inputs=[sequence_data, masked_lm_positions],
-        outputs=output_tensors,
-        **kwargs)
+  def build(self, input_shape):
+    self._vocab_size, hidden_size = self.embedding_table.shape
+    self.dense = tf.keras.layers.Dense(
+        hidden_size,
+        activation=self.activation,
+        kernel_initializer=self.initializer,
+        name='transform/dense')
+    self.layer_norm = tf.keras.layers.LayerNormalization(
+        axis=-1, epsilon=1e-12, name='transform/LayerNorm')
+    self.bias = self.add_weight(
+        'output_bias/bias',
+        shape=(self._vocab_size,),
+        initializer='zeros',
+        trainable=True)
+
+    super(MaskedLM, self).build(input_shape)
+
+  def call(self, sequence_data, masked_positions):
+    masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
+    lm_data = self.dense(masked_lm_input)
+    lm_data = self.layer_norm(lm_data)
+    lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
+    logits = tf.nn.bias_add(lm_data, self.bias)
+    masked_positions_shape = tf_utils.get_shape_list(
+        masked_positions, name='masked_positions_tensor')
+    logits = tf.reshape(logits,
+                        [-1, masked_positions_shape[1], self._vocab_size])
+    if self._output_type == 'logits':
+      return logits
+    return tf.nn.log_softmax(logits)

   def get_config(self):
-    raise NotImplementedError('MaskedLM cannot be directly serialized at this '
-                              'time. Please use it only in Layers or '
-                              'functionally subclassed Models/Networks.')
+    raise NotImplementedError('MaskedLM cannot be directly serialized because '
+                              'it has variable sharing logic.')

   def _gather_indexes(self, sequence_tensor, positions):
     """Gathers the vectors at the specific positions.
@@ -139,51 +122,3 @@ class MaskedLM(tf.keras.Model):
     output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

     return output_tensor
-
-
-@tf.keras.utils.register_keras_serializable(package='Text')
-# Temporary until we can create a Dense layer that ties the embedding.
-class Bias(tf.keras.layers.Layer):
-  """Adds a bias term to an input."""
-
-  def __init__(self,
-               initializer='zeros',
-               regularizer=None,
-               constraint=None,
-               activation=None,
-               **kwargs):
-    super(Bias, self).__init__(**kwargs)
-    self._initializer = tf.keras.initializers.get(initializer)
-    self._regularizer = tf.keras.regularizers.get(regularizer)
-    self._constraint = tf.keras.constraints.get(constraint)
-    self._activation = tf.keras.activations.get(activation)
-
-  def build(self, input_shape):
-    input_shape = tf.TensorShape(input_shape)
-    self._bias = self.add_weight(
-        'bias',
-        shape=input_shape[1:],
-        initializer=self._initializer,
-        regularizer=self._regularizer,
-        constraint=self._constraint,
-        dtype=self._dtype,
-        trainable=True)
-    super(Bias, self).build(input_shape)
-
-  def get_config(self):
-    config = {
-        'activation': tf.keras.activations.serialize(self._activation),
-        'initializer': tf.keras.initializers.serialize(self._initializer),
-        'regularizer': tf.keras.regularizers.serialize(self._regularizer),
-        'constraint': tf.keras.constraints.serialize(self._constraint)
-    }
-    base_config = super(Bias, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
-
-  def call(self, inputs):
-    outputs = tf.nn.bias_add(inputs, self._bias)
-    if self._activation is not None:
-      return self._activation(outputs)  # pylint: disable=not-callable
-    else:
-      return outputs
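For orientation, here is a minimal usage sketch of the refactored layer, mirroring the updated tests below. The encoder hyperparameters and tensor shapes are illustrative, not taken from this commit.

```python
import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import masked_lm
from official.nlp.modeling.networks import transformer_encoder

# Build a small encoder only to obtain a tied embedding table.
encoder = transformer_encoder.TransformerEncoder(
    vocab_size=100,
    num_layers=1,
    hidden_size=64,
    sequence_length=32,
    num_attention_heads=4)

lm_layer = masked_lm.MaskedLM(
    embedding_table=encoder.get_embedding_table(),
    output='predictions')

# The layer is now called directly on a sequence tensor plus the positions to
# predict; it no longer owns its own Keras Input tensors.
sequence_data = tf.random.uniform((3, 32, 64))
masked_positions = tf.constant(
    np.random.randint(32, size=(3, 21)), dtype=tf.int32)
log_probs = lm_layer(sequence_data, masked_positions=masked_positions)
print(log_probs.shape)  # (3, 21, 100)
```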
@@ -23,7 +23,7 @@ import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

-from official.nlp.modeling.networks import masked_lm
+from official.nlp.modeling.layers import masked_lm
 from official.nlp.modeling.networks import transformer_encoder
@@ -32,13 +32,12 @@ from official.nlp.modeling.networks import transformer_encoder
 @keras_parameterized.run_all_keras_modes
 class MaskedLMTest(keras_parameterized.TestCase):

-  def create_network(self,
-                     vocab_size,
-                     sequence_length,
-                     hidden_size,
-                     num_predictions,
-                     output='predictions',
-                     xformer_stack=None):
+  def create_layer(self,
+                   vocab_size,
+                   sequence_length,
+                   hidden_size,
+                   output='predictions',
+                   xformer_stack=None):
     # First, create a transformer stack that we can use to get the LM's
     # vocabulary weight.
     if xformer_stack is None:
@@ -49,82 +48,32 @@ class MaskedLMTest(keras_parameterized.TestCase):
           hidden_size=hidden_size,
           num_attention_heads=4,
       )
-    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    lm_outputs, _ = xformer_stack([word_ids, mask, type_ids])

     # Create a maskedLM from the transformer stack.
-    test_network = masked_lm.MaskedLM(
-        num_predictions=num_predictions,
-        input_width=lm_outputs.shape[-1],
-        source_network=xformer_stack,
+    test_layer = masked_lm.MaskedLM(
+        embedding_table=xformer_stack.get_embedding_table(),
         output=output)
-    return test_network
+    return test_layer

-  def test_network_creation(self):
+  def test_layer_creation(self):
     vocab_size = 100
     sequence_length = 32
     hidden_size = 64
     num_predictions = 21
-    test_network = self.create_network(
+    test_layer = self.create_layer(
         vocab_size=vocab_size,
         sequence_length=sequence_length,
-        hidden_size=hidden_size,
-        num_predictions=num_predictions)
+        hidden_size=hidden_size)

     # Make sure that the output tensor of the masked LM is the right shape.
     lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
-    masked_lm_positions = tf.keras.Input(
-        shape=(num_predictions,), dtype=tf.int32)
-    output = test_network([lm_input_tensor, masked_lm_positions])
+    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
+    output = test_layer(lm_input_tensor, masked_positions=masked_positions)

     expected_output_shape = [None, num_predictions, vocab_size]
     self.assertEqual(expected_output_shape, output.shape.as_list())

-  def test_network_invocation_with_internal_logits(self):
-    vocab_size = 100
-    sequence_length = 32
-    hidden_size = 64
-    num_predictions = 21
-    test_network = self.create_network(
-        vocab_size=vocab_size,
-        sequence_length=sequence_length,
-        hidden_size=hidden_size,
-        num_predictions=num_predictions)
-
-    # Create a model from the masked LM layer.
-    lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
-    masked_lm_positions = tf.keras.Input(
-        shape=(num_predictions,), dtype=tf.int32)
-    output = test_network([lm_input_tensor, masked_lm_positions])
-    model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
-    logits_model = tf.keras.Model(test_network.inputs, test_network.logits)
-
-    # Invoke the masked LM on some fake data to make sure there are no runtime
-    # errors in the code.
-    batch_size = 3
-    lm_input_data = 10 * np.random.random_sample(
-        (batch_size, sequence_length, hidden_size))
-    masked_position_data = np.random.randint(
-        2, size=(batch_size, num_predictions))
-    outputs = model.predict([lm_input_data, masked_position_data])
-    logits = logits_model.predict([lm_input_data, masked_position_data])
-
-    # Ensure that the tensor shapes are correct.
-    expected_output_shape = (batch_size, num_predictions, vocab_size)
-    self.assertEqual(expected_output_shape, outputs.shape)
-    self.assertEqual(expected_output_shape, logits.shape)
-
-    # Ensure that the logits, when softmaxed, create the outputs.
-    input_tensor = tf.keras.Input(expected_output_shape[1:])
-    output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
-    softmax_model = tf.keras.Model(input_tensor, output_tensor)
-    calculated_softmax = softmax_model.predict(logits)
-    self.assertAllClose(outputs, calculated_softmax)
-
-  def test_network_invocation_with_external_logits(self):
+  def test_layer_invocation_with_external_logits(self):
     vocab_size = 100
     sequence_length = 32
     hidden_size = 64
@@ -136,31 +85,28 @@ class MaskedLMTest(keras_parameterized.TestCase):
         hidden_size=hidden_size,
         num_attention_heads=4,
     )
-    test_network = self.create_network(
+    test_layer = self.create_layer(
         vocab_size=vocab_size,
         sequence_length=sequence_length,
         hidden_size=hidden_size,
-        num_predictions=num_predictions,
         xformer_stack=xformer_stack,
         output='predictions')
-    logit_network = self.create_network(
+    logit_layer = self.create_layer(
         vocab_size=vocab_size,
         sequence_length=sequence_length,
         hidden_size=hidden_size,
-        num_predictions=num_predictions,
         xformer_stack=xformer_stack,
         output='logits')
-    logit_network.set_weights(test_network.get_weights())

     # Create a model from the masked LM layer.
     lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
-    masked_lm_positions = tf.keras.Input(
-        shape=(num_predictions,), dtype=tf.int32)
-    output = test_network([lm_input_tensor, masked_lm_positions])
-    logit_output = logit_network([lm_input_tensor, masked_lm_positions])
-
-    model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
-    logits_model = tf.keras.Model(([lm_input_tensor, masked_lm_positions]),
+    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
+    output = test_layer(lm_input_tensor, masked_positions)
+    logit_output = logit_layer(lm_input_tensor, masked_positions)
+    logit_output = tf.keras.layers.Activation(tf.nn.log_softmax)(logit_output)
+    logit_layer.set_weights(test_layer.get_weights())
+    model = tf.keras.Model([lm_input_tensor, masked_positions], output)
+    logits_model = tf.keras.Model(([lm_input_tensor, masked_positions]),
                                   logit_output)

     # Invoke the masked LM on some fake data to make sure there are no runtime
@@ -169,40 +115,33 @@ class MaskedLMTest(keras_parameterized.TestCase):
     lm_input_data = 10 * np.random.random_sample(
         (batch_size, sequence_length, hidden_size))
     masked_position_data = np.random.randint(
-        2, size=(batch_size, num_predictions))
-    outputs = model.predict([lm_input_data, masked_position_data])
-    logits = logits_model.predict([lm_input_data, masked_position_data])
+        sequence_length, size=(batch_size, num_predictions))
+    # ref_outputs = model.predict([lm_input_data, masked_position_data])
+    # outputs = logits_model.predict([lm_input_data, masked_position_data])
+    ref_outputs = model([lm_input_data, masked_position_data])
+    outputs = logits_model([lm_input_data, masked_position_data])

     # Ensure that the tensor shapes are correct.
     expected_output_shape = (batch_size, num_predictions, vocab_size)
+    self.assertEqual(expected_output_shape, ref_outputs.shape)
     self.assertEqual(expected_output_shape, outputs.shape)
-    self.assertEqual(expected_output_shape, logits.shape)
-
-    # Ensure that the logits, when softmaxed, create the outputs.
-    input_tensor = tf.keras.Input(expected_output_shape[1:])
-    output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
-    softmax_model = tf.keras.Model(input_tensor, output_tensor)
-    calculated_softmax = softmax_model.predict(logits)
-    self.assertAllClose(outputs, calculated_softmax)
-
-  def test_network_invocation(self):
+    self.assertAllClose(ref_outputs, outputs)
+
+  def test_layer_invocation(self):
     vocab_size = 100
     sequence_length = 32
     hidden_size = 64
     num_predictions = 21
-    test_network = self.create_network(
+    test_layer = self.create_layer(
         vocab_size=vocab_size,
         sequence_length=sequence_length,
-        hidden_size=hidden_size,
-        num_predictions=num_predictions)
+        hidden_size=hidden_size)

     # Create a model from the masked LM layer.
     lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
-    masked_lm_positions = tf.keras.Input(
-        shape=(num_predictions,), dtype=tf.int32)
-    output = test_network([lm_input_tensor, masked_lm_positions])
-    model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
+    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
+    output = test_layer(lm_input_tensor, masked_positions)
+    model = tf.keras.Model([lm_input_tensor, masked_positions], output)

     # Invoke the masked LM on some fake data to make sure there are no runtime
     # errors in the code.
@@ -215,12 +154,8 @@ class MaskedLMTest(keras_parameterized.TestCase):
   def test_unknown_output_type_fails(self):
     with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
-      _ = self.create_network(
-          vocab_size=8,
-          sequence_length=8,
-          hidden_size=8,
-          num_predictions=8,
-          output='bad')
+      _ = self.create_layer(
+          vocab_size=8, sequence_length=8, hidden_size=8, output='bad')


 if __name__ == '__main__':
......
@@ -23,6 +23,7 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
 from official.nlp.modeling.losses import weighted_sparse_categorical_crossentropy
@@ -48,20 +49,18 @@ class ClassificationLossTest(keras_parameterized.TestCase):
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    lm_outputs, _ = xformer_stack([word_ids, mask, type_ids])
+    _ = xformer_stack([word_ids, mask, type_ids])

     # Create a maskedLM from the transformer stack.
-    test_network = networks.MaskedLM(
-        num_predictions=num_predictions,
-        input_width=lm_outputs.shape[-1],
-        source_network=xformer_stack,
+    test_layer = layers.MaskedLM(
+        embedding_table=xformer_stack.get_embedding_table(),
         output=output)

     # Create a model from the masked LM layer.
     lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
     masked_lm_positions = tf.keras.Input(
         shape=(num_predictions,), dtype=tf.int32)
-    output = test_network([lm_input_tensor, masked_lm_positions])
+    output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions)
     return tf.keras.Model([lm_input_tensor, masked_lm_positions], output)

   def create_classification_model(self, input_width, num_classes):
......
@@ -25,6 +25,7 @@ from typing import List, Optional
 import gin
 import tensorflow as tf

+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
@@ -47,8 +48,8 @@ class BertPretrainer(tf.keras.Model):
     num_token_predictions: Number of tokens to predict from the masked LM.
     embedding_table: Embedding table of a network. If None, the
       "network.get_embedding_table()" is used.
-    activation: The activation (if any) to use in the masked LM network.
-      If None, no activation will be used.
+    activation: The activation (if any) to use in the masked LM network. If
+      None, no activation will be used.
     initializer: The initializer (if any) to use in the masked LM and
       classification networks. Defaults to a Glorot uniform initializer.
     output: The output style for this network. Can be either 'logits' or
@@ -106,16 +107,16 @@ class BertPretrainer(tf.keras.Model):
         dtype=tf.int32)
     inputs.append(masked_lm_positions)

-    self.masked_lm = networks.MaskedLM(
-        num_predictions=num_token_predictions,
-        input_width=sequence_output.shape[-1],
-        source_network=network,
+    if embedding_table is None:
+      embedding_table = self.encoder.get_embedding_table()
+    self.masked_lm = layers.MaskedLM(
         embedding_table=embedding_table,
         activation=activation,
         initializer=initializer,
         output=output,
-        name='masked_lm')
-    lm_outputs = self.masked_lm([sequence_output, masked_lm_positions])
+        name='cls/predictions')
+    lm_outputs = self.masked_lm(
+        sequence_output, masked_positions=masked_lm_positions)

     self.classification = networks.Classification(
         input_width=cls_output.shape[-1],
@@ -126,7 +127,9 @@ class BertPretrainer(tf.keras.Model):
     sentence_outputs = self.classification(cls_output)

     super(BertPretrainer, self).__init__(
-        inputs=inputs, outputs=[lm_outputs, sentence_outputs], **kwargs)
+        inputs=inputs,
+        outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs),
+        **kwargs)

   def get_config(self):
     return self._config
@@ -151,8 +154,8 @@ class BertPretrainerV2(tf.keras.Model):
     num_masked_tokens: Number of tokens to predict from the masked LM.
     encoder_network: A transformer network. This network should output a
       sequence output and a classification output.
-    mlm_activation: The activation (if any) to use in the masked LM network.
-      If None, no activation will be used.
+    mlm_activation: The activation (if any) to use in the masked LM network. If
+      None, no activation will be used.
     mlm_initializer: The initializer (if any) to use in the masked LM. Default
       to a Glorot uniform initializer.
     classification_heads: A list of optional head layers to transform on encoder
@@ -193,17 +196,18 @@ class BertPretrainerV2(tf.keras.Model):
     outputs = dict()
     if num_masked_tokens > 0:
-      self.masked_lm = networks.MaskedLM(
-          num_predictions=num_masked_tokens,
-          input_width=sequence_output.shape[-1],
-          source_network=self.encoder_network,
+      self.masked_lm = layers.MaskedLM(
+          embedding_table=self.encoder_network.get_embedding_table(),
           activation=mlm_activation,
           initializer=mlm_initializer,
-          name='masked_lm')
-      masked_lm_positions = copy.copy(self.masked_lm.inputs[-1])
+          name='cls/predictions')
+      masked_lm_positions = tf.keras.layers.Input(
+          shape=(num_masked_tokens,),
+          name='masked_lm_positions',
+          dtype=tf.int32)
       inputs.append(masked_lm_positions)
       outputs['lm_output'] = self.masked_lm(
-          [sequence_output, masked_lm_positions])
+          sequence_output, masked_positions=masked_lm_positions)
     for cls_head in self.classification_heads:
       outputs[cls_head.name] = cls_head(sequence_output)
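A hedged sketch of BertPretrainerV2 after this change: the masked positions are now an ordinary Keras Input created by the pretrainer (instead of being copied from the old MaskedLM network's inputs), and the masked LM scores come back under the 'lm_output' key. The module path and the exact constructor signature are assumptions based on the docstring context, not shown verbatim in this diff.

```python
import numpy as np
import tensorflow as tf

from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer  # path assumed

# Illustrative encoder; hyperparameters are not taken from this commit.
encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, hidden_size=32,
    sequence_length=16, num_attention_heads=2)
pretrainer_v2 = bert_pretrainer.BertPretrainerV2(
    num_masked_tokens=8, encoder_network=encoder)

batch_size = 2
word_ids = tf.constant(np.random.randint(100, size=(batch_size, 16)), tf.int32)
mask = tf.ones((batch_size, 16), tf.int32)
type_ids = tf.zeros((batch_size, 16), tf.int32)
lm_positions = tf.constant(np.random.randint(16, size=(batch_size, 8)), tf.int32)

# The fourth input is the masked_lm_positions tensor added by this change.
outputs = pretrainer_v2([word_ids, mask, type_ids, lm_positions])
print(outputs['lm_output'].shape)  # (2, 8, 100)
```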
......
@@ -50,16 +50,19 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    lm_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    masked_lm_positions = tf.keras.Input(
+        shape=(num_token_predictions,), dtype=tf.int32)

     # Invoke the trainer model on the inputs. This causes the layer to be built.
-    lm_outs, cls_outs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
+    outputs = bert_trainer_model(
+        [word_ids, mask, type_ids, masked_lm_positions])

     # Validate that the outputs are of the expected shape.
     expected_lm_shape = [None, num_token_predictions, vocab_size]
     expected_classification_shape = [None, num_classes]
-    self.assertAllEqual(expected_lm_shape, lm_outs.shape.as_list())
-    self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
+    self.assertAllEqual(expected_lm_shape, outputs['masked_lm'].shape.as_list())
+    self.assertAllEqual(expected_classification_shape,
+                        outputs['classification'].shape.as_list())

   def test_bert_trainer_tensor_call(self):
     """Validate that the Keras object can be invoked."""
@@ -81,7 +84,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Invoke the trainer model on the tensors. In Eager mode, this does the
     # actual calculation. (We can't validate the outputs, since the network is
     # too complex: this simply ensures we're not hitting runtime errors.)
-    _, _ = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
+    _ = bert_trainer_model([word_ids, mask, type_ids, lm_mask])

   def test_serialize_deserialize(self):
     """Validate that the BERT trainer can be serialized and deserialized."""
@@ -123,7 +126,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    lm_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    lm_mask = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)

     # Invoke the trainer model on the inputs. This causes the layer to be built.
     outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
......
@@ -16,8 +16,6 @@ Self-supervised Learning of Language Representations]
 (https://arxiv.org/abs/1909.11942). Compared with [BERT](https://arxiv.org/abs/1810.04805), ALBERT refactorizes embedding parameters
 into two smaller matrices and shares parameters across layers.

-* [`MaskedLM`](masked_lm.py) implements a masked language model for BERT pretraining. It assumes that the network being passed has a `get_embedding_table()` method.
-
 * [`Classification`](classification.py) contains a single hidden layer, and is
 intended for use as a classification or regression (if number of classes is set
 to 1) head.
......
@@ -16,7 +16,6 @@
 from official.nlp.modeling.networks.albert_transformer_encoder import AlbertTransformerEncoder
 from official.nlp.modeling.networks.classification import Classification
 from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
-from official.nlp.modeling.networks.masked_lm import MaskedLM
 from official.nlp.modeling.networks.span_labeling import SpanLabeling
 from official.nlp.modeling.networks.token_classification import TokenClassification
 from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder