Commit 802488f1 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 316593329
parent cabb22cd
@@ -230,9 +230,10 @@ def pretrain_model(bert_config,
initializer=initializer,
output='predictions')
lm_output, sentence_output = pretrainer_model(
outputs = pretrainer_model(
[input_word_ids, input_mask, input_type_ids, masked_lm_positions])
lm_output = outputs['masked_lm']
sentence_output = outputs['classification']
pretrain_loss_layer = BertPretrainLossAndMetricLayer(
vocab_size=bert_config.vocab_size)
output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
......
@@ -45,6 +45,9 @@ assemble new layers, networks, or models.
should be masked), the output will have masked positions set to
approximately zero.
* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes the
embedding table variable is passed to it.
* [`ClassificationHead`](cls_head.py) implements a pooling head over a sequence
of embeddings, commonly used by classification tasks. A short usage sketch of
these heads follows this list.
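As a quick illustration of the layer-based API, the sketch below wires a `MaskedLM` head to a `TransformerEncoder` and checks the output shape. This is a minimal sketch, not library documentation: the sizes are arbitrary toy values, and the encoder constructor arguments are assumed to match the test code later in this commit.

```python
import tensorflow as tf

from official.nlp.modeling import layers, networks

# Toy sizes, for illustration only.
vocab_size, seq_length, hidden_size, num_predictions = 100, 32, 64, 8

encoder = networks.TransformerEncoder(
    vocab_size=vocab_size, num_layers=2, sequence_length=seq_length,
    hidden_size=hidden_size, num_attention_heads=4)

# Build the encoder once so its embedding table variable exists.
word_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
sequence_output, _ = encoder([word_ids, mask, type_ids])

# The head ties its output projection to the encoder's embedding table.
mlm_head = layers.MaskedLM(
    embedding_table=encoder.get_embedding_table(), output='predictions')

masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
log_probs = mlm_head(sequence_output, masked_positions=masked_positions)
# log_probs: [batch, num_predictions, vocab_size] log-probabilities.
```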
......
@@ -18,6 +18,7 @@ from official.nlp.modeling.layers.attention import *
from official.nlp.modeling.layers.cls_head import *
from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding
......
@@ -25,91 +25,74 @@ from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Text')
class MaskedLM(tf.keras.Model):
class MaskedLM(tf.keras.layers.Layer):
"""Masked language model network head for BERT modeling.
This layer implements a masked language model head. It assumes that the
embedding table variable of the source network is passed to it.
Arguments:
input_width: The innermost dimension of the input tensor to this network.
num_predictions: The number of predictions to make per sequence.
source_network: The network with the embedding layer to use for the
embedding layer.
embedding_table: The embedding table of a source network. If None, the
`source_network.get_embedding_table()` method is used.
activation: The activation, if any, for the dense layer in this network.
initializer: The initializer for the dense layer in this network. Defaults to
a Glorot uniform initializer.
embedding_table: The embedding table of the targets.
activation: The activation, if any, for the dense layer.
initializer: The initializer for the dense layer. Defaults to a Glorot
uniform initializer.
output: The output style for this layer. Can be either 'logits' or
'predictions'.
"""
def __init__(self,
input_width,
num_predictions,
source_network,
embedding_table=None,
embedding_table,
activation=None,
initializer='glorot_uniform',
output='logits',
name='cls/predictions',
**kwargs):
super(MaskedLM, self).__init__(name=name, **kwargs)
self.embedding_table = embedding_table
self.activation = activation
self.initializer = tf.keras.initializers.get(initializer)
if embedding_table is None:
embedding_table = source_network.get_embedding_table()
vocab_size, hidden_size = embedding_table.shape
sequence_data = tf.keras.layers.Input(
shape=(None, input_width), name='sequence_data', dtype=tf.float32)
masked_lm_positions = tf.keras.layers.Input(
shape=(num_predictions,), name='masked_lm_positions', dtype=tf.int32)
masked_lm_input = tf.keras.layers.Lambda(
lambda x: self._gather_indexes(x[0], x[1]))(
[sequence_data, masked_lm_positions])
lm_data = (
tf.keras.layers.Dense(
hidden_size,
activation=activation,
kernel_initializer=initializer,
name='cls/predictions/transform/dense')(masked_lm_input))
lm_data = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='cls/predictions/transform/LayerNorm')(
lm_data)
lm_data = tf.keras.layers.Lambda(
lambda x: tf.matmul(x, embedding_table, transpose_b=True))(
lm_data)
logits = Bias(
initializer=tf.keras.initializers.Zeros(),
name='cls/predictions/output_bias')(
lm_data)
# We can't use the standard Keras reshape layer here, since it expects
# the input and output batch size to be the same.
reshape_layer = tf.keras.layers.Lambda(
lambda x: tf.reshape(x, [-1, num_predictions, vocab_size]))
self.logits = reshape_layer(logits)
predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(self.logits)
if output == 'logits':
output_tensors = self.logits
elif output == 'predictions':
output_tensors = predictions
else:
if output not in ('predictions', 'logits'):
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
self._output_type = output
super(MaskedLM, self).__init__(
inputs=[sequence_data, masked_lm_positions],
outputs=output_tensors,
**kwargs)
def build(self, input_shape):
self._vocab_size, hidden_size = self.embedding_table.shape
self.dense = tf.keras.layers.Dense(
hidden_size,
activation=self.activation,
kernel_initializer=self.initializer,
name='transform/dense')
self.layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='transform/LayerNorm')
self.bias = self.add_weight(
'output_bias/bias',
shape=(self._vocab_size,),
initializer='zeros',
trainable=True)
super(MaskedLM, self).build(input_shape)
def call(self, sequence_data, masked_positions):
masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
lm_data = self.dense(masked_lm_input)
lm_data = self.layer_norm(lm_data)
lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
logits = tf.nn.bias_add(lm_data, self.bias)
masked_positions_shape = tf_utils.get_shape_list(
masked_positions, name='masked_positions_tensor')
logits = tf.reshape(logits,
[-1, masked_positions_shape[1], self._vocab_size])
if self._output_type == 'logits':
return logits
return tf.nn.log_softmax(logits)
def get_config(self):
raise NotImplementedError('MaskedLM cannot be directly serialized at this '
'time. Please use it only in Layers or '
'functionally subclassed Models/Networks.')
raise NotImplementedError('MaskedLM cannot be directly serialized because '
'it has variable sharing logic.')
def _gather_indexes(self, sequence_tensor, positions):
"""Gathers the vectors at the specific positions.
@@ -139,51 +122,3 @@ class MaskedLM(tf.keras.Model):
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor
@tf.keras.utils.register_keras_serializable(package='Text')
# Temporary until we can create a Dense layer that ties the embedding.
class Bias(tf.keras.layers.Layer):
"""Adds a bias term to an input."""
def __init__(self,
initializer='zeros',
regularizer=None,
constraint=None,
activation=None,
**kwargs):
super(Bias, self).__init__(**kwargs)
self._initializer = tf.keras.initializers.get(initializer)
self._regularizer = tf.keras.regularizers.get(regularizer)
self._constraint = tf.keras.constraints.get(constraint)
self._activation = tf.keras.activations.get(activation)
def build(self, input_shape):
input_shape = tf.TensorShape(input_shape)
self._bias = self.add_weight(
'bias',
shape=input_shape[1:],
initializer=self._initializer,
regularizer=self._regularizer,
constraint=self._constraint,
dtype=self._dtype,
trainable=True)
super(Bias, self).build(input_shape)
def get_config(self):
config = {
'activation': tf.keras.activations.serialize(self._activation),
'initializer': tf.keras.initializers.serialize(self._initializer),
'regularizer': tf.keras.regularizers.serialize(self._regularizer),
'constraint': tf.keras.constraints.serialize(self._constraint)
}
base_config = super(Bias, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
outputs = tf.nn.bias_add(inputs, self._bias)
if self._activation is not None:
return self._activation(outputs) # pylint: disable=not-callable
else:
return outputs
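For readers comparing the two implementations, the rewritten forward pass is short enough to restate in plain TensorFlow. The following is a condensed sketch of `call` plus `_gather_indexes`, not the library code itself; `tf_utils.get_shape_list` is replaced with `tf.shape` so the sketch stays self-contained.

```python
import tensorflow as tf

def gather_indexes(sequence_tensor, positions):
  """Gathers vectors at `positions` from a [batch, seq, width] tensor."""
  batch_size, seq_length, width = tf.unstack(tf.shape(sequence_tensor))
  # Offset each example's positions into the flattened batch dimension.
  flat_offsets = tf.reshape(tf.range(batch_size) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
  return tf.gather(flat_sequence, flat_positions)  # [batch * preds, width]

def masked_lm_logits(sequence_data, masked_positions,
                     embedding_table, dense, layer_norm, bias):
  """Mirrors MaskedLM.call: transform, tie to embeddings, add bias."""
  vocab_size = embedding_table.shape[0]
  x = gather_indexes(sequence_data, masked_positions)
  x = layer_norm(dense(x))                  # [batch * preds, hidden]
  x = tf.matmul(x, embedding_table, transpose_b=True)
  logits = tf.nn.bias_add(x, bias)          # [batch * preds, vocab]
  num_preds = tf.shape(masked_positions)[1]
  return tf.reshape(logits, [-1, num_preds, vocab_size])
```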
@@ -23,7 +23,7 @@ import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.networks import masked_lm
from official.nlp.modeling.layers import masked_lm
from official.nlp.modeling.networks import transformer_encoder
@@ -32,13 +32,12 @@ from official.nlp.modeling.networks import transformer_encoder
@keras_parameterized.run_all_keras_modes
class MaskedLMTest(keras_parameterized.TestCase):
def create_network(self,
vocab_size,
sequence_length,
hidden_size,
num_predictions,
output='predictions',
xformer_stack=None):
def create_layer(self,
vocab_size,
sequence_length,
hidden_size,
output='predictions',
xformer_stack=None):
# First, create a transformer stack that we can use to get the LM's
# vocabulary weight.
if xformer_stack is None:
@@ -49,82 +48,32 @@ class MaskedLMTest(keras_parameterized.TestCase):
hidden_size=hidden_size,
num_attention_heads=4,
)
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_outputs, _ = xformer_stack([word_ids, mask, type_ids])
# Create a MaskedLM from the transformer stack.
test_network = masked_lm.MaskedLM(
num_predictions=num_predictions,
input_width=lm_outputs.shape[-1],
source_network=xformer_stack,
test_layer = masked_lm.MaskedLM(
embedding_table=xformer_stack.get_embedding_table(),
output=output)
return test_network
return test_layer
def test_network_creation(self):
def test_layer_creation(self):
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
test_network = self.create_network(
test_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
hidden_size=hidden_size)
# Make sure that the output tensor of the masked LM is the right shape.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
masked_lm_positions = tf.keras.Input(
shape=(num_predictions,), dtype=tf.int32)
output = test_network([lm_input_tensor, masked_lm_positions])
masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
output = test_layer(lm_input_tensor, masked_positions=masked_positions)
expected_output_shape = [None, num_predictions, vocab_size]
self.assertEqual(expected_output_shape, output.shape.as_list())
def test_network_invocation_with_internal_logits(self):
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
test_network = self.create_network(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
masked_lm_positions = tf.keras.Input(
shape=(num_predictions,), dtype=tf.int32)
output = test_network([lm_input_tensor, masked_lm_positions])
model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
logits_model = tf.keras.Model(test_network.inputs, test_network.logits)
# Invoke the masked LM on some fake data to make sure there are no runtime
# errors in the code.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
outputs = model.predict([lm_input_data, masked_position_data])
logits = logits_model.predict([lm_input_data, masked_position_data])
# Ensure that the tensor shapes are correct.
expected_output_shape = (batch_size, num_predictions, vocab_size)
self.assertEqual(expected_output_shape, outputs.shape)
self.assertEqual(expected_output_shape, logits.shape)
# Ensure that the logits, when softmaxed, create the outputs.
input_tensor = tf.keras.Input(expected_output_shape[1:])
output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
softmax_model = tf.keras.Model(input_tensor, output_tensor)
calculated_softmax = softmax_model.predict(logits)
self.assertAllClose(outputs, calculated_softmax)
def test_network_invocation_with_external_logits(self):
def test_layer_invocation_with_external_logits(self):
vocab_size = 100
sequence_length = 32
hidden_size = 64
@@ -136,31 +85,28 @@ class MaskedLMTest(keras_parameterized.TestCase):
hidden_size=hidden_size,
num_attention_heads=4,
)
test_network = self.create_network(
test_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions,
xformer_stack=xformer_stack,
output='predictions')
logit_network = self.create_network(
logit_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions,
xformer_stack=xformer_stack,
output='logits')
logit_network.set_weights(test_network.get_weights())
# Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
masked_lm_positions = tf.keras.Input(
shape=(num_predictions,), dtype=tf.int32)
output = test_network([lm_input_tensor, masked_lm_positions])
logit_output = logit_network([lm_input_tensor, masked_lm_positions])
model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
logits_model = tf.keras.Model(([lm_input_tensor, masked_lm_positions]),
masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
output = test_layer(lm_input_tensor, masked_positions)
logit_output = logit_layer(lm_input_tensor, masked_positions)
logit_output = tf.keras.layers.Activation(tf.nn.log_softmax)(logit_output)
logit_layer.set_weights(test_layer.get_weights())
model = tf.keras.Model([lm_input_tensor, masked_positions], output)
logits_model = tf.keras.Model([lm_input_tensor, masked_positions],
logit_output)
# Invoke the masked LM on some fake data to make sure there are no runtime
@@ -169,40 +115,33 @@ class MaskedLMTest(keras_parameterized.TestCase):
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
outputs = model.predict([lm_input_data, masked_position_data])
logits = logits_model.predict([lm_input_data, masked_position_data])
sequence_length, size=(batch_size, num_predictions))
ref_outputs = model([lm_input_data, masked_position_data])
outputs = logits_model([lm_input_data, masked_position_data])
# Ensure that the tensor shapes are correct.
expected_output_shape = (batch_size, num_predictions, vocab_size)
self.assertEqual(expected_output_shape, ref_outputs.shape)
self.assertEqual(expected_output_shape, outputs.shape)
self.assertEqual(expected_output_shape, logits.shape)
self.assertAllClose(ref_outputs, outputs)
# Ensure that the logits, when softmaxed, create the outputs.
input_tensor = tf.keras.Input(expected_output_shape[1:])
output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
softmax_model = tf.keras.Model(input_tensor, output_tensor)
calculated_softmax = softmax_model.predict(logits)
self.assertAllClose(outputs, calculated_softmax)
def test_network_invocation(self):
def test_layer_invocation(self):
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
test_network = self.create_network(
test_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
hidden_size=hidden_size)
# Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
masked_lm_positions = tf.keras.Input(
shape=(num_predictions,), dtype=tf.int32)
output = test_network([lm_input_tensor, masked_lm_positions])
model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
output = test_layer(lm_input_tensor, masked_positions)
model = tf.keras.Model([lm_input_tensor, masked_positions], output)
# Invoke the masked LM on some fake data to make sure there are no runtime
# errors in the code.
@@ -215,12 +154,8 @@ class MaskedLMTest(keras_parameterized.TestCase):
def test_unknown_output_type_fails(self):
with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
_ = self.create_network(
vocab_size=8,
sequence_length=8,
hidden_size=8,
num_predictions=8,
output='bad')
_ = self.create_layer(
vocab_size=8, sequence_length=8, hidden_size=8, output='bad')
if __name__ == '__main__':
......
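Because the head is now an ordinary layer, the tests above can drive it eagerly without wrapping it in a functional model first. A standalone sketch, using a random variable in place of a real encoder's embedding table:

```python
import tensorflow as tf

from official.nlp.modeling import layers

batch, seq, hidden, vocab, num_preds = 3, 32, 64, 100, 21

# Any [vocab, hidden] variable can serve as the tied table.
table = tf.Variable(tf.random.uniform([vocab, hidden]))
mlm = layers.MaskedLM(embedding_table=table, output='logits')

seq_data = tf.random.uniform([batch, seq, hidden])
positions = tf.random.uniform(
    [batch, num_preds], maxval=seq, dtype=tf.int32)

logits = mlm(seq_data, masked_positions=positions)
assert logits.shape == (batch, num_preds, vocab)

# 'predictions' mode is simply log_softmax over the same logits, so the
# exponentiated rows sum to one.
log_probs = tf.nn.log_softmax(logits, axis=-1)
row_sums = tf.reduce_sum(tf.exp(log_probs), axis=-1)  # all ~1.0
```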
@@ -23,6 +23,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.modeling.losses import weighted_sparse_categorical_crossentropy
@@ -48,20 +49,18 @@ class ClassificationLossTest(keras_parameterized.TestCase):
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_outputs, _ = xformer_stack([word_ids, mask, type_ids])
_ = xformer_stack([word_ids, mask, type_ids])
# Create a MaskedLM from the transformer stack.
test_network = networks.MaskedLM(
num_predictions=num_predictions,
input_width=lm_outputs.shape[-1],
source_network=xformer_stack,
test_layer = layers.MaskedLM(
embedding_table=xformer_stack.get_embedding_table(),
output=output)
# Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
masked_lm_positions = tf.keras.Input(
shape=(num_predictions,), dtype=tf.int32)
output = test_network([lm_input_tensor, masked_lm_positions])
output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions)
return tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
def create_classification_model(self, input_width, num_classes):
......
@@ -25,6 +25,7 @@ from typing import List, Optional
import gin
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling import networks
@@ -47,8 +48,8 @@ class BertPretrainer(tf.keras.Model):
num_token_predictions: Number of tokens to predict from the masked LM.
embedding_table: Embedding table of a network. If None, the
`network.get_embedding_table()` method is used.
activation: The activation (if any) to use in the masked LM network.
If None, no activation will be used.
activation: The activation (if any) to use in the masked LM network. If
None, no activation will be used.
initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or
@@ -106,16 +107,16 @@ class BertPretrainer(tf.keras.Model):
dtype=tf.int32)
inputs.append(masked_lm_positions)
self.masked_lm = networks.MaskedLM(
num_predictions=num_token_predictions,
input_width=sequence_output.shape[-1],
source_network=network,
if embedding_table is None:
embedding_table = self.encoder.get_embedding_table()
self.masked_lm = layers.MaskedLM(
embedding_table=embedding_table,
activation=activation,
initializer=initializer,
output=output,
name='masked_lm')
lm_outputs = self.masked_lm([sequence_output, masked_lm_positions])
name='cls/predictions')
lm_outputs = self.masked_lm(
sequence_output, masked_positions=masked_lm_positions)
self.classification = networks.Classification(
input_width=cls_output.shape[-1],
@@ -126,7 +127,9 @@ class BertPretrainer(tf.keras.Model):
sentence_outputs = self.classification(cls_output)
super(BertPretrainer, self).__init__(
inputs=inputs, outputs=[lm_outputs, sentence_outputs], **kwargs)
inputs=inputs,
outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs),
**kwargs)
def get_config(self):
return self._config
@@ -151,8 +154,8 @@ class BertPretrainerV2(tf.keras.Model):
num_masked_tokens: Number of tokens to predict from the masked LM.
encoder_network: A transformer network. This network should output a
sequence output and a classification output.
mlm_activation: The activation (if any) to use in the masked LM network.
If None, no activation will be used.
mlm_activation: The activation (if any) to use in the masked LM network. If
None, no activation will be used.
mlm_initializer: The initializer (if any) to use in the masked LM. Defaults
to a Glorot uniform initializer.
classification_heads: A list of optional head layers to transform on encoder
@@ -193,17 +196,18 @@ class BertPretrainerV2(tf.keras.Model):
outputs = dict()
if num_masked_tokens > 0:
self.masked_lm = networks.MaskedLM(
num_predictions=num_masked_tokens,
input_width=sequence_output.shape[-1],
source_network=self.encoder_network,
self.masked_lm = layers.MaskedLM(
embedding_table=self.encoder_network.get_embedding_table(),
activation=mlm_activation,
initializer=mlm_initializer,
name='masked_lm')
masked_lm_positions = copy.copy(self.masked_lm.inputs[-1])
name='cls/predictions')
masked_lm_positions = tf.keras.layers.Input(
shape=(num_masked_tokens,),
name='masked_lm_positions',
dtype=tf.int32)
inputs.append(masked_lm_positions)
outputs['lm_output'] = self.masked_lm(
[sequence_output, masked_lm_positions])
sequence_output, masked_positions=masked_lm_positions)
for cls_head in self.classification_heads:
outputs[cls_head.name] = cls_head(sequence_output)
......
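Downstream code now reads the pretrainer's outputs by key rather than by position. A minimal sketch of the new contract, assuming `BertPretrainer` is exported from `official.nlp.modeling.models` and accepts the constructor arguments shown in the file above:

```python
import tensorflow as tf

from official.nlp.modeling import models, networks

seq_length, num_preds = 32, 8
encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=seq_length,
    hidden_size=64, num_attention_heads=4)

pretrainer = models.BertPretrainer(
    network=encoder, num_classes=2, num_token_predictions=num_preds)

word_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
lm_positions = tf.keras.Input(shape=(num_preds,), dtype=tf.int32)

outputs = pretrainer([word_ids, mask, type_ids, lm_positions])
lm_logits = outputs['masked_lm']         # [batch, num_preds, vocab_size]
cls_logits = outputs['classification']   # [batch, num_classes]
```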
@@ -50,16 +50,19 @@ class BertPretrainerTest(keras_parameterized.TestCase):
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
masked_lm_positions = tf.keras.Input(
shape=(num_token_predictions,), dtype=tf.int32)
# Invoke the trainer model on the inputs. This causes the layer to be built.
lm_outs, cls_outs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
outputs = bert_trainer_model(
[word_ids, mask, type_ids, masked_lm_positions])
# Validate that the outputs are of the expected shape.
expected_lm_shape = [None, num_token_predictions, vocab_size]
expected_classification_shape = [None, num_classes]
self.assertAllEqual(expected_lm_shape, lm_outs.shape.as_list())
self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
self.assertAllEqual(expected_lm_shape, outputs['masked_lm'].shape.as_list())
self.assertAllEqual(expected_classification_shape,
outputs['classification'].shape.as_list())
def test_bert_trainer_tensor_call(self):
"""Validate that the Keras object can be invoked."""
@@ -81,7 +84,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
# Invoke the trainer model on the tensors. In Eager mode, this does the
# actual calculation. (We can't validate the outputs, since the network is
# too complex: this simply ensures we're not hitting runtime errors.)
_, _ = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
_ = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
def test_serialize_deserialize(self):
"""Validate that the BERT trainer can be serialized and deserialized."""
@@ -123,7 +126,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_mask = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
# Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
......
@@ -16,8 +16,6 @@ Self-supervised Learning of Language Representations]
(https://arxiv.org/abs/1909.11942). Compared with [BERT](https://arxiv.org/abs/1810.04805), ALBERT factorizes embedding parameters
into two smaller matrices and shares parameters across layers (a toy
illustration follows this section).
* [`MaskedLM`](masked_lm.py) implements a masked language model for BERT pretraining. It assumes that the network being passed has a `get_embedding_table()` method.
* [`Classification`](classification.py) contains a single hidden layer, and is
intended for use as a classification or regression (if number of classes is set
to 1) head.
......
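The factorization the ALBERT bullet refers to replaces a single [vocab, hidden] embedding matrix with a [vocab, E] lookup followed by an [E, hidden] projection, which shrinks the parameter count whenever E is much smaller than hidden. A toy illustration (the sizes are ALBERT-base-like but arbitrary here):

```python
import tensorflow as tf

vocab, hidden, embed = 30000, 768, 128

bert_params = vocab * hidden                    # 23,040,000 parameters
albert_params = vocab * embed + embed * hidden  #  3,938,304 parameters

ids = tf.constant([[1, 2, 3]])
small_table = tf.random.uniform([vocab, embed])
projection = tf.random.uniform([embed, hidden])

# Look up in the small table, then project up to the model width.
embeddings = tf.einsum(
    'bse,eh->bsh', tf.nn.embedding_lookup(small_table, ids), projection)
# embeddings: [1, 3, hidden]
```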
@@ -16,7 +16,6 @@
from official.nlp.modeling.networks.albert_transformer_encoder import AlbertTransformerEncoder
from official.nlp.modeling.networks.classification import Classification
from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
from official.nlp.modeling.networks.masked_lm import MaskedLM
from official.nlp.modeling.networks.span_labeling import SpanLabeling
from official.nlp.modeling.networks.token_classification import TokenClassification
from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder