"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "db718021717b76def2725584303bdb5b221e2677"
Commit 3e44a9d6 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 317020610
parent bdf6f121
...@@ -48,23 +48,20 @@ class MaskedLMTask(base_task.Task): ...@@ -48,23 +48,20 @@ class MaskedLMTask(base_task.Task):
metrics, metrics,
aux_losses=None) -> tf.Tensor: aux_losses=None) -> tf.Tensor:
metrics = dict([(metric.name, metric) for metric in metrics]) metrics = dict([(metric.name, metric) for metric in metrics])
lm_output = tf.nn.log_softmax(model_outputs['lm_output'], axis=-1) lm_output = tf.nn.log_softmax(
tf.cast(model_outputs['lm_output'], tf.float32), axis=-1)
mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss( mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=labels['masked_lm_ids'], labels=labels['masked_lm_ids'],
predictions=lm_output, predictions=lm_output,
weights=labels['masked_lm_weights']) weights=labels['masked_lm_weights'])
metrics['lm_example_loss'].update_state(mlm_loss) metrics['lm_example_loss'].update_state(mlm_loss)
if 'next_sentence_labels' in labels: if 'next_sentence_labels' in labels:
policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == 'mixed_bfloat16': # b/158514794: bf16 is not stable.
policy = tf.float32
predictions = tf.keras.layers.Activation(
tf.nn.log_softmax, dtype=policy)(model_outputs['next_sentence'])
sentence_labels = labels['next_sentence_labels'] sentence_labels = labels['next_sentence_labels']
sentence_outputs = tf.cast(
model_outputs['next_sentence'], dtype=tf.float32)
sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss( sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=sentence_labels, labels=sentence_labels,
predictions=predictions) predictions=tf.nn.log_softmax(sentence_outputs, axis=-1))
metrics['next_sentence_loss'].update_state(sentence_loss) metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss + sentence_loss total_loss = mlm_loss + sentence_loss
else: else:
......
...@@ -83,7 +83,7 @@ class SentencePredictionTask(base_task.Task): ...@@ -83,7 +83,7 @@ class SentencePredictionTask(base_task.Task):
loss = loss_lib.weighted_sparse_categorical_crossentropy_loss( loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=labels, labels=labels,
predictions=tf.nn.log_softmax( predictions=tf.nn.log_softmax(
model_outputs['sentence_prediction'], axis=-1)) tf.cast(model_outputs['sentence_prediction'], tf.float32), axis=-1))
if aux_losses: if aux_losses:
loss += tf.add_n(aux_losses) loss += tf.add_n(aux_losses)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment