Unverified commit 3c5330d8 authored by Hongkun Yu, committed by GitHub

Merged commit includes the following changes: (#7298)

259790197  by hongkuny<hongkuny@google.com>:

    Update pretraining model to match tf1 var names.

--

PiperOrigin-RevId: 259790197
parent 2533c697
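
For context: the pretraining heads in the original TF1 BERT checkpoints live under a cls/ scope, and the renames in this diff appear intended to reproduce those names so TF1 checkpoints map cleanly onto the Keras variables. A minimal sketch for listing them (the checkpoint path is a placeholder, not part of this change):

import tensorflow as tf

# Placeholder path to a TF1 BERT checkpoint downloaded separately.
ckpt_path = '/tmp/uncased_L-12_H-768_A-12/bert_model.ckpt'

# Print the pretraining-head variables whose names the renamed Keras layers
# and weights below are meant to line up with.
for name, shape in tf.train.list_variables(ckpt_path):
  if name.startswith('cls/'):
    print(name, shape)

# Expected entries include names such as:
#   cls/predictions/output_bias
#   cls/predictions/transform/dense/kernel
#   cls/predictions/transform/LayerNorm/beta
#   cls/seq_relationship/output_weights
#   cls/seq_relationship/output_bias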
@@ -86,18 +86,29 @@ class BertPretrainLayer(tf.keras.layers.Layer):

   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
+    self.output_bias = self.add_weight(
+        shape=[self.config.vocab_size],
+        name='predictions/output_bias',
+        initializer=tf.keras.initializers.Zeros())
     self.lm_dense = tf.keras.layers.Dense(
         self.config.hidden_size,
         activation=modeling.get_activation(self.config.hidden_act),
-        kernel_initializer=self.initializer)
-    self.lm_bias = self.add_weight(
-        shape=[self.config.vocab_size],
-        name='lm_bias',
-        initializer=tf.keras.initializers.Zeros())
+        kernel_initializer=self.initializer,
+        name='predictions/transform/dense')
     self.lm_layer_norm = tf.keras.layers.LayerNormalization(
-        axis=-1, epsilon=1e-12)
-    self.next_sentence_dense = tf.keras.layers.Dense(
-        self.num_next_sentence_label, kernel_initializer=self.initializer)
+        axis=-1, epsilon=1e-12, name='predictions/transform/LayerNorm')
+    # Next sentence binary classification dense layer including bias to match
+    # TF1.x BERT variable shapes.
+    with tf.name_scope('seq_relationship'):
+      self.next_seq_weights = self.add_weight(
+          shape=[self.num_next_sentence_label, self.config.hidden_size],
+          name='output_weights',
+          initializer=self.initializer)
+      self.next_seq_bias = self.add_weight(
+          shape=[self.num_next_sentence_label],
+          name='output_bias',
+          initializer=tf.keras.initializers.Zeros())
     super(BertPretrainLayer, self).build(unused_input_shapes)

   def __call__(self,
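
One detail worth noting from the hunk above: a Keras Dense layer stores its kernel as [hidden_size, num_labels], while the TF1 checkpoint stores seq_relationship/output_weights as [num_labels, hidden_size], which is why the Dense layer is replaced by an explicitly created weight used with transpose_b=True. A rough self-contained sketch (toy sizes, not the repo's code) showing the two forms compute the same logits:

import tensorflow as tf

batch, hidden_size, num_labels = 4, 8, 2
pooled = tf.random.normal([batch, hidden_size])

# TF1-style layout: [num_labels, hidden_size].
output_weights = tf.random.normal([num_labels, hidden_size])
output_bias = tf.zeros([num_labels])

# Logits computed as in the new code path.
logits = tf.nn.bias_add(
    tf.matmul(pooled, output_weights, transpose_b=True), output_bias)

# The same result via a Dense-style [hidden_size, num_labels] kernel.
dense_style = tf.matmul(pooled, tf.transpose(output_weights)) + output_bias
print(tf.reduce_max(tf.abs(logits - dense_style)).numpy())  # ~0.0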
@@ -119,15 +130,13 @@ class BertPretrainLayer(tf.keras.layers.Layer):
         sequence_output, masked_lm_positions)
     lm_output = self.lm_dense(mask_lm_input_tensor)
     lm_output = self.lm_layer_norm(lm_output)
-    lm_output = tf.keras.backend.dot(
-        lm_output, tf.keras.backend.transpose(self.embedding_table))
-    lm_output = tf.keras.backend.bias_add(lm_output, self.lm_bias)
-    lm_output = tf.keras.backend.softmax(lm_output)
-    lm_output = tf.keras.backend.log(lm_output)
+    lm_output = tf.matmul(lm_output, self.embedding_table, transpose_b=True)
+    lm_output = tf.nn.bias_add(lm_output, self.output_bias)
+    lm_output = tf.nn.log_softmax(lm_output, axis=-1)

-    sentence_output = self.next_sentence_dense(pooled_output)
-    sentence_output = tf.keras.backend.softmax(sentence_output)
-    sentence_output = tf.keras.backend.log(sentence_output)
+    logits = tf.matmul(pooled_output, self.next_seq_weights, transpose_b=True)
+    logits = tf.nn.bias_add(logits, self.next_seq_bias)
+    sentence_output = tf.nn.log_softmax(logits, axis=-1)
     return (lm_output, sentence_output)
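
Replacing the separate softmax-then-log calls with tf.nn.log_softmax is also numerically safer, since it avoids taking the log of a probability that has underflowed to zero. A tiny illustration with toy values:

import tensorflow as tf

logits = tf.constant([[1000.0, 0.0]])

# Two-step version: the small probability underflows to 0, so its log is -inf.
print(tf.math.log(tf.nn.softmax(logits, axis=-1)).numpy())  # [[0., -inf]]

# Fused version, computed as logits - logsumexp(logits), stays finite.
print(tf.nn.log_softmax(logits, axis=-1).numpy())           # [[0., -1000.]]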
@@ -180,7 +189,7 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     unpacked_inputs = modeling.unpack_inputs(inputs)
     lm_output = unpacked_inputs[0]
     sentence_output = unpacked_inputs[1]
-    lm_label_ids = tf.keras.backend.cast(unpacked_inputs[2], tf.int32)
+    lm_label_ids = unpacked_inputs[2]
     lm_label_ids = tf.keras.backend.reshape(lm_label_ids, [-1])
     lm_label_ids_one_hot = tf.keras.backend.one_hot(lm_label_ids,
                                                     self.config.vocab_size)
@@ -192,13 +201,14 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     denominator = tf.keras.backend.sum(lm_label_weights) + 1e-5
     mask_label_loss = numerator / denominator

-    sentence_labels = tf.keras.backend.cast(unpacked_inputs[4], dtype=tf.int32)
+    sentence_labels = unpacked_inputs[4]
     sentence_labels = tf.keras.backend.reshape(sentence_labels, [-1])
     sentence_label_one_hot = tf.keras.backend.one_hot(sentence_labels, 2)
     per_example_loss_sentence = -tf.keras.backend.sum(
         sentence_label_one_hot * sentence_output, axis=-1)
     sentence_loss = tf.keras.backend.mean(per_example_loss_sentence)
     loss = mask_label_loss + sentence_loss
+    # TODO(hongkuny): Avoids the hack and switches add_loss.
     final_loss = tf.fill(
         tf.keras.backend.shape(per_example_loss_sentence), loss)
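
For reference, the next-sentence term in this hunk is the mean negative log-likelihood of the true label under the log-probabilities produced by the layer above; a small numeric sketch with toy values (not the repo's code):

import tensorflow as tf

# Log-probabilities, as produced by log_softmax, for a toy batch of 3 examples.
sentence_output = tf.math.log(tf.constant([[0.9, 0.1],
                                           [0.2, 0.8],
                                           [0.5, 0.5]]))
sentence_labels = tf.constant([0, 1, 0])

sentence_label_one_hot = tf.one_hot(sentence_labels, 2)
per_example_loss = -tf.reduce_sum(sentence_label_one_hot * sentence_output,
                                  axis=-1)
sentence_loss = tf.reduce_mean(per_example_loss)
print(per_example_loss.numpy())  # [-log 0.9, -log 0.8, -log 0.5]
print(sentence_loss.numpy())     # mean of the three

The tf.fill of the scalar loss into a per-example-shaped tensor is the workaround the new TODO refers to; the usual Keras alternative would be reporting the scalar through the layer's add_loss.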
@@ -245,7 +255,7 @@ def pretrain_model(bert_config,
   masked_lm_ids = tf.keras.layers.Input(
       shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)

-  bert_submodel_name = 'bert_core_layer'
+  bert_submodel_name = 'bert_model'
   bert_submodel = modeling.get_bert_model(
       input_word_ids,
       input_mask,
@@ -258,7 +268,8 @@ def pretrain_model(bert_config,
   pretrain_layer = BertPretrainLayer(
       bert_config,
       bert_submodel.get_layer(bert_submodel_name),
-      initializer=initializer)
+      initializer=initializer,
+      name='cls')
   lm_output, sentence_output = pretrain_layer(pooled_output, sequence_output,
                                               masked_lm_positions)
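
Giving the pretraining layer the name 'cls' matters because Keras prefixes a layer's weight names with the layer name, so together with the predictions/... and seq_relationship/... names added above, the variables end up under a cls/... prefix like the TF1 checkpoints. A toy sketch of the pattern (illustration only; exact prefixing can vary across TF versions):

import tensorflow as tf

class _ToyHead(tf.keras.layers.Layer):
  """Toy stand-in for the naming pattern used by BertPretrainLayer."""

  def build(self, input_shape):
    self.output_bias = self.add_weight(
        shape=[4],
        name='predictions/output_bias',
        initializer=tf.keras.initializers.Zeros())
    super(_ToyHead, self).build(input_shape)

  def call(self, inputs):
    return inputs + self.output_bias

head = _ToyHead(name='cls')
_ = head(tf.zeros([1, 4]))
# Expected to print something like 'cls/predictions/output_bias:0'.
print(head.output_bias.name)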
...