Commit 6c0250ad authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 343132085
parent 6f514024
@@ -18,6 +18,7 @@ import collections
import copy
from typing import List, Optional
from absl import logging
import gin
import tensorflow as tf
@@ -164,7 +165,6 @@ class BertPretrainer(tf.keras.Model):
class BertPretrainerV2(tf.keras.Model):
"""BERT pretraining model V2.
(Experimental).
Adds the masked language model head and optional classification heads upon the
transformer encoder.
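For context, a minimal construction sketch for this class (not part of the diff; the encoder configuration and head size below are illustrative assumptions):

```python
# Sketch only: build a BertPretrainerV2 from an encoder plus an optional
# classification head. The sizes here are arbitrary placeholders.
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer

encoder = networks.BertEncoder(vocab_size=100, num_layers=2)
pretrainer = bert_pretrainer.BertPretrainerV2(
    encoder_network=encoder,
    classification_heads=[
        layers.ClassificationHead(
            inner_dim=768, num_classes=2, name='next_sentence'),
    ])
```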
@@ -198,7 +198,7 @@ class BertPretrainerV2(tf.keras.Model):
customized_masked_lm: Optional[tf.keras.layers.Layer] = None,
name: str = 'bert',
**kwargs):
self._self_setattr_tracking = False
super().__init__(self, name=name, **kwargs)
self._config = {
'encoder_network': encoder_network,
'mlm_initializer': mlm_initializer,
@@ -207,6 +207,28 @@ class BertPretrainerV2(tf.keras.Model):
}
self.encoder_network = encoder_network
inputs = copy.copy(self.encoder_network.inputs)
self.classification_heads = classification_heads or []
if len(set([cls.name for cls in self.classification_heads])) != len(
self.classification_heads):
raise ValueError('Classification heads should have unique names.')
self.masked_lm = customized_masked_lm or layers.MaskedLM(
embedding_table=self.encoder_network.get_embedding_table(),
activation=mlm_activation,
initializer=mlm_initializer,
name='cls/predictions')
masked_lm_positions = tf.keras.layers.Input(
shape=(None,), name='masked_lm_positions', dtype=tf.int32)
inputs.append(masked_lm_positions)
self.inputs = inputs
def call(self, inputs):
if isinstance(inputs, list):
logging.warning('List inputs to BertPretrainer are discouraged.')
inputs = dict([
(ref.name, tensor) for ref, tensor in zip(self.inputs, inputs)
])
outputs = dict()
encoder_network_outputs = self.encoder_network(inputs)
if isinstance(encoder_network_outputs, list):
@@ -224,31 +246,13 @@ class BertPretrainerV2(tf.keras.Model):
else:
raise ValueError('encoder_network\'s output should be either a list '
'or a dict, but got %s' % encoder_network_outputs)
sequence_output = outputs['sequence_output']
self.classification_heads = classification_heads or []
if len(set([cls.name for cls in self.classification_heads])) != len(
self.classification_heads):
raise ValueError('Classification heads should have unique names.')
if customized_masked_lm is not None:
self.masked_lm = customized_masked_lm
else:
self.masked_lm = layers.MaskedLM(
embedding_table=self.encoder_network.get_embedding_table(),
activation=mlm_activation,
initializer=mlm_initializer,
name='cls/predictions')
masked_lm_positions = tf.keras.layers.Input(
shape=(None,), name='masked_lm_positions', dtype=tf.int32)
inputs.append(masked_lm_positions)
masked_lm_positions = inputs['masked_lm_positions']
outputs['mlm_logits'] = self.masked_lm(
sequence_output, masked_positions=masked_lm_positions)
for cls_head in self.classification_heads:
outputs[cls_head.name] = cls_head(sequence_output)
super(BertPretrainerV2, self).__init__(
inputs=inputs, outputs=outputs, name=name, **kwargs)
return outputs
@property
def checkpoint_items(self):
......
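After this change the pretrainer is a subclassed model: `call` returns the outputs dict, and inputs are expected as a dict keyed by the encoder's input names plus `masked_lm_positions` (a plain list still works via the name mapping above, but triggers the warning). A hedged invocation sketch, continuing the construction example, using the standard BERT input names from the test below and illustrative shapes:

```python
import tensorflow as tf

# Sketch only: call the subclassed pretrainer with a dict of int tensors.
# Batch size 2, sequence length 8, and 3 masked positions are placeholders.
features = dict(
    input_word_ids=tf.ones((2, 8), dtype=tf.int32),
    input_mask=tf.ones((2, 8), dtype=tf.int32),
    input_type_ids=tf.zeros((2, 8), dtype=tf.int32),
    masked_lm_positions=tf.zeros((2, 3), dtype=tf.int32))
outputs = pretrainer(features)
# outputs['mlm_logits'] holds per-position vocabulary logits; each
# classification head adds an entry under its own name, e.g.
# outputs['next_sentence'] for the head above.
```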
@@ -142,13 +142,15 @@ class BertPretrainerTest(keras_parameterized.TestCase):
encoder_network=test_network, customized_masked_lm=customized_masked_lm)
num_token_predictions = 20
# Create a set of 2-dimensional inputs (the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_mask = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
inputs = dict(
input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
masked_lm_positions=tf.keras.Input(
shape=(num_token_predictions,), dtype=tf.int32))
# Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
outputs = bert_trainer_model(inputs)
has_encoder_outputs = dict_outputs or return_all_encoder_outputs
if has_encoder_outputs:
......
@@ -103,6 +103,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
inner_dim=768, num_classes=2, name="next_sentence")
])
pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
# The model variables will be created after the forward call.
_ = pretrain_model(pretrain_model.inputs)
ckpt = tf.train.Checkpoint(
model=pretrain_model, **pretrain_model.checkpoint_items)
init_path = ckpt.save(self.get_temp_dir())
......
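The forward call added above is needed because the pretrainer is now a subclassed model, so its variables only exist after it has been called once; without that call the checkpoint would have nothing to save or match. A minimal standalone sketch of that lazy-build behaviour (the tiny model and checkpoint path are illustrative, not from the test):

```python
import tensorflow as tf


class TinyModel(tf.keras.Model):
  """Subclassed model: the Dense kernel/bias are created on first call."""

  def __init__(self):
    super().__init__()
    self.dense = tf.keras.layers.Dense(4)

  def call(self, inputs):
    return self.dense(inputs)


model = TinyModel()
print(len(model.weights))           # 0: nothing to checkpoint yet
_ = model(tf.ones((1, 3)))          # forward pass creates the variables
print(len(model.weights))           # 2: kernel and bias now exist
ckpt = tf.train.Checkpoint(model=model)
path = ckpt.save('/tmp/tiny_ckpt')  # placeholder directory
```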