Commit 1e4fd825 authored by Hongkun Yu, committed by A. Unique TensorFlower

Update weighted sparse categorical loss. Remove per-example loss.

PiperOrigin-RevId: 318421049
parent 02242bc8
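With `per_example_loss` removed, callers that still need an unreduced, per-example tensor can call the Keras primitive directly. A minimal migration sketch, assuming raw logits; the shapes and values here are illustrative, not from this commit:

```python
import tensorflow as tf

labels = tf.constant([2, 0, 4])          # integer class indices
logits = tf.random.normal([3, 5])        # [batch, num_classes]
weights = tf.constant([1.0, 0.0, 1.0])   # optional per-example mask

# Rough replacement for the deleted per_example_loss(): one value per example.
per_example = tf.keras.losses.sparse_categorical_crossentropy(
    labels, logits, from_logits=True) * weights
```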
@@ -4,6 +4,3 @@ Losses contains common loss computation used in NLP tasks.
* `weighted_sparse_categorical_crossentropy_loss` computes per-batch sparse
categorical crossentropy loss.
* `weighted_sparse_categorical_crossentropy_per_example_loss` computes
per-example sparse categorical crossentropy loss.
@@ -14,4 +14,3 @@
# ==============================================================================
"""Activations package definition. Subject to change."""
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import per_example_loss as weighted_sparse_categorical_crossentropy_per_example_loss
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sparse categorical cross-entropy losses."""
"""Weighted sparse categorical cross-entropy losses."""
from __future__ import absolute_import
from __future__ import division
@@ -43,37 +43,7 @@ def _validate_rank(labels, predictions, weights):
"predictions.shape was %s.") % (labels.shape, predictions.shape))
def per_example_loss(labels, predictions, weights=None):
"""Calculate a per-example sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax.
Args:
labels: The labels to evaluate against. Should be a set of integer indices
ranging from 0 to (vocab_size-1).
predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used.
Returns:
A tensor of shape predictions.shape[:-1] containing the per-example
loss.
"""
# When using these functions with the Keras core API, we will need to squeeze
# the labels tensor - Keras adds a spurious inner dimension.
labels, predictions = _adjust_labels(labels, predictions)
_validate_rank(labels, predictions, weights)
labels_one_hot = tf.one_hot(labels, predictions.shape[-1])
labels_one_hot = tf.cast(labels_one_hot, predictions.dtype)
per_example_loss_data = -tf.reduce_sum(
predictions * labels_one_hot, axis=[-1])
if weights is not None:
weights = tf.cast(weights, per_example_loss_data.dtype)
per_example_loss_data = weights * per_example_loss_data
return per_example_loss_data
def loss(labels, predictions, weights=None):
def loss(labels, predictions, weights=None, from_logits=False):
"""Calculate a per-batch sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax.
@@ -83,6 +53,7 @@ def loss(labels, predictions, weights=None):
predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used.
from_logits: Whether the input predictions are logits.
Returns:
A loss scalar.
@@ -95,12 +66,11 @@
labels, predictions = _adjust_labels(labels, predictions)
_validate_rank(labels, predictions, weights)
per_example_loss_data = per_example_loss(labels, predictions, weights)
example_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels, predictions, from_logits=from_logits)
if weights is None:
return tf.reduce_mean(per_example_loss_data)
else:
numerator = tf.reduce_sum(per_example_loss_data)
weights = tf.cast(weights, predictions.dtype)
denominator = tf.reduce_sum(weights) + 1e-5
return numerator / denominator
return tf.reduce_mean(example_losses)
weights = tf.cast(weights, predictions.dtype)
return tf.math.divide_no_nan(
tf.reduce_sum(example_losses * weights), tf.reduce_sum(weights))
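The weighted reduction above drops the old `1e-5` fudge term in the denominator in favor of `tf.math.divide_no_nan`, which returns 0 when every weight is 0 instead of a tiny-denominator quotient. A small sketch of the behavior, with values chosen only for illustration:

```python
import tensorflow as tf

example_losses = tf.constant([1.0, 2.0, 3.0])
weights = tf.constant([1.0, 0.0, 1.0])

# Weighted mean over the unmasked examples: (1.0 + 3.0) / 2.0 == 2.0.
print(tf.math.divide_no_nan(
    tf.reduce_sum(example_losses * weights), tf.reduce_sum(weights)))

# Fully masked batch: divide_no_nan yields 0.0 rather than NaN, matching
# the fully-masked weight tests below.
zero_weights = tf.zeros_like(weights)
print(tf.math.divide_no_nan(
    tf.reduce_sum(example_losses * zero_weights),
    tf.reduce_sum(zero_weights)))
```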
@@ -53,8 +53,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Create a maskedLM from the transformer stack.
test_layer = layers.MaskedLM(
embedding_table=xformer_stack.get_embedding_table(),
output=output)
embedding_table=xformer_stack.get_embedding_table(), output=output)
# Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
@@ -63,123 +62,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions)
return tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
def create_classification_model(self, input_width, num_classes):
test_object = networks.Classification(
input_width=input_width, num_classes=num_classes)
# Create a 2-dimensional input (the first dimension is implicit).
pooled_data = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
output = test_object(pooled_data)
return tf.keras.Model(pooled_data, output)
def test_per_example_loss_3d_input(self):
"""Test per-example loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
model = self.create_lm_model(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Get the output of the masked LM.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
output_data = model.predict([lm_input_data, masked_position_data])
# Calculate per-example loss.
labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
# Per-example loss data should have one value per prediction, and those
# values shouldn't be zero in this case (as we're using random data).
expected_shape = [batch_size, num_predictions]
self.assertEqual(expected_shape, per_example_loss_data.shape.as_list())
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_per_example_loss_2d_input(self):
"""Test per-example loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per example loss.
labels = np.random.randint(num_classes, size=(batch_size))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
# Per-example loss data should have one value per batch item, and those
# values shouldn't be zero in this case (as we're using random data).
self.assertEqual([batch_size], per_example_loss_data.shape.as_list())
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_per_example_loss_weights_3d_input(self):
"""Test weighted per-example loss with a 3-d input, from a masked LM."""
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
model = self.create_lm_model(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Get the output of the masked LM.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
output_data = model.predict([lm_input_data, masked_position_data])
# Calculate per-example loss with weights.
labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
weights = np.random.randint(2, size=(batch_size, num_predictions))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss = per_example_loss_data * weights
self.assertAllClose(expected_weighted_loss, per_example_loss_data)
def test_per_example_loss_weights_2d_input(self):
"""Test weighted per-example loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per-example loss with weights.
labels = np.random.randint(num_classes, size=(batch_size))
weights = np.random.randint(2, size=(batch_size))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss = per_example_loss_data * weights
self.assertAllClose(expected_weighted_loss, per_example_loss_data)
def test_loss_3d_input(self):
"""Test overall loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100
@@ -213,26 +95,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_loss_2d_input(self):
"""Test overall loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per example loss.
labels = np.random.randint(num_classes, size=(batch_size))
loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels)
# Loss data should have one value only, and that value shouldn't be zero in
# this case (as we're using random data).
self.assertNotAllClose(0, loss_data)
def test_loss_weights_3d_input(self):
"""Test masked loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100
@@ -262,26 +124,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Because the tensor is fully masked, the loss should be 0.
self.assertAllClose(0, weighted_loss_data)
def test_loss_weights_2d_input(self):
"""Test masked loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate a fully masked weight tensor. This should give a loss of zero.
labels = np.random.randint(num_classes, size=(batch_size))
null_weights = np.zeros((batch_size))
weighted_loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=null_weights)
# Because the tensor is fully masked, the loss should be 0.
self.assertAllClose(0, weighted_loss_data)
def test_mismatched_predictions_and_labels_ranks_squeezes(self):
"""Test that the loss asserts when rank(predictions)-1 != rank(labels)."""
batch_size = 3
@@ -289,7 +131,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels = np.random.randint(10, size=(batch_size, 1))
# All that this test tests is that the squeeze is successful.
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
_ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels)
def test_mismatched_weights_and_labels_ranks_fail(self):
@@ -299,9 +141,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels = np.random.randint(10, size=(batch_size, 10))
weights = np.random.randint(2, size=(batch_size))
with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
_ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
@@ -317,8 +156,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# We're not trying to validate numerical correctness, just ensure that
# we can in fact pass tensors to these functions without causing runtime
# errors from the shape checking code.
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
_ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
@@ -338,20 +175,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[-2.7760355, -1.8219438, -3.0924666, -1.0779881, -0.9407509]]])
labels = np.array([[4, 0], [2, 2], [2, 1]])
# Validate that per_example loss calculations are the same.
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
expected_per_example_loss_data = [[1.2923571, 2.7117882],
[2.287932, 2.287932],
[3.0924666, 1.8219438]]
self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
# Validate that overall loss calculations are the same.
weights = np.array([[1, 0], [0, 0], [0, 0]])
loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
predictions=output_data,
labels=labels,
weights=weights,
from_logits=True)
expected_loss_data = 1.2923441
self.assertAllClose(expected_loss_data, loss_data)
self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
def test_legacy_classification_loss_compatibility(self):
"""Test to validate computational correctness during refactors."""
@@ -362,19 +194,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[-1.6975292e-03, -6.4009643e+00, -1.0226612e+01]])
labels = np.array([2, 1])
# Validate that per_example loss calculations are the same.
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
expected_per_example_loss_data = [6.4434357, 6.4009643]
self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
# Validate that overall loss calculations are the same.
weights = None
loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
predictions=output_data,
labels=labels,
weights=weights,
from_logits=True)
expected_loss_data = 6.4222
self.assertAllClose(expected_loss_data, loss_data)
self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
if __name__ == "__main__":
tf.test.main()
@@ -21,7 +21,6 @@ from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.data import pretrain_dataloader
from official.nlp.modeling import losses as loss_lib
@dataclasses.dataclass
@@ -61,9 +60,10 @@ class MaskedLMTask(base_task.Task):
sentence_labels = labels['next_sentence_labels']
sentence_outputs = tf.cast(
model_outputs['next_sentence'], dtype=tf.float32)
sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=sentence_labels,
predictions=tf.nn.log_softmax(sentence_outputs, axis=-1))
sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
sentence_labels,
sentence_outputs,
from_logits=True)
metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss + sentence_loss
else:
......
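The next-sentence loss now feeds raw `sentence_outputs` to Keras with `from_logits=True` rather than pre-applying `tf.nn.log_softmax` and calling the removed library loss. The two formulations are numerically equivalent; a quick sanity check, with made-up 2-class logits:

```python
import tensorflow as tf

labels = tf.constant([0, 1])
logits = tf.constant([[2.0, -1.0], [0.5, 0.3]])

via_logits = tf.keras.losses.sparse_categorical_crossentropy(
    labels, logits, from_logits=True)
via_log_softmax = -tf.gather(
    tf.nn.log_softmax(logits, axis=-1), labels, batch_dims=1)
tf.debugging.assert_near(via_logits, via_log_softmax)
```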
@@ -26,7 +26,6 @@ from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.modeling import losses as loss_lib
from official.nlp.tasks import utils
@@ -75,10 +74,10 @@ class SentencePredictionTask(base_task.Task):
return bert.instantiate_bertpretrainer_from_cfg(self.task_config.model)
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=labels,
predictions=tf.nn.log_softmax(
tf.cast(model_outputs['sentence_prediction'], tf.float32), axis=-1))
loss = tf.keras.losses.sparse_categorical_crossentropy(
labels,
tf.cast(model_outputs['sentence_prediction'], tf.float32),
from_logits=True)
if aux_losses:
loss += tf.add_n(aux_losses)
@@ -94,7 +93,7 @@
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
y = tf.ones((1, 1), dtype=tf.int32)
y = tf.zeros((1, 1), dtype=tf.int32)
return (x, y)
dataset = tf.data.Dataset.range(1)
......
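The dummy-label switch from `tf.ones` to `tf.zeros` is presumably about label validity: 0 is a legal sparse label for any `num_classes`, whereas the Keras logits path validates the label range, and the old log-softmax/one-hot formulation silently produced zero loss for an out-of-range label. A hypothetical single-class head illustrates the edge case:

```python
import tensorflow as tf

num_classes = 1  # hypothetical; makes a label of 1 out of range
logits = tf.zeros((2, num_classes))
labels = tf.zeros((2,), dtype=tf.int32)  # 0 is valid for any num_classes

# With labels of 1 this would raise InvalidArgumentError on CPU (and can
# produce NaN on accelerators); with zeros it is always well-defined.
loss = tf.keras.losses.sparse_categorical_crossentropy(
    labels, logits, from_logits=True)
```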