Commit ef99be0b authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 315789209
parent dc9c75dd
@@ -63,7 +63,13 @@ class Classification(tf.keras.Model):
         kernel_initializer=initializer,
         name='predictions/transform/logits')(
             cls_output)
-    predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(self.logits)
+    policy = tf.keras.mixed_precision.experimental.global_policy()
+    if policy.name == 'mixed_bfloat16':
+      # b/158514794: bf16 is not stable with post-softmax cross-entropy.
+      policy = tf.float32
+    predictions = tf.keras.layers.Activation(tf.nn.log_softmax,
+                                             dtype=policy)(self.logits)
     if output == 'logits':
       output_tensors = self.logits
...
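The hunk above pins the log-softmax activation to float32 whenever the global Keras mixed-precision policy is 'mixed_bfloat16', because the post-softmax cross-entropy is not numerically stable in bfloat16 (b/158514794). A minimal, illustrative sketch of that dtype-override pattern follows; the Dense layer, input shape, and set_policy call are assumptions for the example, not part of this commit.

# Illustrative sketch (not from this commit): under a 'mixed_bfloat16'
# global policy, a single layer can be pinned to float32 by passing `dtype`
# explicitly, so log_softmax stays numerically stable while the rest of the
# model computes in bfloat16.
import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')

inputs = tf.keras.Input(shape=(16,))
logits = tf.keras.layers.Dense(4)(inputs)  # computes and outputs in bf16

policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == 'mixed_bfloat16':
  policy = tf.float32  # override just this activation (cf. b/158514794)
log_probs = tf.keras.layers.Activation(tf.nn.log_softmax,
                                       dtype=policy)(logits)

model = tf.keras.Model(inputs, log_probs)
print(model.output.dtype)  # float32, even though the global policy is bf16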
@@ -55,11 +55,16 @@ class MaskedLMTask(base_task.Task):
         weights=features['masked_lm_weights'])
     metrics['lm_example_loss'].update_state(mlm_loss)
     if 'next_sentence_labels' in features:
+      policy = tf.keras.mixed_precision.experimental.global_policy()
+      if policy.name == 'mixed_bfloat16':  # b/158514794: bf16 is not stable.
+        policy = tf.float32
+      predictions = tf.keras.layers.Activation(
+          tf.nn.log_softmax, dtype=policy)(model_outputs['next_sentence'])
       sentence_labels = features['next_sentence_labels']
       sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
           labels=sentence_labels,
-          predictions=tf.nn.log_softmax(
-              model_outputs['next_sentence'], axis=-1))
+          predictions=predictions)
       metrics['next_sentence_loss'].update_state(sentence_loss)
       total_loss = mlm_loss + sentence_loss
     else:
...
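This second hunk moves the log_softmax out of the loss call and into a float32 Activation layer, so loss_lib.weighted_sparse_categorical_crossentropy_loss now receives log-probabilities directly. As a hedged sketch of what consuming log-probabilities looks like, here is a stand-in with the same name and calling convention; it is an assumption for illustration, not the model garden's actual implementation.

# Stand-in sketch (not the model-garden implementation) for a loss that, like
# loss_lib.weighted_sparse_categorical_crossentropy_loss above, consumes
# log-probabilities rather than raw logits.
import tensorflow as tf

def weighted_sparse_categorical_crossentropy_loss(labels, predictions,
                                                  weights=None):
  """labels: [batch] int class ids; predictions: [batch, num_classes]
  log-probabilities (e.g. the float32 log_softmax output above)."""
  labels = tf.reshape(tf.cast(labels, tf.int32), [-1])
  one_hot = tf.one_hot(labels, tf.shape(predictions)[-1],
                       dtype=predictions.dtype)
  # Cross-entropy from log-probs: negative log-probability of the true class.
  per_example = -tf.reduce_sum(one_hot * predictions, axis=-1)
  if weights is None:
    return tf.reduce_mean(per_example)
  weights = tf.cast(weights, per_example.dtype)
  return tf.math.divide_no_nan(tf.reduce_sum(per_example * weights),
                               tf.reduce_sum(weights))

# Example: float32 log-probs keep -log p well-behaved even when the model
# body runs in bfloat16.
log_probs = tf.math.log(tf.constant([[0.9, 0.1], [0.2, 0.8]]))
print(weighted_sparse_categorical_crossentropy_loss(
    tf.constant([0, 1]), log_probs))  # mean of -log 0.9 and -log 0.8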