Commit 1357ce19 authored by Jeremiah Harmsen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 317638173
parent 8aa44501
@@ -25,7 +25,6 @@ import tensorflow_hub as hub
 from official.modeling import tf_utils
 from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import configs
-from official.nlp.modeling import losses
 from official.nlp.modeling import models
 from official.nlp.modeling import networks
@@ -67,22 +66,27 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
         next_sentence_loss, name='next_sentence_loss', aggregation='mean')
 
   def call(self,
-           lm_output,
-           sentence_output,
+           lm_output_logits,
+           sentence_output_logits,
            lm_label_ids,
            lm_label_weights,
            sentence_labels=None):
     """Implements call() for the layer."""
     lm_label_weights = tf.cast(lm_label_weights, tf.float32)
-    lm_output = tf.cast(lm_output, tf.float32)
+    lm_output_logits = tf.cast(lm_output_logits, tf.float32)
 
-    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
-        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
+    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
+        lm_label_ids, lm_output_logits, from_logits=True)
+    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
+    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
+    mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
+                                            lm_denominator_loss)
 
     if sentence_labels is not None:
-      sentence_output = tf.cast(sentence_output, tf.float32)
-      sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
-          labels=sentence_labels, predictions=sentence_output)
+      sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
+      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
+          sentence_labels, sentence_output_logits, from_logits=True)
+      sentence_loss = tf.reduce_mean(sentence_loss)
       loss = mask_label_loss + sentence_loss
     else:
       sentence_loss = None
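
For context, a minimal sketch (not part of this commit) of what the new masked-LM branch computes; all shapes and values below are hypothetical:

import tensorflow as tf

# Hypothetical inputs: batch of 2 sequences, 3 masked slots, vocab of 5.
lm_output_logits = tf.random.normal([2, 3, 5])                 # raw logits
lm_label_ids = tf.constant([[1, 4, 0], [2, 0, 0]])             # target token ids
lm_label_weights = tf.constant([[1., 1., 0.], [1., 0., 0.]])   # 0 marks padding

# Per-position crossentropy taken directly from logits.
per_position = tf.keras.losses.sparse_categorical_crossentropy(
    lm_label_ids, lm_output_logits, from_logits=True)

# Weighted mean over the real (non-padding) predictions; divide_no_nan
# returns 0 instead of NaN if a batch contains only padding.
numerator = tf.reduce_sum(per_position * lm_label_weights)
denominator = tf.reduce_sum(lm_label_weights)
mask_label_loss = tf.math.divide_no_nan(numerator, denominator)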
@@ -92,8 +96,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     # TODO(hongkuny): Avoids the hack and switches add_loss.
     final_loss = tf.fill(batch_shape, loss)
 
-    self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
-                      mask_label_loss, sentence_output, sentence_labels,
+    self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
+                      mask_label_loss, sentence_output_logits, sentence_labels,
                       sentence_loss)
     return final_loss
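
For reference, the "hack" the TODO refers to: the scalar loss is tiled to batch shape with tf.fill so it can be emitted as a regular per-example model output rather than attached via add_loss. A minimal illustration with hypothetical values:

import tensorflow as tf

loss = tf.constant(1.37)                 # hypothetical scalar combined loss
batch_shape = tf.constant([4])           # hypothetical batch size
final_loss = tf.fill(batch_shape, loss)  # shape [4]: one copy per example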
@@ -228,7 +232,7 @@ def pretrain_model(bert_config,
       activation=tf_utils.get_activation(bert_config.hidden_act),
       num_token_predictions=max_predictions_per_seq,
       initializer=initializer,
-      output='predictions')
+      output='logits')
   outputs = pretrainer_model(
       [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
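
The pretrainer head now emits raw logits instead of normalized predictions, matching the from_logits=True losses above. Feeding logits skips a softmax/log round-trip and is the more numerically stable path; the two formulations agree, as this sketch with hypothetical logits shows:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
label = tf.constant([0])

# New path: crossentropy straight from logits.
loss_a = tf.keras.losses.sparse_categorical_crossentropy(
    label, logits, from_logits=True)

# Equivalent: normalize first, then use the default probability input.
loss_b = tf.keras.losses.sparse_categorical_crossentropy(
    label, tf.nn.softmax(logits))

# loss_a and loss_b match up to floating-point error.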