Commit 1b3b2839 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

No need to use `tf.keras.backend` ops; use the equivalent `tf.*` ops instead.

PiperOrigin-RevId: 294588219
parent 6d1d918f
......@@ -68,14 +68,14 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
sentence_labels):
"""Implements call() for the layer."""
lm_label_weights = tf.keras.backend.cast(lm_label_weights, tf.float32)
lm_label_weights = tf.cast(lm_label_weights, tf.float32)
mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
labels=sentence_labels, predictions=sentence_output)
loss = mask_label_loss + sentence_loss
batch_shape = tf.slice(tf.keras.backend.shape(sentence_labels), [0], [1])
batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
# TODO(hongkuny): Avoids the hack and switches add_loss.
final_loss = tf.fill(batch_shape, loss)
......@@ -208,10 +208,9 @@ class BertSquadLogitsLayer(tf.keras.layers.Layer):
sequence_length = input_shape[1]
num_hidden_units = input_shape[2]
final_hidden_input = tf.keras.backend.reshape(sequence_output,
[-1, num_hidden_units])
final_hidden_input = tf.reshape(sequence_output, [-1, num_hidden_units])
logits = self.final_dense(final_hidden_input)
logits = tf.keras.backend.reshape(logits, [-1, sequence_length, 2])
logits = tf.reshape(logits, [-1, sequence_length, 2])
logits = tf.transpose(logits, [2, 0, 1])
unstacked_logits = tf.unstack(logits, axis=0)
return unstacked_logits[0], unstacked_logits[1]
......
def get_loss_fn(loss_factor=1.0):
  """Returns a Keras-compatible loss function for BERT pretraining.

  Args:
    loss_factor: Scalar multiplier applied to the mean loss (e.g. for loss
      scaling across replicas). Defaults to 1.0, which leaves the mean loss
      unchanged.

  Returns:
    A callable with signature `(unused_labels, losses, **unused_args)` that
    reduces the per-example `losses` tensor to its mean and scales it by
    `loss_factor`.
  """

  def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
    # The model already emits per-example losses, so the loss function only
    # needs to average them; `tf.reduce_mean` replaces the equivalent
    # `tf.keras.backend.mean` per this commit's cleanup.
    return tf.reduce_mean(losses) * loss_factor

  return _bert_pretrain_loss_fn
def per_example_loss(labels, predictions, weights=None):
  """Computes a per-example sparse categorical cross-entropy loss.

  Args:
    labels: Integer tensor of true class ids.
    predictions: Tensor whose last dimension indexes classes.
      NOTE(review): the `-sum(predictions * one_hot)` form implies these are
      log-probabilities — confirm against the callers.
    weights: Optional tensor of per-example weights. When provided, each
      example's loss is multiplied by its weight.

  Returns:
    A tensor of per-example (optionally weighted) losses.
  """
  labels, predictions = _adjust_labels(labels, predictions)
  _validate_rank(labels, predictions, weights)
  # One-hot the labels so the cross-entropy reduces to a masked sum over the
  # class axis; cast to the prediction dtype before multiplying.
  labels_one_hot = tf.one_hot(labels, predictions.shape[-1])
  labels_one_hot = tf.cast(labels_one_hot, predictions.dtype)
  per_example_loss_data = -tf.reduce_sum(
      predictions * labels_one_hot, axis=[-1])
  if weights is not None:
    weights = tf.cast(weights, per_example_loss_data.dtype)
    per_example_loss_data = weights * per_example_loss_data
  return per_example_loss_data
def loss(labels, predictions, weights=None):
  """Computes the scalar (weighted) mean loss over all examples.

  Args:
    labels: Integer tensor of true class ids.
    predictions: Tensor whose last dimension indexes classes (see
      `per_example_loss` for the assumed log-probability form).
    weights: Optional per-example weights.

  Returns:
    A scalar: the plain mean of per-example losses when `weights` is None,
    otherwise the weighted sum normalized by the total weight.
  """
  per_example_loss_data = per_example_loss(labels, predictions, weights)
  if weights is None:
    return tf.reduce_mean(per_example_loss_data)
  else:
    numerator = tf.reduce_sum(per_example_loss_data)
    weights = tf.cast(weights, predictions.dtype)
    # The 1e-5 epsilon guards against division by zero when every weight is
    # zero (e.g. a batch with no masked tokens).
    denominator = tf.reduce_sum(weights) + 1e-5
    return numerator / denominator
......@@ -128,11 +128,11 @@ class MaskedLM(network.Network):
sequence_tensor, name='sequence_output_tensor')
batch_size, seq_length, width = sequence_shape
flat_offsets = tf.keras.backend.reshape(
flat_offsets = tf.reshape(
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
flat_positions = tf.keras.backend.reshape(positions + flat_offsets, [-1])
flat_sequence_tensor = tf.keras.backend.reshape(
sequence_tensor, [batch_size * seq_length, width])
flat_positions = tf.reshape(positions + flat_offsets, [-1])
flat_sequence_tensor = tf.reshape(sequence_tensor,
[batch_size * seq_length, width])
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment