Commit f8b0f1dd, authored by Vighnesh Birodkar and committed by TF Object Detection Team
Browse files

Add support for dice loss.

PiperOrigin-RevId: 355908695
parent 8a064338
......@@ -227,7 +227,7 @@ def _build_classification_loss(loss_config):
if loss_type == 'weighted_sigmoid':
return losses.WeightedSigmoidClassificationLoss()
if loss_type == 'weighted_sigmoid_focal':
elif loss_type == 'weighted_sigmoid_focal':
config = loss_config.weighted_sigmoid_focal
alpha = None
if config.HasField('alpha'):
......@@ -236,25 +236,31 @@ def _build_classification_loss(loss_config):
gamma=config.gamma,
alpha=alpha)
if loss_type == 'weighted_softmax':
elif loss_type == 'weighted_softmax':
config = loss_config.weighted_softmax
return losses.WeightedSoftmaxClassificationLoss(
logit_scale=config.logit_scale)
if loss_type == 'weighted_logits_softmax':
elif loss_type == 'weighted_logits_softmax':
config = loss_config.weighted_logits_softmax
return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
logit_scale=config.logit_scale)
if loss_type == 'bootstrapped_sigmoid':
elif loss_type == 'bootstrapped_sigmoid':
config = loss_config.bootstrapped_sigmoid
return losses.BootstrappedSigmoidClassificationLoss(
alpha=config.alpha,
bootstrap_type=('hard' if config.hard_bootstrap else 'soft'))
if loss_type == 'penalty_reduced_logistic_focal_loss':
elif loss_type == 'penalty_reduced_logistic_focal_loss':
config = loss_config.penalty_reduced_logistic_focal_loss
return losses.PenaltyReducedLogisticFocalLoss(
alpha=config.alpha, beta=config.beta)
raise ValueError('Empty loss config.')
elif loss_type == 'weighted_dice_classification_loss':
config = loss_config.weighted_dice_classification_loss
return losses.WeightedDiceClassificationLoss(
squared_normalization=config.squared_normalization)
else:
raise ValueError('Empty loss config.')
......@@ -298,6 +298,45 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
with self.assertRaises(ValueError):
losses_builder.build(losses_proto)
def test_build_penalty_reduced_logistic_focal_loss(self):
  """Builds a penalty-reduced logistic focal loss and checks alpha and beta."""
  config_text = """
classification_loss {
penalty_reduced_logistic_focal_loss {
alpha: 2.0
beta: 4.0
}
}
localization_loss {
l1_localization_loss {
}
}
"""
  loss_config = losses_pb2.Loss()
  text_format.Merge(config_text, loss_config)

  # The builder returns a 7-tuple; only the first entry (the classification
  # loss) is inspected here.
  built_loss, _, _, _, _, _, _ = losses_builder.build(loss_config)

  self.assertIsInstance(
      built_loss, losses.PenaltyReducedLogisticFocalLoss)
  # The configured hyper-parameters must be forwarded to the loss object.
  self.assertAlmostEqual(built_loss._alpha, 2.0)
  self.assertAlmostEqual(built_loss._beta, 4.0)
def test_build_dice_loss(self):
  """Builds a weighted dice classification loss and checks its config flag."""
  losses_text_proto = """
classification_loss {
weighted_dice_classification_loss {
squared_normalization: true
}
}
localization_loss {
l1_localization_loss {
}
}
"""
  losses_proto = losses_pb2.Loss()
  text_format.Merge(losses_text_proto, losses_proto)
  # The builder returns a 7-tuple; only the classification loss is checked.
  classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
  self.assertIsInstance(classification_loss,
                        losses.WeightedDiceClassificationLoss)
  # Use the TestCase assertion rather than a bare `assert`, which is removed
  # when Python runs with -O and gives no failure message.
  self.assertTrue(classification_loss._squared_normalization)
class HardExampleMinerBuilderTest(tf.test.TestCase):
......
......@@ -278,6 +278,79 @@ class WeightedSigmoidClassificationLoss(Loss):
return per_entry_cross_ent * weights
class WeightedDiceClassificationLoss(Loss):
  """Dice loss for classification [1][2].

  [1]: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
  [2]: https://arxiv.org/abs/1606.04797
  """

  def __init__(self, squared_normalization):
    """Initializes the loss object.

    Args:
      squared_normalization: boolean, if set, we square the probabilities in the
        denominator term used for normalization.
    """
    self._squared_normalization = squared_normalization
    super(WeightedDiceClassificationLoss, self).__init__()

  def _compute_loss(self,
                    prediction_tensor,
                    target_tensor,
                    weights,
                    class_indices=None):
    """Computes the loss value.

    Dice loss uses the area of the ground truth and prediction tensors for
    normalization. We compute area by summing along the pixels (2nd)
    dimension.

    Args:
      prediction_tensor: A float tensor of shape [batch_size, num_pixels,
        num_classes] representing the predicted logits for each class.
        num_pixels denotes the total number of pixels in the spatial dimensions
        of the mask after flattening.
      target_tensor: A float tensor of shape [batch_size, num_pixels,
        num_classes] representing one-hot encoded classification targets.
        num_pixels denotes the total number of pixels in the spatial dimensions
        of the mask after flattening.
      weights: a float tensor of shape, either [batch_size, num_pixels,
        num_classes] or [batch_size, num_pixels, 1]. If the shape is
        [batch_size, num_pixels, 1], all the classes are equally weighted.
      class_indices: (Optional) A 1-D integer tensor of class indices.
        If provided, computes loss only for the specified class indices.

    Returns:
      loss: a float tensor of shape [batch_size, num_classes]
        representing the value of the loss function.
    """
    if class_indices is not None:
      class_mask = tf.reshape(
          ops.indices_to_dense_vector(class_indices,
                                      tf.shape(prediction_tensor)[2]),
          [1, 1, -1])
      # Rebind instead of using `*=`: an in-place multiply could modify a
      # mutable `weights` object (e.g. a NumPy array) owned by the caller.
      weights = weights * class_mask

    prob_tensor = tf.nn.sigmoid(prediction_tensor)

    if self._squared_normalization:
      # NOTE(review): the squared tensors are reused for the intersection
      # below, not only for the two area terms in the denominator. Confirm
      # this is intended.
      prob_tensor = tf.pow(prob_tensor, 2)
      target_tensor = tf.pow(target_tensor, 2)

    prob_tensor *= weights
    target_tensor *= weights

    # Areas and intersection are reduced over axis 1, leaving one value per
    # (batch, class) pair.
    prediction_area = tf.reduce_sum(prob_tensor, axis=1)
    gt_area = tf.reduce_sum(target_tensor, axis=1)

    intersection = tf.reduce_sum(prob_tensor * target_tensor, axis=1)
    # The denominator is clamped to at least 1.0, so the division is safe even
    # when both areas are zero.
    dice_coeff = 2 * intersection / tf.maximum(gt_area + prediction_area, 1.0)
    dice_loss = 1 - dice_coeff

    return dice_loss
class SigmoidFocalClassificationLoss(Loss):
"""Sigmoid focal cross entropy loss.
......
......@@ -1447,5 +1447,111 @@ class L1LocalizationLossTest(test_case.TestCase):
self.assertAllClose(computed_value, [[0.8, 0.0], [0.6, 0.1]], rtol=1e-6)
class WeightedDiceClassificationLoss(test_case.TestCase):
  """Tests `_compute_loss` of the dice loss with squared_normalization off."""

  @staticmethod
  def _build_inputs():
    """Returns the (logits, targets) arrays shared by all tests below.

    Both arrays have shape (2, 3, 4). Logits not listed here stay at 0.
    """
    logits = np.zeros((2, 3, 4), dtype=np.float32)
    targets = np.zeros((2, 3, 4), dtype=np.float32)

    probabilities = (
        ((0, 1, 0), 0.9),
        ((0, 2, 0), 0.1),
        ((0, 2, 2), 0.5),
        ((0, 1, 3), 0.1),
        ((1, 2, 3), 0.2),
        ((1, 1, 1), 0.3),
        ((1, 0, 2), 0.1),
    )
    for position, probability in probabilities:
      logits[position] = _logit(probability)

    labels = (
        ((0, 1, 0), 1.0),
        ((0, 2, 2), 1.0),
        ((0, 1, 3), 1.0),
        ((1, 2, 3), 1.0),
        ((1, 1, 1), 0.0),
        ((1, 0, 2), 0.0),
    )
    for position, label in labels:
      targets[position] = label

    return logits, targets

  def test_compute_weights_1(self):
    """All weights are one: every (batch, class) entry contributes."""

    def graph_fn():
      dice = losses.WeightedDiceClassificationLoss(squared_normalization=False)
      logits, targets = self._build_inputs()
      return dice._compute_loss(logits, targets, np.ones_like(targets))

    expected_coeff = np.zeros((2, 4))
    expected_coeff[0, 0] = 2 * 0.9 / 2.5
    expected_coeff[0, 2] = 2 * 0.5 / 2.5
    expected_coeff[0, 3] = 2 * 0.1 / 2.1
    expected_coeff[1, 3] = 2 * 0.2 / 2.2

    actual_loss = self.execute(graph_fn, [])
    self.assertAllClose(actual_loss, 1 - expected_coeff, rtol=1e-6)

  def test_compute_weights_set(self):
    """Class 0 has zero weight, so its dice coefficient is expected to be 0."""

    def graph_fn():
      dice = losses.WeightedDiceClassificationLoss(squared_normalization=False)
      logits, targets = self._build_inputs()
      per_entry_weights = np.ones_like(targets)
      per_entry_weights[:, :, 0] = 0.0
      return dice._compute_loss(logits, targets, per_entry_weights)

    expected_coeff = np.zeros((2, 4))
    expected_coeff[0, 2] = 2 * 0.5 / 2.5
    expected_coeff[0, 3] = 2 * 0.1 / 2.1
    expected_coeff[1, 3] = 2 * 0.2 / 2.2

    actual_loss = self.execute(graph_fn, [])
    self.assertAllClose(actual_loss, 1 - expected_coeff, rtol=1e-6)

  def test_class_indices(self):
    """Only class 0 is selected via class_indices."""

    def graph_fn():
      dice = losses.WeightedDiceClassificationLoss(squared_normalization=False)
      logits, targets = self._build_inputs()
      return dice._compute_loss(
          logits, targets, np.ones_like(targets), class_indices=[0])

    expected_coeff = np.zeros((2, 4))
    expected_coeff[0, 0] = 2 * 0.9 / 2.5

    actual_loss = self.execute(graph_fn, [])
    self.assertAllClose(actual_loss, 1 - expected_coeff, rtol=1e-6)
if __name__ == '__main__':
tf.test.main()
......@@ -110,6 +110,7 @@ message ClassificationLoss {
BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3;
SigmoidFocalClassificationLoss weighted_sigmoid_focal = 4;
PenaltyReducedLogisticFocalLoss penalty_reduced_logistic_focal_loss = 6;
WeightedDiceClassificationLoss weighted_dice_classification_loss = 7;
}
}
......@@ -217,3 +218,14 @@ message RandomExampleSampler {
// example sampling.
optional float positive_sample_fraction = 1 [default = 0.01];
}
// Dice loss for training instance masks [1][2].
// [1]: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
// [2]: https://arxiv.org/abs/1606.04797
message WeightedDiceClassificationLoss {
  // When true, the probabilities in the denominator (normalization) term are
  // squared. Defaults to false.
  optional bool squared_normalization = 1 [default = false];
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment