"vscode:/vscode.git/clone" did not exist on "2ae06c8a6447453558e972313aa9d8a29bb86491"
Commit 78a367e1 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 343132085
parent e8b6955e
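In effect, this change converts BertPretrainerV2 from a functional-API Keras model assembled in __init__ into a subclassed tf.keras.Model with an explicit call() method, switches the pretrainer test to dictionary inputs, and adds a forward pass before checkpointing in the masked-LM task test, since a subclassed model only creates its variables on first call.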
@@ -18,6 +18,7 @@ import collections
 import copy
 from typing import List, Optional
+from absl import logging
 import gin
 import tensorflow as tf
@@ -164,7 +165,6 @@ class BertPretrainer(tf.keras.Model):

 class BertPretrainerV2(tf.keras.Model):
   """BERT pretraining model V2.
-  (Experimental).

   Adds the masked language model head and optional classification heads upon the
   transformer encoder.
@@ -198,7 +198,7 @@ class BertPretrainerV2(tf.keras.Model):
       customized_masked_lm: Optional[tf.keras.layers.Layer] = None,
       name: str = 'bert',
       **kwargs):
-    self._self_setattr_tracking = False
+    super().__init__(self, name=name, **kwargs)
     self._config = {
         'encoder_network': encoder_network,
         'mlm_initializer': mlm_initializer,
@@ -207,6 +207,28 @@ class BertPretrainerV2(tf.keras.Model):
     }
     self.encoder_network = encoder_network
     inputs = copy.copy(self.encoder_network.inputs)
+    self.classification_heads = classification_heads or []
+    if len(set([cls.name for cls in self.classification_heads])) != len(
+        self.classification_heads):
+      raise ValueError('Classification heads should have unique names.')
+
+    self.masked_lm = customized_masked_lm or layers.MaskedLM(
+        embedding_table=self.encoder_network.get_embedding_table(),
+        activation=mlm_activation,
+        initializer=mlm_initializer,
+        name='cls/predictions')
+
+    masked_lm_positions = tf.keras.layers.Input(
+        shape=(None,), name='masked_lm_positions', dtype=tf.int32)
+    inputs.append(masked_lm_positions)
+    self.inputs = inputs
+
+  def call(self, inputs):
+    if isinstance(inputs, list):
+      logging.warning('List inputs to BertPretrainer are discouraged.')
+      inputs = dict([
+          (ref.name, tensor) for ref, tensor in zip(self.inputs, inputs)
+      ])
+
+    outputs = dict()
     encoder_network_outputs = self.encoder_network(inputs)
     if isinstance(encoder_network_outputs, list):
@@ -224,31 +246,13 @@ class BertPretrainerV2(tf.keras.Model):
     else:
       raise ValueError('encoder_network\'s output should be either a list '
                        'or a dict, but got %s' % encoder_network_outputs)
     sequence_output = outputs['sequence_output']
-    self.classification_heads = classification_heads or []
-    if len(set([cls.name for cls in self.classification_heads])) != len(
-        self.classification_heads):
-      raise ValueError('Classification heads should have unique names.')
-
-    if customized_masked_lm is not None:
-      self.masked_lm = customized_masked_lm
-    else:
-      self.masked_lm = layers.MaskedLM(
-          embedding_table=self.encoder_network.get_embedding_table(),
-          activation=mlm_activation,
-          initializer=mlm_initializer,
-          name='cls/predictions')
-
-    masked_lm_positions = tf.keras.layers.Input(
-        shape=(None,), name='masked_lm_positions', dtype=tf.int32)
-    inputs.append(masked_lm_positions)
+    masked_lm_positions = inputs['masked_lm_positions']
     outputs['mlm_logits'] = self.masked_lm(
         sequence_output, masked_positions=masked_lm_positions)
     for cls_head in self.classification_heads:
       outputs[cls_head.name] = cls_head(sequence_output)
-
-    super(BertPretrainerV2, self).__init__(
-        inputs=inputs, outputs=outputs, name=name, **kwargs)
+    return outputs

   @property
   def checkpoint_items(self):
...
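The net effect of the hunks above is that BertPretrainerV2 no longer wires a static input/output graph in __init__; instead, call() accepts a dict of features keyed by input name (list inputs still work, but log a warning). Below is a minimal sketch of driving the refactored model, assuming the official.nlp.modeling package layout; the BertEncoder arguments are illustrative, not taken from this commit:

```python
import tensorflow as tf
from official.nlp.modeling import models, networks

# Illustrative encoder config; only encoder_network is required by the
# pretrainer. The vocab_size/num_layers values here are arbitrary.
encoder = networks.BertEncoder(vocab_size=100, num_layers=2)
pretrainer = models.BertPretrainerV2(encoder_network=encoder)

batch, seq_len, num_preds = 2, 16, 4
features = dict(
    input_word_ids=tf.ones((batch, seq_len), tf.int32),
    input_mask=tf.ones((batch, seq_len), tf.int32),
    input_type_ids=tf.zeros((batch, seq_len), tf.int32),
    masked_lm_positions=tf.zeros((batch, num_preds), tf.int32))

outputs = pretrainer(features)          # dict of outputs, incl. 'mlm_logits'
print(outputs['mlm_logits'].shape)      # (2, 4, 100): batch x preds x vocab
```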
@@ -142,13 +142,15 @@ class BertPretrainerTest(keras_parameterized.TestCase):
         encoder_network=test_network, customized_masked_lm=customized_masked_lm)

     num_token_predictions = 20
     # Create a set of 2-dimensional inputs (the first dimension is implicit).
-    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    lm_mask = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
+    inputs = dict(
+        input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
+        input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
+        input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
+        masked_lm_positions=tf.keras.Input(
+            shape=(num_token_predictions,), dtype=tf.int32))

     # Invoke the trainer model on the inputs. This causes the layer to be built.
-    outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
+    outputs = bert_trainer_model(inputs)
     has_encoder_outputs = dict_outputs or return_all_encoder_outputs
     if has_encoder_outputs:
...
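The test now builds its inputs as a dict keyed by the encoder's feature names rather than a positional list, matching the new call() contract. For reference, the list fallback in call() pairs positional tensors with the model's declared inputs by order and re-keys them by each symbolic input's name; a hypothetical standalone sketch of that re-keying (names and shapes here are illustrative, not from the commit):

```python
import tensorflow as tf

# Symbolic inputs standing in for self.inputs on the model.
refs = [
    tf.keras.Input(shape=(16,), dtype=tf.int32, name='input_word_ids'),
    tf.keras.Input(shape=(16,), dtype=tf.int32, name='input_mask'),
]
# Positional tensors as a caller might pass them in a list.
tensors = [tf.ones((2, 16), tf.int32), tf.zeros((2, 16), tf.int32)]

# Match by position, key by input name -- what the warning path does.
as_dict = {ref.name: t for ref, t in zip(refs, tensors)}
# Keys follow the Input names, e.g. 'input_word_ids', 'input_mask'
# (exact .name formatting can vary across TF versions).
```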
@@ -103,6 +103,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
             inner_dim=768, num_classes=2, name="next_sentence")
     ])
     pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
+    # The model variables will be created after the forward call.
+    _ = pretrain_model(pretrain_model.inputs)
     ckpt = tf.train.Checkpoint(
         model=pretrain_model, **pretrain_model.checkpoint_items)
     init_path = ckpt.save(self.get_temp_dir())
...
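The forward pass added here matters because a subclassed tf.keras.Model creates its variables lazily on first call; checkpointing before that would capture no weights. A self-contained illustration (Tiny is a hypothetical stand-in, not part of this change):

```python
import tensorflow as tf

class Tiny(tf.keras.Model):
  """Hypothetical subclassed model used only to illustrate lazy building."""

  def __init__(self):
    super().__init__()
    self.dense = tf.keras.layers.Dense(4)

  def call(self, x):
    return self.dense(x)

model = Tiny()
assert not model.variables          # no variables before the first call
_ = model(tf.zeros((1, 8)))         # first call builds the Dense kernel/bias
assert len(model.variables) == 2    # now there is something to checkpoint
ckpt = tf.train.Checkpoint(model=model)
path = ckpt.save('/tmp/tiny_ckpt')  # saves the built variables
```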