"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "552cd32058660573c14ac481d88417e828c68756"
Commit 1357ce19 authored by Jeremiah Harmsen, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 317638173
parent 8aa44501
...@@ -25,7 +25,6 @@ import tensorflow_hub as hub ...@@ -25,7 +25,6 @@ import tensorflow_hub as hub
from official.modeling import tf_utils from official.modeling import tf_utils
from official.nlp.albert import configs as albert_configs from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs from official.nlp.bert import configs
from official.nlp.modeling import losses
from official.nlp.modeling import models from official.nlp.modeling import models
from official.nlp.modeling import networks from official.nlp.modeling import networks
...@@ -67,22 +66,27 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer): ...@@ -67,22 +66,27 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
next_sentence_loss, name='next_sentence_loss', aggregation='mean') next_sentence_loss, name='next_sentence_loss', aggregation='mean')
def call(self, def call(self,
lm_output, lm_output_logits,
sentence_output, sentence_output_logits,
lm_label_ids, lm_label_ids,
lm_label_weights, lm_label_weights,
sentence_labels=None): sentence_labels=None):
"""Implements call() for the layer.""" """Implements call() for the layer."""
lm_label_weights = tf.cast(lm_label_weights, tf.float32) lm_label_weights = tf.cast(lm_label_weights, tf.float32)
lm_output = tf.cast(lm_output, tf.float32) lm_output_logits = tf.cast(lm_output_logits, tf.float32)
mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss( lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights) lm_label_ids, lm_output_logits, from_logits=True)
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
lm_denominator_loss)
if sentence_labels is not None: if sentence_labels is not None:
sentence_output = tf.cast(sentence_output, tf.float32) sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss( sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels=sentence_labels, predictions=sentence_output) sentence_labels, sentence_output_logits, from_logits=True)
sentence_loss = tf.reduce_mean(sentence_loss)
loss = mask_label_loss + sentence_loss loss = mask_label_loss + sentence_loss
else: else:
sentence_loss = None sentence_loss = None
...@@ -92,8 +96,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer): ...@@ -92,8 +96,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
# TODO(hongkuny): Avoids the hack and switches add_loss. # TODO(hongkuny): Avoids the hack and switches add_loss.
final_loss = tf.fill(batch_shape, loss) final_loss = tf.fill(batch_shape, loss)
self._add_metrics(lm_output, lm_label_ids, lm_label_weights, self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
mask_label_loss, sentence_output, sentence_labels, mask_label_loss, sentence_output_logits, sentence_labels,
sentence_loss) sentence_loss)
return final_loss return final_loss
...@@ -228,7 +232,7 @@ def pretrain_model(bert_config, ...@@ -228,7 +232,7 @@ def pretrain_model(bert_config,
activation=tf_utils.get_activation(bert_config.hidden_act), activation=tf_utils.get_activation(bert_config.hidden_act),
num_token_predictions=max_predictions_per_seq, num_token_predictions=max_predictions_per_seq,
initializer=initializer, initializer=initializer,
output='predictions') output='logits')
outputs = pretrainer_model( outputs = pretrainer_model(
[input_word_ids, input_mask, input_type_ids, masked_lm_positions]) [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment