Commit 68d9973c authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 469755043
parent c04133d9
@@ -14,7 +14,6 @@
 """Losses used for segmentation models."""
-# Import libraries
 import tensorflow as tf
 from official.modeling import tf_utils
@@ -25,16 +24,45 @@ EPSILON = 1e-5
 class SegmentationLoss:
   """Semantic segmentation loss."""
 
-  def __init__(self, label_smoothing, class_weights, ignore_label,
-               use_groundtruth_dimension, top_k_percent_pixels=1.0):
-    self._top_k_percent_pixels = top_k_percent_pixels
+  def __init__(self,
+               label_smoothing,
+               class_weights,
+               ignore_label,
+               use_groundtruth_dimension,
+               top_k_percent_pixels=1.0):
+    """Initializes `SegmentationLoss`.
+
+    Args:
+      label_smoothing: A float, if > 0., smooth out one-hot probability by
+        spreading the amount of probability to all other label classes.
+      class_weights: A float list containing the weight of each class.
+      ignore_label: An integer specifying the ignore label.
+      use_groundtruth_dimension: A boolean, whether to resize the output to
+        match the dimension of the ground truth.
+      top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its
+        value < 1., only compute the loss for the top k percent pixels. This is
+        useful for hard pixel mining.
+    """
+    self._label_smoothing = label_smoothing
     self._class_weights = class_weights
     self._ignore_label = ignore_label
     self._use_groundtruth_dimension = use_groundtruth_dimension
-    self._label_smoothing = label_smoothing
+    self._top_k_percent_pixels = top_k_percent_pixels
 
   def __call__(self, logits, labels, **kwargs):
-    _, height, width, num_classes = logits.get_shape().as_list()
+    """Computes `SegmentationLoss`.
+
+    Args:
+      logits: A float tensor in shape (batch_size, height, width, num_classes)
+        which is the output of the network.
+      labels: A tensor in shape (batch_size, height, width, 1), which is the
+        label mask of the ground truth.
+      **kwargs: additional keyword arguments.
+
+    Returns:
+      A 0-D float which stores the overall loss of the batch.
+    """
+    _, height, width, _ = logits.get_shape().as_list()
     if self._use_groundtruth_dimension:
       # TODO(arashwan): Test using align corners to match deeplab alignment.
@@ -45,14 +73,38 @@ class SegmentationLoss:
           labels, (height, width),
           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+    labels = tf.cast(labels, tf.int32)
     valid_mask = tf.not_equal(labels, self._ignore_label)
-    normalizer = tf.reduce_sum(tf.cast(valid_mask, tf.float32)) + EPSILON
+    cross_entropy_loss = self.compute_pixelwise_loss(labels, logits, valid_mask,
+                                                     **kwargs)
+    if self._top_k_percent_pixels < 1.0:
+      return self.aggregate_loss_top_k(cross_entropy_loss)
+    else:
+      return self.aggregate_loss(cross_entropy_loss, valid_mask)
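To make the refactored flow concrete, here is a minimal usage sketch (illustrative only, not part of this commit): it assumes the `SegmentationLoss` class as shown above, four classes, uniform class weights, and dummy tensors.

import tensorflow as tf

# Illustrative only -- not part of this commit. Shapes follow the docstrings
# above; the random tensors and constructor values are made up.
loss_fn = SegmentationLoss(
    label_smoothing=0.0,
    class_weights=[1.0, 1.0, 1.0, 1.0],
    ignore_label=255,
    use_groundtruth_dimension=True,
    top_k_percent_pixels=0.2)

logits = tf.random.normal([2, 64, 64, 4])                   # (B, H, W, C)
labels = tf.random.uniform([2, 64, 64, 1], 0, 4, tf.int32)  # (B, H, W, 1)
loss = loss_fn(logits, labels)  # 0-D tensor; only the top 20% hardest pixels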
+
+  def compute_pixelwise_loss(self, labels, logits, valid_mask, **kwargs):
+    """Computes the loss for each pixel.
+
+    Args:
+      labels: An int32 tensor in shape (batch_size, height, width, 1), which is
+        the label mask of the ground truth.
+      logits: A float tensor in shape (batch_size, height, width, num_classes)
+        which is the output of the network.
+      valid_mask: A bool tensor in shape (batch_size, height, width, 1) which
+        masks out ignored pixels.
+      **kwargs: additional keyword arguments.
+
+    Returns:
+      A float tensor in shape (batch_size, height, width) which stores the loss
+        value for each pixel.
+    """
+    num_classes = logits.get_shape().as_list()[-1]
     # Assign pixel with ignore label to class 0 (background). The loss on the
     # pixel will later be masked out.
     labels = tf.where(valid_mask, labels, tf.zeros_like(labels))
-    labels = tf.squeeze(tf.cast(labels, tf.int32), axis=3)
-    valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=3)
     cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits(
         labels=self.get_labels_with_prob(labels, logits, **kwargs),
         logits=logits)
@@ -66,26 +118,12 @@ class SegmentationLoss:
       raise ValueError(
           'Length of class_weights should be {}'.format(num_classes))
-    weight_mask = tf.einsum('...y,y->...',
-                            tf.one_hot(labels, num_classes, dtype=tf.float32),
-                            tf.constant(class_weights, tf.float32))
-    valid_mask *= weight_mask
-    cross_entropy_loss *= tf.cast(valid_mask, tf.float32)
-    if self._top_k_percent_pixels >= 1.0:
-      loss = tf.reduce_sum(cross_entropy_loss) / normalizer
-    else:
-      cross_entropy_loss = tf.reshape(cross_entropy_loss, shape=[-1])
-      top_k_pixels = tf.cast(
-          self._top_k_percent_pixels *
-          tf.cast(tf.size(cross_entropy_loss), tf.float32), tf.int32)
-      top_k_losses, _ = tf.math.top_k(
-          cross_entropy_loss, k=top_k_pixels, sorted=True)
-      normalizer = tf.reduce_sum(
-          tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32)) + EPSILON
-      loss = tf.reduce_sum(top_k_losses) / normalizer
-    return loss
+    valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=-1)
+    weight_mask = tf.einsum(
+        '...y,y->...',
+        tf.one_hot(tf.squeeze(labels, axis=-1), num_classes, dtype=tf.float32),
+        tf.constant(class_weights, tf.float32))
+    return cross_entropy_loss * valid_mask * weight_mask
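The new `compute_pixelwise_loss` body builds a per-pixel weight map by one-hot encoding the label map and contracting it against the class-weight vector with the same `'...y,y->...'` einsum shown above. A small standalone sketch of that contraction (illustrative values only, not part of the commit):

import tensorflow as tf

# Illustrative values only; mirrors the einsum used above.
labels = tf.constant([[[0], [2]],
                      [[1], [1]]], dtype=tf.int32)  # (2, 2, 1) label map
class_weights = tf.constant([1.0, 2.0, 0.5])        # one weight per class
weight_mask = tf.einsum(
    '...y,y->...',
    tf.one_hot(tf.squeeze(labels, axis=-1), 3, dtype=tf.float32),
    class_weights)
# weight_mask == [[1.0, 0.5],
#                 [2.0, 2.0]]  -- each pixel picks up its class's weight.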
   def get_labels_with_prob(self, labels, logits, **unused_kwargs):
     """Get a tensor representing the probability of each class for each pixel.
@@ -93,8 +131,8 @@ class SegmentationLoss:
     This method can be overridden in subclasses for customizing loss function.
 
     Args:
-      labels: A float tensor in shape (batch_size, height, width), which is the
-        label map of the ground truth.
+      labels: An int32 tensor in shape (batch_size, height, width, 1), which is
+        the label map of the ground truth.
       logits: A float tensor in shape (batch_size, height, width, num_classes)
         which is the output of the network.
       **unused_kwargs: Unused keyword arguments.
@@ -102,11 +140,46 @@ class SegmentationLoss:
 
     Returns:
       A float tensor in shape (batch_size, height, width, num_classes).
     """
+    labels = tf.squeeze(labels, axis=-1)
     num_classes = logits.get_shape().as_list()[-1]
     onehot_labels = tf.one_hot(labels, num_classes)
     return onehot_labels * (
         1 - self._label_smoothing) + self._label_smoothing / num_classes
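The smoothing above follows the usual formula: with smoothing `s` and `C` classes, the true class gets probability `1 - s + s / C` and every other class gets `s / C`. A tiny numeric check (illustrative, not part of the commit):

import tensorflow as tf

# Illustrative check of the smoothing arithmetic used above.
num_classes = 4
label_smoothing = 0.1
onehot = tf.one_hot([2], num_classes)  # [[0., 0., 1., 0.]]
smoothed = onehot * (1 - label_smoothing) + label_smoothing / num_classes
# smoothed == [[0.025, 0.025, 0.925, 0.025]]; each row still sums to 1.0.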
+
+  def aggregate_loss(self, pixelwise_loss, valid_mask):
+    """Aggregate the pixelwise loss.
+
+    Args:
+      pixelwise_loss: A float tensor in shape (batch_size, height, width) which
+        stores the loss of each pixel.
+      valid_mask: A bool tensor in shape (batch_size, height, width, 1) which
+        masks out ignored pixels.
+
+    Returns:
+      A 0-D float which stores the overall loss of the batch.
+    """
+    normalizer = tf.reduce_sum(tf.cast(valid_mask, tf.float32)) + EPSILON
+    return tf.reduce_sum(pixelwise_loss) / normalizer
+
+  def aggregate_loss_top_k(self, pixelwise_loss):
+    """Aggregate the top-k greatest pixelwise loss.
+
+    Args:
+      pixelwise_loss: A float tensor in shape (batch_size, height, width) which
+        stores the loss of each pixel.
+
+    Returns:
+      A 0-D float which stores the overall loss of the batch.
+    """
+    pixelwise_loss = tf.reshape(pixelwise_loss, shape=[-1])
+    top_k_pixels = tf.cast(
+        self._top_k_percent_pixels *
+        tf.cast(tf.size(pixelwise_loss), tf.float32), tf.int32)
+    top_k_losses, _ = tf.math.top_k(pixelwise_loss, k=top_k_pixels, sorted=True)
+    normalizer = tf.reduce_sum(
+        tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32)) + EPSILON
+    return tf.reduce_sum(top_k_losses) / normalizer
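`aggregate_loss_top_k` implements hard pixel mining: the per-pixel losses are flattened, only the `top_k_percent_pixels` largest values are kept, and their sum is normalized by the number of kept non-zero losses. A small sketch of the same steps on made-up values (illustrative, not part of the commit):

import tensorflow as tf

# Illustrative values only; mirrors the aggregation steps above.
EPSILON = 1e-5
pixelwise_loss = tf.constant([0.1, 2.0, 0.0, 0.7, 1.5, 0.0])  # six "pixels"
top_k_percent_pixels = 0.5

flat = tf.reshape(pixelwise_loss, [-1])
k = tf.cast(top_k_percent_pixels * tf.cast(tf.size(flat), tf.float32), tf.int32)
top_k_losses, _ = tf.math.top_k(flat, k=k, sorted=True)  # [2.0, 1.5, 0.7]
normalizer = tf.reduce_sum(
    tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32)) + EPSILON
loss = tf.reduce_sum(top_k_losses) / normalizer  # approx. 1.4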
 def get_actual_mask_scores(logits, labels, ignore_label):
   """Gets actual mask scores."""
...