Unverified commit 0344c550, authored by Katherine Wu and committed by GitHub
Browse files

Fix transformer loss (#4270)

parent 461fc094
...@@ -81,9 +81,12 @@ def model_fn(features, labels, mode, params): ...@@ -81,9 +81,12 @@ def model_fn(features, labels, mode, params):
logits = output logits = output
# Calculate model loss. # Calculate model loss.
# xentropy contains the cross entropy loss of every nonpadding token in the
# targets.
xentropy, weights = metrics.padded_cross_entropy_loss( xentropy, weights = metrics.padded_cross_entropy_loss(
logits, targets, params.label_smoothing, params.vocab_size) logits, targets, params.label_smoothing, params.vocab_size)
- loss = tf.reduce_sum(xentropy * weights) / tf.reduce_sum(weights)
+ # Compute the weighted mean of the cross entropy losses
+ loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
# Save loss as named tensor that will be logged with the logging hook. # Save loss as named tensor that will be logged with the logging hook.
tf.identity(loss, "cross_entropy") tf.identity(loss, "cross_entropy")
......
...@@ -58,8 +58,8 @@ def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size): ...@@ -58,8 +58,8 @@ def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
smoothing: Label smoothing constant, used to determine the on and off values smoothing: Label smoothing constant, used to determine the on and off values
vocab_size: int size of the vocabulary vocab_size: int size of the vocabulary
Returns: Returns:
- Returns a float32 tensor with shape
- [batch_size, max(length_logits, length_labels)]
+ Returns the cross entropy loss and weight tensors: float32 tensors with
+ shape [batch_size, max(length_logits, length_labels)]
""" """
with tf.name_scope("loss", [logits, labels]): with tf.name_scope("loss", [logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels) logits, labels = _pad_tensors_to_same_length(logits, labels)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment