Commit a7c38397 authored by Chen Chen, committed by A. Unique TensorFlower

Allow customized masked lm in BertPretrainerV2.

PiperOrigin-RevId: 341708068
parent 6e897a35
official/nlp/modeling/models/bert_pretrainer.py

@@ -177,6 +177,10 @@ class BertPretrainerV2(tf.keras.Model):
       to a Glorot uniform initializer.
     classification_heads: A list of optional head layers to transform on encoder
       sequence outputs.
+    customized_masked_lm: A customized masked-LM layer. If None, a standard
+      layer will be created from `layers.MaskedLM`; if not None, the given
+      layer will be used instead, and the `mlm_activation` and
+      `mlm_initializer` arguments above will be ignored.
     name: The name of the model.
   Inputs: Inputs defined by the encoder network, plus `masked_lm_positions` as a
     dictionary.
@@ -191,6 +195,7 @@ class BertPretrainerV2(tf.keras.Model):
                mlm_activation=None,
                mlm_initializer='glorot_uniform',
                classification_heads: Optional[List[tf.keras.layers.Layer]] = None,
+               customized_masked_lm: Optional[tf.keras.layers.Layer] = None,
                name: str = 'bert',
                **kwargs):
     self._self_setattr_tracking = False
@@ -226,6 +231,9 @@ class BertPretrainerV2(tf.keras.Model):
         self.classification_heads):
       raise ValueError('Classification heads should have unique names.')
-    self.masked_lm = layers.MaskedLM(
-        embedding_table=self.encoder_network.get_embedding_table(),
-        activation=mlm_activation,
+    if customized_masked_lm is not None:
+      self.masked_lm = customized_masked_lm
+    else:
+      self.masked_lm = layers.MaskedLM(
+          embedding_table=self.encoder_network.get_embedding_table(),
+          activation=mlm_activation,
......
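
For reference, a minimal sketch of how the new argument can be used, assuming the TransformerEncoder network that the accompanying test builds (the constructor arguments shown here are illustrative and may differ between Model Garden versions):

    from official.nlp.modeling import layers
    from official.nlp.modeling import networks
    from official.nlp.modeling.models import bert_pretrainer

    # Build an encoder, then wire a MaskedLM head to its embedding table.
    encoder = networks.TransformerEncoder(
        vocab_size=100, num_layers=2, dict_outputs=True)
    custom_mlm = layers.MaskedLM(
        embedding_table=encoder.get_embedding_table(),
        activation='gelu')  # illustrative activation choice

    # When customized_masked_lm is given, the mlm_activation and
    # mlm_initializer arguments of BertPretrainerV2 are ignored.
    pretrainer = bert_pretrainer.BertPretrainerV2(
        encoder_network=encoder, customized_masked_lm=custom_mlm)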
official/nlp/modeling/models/bert_pretrainer_test.py

@@ -19,6 +19,7 @@ from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
 from official.nlp.modeling.models import bert_pretrainer
@@ -112,8 +113,10 @@ class BertPretrainerTest(keras_parameterized.TestCase):
   @parameterized.parameters(itertools.product(
       (False, True),
       (False, True),
+      (False, True),
   ))
-  def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs):
+  def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs,
+                             use_customized_masked_lm):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
@@ -129,8 +132,14 @@ class BertPretrainerTest(keras_parameterized.TestCase):
         dict_outputs=dict_outputs)
     # Create a BERT trainer with the created network.
+    if use_customized_masked_lm:
+      customized_masked_lm = layers.MaskedLM(
+          embedding_table=test_network.get_embedding_table())
+    else:
+      customized_masked_lm = None
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
-        encoder_network=test_network)
+        encoder_network=test_network, customized_masked_lm=customized_masked_lm)
     num_token_predictions = 20
     # Create a set of 2-dimensional inputs (the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
......
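
Since the pretrainer simply stores the provided layer and applies it to the encoder's sequence output at the masked positions, any layer whose call signature matches that of layers.MaskedLM should work, not only MaskedLM itself. A hedged sketch of one such custom head (the subclass name and the temperature scaling are hypothetical, not part of this change):

    from official.nlp.modeling import layers

    class TemperatureMaskedLM(layers.MaskedLM):
      """Hypothetical MaskedLM variant that temperature-scales its outputs.

      Assumes the default output='logits' mode of layers.MaskedLM.
      """

      def __init__(self, temperature=2.0, **kwargs):
        super().__init__(**kwargs)
        self._temperature = temperature

      def call(self, sequence_data, masked_positions):
        # Standard masked-LM logits from the parent layer, then softened.
        logits = super().call(sequence_data, masked_positions)
        return logits / self._temperature

An instance of such a layer could then be passed as customized_masked_lm exactly like the standard layer in the test above.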