Unverified Commit 3c5330d8 authored by Hongkun Yu, committed by GitHub

Merged commit includes the following changes: (#7298)

259790197  by hongkuny<hongkuny@google.com>:

    Update pretraining model to match tf1 var names.

--

PiperOrigin-RevId: 259790197
parent 2533c697
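
Editorial note: the renames below hinge on how Keras scopes weight names. A weight created with add_weight(name='predictions/output_bias') inside a layer constructed with name='cls' surfaces as the variable 'cls/predictions/output_bias', which is the layout TF1.x BERT pretraining checkpoints use (cls/predictions/... and cls/seq_relationship/...). A minimal, hypothetical sketch of that mechanism, not part of this commit (TinyHead is a toy stand-in for BertPretrainLayer):

import tensorflow as tf


class TinyHead(tf.keras.layers.Layer):
  """Toy stand-in for BertPretrainLayer, only to illustrate variable naming."""

  def build(self, input_shape):
    # Created while the layer's name scope is active, so with name='cls' the
    # variable should come out as 'cls/predictions/output_bias:0'.
    self.output_bias = self.add_weight(
        shape=[int(input_shape[-1])],
        name='predictions/output_bias',
        initializer=tf.keras.initializers.Zeros())
    super(TinyHead, self).build(input_shape)

  def call(self, inputs):
    return tf.nn.bias_add(inputs, self.output_bias)


head = TinyHead(name='cls')
_ = head(tf.zeros([2, 4]))  # triggers build() under the layer's name scope
print([w.name for w in head.weights])
# Expected (exact form may vary by TF version): ['cls/predictions/output_bias:0']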
@@ -86,18 +86,29 @@ class BertPretrainLayer(tf.keras.layers.Layer):
 
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
+    self.output_bias = self.add_weight(
+        shape=[self.config.vocab_size],
+        name='predictions/output_bias',
+        initializer=tf.keras.initializers.Zeros())
     self.lm_dense = tf.keras.layers.Dense(
         self.config.hidden_size,
         activation=modeling.get_activation(self.config.hidden_act),
-        kernel_initializer=self.initializer)
-    self.lm_bias = self.add_weight(
-        shape=[self.config.vocab_size],
-        name='lm_bias',
-        initializer=tf.keras.initializers.Zeros())
+        kernel_initializer=self.initializer,
+        name='predictions/transform/dense')
     self.lm_layer_norm = tf.keras.layers.LayerNormalization(
-        axis=-1, epsilon=1e-12)
-    self.next_sentence_dense = tf.keras.layers.Dense(
-        self.num_next_sentence_label, kernel_initializer=self.initializer)
+        axis=-1, epsilon=1e-12, name='predictions/transform/LayerNorm')
+    # Next sentence binary classification dense layer including bias to match
+    # TF1.x BERT variable shapes.
+    with tf.name_scope('seq_relationship'):
+      self.next_seq_weights = self.add_weight(
+          shape=[self.num_next_sentence_label, self.config.hidden_size],
+          name='output_weights',
+          initializer=self.initializer)
+      self.next_seq_bias = self.add_weight(
+          shape=[self.num_next_sentence_label],
+          name='output_bias',
+          initializer=tf.keras.initializers.Zeros())
     super(BertPretrainLayer, self).build(unused_input_shapes)
 
   def __call__(self,
@@ -119,15 +130,13 @@ class BertPretrainLayer(tf.keras.layers.Layer):
         sequence_output, masked_lm_positions)
     lm_output = self.lm_dense(mask_lm_input_tensor)
     lm_output = self.lm_layer_norm(lm_output)
-    lm_output = tf.keras.backend.dot(
-        lm_output, tf.keras.backend.transpose(self.embedding_table))
-    lm_output = tf.keras.backend.bias_add(lm_output, self.lm_bias)
-    lm_output = tf.keras.backend.softmax(lm_output)
-    lm_output = tf.keras.backend.log(lm_output)
-    sentence_output = self.next_sentence_dense(pooled_output)
-    sentence_output = tf.keras.backend.softmax(sentence_output)
-    sentence_output = tf.keras.backend.log(sentence_output)
+    lm_output = tf.matmul(lm_output, self.embedding_table, transpose_b=True)
+    lm_output = tf.nn.bias_add(lm_output, self.output_bias)
+    lm_output = tf.nn.log_softmax(lm_output, axis=-1)
+    logits = tf.matmul(pooled_output, self.next_seq_weights, transpose_b=True)
+    logits = tf.nn.bias_add(logits, self.next_seq_bias)
+    sentence_output = tf.nn.log_softmax(logits, axis=-1)
     return (lm_output, sentence_output)
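
Editorial note: the hunk above replaces the Keras-backend softmax-then-log chain with tf.nn.log_softmax, and the backend dot/transpose with tf.matmul(..., transpose_b=True). A small standalone sketch, using hypothetical tensors rather than the model's real inputs, of why the two formulations agree:

import tensorflow as tf

# Hypothetical shapes: 3 masked positions, hidden size 5, vocab size 7.
hidden = tf.random.normal([3, 5])
embedding_table = tf.random.normal([7, 5])
bias = tf.zeros([7])

# Old style: explicit transpose, then softmax followed by log.
old = tf.keras.backend.dot(hidden, tf.keras.backend.transpose(embedding_table))
old = tf.keras.backend.bias_add(old, bias)
old = tf.keras.backend.log(tf.keras.backend.softmax(old))

# New style: transpose_b avoids materializing the transpose, and log_softmax
# computes the same quantity in one numerically stabler op.
new = tf.matmul(hidden, embedding_table, transpose_b=True)
new = tf.nn.bias_add(new, bias)
new = tf.nn.log_softmax(new, axis=-1)

print(bool(tf.reduce_all(tf.abs(old - new) < 1e-5)))  # True, up to float error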
@@ -180,7 +189,7 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     unpacked_inputs = modeling.unpack_inputs(inputs)
     lm_output = unpacked_inputs[0]
     sentence_output = unpacked_inputs[1]
-    lm_label_ids = tf.keras.backend.cast(unpacked_inputs[2], tf.int32)
+    lm_label_ids = unpacked_inputs[2]
     lm_label_ids = tf.keras.backend.reshape(lm_label_ids, [-1])
     lm_label_ids_one_hot = tf.keras.backend.one_hot(lm_label_ids,
                                                     self.config.vocab_size)
@@ -192,13 +201,14 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     denominator = tf.keras.backend.sum(lm_label_weights) + 1e-5
     mask_label_loss = numerator / denominator
 
-    sentence_labels = tf.keras.backend.cast(unpacked_inputs[4], dtype=tf.int32)
+    sentence_labels = unpacked_inputs[4]
     sentence_labels = tf.keras.backend.reshape(sentence_labels, [-1])
     sentence_label_one_hot = tf.keras.backend.one_hot(sentence_labels, 2)
     per_example_loss_sentence = -tf.keras.backend.sum(
         sentence_label_one_hot * sentence_output, axis=-1)
     sentence_loss = tf.keras.backend.mean(per_example_loss_sentence)
     loss = mask_label_loss + sentence_loss
+    # TODO(hongkuny): Avoids the hack and switches add_loss.
     final_loss = tf.fill(
         tf.keras.backend.shape(per_example_loss_sentence), loss)
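
Editorial note: the added TODO refers to a workaround in this layer: the scalar total loss is broadcast with tf.fill to the shape of the per-example sentence loss so Keras receives a batch-shaped output it can average, with Layer.add_loss suggested as the cleaner alternative. A tiny sketch of that broadcast, with hypothetical values standing in for the layer's intermediates:

import tensorflow as tf

# Hypothetical values standing in for the layer's intermediates.
per_example_loss_sentence = tf.constant([0.7, 0.4])  # shape [batch]
mask_label_loss = tf.constant(1.2)
sentence_loss = tf.reduce_mean(per_example_loss_sentence)
loss = mask_label_loss + sentence_loss

# The "hack" the TODO mentions: broadcast the scalar total loss to a
# per-example tensor so Keras sees a batch-shaped output.
final_loss = tf.fill(tf.shape(per_example_loss_sentence), loss)
print(final_loss.numpy())  # [1.75 1.75]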
@@ -245,7 +255,7 @@ def pretrain_model(bert_config,
   masked_lm_ids = tf.keras.layers.Input(
       shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
 
-  bert_submodel_name = 'bert_core_layer'
+  bert_submodel_name = 'bert_model'
   bert_submodel = modeling.get_bert_model(
       input_word_ids,
       input_mask,
@@ -258,7 +268,8 @@ def pretrain_model(bert_config,
 
   pretrain_layer = BertPretrainLayer(
       bert_config,
       bert_submodel.get_layer(bert_submodel_name),
-      initializer=initializer)
+      initializer=initializer,
+      name='cls')
   lm_output, sentence_output = pretrain_layer(pooled_output, sequence_output,
                                               masked_lm_positions)
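
Editorial note: the two renames in the last hunks keep pretrain_model's lookup in sync, since the BERT core network is fetched from the functional model by layer name and the constant has to match whatever name modeling.get_bert_model assigns (now 'bert_model'). A minimal, generic illustration of the get-layer-by-name pattern, using a toy model rather than the real BERT graph:

import tensorflow as tf

# Toy functional model with a named sub-layer, standing in for the BERT core.
inputs = tf.keras.layers.Input(shape=(4,), name='input_word_ids')
core = tf.keras.layers.Dense(8, name='bert_model')  # assumed stand-in name
outputs = core(inputs)
model = tf.keras.Model(inputs, outputs)

# If the constant drifts from the layer's actual name, get_layer raises a
# ValueError, which is why the rename has to happen in both places.
sub = model.get_layer('bert_model')
print(sub is core)  # True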