Commit 07a07f6a authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 333772508
parent 071b3b94
...@@ -215,7 +215,7 @@ class BertPretrainerV2(tf.keras.Model): ...@@ -215,7 +215,7 @@ class BertPretrainerV2(tf.keras.Model):
masked_lm_positions = tf.keras.layers.Input( masked_lm_positions = tf.keras.layers.Input(
shape=(None,), name='masked_lm_positions', dtype=tf.int32) shape=(None,), name='masked_lm_positions', dtype=tf.int32)
inputs.append(masked_lm_positions) inputs.append(masked_lm_positions)
outputs['lm_output'] = self.masked_lm( outputs['mlm_logits'] = self.masked_lm(
sequence_output, masked_positions=masked_lm_positions) sequence_output, masked_positions=masked_lm_positions)
for cls_head in self.classification_heads: for cls_head in self.classification_heads:
outputs[cls_head.name] = cls_head(sequence_output) outputs[cls_head.name] = cls_head(sequence_output)
......
...@@ -145,15 +145,16 @@ class BertPretrainerTest(keras_parameterized.TestCase): ...@@ -145,15 +145,16 @@ class BertPretrainerTest(keras_parameterized.TestCase):
if has_encoder_outputs: if has_encoder_outputs:
self.assertSameElements( self.assertSameElements(
outputs.keys(), outputs.keys(),
['sequence_output', 'pooled_output', 'lm_output', 'encoder_outputs']) ['sequence_output', 'pooled_output', 'mlm_logits', 'encoder_outputs'])
self.assertLen(outputs['encoder_outputs'], num_layers) self.assertLen(outputs['encoder_outputs'], num_layers)
else: else:
self.assertSameElements(outputs.keys(), self.assertSameElements(
['sequence_output', 'pooled_output', 'lm_output']) outputs.keys(), ['sequence_output', 'pooled_output', 'mlm_logits'])
# Validate that the outputs are of the expected shape. # Validate that the outputs are of the expected shape.
expected_lm_shape = [None, num_token_predictions, vocab_size] expected_lm_shape = [None, num_token_predictions, vocab_size]
self.assertAllEqual(expected_lm_shape, outputs['lm_output'].shape.as_list()) self.assertAllEqual(expected_lm_shape,
outputs['mlm_logits'].shape.as_list())
expected_sequence_output_shape = [None, sequence_length, hidden_size] expected_sequence_output_shape = [None, sequence_length, hidden_size]
self.assertAllEqual(expected_sequence_output_shape, self.assertAllEqual(expected_sequence_output_shape,
......
...@@ -67,7 +67,7 @@ class MaskedLMTask(base_task.Task): ...@@ -67,7 +67,7 @@ class MaskedLMTask(base_task.Task):
metrics = dict([(metric.name, metric) for metric in metrics]) metrics = dict([(metric.name, metric) for metric in metrics])
lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy( lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels['masked_lm_ids'], labels['masked_lm_ids'],
tf.cast(model_outputs['lm_output'], tf.float32), tf.cast(model_outputs['mlm_logits'], tf.float32),
from_logits=True) from_logits=True)
lm_label_weights = labels['masked_lm_weights'] lm_label_weights = labels['masked_lm_weights']
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_numerator_loss = tf.reduce_sum(lm_prediction_losses *
...@@ -134,7 +134,7 @@ class MaskedLMTask(base_task.Task): ...@@ -134,7 +134,7 @@ class MaskedLMTask(base_task.Task):
metrics = dict([(metric.name, metric) for metric in metrics]) metrics = dict([(metric.name, metric) for metric in metrics])
if 'masked_lm_accuracy' in metrics: if 'masked_lm_accuracy' in metrics:
metrics['masked_lm_accuracy'].update_state( metrics['masked_lm_accuracy'].update_state(
labels['masked_lm_ids'], model_outputs['lm_output'], labels['masked_lm_ids'], model_outputs['mlm_logits'],
labels['masked_lm_weights']) labels['masked_lm_weights'])
if 'next_sentence_accuracy' in metrics: if 'next_sentence_accuracy' in metrics:
metrics['next_sentence_accuracy'].update_state( metrics['next_sentence_accuracy'].update_state(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment