Commit 38c61e26 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 479685685
parent 96481557
...@@ -14,15 +14,13 @@ ...@@ -14,15 +14,13 @@
"""IOU Metrics used for semantic segmentation models.""" """IOU Metrics used for semantic segmentation models."""
import numpy as np
import tensorflow as tf import tensorflow as tf
class PerClassIoU(tf.keras.metrics.Metric): class PerClassIoU(tf.keras.metrics.MeanIoU):
"""Computes the per-class Intersection-Over-Union metric. """Computes the per-class Intersection-Over-Union metric.
Mean Intersection-Over-Union is a common evaluation metric for semantic image This metric computes the IOU for each semantic class.
segmentation, which first computes the IOU for each semantic class.
IOU is defined as follows: IOU is defined as follows:
IOU = true_positive / (true_positive + false_positive + false_negative). IOU = true_positive / (true_positive + false_positive + false_negative).
The predictions are accumulated in a confusion matrix, weighted by The predictions are accumulated in a confusion matrix, weighted by
...@@ -42,70 +40,10 @@ class PerClassIoU(tf.keras.metrics.Metric): ...@@ -42,70 +40,10 @@ class PerClassIoU(tf.keras.metrics.Metric):
>>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1]) >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
>>> m.result().numpy() >>> m.result().numpy()
[0.33333334, 0.33333334] [0.33333334, 0.33333334]
"""
def __init__(self, num_classes, name=None, dtype=None):
  """Initializes `PerClassIoU`.

  Args:
    num_classes: The possible number of labels the prediction task can have.
      This value must be provided, since a confusion matrix of dimension =
      [num_classes, num_classes] will be allocated.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
  """
  super().__init__(name=name, dtype=dtype)
  self.num_classes = num_classes
  # Running confusion matrix; predictions are accumulated here across batches.
  self.total_cm = self.add_weight(
      'total_confusion_matrix',
      shape=(num_classes, num_classes),
      initializer=tf.compat.v1.zeros_initializer)
def update_state(self, y_true, y_pred, sample_weight=None):
  """Accumulates the confusion matrix statistics.

  Args:
    y_true: The ground truth values.
    y_pred: The predicted values.
    sample_weight: Optional weighting of each example. Defaults to 1. Can be a
      `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
      be broadcastable to `y_true`.

  Returns:
    The op that adds this batch's confusion matrix to the accumulated
    `total_cm` (per-class IOU itself is computed in `result()`).
  """
  y_true = tf.cast(y_true, self._dtype)
  y_pred = tf.cast(y_pred, self._dtype)

  # Flatten the input if its rank > 1.
  # NOTE(review): assumes statically known rank — `ndims` is None for tensors
  # of unknown rank; confirm callers always pass shaped tensors.
  if y_pred.shape.ndims > 1:
    y_pred = tf.reshape(y_pred, [-1])

  if y_true.shape.ndims > 1:
    y_true = tf.reshape(y_true, [-1])

  if sample_weight is not None:
    sample_weight = tf.cast(sample_weight, self._dtype)
    if sample_weight.shape.ndims > 1:
      sample_weight = tf.reshape(sample_weight, [-1])

  # Accumulate the prediction to current confusion matrix.
  current_cm = tf.math.confusion_matrix(
      y_true,
      y_pred,
      self.num_classes,
      weights=sample_weight,
      dtype=self._dtype)
  return self.total_cm.assign_add(current_cm)
def result(self):
  """Computes IOU for each class from the accumulated confusion matrix."""
  sum_over_row = tf.cast(
      tf.reduce_sum(self.total_cm, axis=0), dtype=self._dtype)
  sum_over_col = tf.cast(
      tf.reduce_sum(self.total_cm, axis=1), dtype=self._dtype)
  true_positives = tf.cast(
      tf.linalg.tensor_diag_part(self.total_cm), dtype=self._dtype)

  # sum_over_row + sum_over_col =
  #     2 * true_positives + false_positives + false_negatives.
  denominator = sum_over_row + sum_over_col - true_positives

  # divide_no_nan yields 0 for classes never seen in labels or predictions.
  return tf.math.divide_no_nan(true_positives, denominator)
def reset_states(self):
  """Clears the accumulated confusion matrix."""
  zeros = np.zeros((self.num_classes, self.num_classes))
  tf.keras.backend.set_value(self.total_cm, zeros)
def get_config(self):
  """Returns the serializable configuration of this metric."""
  base_config = super().get_config()
  # `num_classes` is the only extra constructor argument to restore.
  return {**base_config, 'num_classes': self.num_classes}
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
import tensorflow as tf import tensorflow as tf
from official.vision.evaluation import iou
from official.vision.ops import box_ops from official.vision.ops import box_ops
from official.vision.ops import spatial_transform_ops from official.vision.ops import spatial_transform_ops
...@@ -31,8 +30,11 @@ class MeanIoU(tf.keras.metrics.MeanIoU): ...@@ -31,8 +30,11 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
size. size.
""" """
def __init__( def __init__(self,
self, num_classes, rescale_predictions=False, name=None, dtype=None): num_classes,
rescale_predictions=False,
name=None,
dtype=None):
"""Constructs Segmentation evaluator class. """Constructs Segmentation evaluator class.
Args: Args:
...@@ -103,86 +105,23 @@ class MeanIoU(tf.keras.metrics.MeanIoU): ...@@ -103,86 +105,23 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
sample_weight=tf.cast(flatten_valid_masks, tf.float32)) sample_weight=tf.cast(flatten_valid_masks, tf.float32))
class PerClassIoU(iou.PerClassIoU): class PerClassIoU(MeanIoU):
"""Per Class IoU metric for semantic segmentation. """Per class IoU metric for semantic segmentation."""
This class utilizes iou.PerClassIoU to perform batched per class def result(self):
iou when both input images and groundtruth masks are resized to the same size """Compute IoU for each class via the confusion matrix."""
(rescale_predictions=False). It also computes per class iou on groundtruth sum_over_row = tf.cast(
original sizes, in which case, each prediction is rescaled back to the tf.reduce_sum(self.total_cm, axis=0), dtype=self._dtype)
original image size. sum_over_col = tf.cast(
""" tf.reduce_sum(self.total_cm, axis=1), dtype=self._dtype)
true_positives = tf.cast(
def __init__(
    self, num_classes, rescale_predictions=False, name=None, dtype=None):
  """Constructs Segmentation evaluator class.

  Args:
    num_classes: `int`, number of classes.
    rescale_predictions: `bool`, whether to scale back prediction to original
      image sizes. If True, y_true['image_info'] is used to rescale
      predictions.
    name: `str`, name of the metric instance.
    dtype: data type of the metric result.
  """
  # Stored before super().__init__ so update_state can branch on it.
  self._rescale_predictions = rescale_predictions
  super().__init__(num_classes=num_classes, name=name, dtype=dtype)
def update_state(self, y_true, y_pred):
  """Updates metric state.

  Args:
    y_true: `dict`, dictionary with the following name, and key values.
      - masks: [batch, height, width, 1], groundtruth masks.
      - valid_masks: [batch, height, width, 1], valid elements in the mask.
      - image_info: [batch, 4, 2], a tensor that holds information about
        original and preprocessed images. Each entry is in the format of
        [[original_height, original_width], [input_height, input_width],
        [y_scale, x_scale], [y_offset, x_offset]], where [desired_height,
        desired_width] is the actual scaled image size, and
        [y_scale, x_scale] is the scaling factor, which is the ratio of
        scaled dimension / original dimension.
    y_pred: Tensor [batch, height_p, width_p, num_classes], predicated masks.
  """
  predictions = y_pred
  masks = y_true['masks']
  valid_masks = y_true['valid_masks']
  images_info = y_true['image_info']

  # Multi-level outputs arrive as a tuple/list; merge them on the batch axis.
  if isinstance(predictions, (tuple, list)):
    predictions = tf.concat(predictions, axis=0)
    masks = tf.concat(masks, axis=0)
    valid_masks = tf.concat(valid_masks, axis=0)
    images_info = tf.concat(images_info, axis=0)

  # Ignored mask elements are zeroed so they form a valid class id for argmax.
  masks = tf.where(valid_masks, masks, tf.zeros_like(masks))
  masks_size = tf.shape(masks)[1:3]

  if self._rescale_predictions:
    # Scale back predictions to original image shapes and pad to mask size.
    # Note: instead of cropping the masks to image shape (dynamic), here we
    # pad the rescaled predictions to mask size (fixed), and update the
    # valid_masks to mask out the pixels outside the original image shape.
    predictions, image_shape_masks = _rescale_and_pad_predictions(
        predictions, images_info, output_size=masks_size)
    # Only the area within the original image shape is valid:
    # (batch_size, height, width, 1).
    valid_masks = tf.cast(valid_masks, tf.bool) & tf.expand_dims(
        image_shape_masks, axis=-1)
  else:
    predictions = tf.image.resize(
        predictions, masks_size, method=tf.image.ResizeMethod.BILINEAR)

  predictions = tf.argmax(predictions, axis=3)
  flatten_predictions = tf.reshape(predictions, shape=[-1])
  flatten_masks = tf.reshape(masks, shape=[-1])
  flatten_valid_masks = tf.reshape(valid_masks, shape=[-1])

  # Invalid pixels get zero weight so they do not affect the confusion matrix.
  super().update_state(
      y_true=flatten_masks,
      y_pred=flatten_predictions,
      sample_weight=tf.cast(flatten_valid_masks, tf.float32))
def _rescale_and_pad_predictions(predictions, images_info, output_size): def _rescale_and_pad_predictions(predictions, images_info, output_size):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment