Commit 7ebcbe20 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Clean up: use sparse_categorical_crossentropy directly for MLM loss.

PiperOrigin-RevId: 318322629
parent 4140da21
...@@ -48,12 +48,14 @@ class MaskedLMTask(base_task.Task): ...@@ -48,12 +48,14 @@ class MaskedLMTask(base_task.Task):
metrics, metrics,
aux_losses=None) -> tf.Tensor: aux_losses=None) -> tf.Tensor:
metrics = dict([(metric.name, metric) for metric in metrics]) metrics = dict([(metric.name, metric) for metric in metrics])
lm_output = tf.nn.log_softmax( lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
tf.cast(model_outputs['lm_output'], tf.float32), axis=-1) labels['masked_lm_ids'],
mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss( tf.cast(model_outputs['lm_output'], tf.float32),
labels=labels['masked_lm_ids'], from_logits=True)
predictions=lm_output, lm_label_weights = labels['masked_lm_weights']
weights=labels['masked_lm_weights']) lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
metrics['lm_example_loss'].update_state(mlm_loss) metrics['lm_example_loss'].update_state(mlm_loss)
if 'next_sentence_labels' in labels: if 'next_sentence_labels' in labels:
sentence_labels = labels['next_sentence_labels'] sentence_labels = labels['next_sentence_labels']
...@@ -74,6 +76,7 @@ class MaskedLMTask(base_task.Task): ...@@ -74,6 +76,7 @@ class MaskedLMTask(base_task.Task):
def build_inputs(self, params, input_context=None): def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for pretraining.""" """Returns tf.data.Dataset for pretraining."""
if params.input_path == 'dummy': if params.input_path == 'dummy':
def dummy_data(_): def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32) dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32) dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment