"tests/vscode:/vscode.git/clone" did not exist on "a69ebe5527b24c5f688545460a7d83eb4cc66648"
Commit bc324fda authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Add per-class IoU to list of semantic segmentation metrics.

PiperOrigin-RevId: 356892700
parent 9815ea67
@@ -90,6 +90,12 @@ class Losses(hyperparams.Config):
   top_k_percent_pixels: float = 1.0


+@dataclasses.dataclass
+class Evaluation(hyperparams.Config):
+  report_per_class_iou: bool = True
+  report_train_mean_iou: bool = True  # Turning this off can speed up training.
+
+
 @dataclasses.dataclass
 class SemanticSegmentationTask(cfg.TaskConfig):
   """The model config."""
@@ -97,6 +103,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
   train_data: DataConfig = DataConfig(is_training=True)
   validation_data: DataConfig = DataConfig(is_training=False)
   losses: Losses = Losses()
+  evaluation: Evaluation = Evaluation()
   train_input_partition_dims: List[int] = dataclasses.field(
       default_factory=list)
   eval_input_partition_dims: List[int] = dataclasses.field(
...
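
(Aside, not part of the commit: a minimal usage sketch of the new Evaluation flags on a task config built from the classes above. The construction style is assumed to follow the usual hyperparams dataclass overrides.)

  # Usage sketch (assumed, not from this diff): toggling the new flags.
  task_config = SemanticSegmentationTask()
  task_config.evaluation.report_per_class_iou = True    # report iou/<class> keys at eval
  task_config.evaluation.report_train_mean_iou = False  # skip mean IoU during training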
@@ -16,6 +16,8 @@
 import tensorflow as tf

+from official.vision import keras_cv
+

 class MeanIoU(tf.keras.metrics.MeanIoU):
   """Mean IoU metric for semantic segmentation.
@@ -122,3 +124,110 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
     super(MeanIoU, self).update_state(
         flatten_masks, flatten_predictions,
         tf.cast(flatten_valid_masks, tf.float32))
+class PerClassIoU(keras_cv.metrics.PerClassIoU):
+  """Per-class IoU metric for semantic segmentation.
+
+  This class uses keras_cv.metrics.PerClassIoU to compute batched per-class
+  IoU when both input images and groundtruth masks are resized to the same
+  size (rescale_predictions=False). It can also compute per-class IoU on the
+  original groundtruth sizes, in which case each prediction is rescaled back
+  to the original image size.
+  """
+
+  def __init__(
+      self, num_classes, rescale_predictions=False, name=None, dtype=None):
+    """Constructs the segmentation evaluator class.
+
+    Args:
+      num_classes: `int`, number of classes.
+      rescale_predictions: `bool`, whether to scale predictions back to the
+        original image sizes. If True, y_true['image_info'] is used to rescale
+        predictions.
+      name: `str`, name of the metric instance.
+      dtype: data type of the metric result.
+    """
+    self._rescale_predictions = rescale_predictions
+    super(PerClassIoU, self).__init__(
+        num_classes=num_classes, name=name, dtype=dtype)
+
+  def update_state(self, y_true, y_pred):
+    """Updates metric state.
+
+    Args:
+      y_true: `dict`, dictionary with the following keys:
+        - masks: [batch, width, height, 1], groundtruth masks.
+        - valid_masks: [batch, width, height, 1], valid elements in the mask.
+        - image_info: [batch, 4, 2], a tensor that holds information about
+          the original and preprocessed images. Each entry is in the format of
+          [[original_height, original_width], [input_height, input_width],
+          [y_scale, x_scale], [y_offset, x_offset]], where
+          [input_height, input_width] is the actual scaled image size and
+          [y_scale, x_scale] is the scaling factor, i.e. the ratio of
+          scaled dimension / original dimension.
+      y_pred: Tensor [batch, width_p, height_p, num_classes], predicted masks.
+    """
+    predictions = y_pred
+    masks = y_true['masks']
+    valid_masks = y_true['valid_masks']
+    images_info = y_true['image_info']
+
+    if isinstance(predictions, tuple) or isinstance(predictions, list):
+      predictions = tf.concat(predictions, axis=0)
+      masks = tf.concat(masks, axis=0)
+      valid_masks = tf.concat(valid_masks, axis=0)
+      images_info = tf.concat(images_info, axis=0)
+
+    # Ignored mask elements are set to zero for the argmax op.
+    masks = tf.where(valid_masks, masks, tf.zeros_like(masks))
+
+    if self._rescale_predictions:
+      # This part can only run on CPU/GPU due to dynamic image resizing.
+      flatten_predictions = []
+      flatten_masks = []
+      flatten_valid_masks = []
+      for mask, valid_mask, predicted_mask, image_info in zip(
+          masks, valid_masks, predictions, images_info):
+        # Undo the resize-and-pad preprocessing: resize the prediction back
+        # to the pre-padding scaled size, then crop away the padding offsets
+        # so it aligns with the original image.
+        rescale_size = tf.cast(
+            tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
+        image_shape = tf.cast(image_info[0, :], tf.int32)
+        offsets = tf.cast(image_info[3, :], tf.int32)
+
+        predicted_mask = tf.image.resize(
+            predicted_mask,
+            rescale_size,
+            method=tf.image.ResizeMethod.BILINEAR)
+        predicted_mask = tf.image.crop_to_bounding_box(predicted_mask,
+                                                       offsets[0], offsets[1],
+                                                       image_shape[0],
+                                                       image_shape[1])
+        mask = tf.image.crop_to_bounding_box(mask, 0, 0, image_shape[0],
+                                             image_shape[1])
+        valid_mask = tf.image.crop_to_bounding_box(valid_mask, 0, 0,
+                                                   image_shape[0],
+                                                   image_shape[1])
+        predicted_mask = tf.argmax(predicted_mask, axis=2)
+        flatten_predictions.append(tf.reshape(predicted_mask, shape=[1, -1]))
+        flatten_masks.append(tf.reshape(mask, shape=[1, -1]))
+        flatten_valid_masks.append(tf.reshape(valid_mask, shape=[1, -1]))
+      flatten_predictions = tf.concat(flatten_predictions, axis=1)
+      flatten_masks = tf.concat(flatten_masks, axis=1)
+      flatten_valid_masks = tf.concat(flatten_valid_masks, axis=1)
+    else:
+      # Masks and predictions already share a common size; resize predictions
+      # to the mask resolution and compare densely.
+      predictions = tf.image.resize(
+          predictions,
+          tf.shape(masks)[1:3],
+          method=tf.image.ResizeMethod.BILINEAR)
+      predictions = tf.argmax(predictions, axis=3)
+      flatten_predictions = tf.reshape(predictions, shape=[-1])
+      flatten_masks = tf.reshape(masks, shape=[-1])
+      flatten_valid_masks = tf.reshape(valid_masks, shape=[-1])
+
+    super(PerClassIoU, self).update_state(
+        flatten_masks, flatten_predictions,
+        tf.cast(flatten_valid_masks, tf.float32))
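
(Aside, not part of the commit: a toy invocation of the new metric on the rescale_predictions=False path, assuming keras_cv.metrics.PerClassIoU follows the tf.keras.metrics.MeanIoU update/result contract. Shapes and values are illustrative only.)

  import tensorflow as tf

  metric = PerClassIoU(num_classes=3, rescale_predictions=False)
  y_pred = tf.random.uniform([2, 8, 8, 3])            # [batch, h, w, classes]
  y_true = {
      'masks': tf.zeros([2, 8, 8, 1], tf.int32),      # groundtruth labels
      'valid_masks': tf.ones([2, 8, 8, 1], tf.bool),  # all pixels valid
      'image_info': tf.ones([2, 4, 2], tf.float32),   # unused when not rescaling
  }
  metric.update_state(y_true, y_pred)
  print(metric.result())  # shape [3]: one IoU value per class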
@@ -143,15 +143,15 @@ class SemanticSegmentationTask(base_task.Task):
   def build_metrics(self, training=True):
     """Gets streaming metrics for training/validation."""
     metrics = []
-    if training:
+    if training and self.task_config.evaluation.report_train_mean_iou:
       metrics.append(segmentation_metrics.MeanIoU(
           name='mean_iou',
           num_classes=self.task_config.model.num_classes,
           rescale_predictions=False,
           dtype=tf.float32))
     else:
-      self.miou_metric = segmentation_metrics.MeanIoU(
-          name='val_mean_iou',
+      self.iou_metric = segmentation_metrics.PerClassIoU(
+          name='per_class_iou',
           num_classes=self.task_config.model.num_classes,
           rescale_predictions=not self.task_config.validation_data
           .resize_eval_groundtruth,
@@ -243,7 +243,7 @@ class SemanticSegmentationTask(base_task.Task):
       loss = 0
     logs = {self.loss: loss}
-    logs.update({self.miou_metric.name: (labels, outputs)})
+    logs.update({self.iou_metric.name: (labels, outputs)})
     if metrics:
       self.process_metrics(metrics, labels, outputs)
@@ -257,11 +257,19 @@ class SemanticSegmentationTask(base_task.Task):
   def aggregate_logs(self, state=None, step_outputs=None):
     if state is None:
-      self.miou_metric.reset_states()
-      state = self.miou_metric
-    self.miou_metric.update_state(step_outputs[self.miou_metric.name][0],
-                                  step_outputs[self.miou_metric.name][1])
+      self.iou_metric.reset_states()
+      state = self.iou_metric
+    self.iou_metric.update_state(step_outputs[self.iou_metric.name][0],
+                                 step_outputs[self.iou_metric.name][1])
     return state

   def reduce_aggregated_logs(self, aggregated_logs):
-    return {self.miou_metric.name: self.miou_metric.result().numpy()}
+    result = {}
+    ious = self.iou_metric.result()
+    # TODO(arashwan): support loading class names from a label map file.
+    if self.task_config.evaluation.report_per_class_iou:
+      for i, value in enumerate(ious.numpy()):
+        result.update({'iou/{}'.format(i): value})
+    # Computes the mean IoU over classes.
+    result.update({'mean_iou': tf.reduce_mean(ious).numpy()})
+    return result
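
(Aside: with num_classes=3 and report_per_class_iou enabled, the reduced eval logs take this shape; the IoU numbers below are invented for illustration.)

  {
      'iou/0': 0.91,   # per-class IoU, class index 0
      'iou/1': 0.58,
      'iou/2': 0.74,
      'mean_iou': 0.7433,
  }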