Commit 0375a63b authored by Fan Yang, committed by A. Unique TensorFlower

Fix the issue that when both the numerator and denominator are 0, the dice score is incorrectly computed as 1; it should be 0.

PiperOrigin-RevId: 384257731
parent 9c377959
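The degenerate case described above is easy to reproduce by hand. Below is a minimal sketch (not code from this commit), assuming the plain soft-Dice form used by the loss before this change, where epsilon appears in both the numerator and the denominator; an all-zero label/prediction pair then scores ~1 and inflates any averaged metric:

import tensorflow as tf

def naive_dice(labels, logits, eps=tf.keras.backend.epsilon()):
  # Pre-fix formulation: epsilon in the numerator makes the all-zero case
  # evaluate to ~1.0 even though nothing was labeled or predicted.
  intersection = tf.reduce_sum(labels * logits)
  union = tf.reduce_sum(labels + logits)
  return (2.0 * intersection + eps) / (union + eps)

empty = tf.zeros([4, 4])
print(naive_dice(empty, empty).numpy())  # ~1.0; such a sample should not count at all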
@@ -63,6 +63,7 @@ class DiceScore:
       self._dice_scores_per_class = [
           tf.Variable(0.0) for _ in range(num_classes)
       ]
+      self._count_per_class = [tf.Variable(0.0) for _ in range(num_classes)]
     self.name = name
     self.dtype = dtype

@@ -87,15 +88,21 @@ class DiceScore:
                          self._num_classes,
                          y_true.get_shape()[-1]))
-    self._count.assign_add(1.)
-    self._dice_scores_overall.assign_add(1 -
-                                         self._dice_op_overall(y_pred, y_true))
-    if self._per_class_metric:
-      for class_id in range(self._num_classes):
-        self._dice_scores_per_class[class_id].assign_add(
-            1 -
-            self._dice_op_per_class(y_pred[..., class_id], y_true[...,
-                                                                  class_id]))
+    # If both y_pred and y_true are all 0s, we skip computing the metrics;
+    # otherwise the averaged metrics will be erroneously lower.
+    if tf.reduce_sum(y_true) != 0 or tf.reduce_sum(y_pred) != 0:
+      self._count.assign_add(1.)
+      self._dice_scores_overall.assign_add(
+          1 - self._dice_op_overall(y_pred, y_true))
+      if self._per_class_metric:
+        for class_id in range(self._num_classes):
+          if tf.reduce_sum(y_true[..., class_id]) != 0 or tf.reduce_sum(
+              y_pred[..., class_id]) != 0:
+            self._count_per_class[class_id].assign_add(1.)
+            self._dice_scores_per_class[class_id].assign_add(
+                1 - self._dice_op_per_class(y_pred[..., class_id],
+                                            y_true[..., class_id]))

   def result(self) -> tf.Tensor:
     """Computes and returns the metric.

@@ -114,7 +121,7 @@ class DiceScore:
       for class_id in range(self._num_classes):
         dice_scores.append(
             tf.math.divide_no_nan(self._dice_scores_per_class[class_id],
-                                  self._count))
+                                  self._count_per_class[class_id]))
       return tf.stack(dice_scores)
     else:
       return tf.math.divide_no_nan(self._dice_scores_overall, self._count)
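Note: the reason result() now divides by a per-class count can be seen in a small self-contained sketch (illustrative names only, not the repo's class). A class that never appears in either y_true or y_pred is simply not counted, so tf.math.divide_no_nan returns 0 for it instead of pulling the other classes' averages down:

import tensorflow as tf

num_classes = 2
scores = [tf.Variable(0.0) for _ in range(num_classes)]
counts = [tf.Variable(0.0) for _ in range(num_classes)]

def update(y_pred, y_true, dice_fn):
  # Mirrors the update logic above: a class is only accumulated when it is
  # non-empty in y_true or y_pred, so empty classes never dilute the average.
  for c in range(num_classes):
    if tf.reduce_sum(y_true[..., c]) != 0 or tf.reduce_sum(y_pred[..., c]) != 0:
      counts[c].assign_add(1.0)
      scores[c].assign_add(dice_fn(y_pred[..., c], y_true[..., c]))

def result():
  # divide_no_nan yields 0.0 (not 1.0) for a class that was never observed.
  return tf.stack([
      tf.math.divide_no_nan(scores[c], counts[c]) for c in range(num_classes)
  ])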
@@ -22,10 +22,10 @@ from official.vision.beta.projects.volumetric_models.evaluation import segmentat
 class SegmentationMetricsTest(parameterized.TestCase, tf.test.TestCase):

-  @parameterized.parameters((1, 'generalized', 0.5, [0.74, 0.67]),
+  @parameterized.parameters((1, 'generalized', 0.5, [0.67, 0.67]),
                             (1, 'adaptive', 0.5, [0.93, 0.67]),
                             (2, None, 0.5, [0.67, 0.67, 0.67]),
-                            (3, 'generalized', 0.5, [0.7, 0.67, 0.67, 0.67]))
+                            (3, 'generalized', 0.5, [0.67, 0.67, 0.67, 0.67]))
   def test_forward_dice_score(self, num_classes, metric_type, output,
                               expected_score):
     metric = segmentation_metrics.DiceScore(
@@ -66,26 +66,43 @@ class SegmentationLossDiceScore(object):
       raise ValueError('The labels and logits must be at least rank 2.')

     epsilon = tf.keras.backend.epsilon()
-    axis = list(range(len(logits.shape) - 1))
+    keep_label_axis = list(range(len(logits.shape) - 1))
+    keep_batch_axis = list(range(1, len(logits.shape)))
+
+    # Compute sample mask to filter out samples with both all-0's labels and
+    # predictions because such samples should not contribute to mean dice score
+    # in this batch.
+    sample_mask = tf.logical_or(
+        tf.cast(tf.reduce_sum(labels, axis=keep_batch_axis), dtype=tf.bool),
+        tf.cast(tf.reduce_sum(logits, axis=keep_batch_axis), dtype=tf.bool))
+    labels = tf.boolean_mask(labels, sample_mask)
+    logits = tf.boolean_mask(logits, sample_mask)
+
+    # If all samples are filtered out, return 0 as the loss so this batch does
+    # not contribute.
+    if labels.shape[0] == 0:
+      return tf.convert_to_tensor(0.0)

     # Calculate intersections and unions per class.
-    intersection = tf.reduce_sum(labels * logits, axis=axis)
-    union = tf.reduce_sum(labels + logits, axis=axis)
+    intersection = tf.reduce_sum(labels * logits, axis=keep_label_axis)
+    union = tf.reduce_sum(labels + logits, axis=keep_label_axis)

     if self._metric_type == 'generalized':
       # Calculate the volume of groundtruth labels.
       w = tf.math.reciprocal(
-          tf.square(tf.reduce_sum(labels, axis=axis)) + epsilon)
+          tf.square(tf.reduce_sum(labels, axis=keep_label_axis)) + epsilon)

       # Calculate the weighted dice score and normalizer.
-      dice = 2 * tf.reduce_sum(w * intersection) + epsilon
-      normalizer = tf.reduce_sum(w * union) + epsilon
+      dice = 2 * tf.reduce_sum(w * intersection)
+      normalizer = tf.reduce_sum(w * union)
+      if normalizer == 0:
+        return tf.convert_to_tensor(1.0)

       dice = tf.cast(dice, dtype=tf.float32)
       normalizer = tf.cast(normalizer, dtype=tf.float32)
       return 1 - tf.reduce_mean(dice / normalizer)
     elif self._metric_type == 'adaptive':
-      dice = 2.0 * (intersection + epsilon) / (union + epsilon)
+      dice = 2.0 * intersection / (union + epsilon)
       # Calculate weights based on Dice scores.
       weights = tf.exp(-1.0 * dice)

@@ -94,12 +111,14 @@ class SegmentationLossDiceScore(object):
       # Calculate normalization factor.
       normalizer = tf.cast(tf.size(input=dice), dtype=tf.float32) * tf.exp(-1.0)
+      if normalizer == 0:
+        return tf.convert_to_tensor(1.0)
       weighted_dice = tf.cast(weighted_dice, dtype=tf.float32)
       return 1 - tf.reduce_mean(weighted_dice / normalizer)
     else:
       summation = tf.reduce_sum(
           labels, axis=self._axis) + tf.reduce_sum(
               logits, axis=self._axis)
-      dice = (2 * tf.reduce_sum(labels * logits, axis=self._axis) + epsilon) / (
+      dice = (2 * tf.reduce_sum(labels * logits, axis=self._axis)) / (
           summation + epsilon)
       return 1 - tf.reduce_mean(dice)
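Note: the new per-sample masking can be exercised on its own. A minimal standalone sketch of the same tf.boolean_mask pattern (tensors and shapes here are illustrative, assuming a [batch, ..., channels] layout):

import tensorflow as tf

# Two samples: the second has all-zero labels and logits.
labels = tf.stack([tf.ones([8, 8, 1]), tf.zeros([8, 8, 1])])
logits = tf.stack([tf.fill([8, 8, 1], 0.5), tf.zeros([8, 8, 1])])

keep_batch_axis = list(range(1, len(logits.shape)))  # reduce all but the batch axis
# Keep a sample if either its labels or its logits contain any non-zero value.
sample_mask = tf.logical_or(
    tf.cast(tf.reduce_sum(labels, axis=keep_batch_axis), tf.bool),
    tf.cast(tf.reduce_sum(logits, axis=keep_batch_axis), tf.bool))

labels = tf.boolean_mask(labels, sample_mask)
logits = tf.boolean_mask(logits, sample_mask)
print(labels.shape)  # (1, 8, 8, 1): the empty sample no longer affects the mean loss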
@@ -27,8 +27,19 @@ class SegmentationLossDiceScoreTest(parameterized.TestCase, tf.test.TestCase):
   def test_supported_loss(self, metric_type, output, expected_score):
     loss = segmentation_losses.SegmentationLossDiceScore(
         metric_type=metric_type)
-    logits = tf.constant(output, shape=[1, 128, 128, 128, 1], dtype=tf.float32)
-    labels = tf.ones(shape=[1, 128, 128, 128, 1], dtype=tf.float32)
+    logits = tf.constant(output, shape=[2, 128, 128, 128, 1], dtype=tf.float32)
+    labels = tf.ones(shape=[2, 128, 128, 128, 1], dtype=tf.float32)
+    actual_score = loss(logits=logits, labels=labels)
+    self.assertAlmostEqual(actual_score.numpy(), expected_score, places=1)
+
+  @parameterized.parameters((None, 0, 0), ('generalized', 0, 0),
+                            ('adaptive', 0, 0))
+  def test_supported_loss_zero_labels_logits(self, metric_type, output,
+                                             expected_score):
+    loss = segmentation_losses.SegmentationLossDiceScore(
+        metric_type=metric_type)
+    logits = tf.constant(output, shape=[2, 128, 128, 128, 1], dtype=tf.float32)
+    labels = tf.zeros(shape=[2, 128, 128, 128, 1], dtype=tf.float32)
     actual_score = loss(logits=logits, labels=labels)
     self.assertAlmostEqual(actual_score.numpy(), expected_score, places=1)