Commit b2bf29cf authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Update losses.

PiperOrigin-RevId: 279111017
parent 833e6939
......@@ -36,12 +36,13 @@ def focal_loss(logits, targets, alpha, gamma, normalizer):
and (1-alpha) to the loss from negative examples.
gamma: A float32 scalar modulating loss from hard and easy examples.
normalizer: A float32 scalar normalizes the total loss from all examples.
Returns:
loss: A float32 Tensor of size [batch, height_in, width_in, num_predictions]
representing normalized loss on the prediction map.
"""
with tf.name_scope('focal_loss'):
positive_label_mask = tf.equal(targets, 1.0)
positive_label_mask = tf.math.equal(targets, 1.0)
cross_entropy = (
tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))
# Below are comments/derivations for computing modulator.
......@@ -74,8 +75,8 @@ def focal_loss(logits, targets, alpha, gamma, normalizer):
# samples is:
# (1 - p_t)^r = exp(-r * z * x - r * log(1 + exp(-x))).
neg_logits = -1.0 * logits
modulator = tf.exp(gamma * targets * neg_logits -
gamma * tf.math.log1p(tf.exp(neg_logits)))
modulator = tf.math.exp(gamma * targets * neg_logits -
gamma * tf.math.log1p(tf.math.exp(neg_logits)))
loss = modulator * cross_entropy
weighted_loss = tf.where(positive_label_mask, alpha * loss,
(1.0 - alpha) * loss)
......@@ -87,19 +88,19 @@ class RpnScoreLoss(object):
"""Region Proposal Network score loss function."""
def __init__(self, params):
raise ValueError('Not TF 2.0 ready.')
self._batch_size = params.batch_size
self._rpn_batch_size_per_im = params.rpn_batch_size_per_im
def __call__(self, score_outputs, labels):
"""Computes total RPN detection loss.
Computes total RPN detection loss including box and score from all levels.
Args:
score_outputs: an OrderDict with keys representing levels and values
representing scores in [batch_size, height, width, num_anchors].
labels: the dictionary that returned from dataloader that includes
groundturth targets.
Returns:
rpn_score_loss: a scalar tensor representing total score loss.
"""
......@@ -108,17 +109,16 @@ class RpnScoreLoss(object):
score_losses = []
for level in levels:
score_targets_l = labels['score_targets_%d' % level]
score_losses.append(
self._rpn_score_loss(
score_outputs[level],
score_targets_l,
labels[level],
normalizer=tf.cast(
self._batch_size * self._rpn_batch_size_per_im,
dtype=tf.float32)))
tf.shape(score_outputs[level])[0] *
self._rpn_batch_size_per_im, dtype=tf.float32)))
# Sums per level losses to total loss.
return tf.add_n(score_losses)
return tf.math.add_n(score_losses)
def _rpn_score_loss(self, score_outputs, score_targets, normalizer=1.0):
"""Computes score loss."""
......@@ -127,10 +127,13 @@ class RpnScoreLoss(object):
# (2) score_targets[i]=0, negative.
# (3) score_targets[i]=-1, the anchor is don't care (ignore).
with tf.name_scope('rpn_score_loss'):
mask = tf.logical_or(tf.equal(score_targets, 1),
tf.equal(score_targets, 0))
score_targets = tf.maximum(score_targets, tf.zeros_like(score_targets))
mask = tf.math.logical_or(tf.math.equal(score_targets, 1),
tf.math.equal(score_targets, 0))
score_targets = tf.math.maximum(score_targets, tf.zeros_like(score_targets))
# RPN score loss is sum over all except ignored samples.
# Keep the compat.v1 loss because Keras does not have a
# sigmoid_cross_entropy substitution yet.
# TODO(b/143720144): replace this loss.
score_loss = tf.compat.v1.losses.sigmoid_cross_entropy(
score_targets,
score_outputs,
......@@ -144,31 +147,32 @@ class RpnBoxLoss(object):
"""Region Proposal Network box regression loss function."""
def __init__(self, params):
raise ValueError('Not TF 2.0 ready.')
self._delta = params.huber_loss_delta
self._huber_loss = tf.keras.losses.Huber(
delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
def __call__(self, box_outputs, labels):
"""Computes total RPN detection loss.
Computes total RPN detection loss including box and score from all levels.
Args:
box_outputs: an OrderDict with keys representing levels and values
representing box regression targets in
[batch_size, height, width, num_anchors * 4].
labels: the dictionary that returned from dataloader that includes
groundturth targets.
Returns:
rpn_box_loss: a scalar tensor representing total box regression loss.
"""
with tf.compat.v1.name_scope('rpn_loss'):
with tf.name_scope('rpn_loss'):
levels = sorted(box_outputs.keys())
box_losses = []
for level in levels:
box_targets_l = labels['box_targets_%d' % level]
box_losses.append(
self._rpn_box_loss(
box_outputs[level], box_targets_l, delta=self._delta))
box_outputs[level], labels[level], delta=self._delta))
# Sum per level losses to total loss.
return tf.add_n(box_losses)
......@@ -178,16 +182,11 @@ class RpnBoxLoss(object):
# The delta is typically around the mean value of regression target.
# for instances, the regression targets of 512x512 input with 6 anchors on
# P2-P6 pyramid is about [0.1, 0.1, 0.2, 0.2].
with tf.compat.v1.name_scope('rpn_box_loss'):
mask = tf.not_equal(box_targets, 0.0)
with tf.name_scope('rpn_box_loss'):
mask = tf.math.not_equal(box_targets, 0.0)
# The loss is normalized by the sum of non-zero weights before additional
# normalizer provided by the function caller.
box_loss = tf.compat.v1.losses.huber_loss(
box_targets,
box_outputs,
weights=mask,
delta=delta,
reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
box_loss /= normalizer
return box_loss
......@@ -195,9 +194,6 @@ class RpnBoxLoss(object):
class FastrcnnClassLoss(object):
"""Fast R-CNN classification loss function."""
def __init__(self):
raise ValueError('Not TF 2.0 ready.')
def __call__(self, class_outputs, class_targets):
"""Computes the class loss (Fast-RCNN branch) of Mask-RCNN.
......@@ -211,10 +207,11 @@ class FastrcnnClassLoss(object):
with a shape of [batch_size, num_boxes, num_classes].
class_targets: a float tensor representing the class label for each box
with a shape of [batch_size, num_boxes].
Returns:
a scalar tensor representing total class loss.
"""
with tf.compat.v1.name_scope('fast_rcnn_loss'):
with tf.name_scope('fast_rcnn_loss'):
_, _, _, num_classes = class_outputs.get_shape().as_list()
class_targets = tf.cast(class_targets, dtype=tf.int32)
class_targets_one_hot = tf.one_hot(class_targets, num_classes)
......@@ -223,9 +220,12 @@ class FastrcnnClassLoss(object):
def _fast_rcnn_class_loss(self, class_outputs, class_targets_one_hot,
normalizer=1.0):
"""Computes classification loss."""
with tf.compat.v1.name_scope('fast_rcnn_class_loss'):
with tf.name_scope('fast_rcnn_class_loss'):
# The loss is normalized by the sum of non-zero weights before additional
# normalizer provided by the function caller.
# Keep the compat.v1 loss because Keras does not have a
# softmax_cross_entropy substitution yet.
# TODO(b/143720144): replace this loss.
class_loss = tf.compat.v1.losses.softmax_cross_entropy(
class_targets_one_hot,
class_outputs,
......@@ -238,7 +238,6 @@ class FastrcnnBoxLoss(object):
"""Fast R-CNN box regression loss function."""
def __init__(self, params):
raise ValueError('Not TF 2.0 ready.')
self._delta = params.huber_loss_delta
def __call__(self, box_outputs, class_targets, box_targets):
......@@ -261,10 +260,11 @@ class FastrcnnBoxLoss(object):
with a shape of [batch_size, num_boxes].
box_targets: a float tensor representing the box label for each box
with a shape of [batch_size, num_boxes, 4].
Returns:
box_loss: a scalar tensor representing total box regression loss.
"""
with tf.compat.v1.name_scope('fast_rcnn_loss'):
with tf.name_scope('fast_rcnn_loss'):
class_targets = tf.cast(class_targets, dtype=tf.int32)
# Selects the box from `box_outputs` based on `class_targets`, with which
......@@ -299,11 +299,14 @@ class FastrcnnBoxLoss(object):
# The delta is typically around the mean value of regression target.
# for instances, the regression targets of 512x512 input with 6 anchors on
# P2-P6 pyramid is about [0.1, 0.1, 0.2, 0.2].
with tf.compat.v1.name_scope('fast_rcnn_box_loss'):
with tf.name_scope('fast_rcnn_box_loss'):
mask = tf.tile(tf.expand_dims(tf.greater(class_targets, 0), axis=2),
[1, 1, 4])
# The loss is normalized by the sum of non-zero weights before additional
# normalizer provided by the function caller.
# Keep the compat.v1 loss because Keras does not have a
# Reduction.SUM_BY_NONZERO_WEIGHTS substitution yet.
# TODO(b/143720144): replace this loss.
box_loss = tf.compat.v1.losses.huber_loss(
box_targets,
box_outputs,
......@@ -340,10 +343,11 @@ class MaskrcnnLoss(object):
[batch_size, num_masks, mask_height, mask_width].
select_class_targets: a tensor with a shape of [batch_size, num_masks],
representing the foreground mask targets.
Returns:
mask_loss: a float tensor representing total mask loss.
"""
with tf.compat.v1.name_scope('mask_loss'):
with tf.name_scope('mask_rcnn_loss'):
(batch_size, num_masks, mask_height,
mask_width) = mask_outputs.get_shape().as_list()
......@@ -370,6 +374,7 @@ class RetinanetClassLoss(object):
"""Computes total detection loss.
Computes total detection loss including box and class loss from all levels.
Args:
cls_outputs: an OrderDict with keys representing levels and values
representing logits in [batch_size, height, width,
......@@ -426,6 +431,7 @@ class RetinanetBoxLoss(object):
"""Computes box detection loss.
Computes total detection loss including box and class loss from all levels.
Args:
box_outputs: an OrderDict with keys representing levels and values
representing box regression targets in [batch_size, height, width,
......@@ -462,62 +468,15 @@ class RetinanetBoxLoss(object):
return box_loss
class ShapeMaskLoss(object):
"""ShapeMask mask loss function wrapper."""
class ShapemaskMseLoss(object):
"""ShapeMask mask Mean Squared Error loss function wrapper."""
def __init__(self):
raise ValueError('Not TF 2.0 ready.')
raise NotImplementedError('Not Implemented.')
def __call__(self, logits, scaled_labels, classes,
category_loss=True, mse_loss=False):
"""Compute instance segmentation loss.
Args:
logits: A Tensor of shape [batch_size * num_points, height, width,
num_classes]. The logits are not necessarily between 0 and 1.
scaled_labels: A float16 Tensor of shape [batch_size, num_instances,
mask_size, mask_size], where mask_size =
mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
for coarse masks and shape priors.
classes: A int tensor of shape [batch_size, num_instances].
category_loss: use class specific mask prediction or not.
mse_loss: use mean square error for mask loss or not
class ShapemaskLoss(object):
"""ShapeMask mask loss function wrapper."""
Returns:
mask_loss: an float tensor representing total mask classification loss.
iou: a float tensor representing the IoU between target and prediction.
"""
classes = tf.reshape(classes, [-1])
_, _, height, width = scaled_labels.get_shape().as_list()
scaled_labels = tf.reshape(scaled_labels, [-1, height, width])
if not category_loss:
logits = logits[:, :, :, 0]
else:
logits = tf.transpose(a=logits, perm=(0, 3, 1, 2))
gather_idx = tf.stack([tf.range(tf.size(input=classes)), classes - 1],
axis=1)
logits = tf.gather_nd(logits, gather_idx)
# Ignore loss on empty mask targets.
valid_labels = tf.reduce_any(
input_tensor=tf.greater(scaled_labels, 0), axis=[1, 2])
if mse_loss:
# Logits are probabilities in the case of shape prior prediction.
logits *= tf.reshape(
tf.cast(valid_labels, logits.dtype), [-1, 1, 1])
weighted_loss = tf.nn.l2_loss(scaled_labels - logits)
probs = logits
else:
weighted_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=scaled_labels, logits=logits)
probs = tf.sigmoid(logits)
weighted_loss *= tf.reshape(
tf.cast(valid_labels, weighted_loss.dtype), [-1, 1, 1])
iou = tf.reduce_sum(
input_tensor=tf.minimum(scaled_labels, probs)) / tf.reduce_sum(
input_tensor=tf.maximum(scaled_labels, probs))
mask_loss = tf.reduce_sum(input_tensor=weighted_loss) / tf.reduce_sum(
input_tensor=scaled_labels)
return tf.cast(mask_loss, tf.float32), tf.cast(iou, tf.float32)
def __init__(self):
raise NotImplementedError('Not Implemented.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment