Commit 87d7e974 authored by Chen Chen, committed by A. Unique TensorFlower

Move MobileBertMaskedLM to official/nlp/modeling/layers/mobile_bert_layers.py

PiperOrigin-RevId: 346826537
parent a4b767b7
@@ -22,6 +22,7 @@ from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.mat_mul_with_margin import MatMulWithMargin
from official.nlp.modeling.layers.mobile_bert_layers import MobileBertEmbedding
from official.nlp.modeling.layers.mobile_bert_layers import MobileBertMaskedLM
from official.nlp.modeling.layers.mobile_bert_layers import MobileBertTransformer
from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
@@ -18,6 +18,7 @@ import tensorflow as tf
from official.nlp import keras_nlp
@tf.keras.utils.register_keras_serializable(package='Text')
class NoNorm(tf.keras.layers.Layer):
"""Apply element-wise linear transformation to the last dimension."""
@@ -62,6 +63,7 @@ def _get_norm_layer(normalization_type='no_norm', name=None):
  return layer
@tf.keras.utils.register_keras_serializable(package='Text')
class MobileBertEmbedding(tf.keras.layers.Layer):
"""Performs an embedding lookup for MobileBERT.
@@ -163,6 +165,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
    return embedding_out
@tf.keras.utils.register_keras_serializable(package='Text')
class MobileBertTransformer(tf.keras.layers.Layer):
"""Transformer block for MobileBERT.
@@ -422,3 +425,129 @@ class MobileBertTransformer(tf.keras.layers.Layer):
      return layer_output, attention_scores
    else:
      return layer_output
@tf.keras.utils.register_keras_serializable(package='Text')
class MobileBertMaskedLM(tf.keras.layers.Layer):
  """Masked language model network head for BERT modeling.

  This layer implements a masked language model based on the provided
  transformer-based encoder. It assumes that the encoder network being passed
  has a "get_embedding_table()" method. Unlike the canonical BERT masked LM
  layer, when the embedding width is smaller than hidden_size, it adds extra
  output weights of shape [vocab_size, hidden_size - embedding_width].
  """
  def __init__(self,
               embedding_table,
               activation=None,
               initializer='glorot_uniform',
               output='logits',
               **kwargs):
    """Class initialization.

    Arguments:
      embedding_table: The embedding table from the encoder network.
      activation: The activation, if any, for the dense layer.
      initializer: The initializer for the dense layer. Defaults to a Glorot
        uniform initializer.
      output: The output style for this layer. Can be either 'logits' or
        'predictions'.
      **kwargs: keyword arguments.
    """
    super(MobileBertMaskedLM, self).__init__(**kwargs)
    self.embedding_table = embedding_table
    self.activation = activation
    self.initializer = tf.keras.initializers.get(initializer)

    if output not in ('predictions', 'logits'):
      raise ValueError(
          ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)
    self._output_type = output
  def build(self, input_shape):
    self._vocab_size, embedding_width = self.embedding_table.shape
    hidden_size = input_shape[-1]
    self.dense = tf.keras.layers.Dense(
        hidden_size,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name='transform/dense')

    if hidden_size > embedding_width:
      self.extra_output_weights = self.add_weight(
          'extra_output_weights',
          shape=(self._vocab_size, hidden_size - embedding_width),
          initializer=self.initializer,
          trainable=True)
    elif hidden_size == embedding_width:
      self.extra_output_weights = None
    else:
      raise ValueError(
          'hidden size %d cannot be smaller than embedding width %d.' %
          (hidden_size, embedding_width))

    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name='transform/LayerNorm')
    self.bias = self.add_weight(
        'output_bias/bias',
        shape=(self._vocab_size,),
        initializer='zeros',
        trainable=True)

    super(MobileBertMaskedLM, self).build(input_shape)
  def call(self, sequence_data, masked_positions):
    masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
    lm_data = self.dense(masked_lm_input)
    lm_data = self.layer_norm(lm_data)
    if self.extra_output_weights is None:
      lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
    else:
      lm_data = tf.matmul(
          lm_data,
          tf.concat([self.embedding_table, self.extra_output_weights], axis=1),
          transpose_b=True)
    logits = tf.nn.bias_add(lm_data, self.bias)

    masked_positions_length = masked_positions.shape.as_list()[1] or tf.shape(
        masked_positions)[1]
    logits = tf.reshape(logits,
                        [-1, masked_positions_length, self._vocab_size])
    if self._output_type == 'logits':
      return logits
    return tf.nn.log_softmax(logits)
  def get_config(self):
    raise NotImplementedError(
        'MobileBertMaskedLM cannot be directly serialized because it has '
        'variable sharing logic.')
  def _gather_indexes(self, sequence_tensor, positions):
    """Gathers the vectors at the specific positions.

    Args:
      sequence_tensor: Sequence output of the encoder of shape
        (batch_size, seq_length, num_hidden), where num_hidden is the number
        of hidden units of the encoder.
      positions: Position ids of the tokens in the sequence to mask for
        pretraining, of shape (batch_size, num_predictions), where
        `num_predictions` is the maximum number of tokens to mask out and
        predict per sequence.

    Returns:
      Masked-out sequence tensor of shape (batch_size * num_predictions,
        num_hidden).
    """
    sequence_shape = tf.shape(sequence_tensor)
    batch_size, seq_length = sequence_shape[0], sequence_shape[1]
    width = sequence_tensor.shape.as_list()[2] or sequence_shape[2]

    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
    return output_tensor
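For orientation, here is a minimal usage sketch of the new layer; it is not part of this commit. It mirrors the unit tests added below and assumes the same MobileBERTEncoder constructor arguments and toy dimensions: because hidden_size (64) is larger than word_embed_size (32), the layer builds extra_output_weights of shape [vocab_size, 32] and concatenates them with the shared embedding table before the output projection.

import tensorflow as tf

from official.nlp.modeling.layers import mobile_bert_layers
from official.nlp.modeling.networks import mobile_bert_encoder

# Illustrative dimensions only; they match the tests further down.
encoder = mobile_bert_encoder.MobileBERTEncoder(
    word_vocab_size=100,
    num_blocks=1,
    hidden_size=64,
    num_attention_heads=4,
    word_embed_size=32)

# The masked LM head shares the encoder's [100, 32] embedding table.
masked_lm = mobile_bert_layers.MobileBertMaskedLM(
    embedding_table=encoder.get_embedding_table(), output='predictions')

sequence_output = tf.keras.Input(shape=(32, 64))   # [batch, seq_len, hidden]
masked_positions = tf.keras.Input(shape=(21,), dtype=tf.int32)
log_probs = masked_lm(sequence_output, masked_positions)  # [batch, 21, 100]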
@@ -18,6 +18,7 @@ import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import mobile_bert_layers
from official.nlp.modeling.networks import mobile_bert_encoder
def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
@@ -127,5 +128,146 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
    self.assertEqual(layer_config, new_layer.get_config())
class MobileBertMaskedLMTest(tf.test.TestCase):

  def create_layer(self,
                   vocab_size,
                   hidden_size,
                   embedding_width,
                   output='predictions',
                   xformer_stack=None):
    # First, create a transformer stack that we can use to get the LM's
    # vocabulary weights.
    if xformer_stack is None:
      xformer_stack = mobile_bert_encoder.MobileBERTEncoder(
          word_vocab_size=vocab_size,
          num_blocks=1,
          hidden_size=hidden_size,
          num_attention_heads=4,
          word_embed_size=embedding_width)

    # Create a masked LM head from the transformer stack.
    test_layer = mobile_bert_layers.MobileBertMaskedLM(
        embedding_table=xformer_stack.get_embedding_table(), output=output)
    return test_layer
  def test_layer_creation(self):
    vocab_size = 100
    sequence_length = 32
    hidden_size = 64
    embedding_width = 32
    num_predictions = 21
    test_layer = self.create_layer(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        embedding_width=embedding_width)

    # Make sure that the output tensor of the masked LM is the right shape.
    lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
    output = test_layer(lm_input_tensor, masked_positions=masked_positions)

    expected_output_shape = [None, num_predictions, vocab_size]
    self.assertEqual(expected_output_shape, output.shape.as_list())
  def test_layer_invocation_with_external_logits(self):
    vocab_size = 100
    sequence_length = 32
    hidden_size = 64
    embedding_width = 32
    num_predictions = 21
    xformer_stack = mobile_bert_encoder.MobileBERTEncoder(
        word_vocab_size=vocab_size,
        num_blocks=1,
        hidden_size=hidden_size,
        num_attention_heads=4,
        word_embed_size=embedding_width)
    test_layer = self.create_layer(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        embedding_width=embedding_width,
        xformer_stack=xformer_stack,
        output='predictions')
    logit_layer = self.create_layer(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        embedding_width=embedding_width,
        xformer_stack=xformer_stack,
        output='logits')

    # Create a model from the masked LM layer.
    lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
    output = test_layer(lm_input_tensor, masked_positions)
    logit_output = logit_layer(lm_input_tensor, masked_positions)
    logit_output = tf.keras.layers.Activation(tf.nn.log_softmax)(logit_output)
    logit_layer.set_weights(test_layer.get_weights())
    model = tf.keras.Model([lm_input_tensor, masked_positions], output)
    logits_model = tf.keras.Model(([lm_input_tensor, masked_positions]),
                                  logit_output)

    # Invoke the masked LM on some fake data to make sure there are no runtime
    # errors in the code.
    batch_size = 3
    lm_input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, hidden_size))
    masked_position_data = np.random.randint(
        sequence_length, size=(batch_size, num_predictions))
    # ref_outputs = model.predict([lm_input_data, masked_position_data])
    # outputs = logits_model.predict([lm_input_data, masked_position_data])
    ref_outputs = model([lm_input_data, masked_position_data])
    outputs = logits_model([lm_input_data, masked_position_data])

    # Ensure that the tensor shapes are correct.
    expected_output_shape = (batch_size, num_predictions, vocab_size)
    self.assertEqual(expected_output_shape, ref_outputs.shape)
    self.assertEqual(expected_output_shape, outputs.shape)
    self.assertAllClose(ref_outputs, outputs)
  def test_layer_invocation(self):
    vocab_size = 100
    sequence_length = 32
    hidden_size = 64
    embedding_width = 32
    num_predictions = 21
    test_layer = self.create_layer(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        embedding_width=embedding_width)

    # Create a model from the masked LM layer.
    lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
    output = test_layer(lm_input_tensor, masked_positions)
    model = tf.keras.Model([lm_input_tensor, masked_positions], output)

    # Invoke the masked LM on some fake data to make sure there are no runtime
    # errors in the code.
    batch_size = 3
    lm_input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, hidden_size))
    masked_position_data = np.random.randint(
        2, size=(batch_size, num_predictions))
    _ = model.predict([lm_input_data, masked_position_data])
  def test_unknown_output_type_fails(self):
    with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
      _ = self.create_layer(
          vocab_size=8, hidden_size=8, embedding_width=4, output='bad')
  def test_hidden_size_smaller_than_embedding_width(self):
    hidden_size = 8
    sequence_length = 32
    num_predictions = 20
    with self.assertRaisesRegex(
        ValueError, 'hidden size 8 cannot be smaller than embedding width 16.'):
      test_layer = self.create_layer(
          vocab_size=8, hidden_size=8, embedding_width=16)
      lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
      masked_positions = tf.keras.Input(
          shape=(num_predictions,), dtype=tf.int32)
      _ = test_layer(lm_input_tensor, masked_positions)
if __name__ == '__main__':
  tf.test.main()
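As a quick sanity check on the shape arithmetic these tests exercise (vocab_size=100, hidden_size=64, embedding_width=32), the following purely illustrative snippet reproduces the concat-and-matmul in the hidden_size > embedding_width branch of MobileBertMaskedLM.call; it is not part of the commit.

import tensorflow as tf

embedding_table = tf.zeros([100, 32])       # [vocab_size, embedding_width]
extra_output_weights = tf.zeros([100, 32])  # [vocab_size, hidden - embedding]
lm_data = tf.zeros([3 * 21, 64])            # gathered masked positions, [N, hidden]

projection = tf.concat([embedding_table, extra_output_weights], axis=1)
logits = tf.matmul(lm_data, projection, transpose_b=True)
print(projection.shape)  # (100, 64)
print(logits.shape)      # (63, 100), later reshaped to [batch, num_predictions, vocab]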