Commit 2d3235f8 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 477493248
parent 242f4098
@@ -133,11 +133,12 @@ class PanopticDeeplabTask(base_task.Task):
       The total loss tensor.
     """
     loss_config = self._task_config.losses
-    segmentation_loss_fn = panoptic_deeplab_losses.WeightedBootstrappedCrossEntropyLoss(
-        loss_config.label_smoothing,
-        loss_config.class_weights,
-        loss_config.ignore_label,
-        top_k_percent_pixels=loss_config.top_k_percent_pixels)
+    segmentation_loss_fn = (
+        panoptic_deeplab_losses.WeightedBootstrappedCrossEntropyLoss(
+            loss_config.label_smoothing,
+            loss_config.class_weights,
+            loss_config.ignore_label,
+            top_k_percent_pixels=loss_config.top_k_percent_pixels))
     instance_center_heatmap_loss_fn = panoptic_deeplab_losses.CenterHeatmapLoss(
     )
     instance_center_offset_loss_fn = panoptic_deeplab_losses.CenterOffsetLoss()
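
For context on the loss being wrapped here: a weighted bootstrapped cross-entropy averages the per-pixel cross-entropy over only the hardest `top_k_percent_pixels` fraction of pixels, so easy background pixels do not dominate the gradient. A minimal sketch of that idea, ignoring the class-weight, label-smoothing, and ignore-label handling of the real `panoptic_deeplab_losses` implementation (`bootstrapped_ce` is a hypothetical name):

    import tensorflow as tf

    def bootstrapped_ce(labels, logits, top_k_percent_pixels=0.2):
      # Per-pixel cross-entropy over a flattened view of all pixels.
      num_classes = tf.shape(logits)[-1]
      pixel_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=tf.reshape(labels, [-1]),
          logits=tf.reshape(logits, [-1, num_classes]))
      # Keep only the top-k% highest-loss (hardest) pixels, then average.
      num_pixels = tf.cast(tf.size(pixel_ce), tf.float32)
      k = tf.cast(tf.math.ceil(top_k_percent_pixels * num_pixels), tf.int32)
      top_losses, _ = tf.math.top_k(pixel_ce, k=k)
      return tf.reduce_mean(top_losses)

    # Example: 2 images, 8x8 logits over 3 classes.
    labels = tf.random.uniform([2, 8, 8], maxval=3, dtype=tf.int32)
    logits = tf.random.normal([2, 8, 8, 3])
    loss = bootstrapped_ce(labels, logits)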
@@ -214,24 +215,16 @@ class PanopticDeeplabTask(base_task.Task):
           rescale_predictions=rescale_predictions,
           dtype=tf.float32)
 
-      if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
-        self._process_iou_metric_on_cpu = True
-      else:
-        self._process_iou_metric_on_cpu = False
-
       if self.task_config.model.generate_panoptic_masks:
-        self.panoptic_quality_metric = panoptic_quality_evaluator.PanopticQualityEvaluator(
-            num_categories=self.task_config.model.num_classes,
-            ignored_label=eval_config.ignored_label,
-            max_instances_per_category=eval_config.max_instances_per_category,
-            offset=eval_config.offset,
-            is_thing=eval_config.is_thing,
-            rescale_predictions=eval_config.rescale_predictions)
-
-      # Update state on CPU if TPUStrategy due to dynamic resizing.
-      self._process_iou_metric_on_cpu = isinstance(
-          tf.distribute.get_strategy(),
-          tf.distribute.TPUStrategy)
+        self.panoptic_quality_metric = (
+            panoptic_quality_evaluator.PanopticQualityEvaluator(
+                num_categories=self.task_config.model.num_classes,
+                ignored_label=eval_config.ignored_label,
+                max_instances_per_category=eval_config
+                .max_instances_per_category,
+                offset=eval_config.offset,
+                is_thing=eval_config.is_thing,
+                rescale_predictions=eval_config.rescale_predictions))
 
     return metrics
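
For reference, the panoptic quality this evaluator reports follows the standard definition from the panoptic segmentation literature: PQ = (sum of IoUs over matched segment pairs) / (|TP| + 0.5 |FP| + 0.5 |FN|), with predicted and ground-truth segments matched at IoU > 0.5. The `ignored_label`, `max_instances_per_category`, and `offset` arguments control how category and instance ids are combined into single panoptic ids during matching.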
@@ -334,22 +327,13 @@ class PanopticDeeplabTask(base_task.Task):
         'image_info': labels['image_info']
     }
 
-    if self._process_iou_metric_on_cpu:
-      logs.update({
-          self.perclass_iou_metric.name:
-              (segmentation_labels, outputs['segmentation_outputs'])
-      })
-    else:
-      self.perclass_iou_metric.update_state(
-          segmentation_labels,
-          outputs['segmentation_outputs'])
+    self.perclass_iou_metric.update_state(segmentation_labels,
+                                          outputs['segmentation_outputs'])
 
     if self.task_config.model.generate_panoptic_masks:
       pq_metric_labels = {
-          'category_mask':
-              tf.squeeze(labels['category_mask'], axis=3),
-          'instance_mask':
-              tf.squeeze(labels['instance_mask'], axis=3),
+          'category_mask': tf.squeeze(labels['category_mask'], axis=3),
+          'instance_mask': tf.squeeze(labels['instance_mask'], axis=3),
          'image_info': labels['image_info']
       }
       panoptic_outputs = {
@@ -370,11 +354,6 @@ class PanopticDeeplabTask(base_task.Task):
       if self.task_config.model.generate_panoptic_masks:
         state += [self.panoptic_quality_metric]
 
-    if self._process_iou_metric_on_cpu:
-      self.perclass_iou_metric.update_state(
-          step_outputs[self.perclass_iou_metric.name][0],
-          step_outputs[self.perclass_iou_metric.name][1])
-
     if self.task_config.model.generate_panoptic_masks:
       self.panoptic_quality_metric.update_state(
           step_outputs[self.panoptic_quality_metric.name][0],
...
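
Context for the `_process_iou_metric_on_cpu` deletions above: the flag guarded a common Model Garden pattern in which, under `TPUStrategy`, a metric's raw inputs are placed in the step's `logs` under the metric's name so that `aggregate_logs` can run `update_state` later on the host CPU, where dynamic image resizing is allowed. Since the IoU metrics now rescale predictions with fixed output shapes (see the `segmentation_metrics` changes below), the metric can always be updated directly in the replicated step, and the two code paths collapse into one. A schematic of the two paths the removed code switched between (hypothetical names, not the task implementation itself):

    def validation_step_logs(metric, labels, outputs, defer_to_host):
      logs = {}
      if defer_to_host:
        # TPU-unsafe metric: stash tensors in `logs`; the trainer feeds them
        # back through aggregate_logs(), where update_state() runs on CPU.
        logs[metric.name] = (labels, outputs)
      else:
        # TPU-safe metric: update directly inside the distributed step.
        metric.update_state(labels, outputs)
      return logs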
@@ -175,7 +175,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
     """Build Panoptic Mask R-CNN losses."""
     params = self.task_config.losses
-    use_groundtruth_dimension = params.semantic_segmentation_use_groundtruth_dimension
+    use_groundtruth_dimension = (
+        params.semantic_segmentation_use_groundtruth_dimension)
 
     segmentation_loss_fn = segmentation_losses.SegmentationLoss(
         label_smoothing=params.semantic_segmentation_label_smoothing,
@@ -218,7 +219,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
                                     tf.keras.metrics.Metric]:
     """Build detection metrics."""
     metrics = []
-    num_segmentation_classes = self.task_config.model.segmentation_model.num_classes
+    num_segmentation_classes = (
+        self.task_config.model.segmentation_model.num_classes)
     if training:
       metric_names = [
           'total_loss',
@@ -253,23 +255,19 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
           rescale_predictions=rescale_predictions,
           dtype=tf.float32)
 
-      if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
-        self._process_iou_metric_on_cpu = True
-      else:
-        self._process_iou_metric_on_cpu = False
-
       if self.task_config.model.generate_panoptic_masks:
         if not self.task_config.validation_data.parser.include_panoptic_masks:
           raise ValueError('`include_panoptic_masks` should be set to True when'
                            ' computing panoptic quality.')
         pq_config = self.task_config.panoptic_quality_evaluator
-        self.panoptic_quality_metric = panoptic_quality_evaluator.PanopticQualityEvaluator(
-            num_categories=pq_config.num_categories,
-            ignored_label=pq_config.ignored_label,
-            max_instances_per_category=pq_config.max_instances_per_category,
-            offset=pq_config.offset,
-            is_thing=pq_config.is_thing,
-            rescale_predictions=pq_config.rescale_predictions)
+        self.panoptic_quality_metric = (
+            panoptic_quality_evaluator.PanopticQualityEvaluator(
+                num_categories=pq_config.num_categories,
+                ignored_label=pq_config.ignored_label,
+                max_instances_per_category=pq_config.max_instances_per_category,
+                offset=pq_config.offset,
+                is_thing=pq_config.is_thing,
+                rescale_predictions=pq_config.rescale_predictions))
 
     return metrics
@@ -385,22 +383,14 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
     logs.update(
         {self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
 
-    if self._process_iou_metric_on_cpu:
-      logs.update({
-          self.segmentation_perclass_iou_metric.name:
-              (segmentation_labels, outputs['segmentation_outputs'])
-      })
-    else:
-      self.segmentation_perclass_iou_metric.update_state(
-          segmentation_labels,
-          outputs['segmentation_outputs'])
+    self.segmentation_perclass_iou_metric.update_state(
+        segmentation_labels, outputs['segmentation_outputs'])
 
     if self.task_config.model.generate_panoptic_masks:
       pq_metric_labels = {
-          'category_mask':
-              labels['groundtruths']['gt_panoptic_category_mask'],
-          'instance_mask':
-              labels['groundtruths']['gt_panoptic_instance_mask'],
+          'category_mask': labels['groundtruths']['gt_panoptic_category_mask'],
+          'instance_mask': labels['groundtruths']['gt_panoptic_instance_mask'],
          'image_info': labels['image_info']
       }
       logs.update({
@@ -420,11 +410,6 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
         step_outputs[self.coco_metric.name][0],
         step_outputs[self.coco_metric.name][1])
 
-    if self._process_iou_metric_on_cpu:
-      self.segmentation_perclass_iou_metric.update_state(
-          step_outputs[self.segmentation_perclass_iou_metric.name][0],
-          step_outputs[self.segmentation_perclass_iou_metric.name][1])
-
     if self.task_config.model.generate_panoptic_masks:
       self.panoptic_quality_metric.update_state(
           step_outputs[self.panoptic_quality_metric.name][0],
@@ -433,11 +418,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
     return state
 
   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
-    result = {}
-    result = super(
-        PanopticMaskRCNNTask, self).reduce_aggregated_logs(
-            aggregated_logs=aggregated_logs,
-            global_step=global_step)
+    result = super().reduce_aggregated_logs(
+        aggregated_logs=aggregated_logs, global_step=global_step)
 
     ious = self.segmentation_perclass_iou_metric.result()
     if self.task_config.segmentation_evaluation.report_per_class_iou:
@@ -447,7 +429,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
     result.update({'segmentation_mean_iou': tf.reduce_mean(ious).numpy()})
 
     if self.task_config.model.generate_panoptic_masks:
-      report_per_class_metrics = self.task_config.panoptic_quality_evaluator.report_per_class_metrics
+      report_per_class_metrics = (
+          self.task_config.panoptic_quality_evaluator.report_per_class_metrics)
       panoptic_quality_results = self.panoptic_quality_metric.result()
       for k, value in panoptic_quality_results.items():
         if k.endswith('per_class'):
...
@@ -13,9 +13,12 @@
 # limitations under the License.
 
 """Metrics for segmentation."""
+
 import tensorflow as tf
 
 from official.vision.evaluation import iou
+from official.vision.ops import box_ops
+from official.vision.ops import spatial_transform_ops
 
 
 class MeanIoU(tf.keras.metrics.MeanIoU):
@@ -48,8 +51,8 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
     Args:
       y_true: `dict`, dictionary with the following name, and key values.
-        - masks: [batch, width, height, 1], groundtruth masks.
-        - valid_masks: [batch, width, height, 1], valid elements in the mask.
+        - masks: [batch, height, width, 1], groundtruth masks.
+        - valid_masks: [batch, height, width, 1], valid elements in the mask.
         - image_info: [batch, 4, 2], a tensor that holds information about
           original and preprocessed images. Each entry is in the format of
          [[original_height, original_width], [input_height, input_width],
@@ -57,7 +60,7 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
           desired_width] is the actual scaled image size, and [y_scale, x_scale]
           is the scaling factor, which is the ratio of scaled dimension /
           original dimension.
-      y_pred: Tensor [batch, width_p, height_p, num_classes], predicated masks.
+      y_pred: Tensor [batch, height_p, width_p, num_classes], predicated masks.
     """
     predictions = y_pred
     masks = y_true['masks']
@@ -72,55 +75,32 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
     # Ignore mask elements is set to zero for argmax op.
     masks = tf.where(valid_masks, masks, tf.zeros_like(masks))
 
+    masks_size = tf.shape(masks)[1:3]
     if self._rescale_predictions:
-      # This part can only run on cpu/gpu due to dynamic image resizing.
-      for i in range(tf.shape(predictions)[0]):
-        mask = masks[i]
-        valid_mask = valid_masks[i]
-        predicted_mask = predictions[i]
-        image_info = images_info[i]
-
-        rescale_size = tf.cast(
-            tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
-        image_shape = tf.cast(image_info[0, :], tf.int32)
-        offsets = tf.cast(image_info[3, :], tf.int32)
-
-        predicted_mask = tf.image.resize(
-            predicted_mask,
-            rescale_size,
-            method=tf.image.ResizeMethod.BILINEAR)
-        predicted_mask = tf.image.crop_to_bounding_box(predicted_mask,
-                                                       offsets[0], offsets[1],
-                                                       image_shape[0],
-                                                       image_shape[1])
-        mask = tf.image.crop_to_bounding_box(mask, 0, 0, image_shape[0],
-                                             image_shape[1])
-        valid_mask = tf.image.crop_to_bounding_box(valid_mask, 0, 0,
-                                                   image_shape[0],
-                                                   image_shape[1])
-
-        predicted_mask = tf.argmax(predicted_mask, axis=2)
-        flatten_predictions = tf.reshape(predicted_mask, shape=[1, -1])
-        flatten_masks = tf.reshape(mask, shape=[1, -1])
-        flatten_valid_masks = tf.reshape(valid_mask, shape=[1, -1])
-        super(MeanIoU, self).update_state(
-            flatten_masks, flatten_predictions,
-            tf.cast(flatten_valid_masks, tf.float32))
+      # Scale back predictions to original image shapes and pad to mask size.
+      # Note: instead of cropping the masks to image shape (dynamic), here we
+      # pad the rescaled predictions to mask size (fixed). And update the
+      # valid_masks to mask out the pixels outside the original image shape.
+      predictions, image_shape_masks = _rescale_and_pad_predictions(
+          predictions, images_info, output_size=masks_size)
+      # Only the area within the original image shape is valid.
+      # (batch_size, height, width, 1)
+      valid_masks = tf.cast(valid_masks, tf.bool) & tf.expand_dims(
+          image_shape_masks, axis=-1)
     else:
       predictions = tf.image.resize(
-          predictions,
-          tf.shape(masks)[1:3],
-          method=tf.image.ResizeMethod.BILINEAR)
-      predictions = tf.argmax(predictions, axis=3)
-      flatten_predictions = tf.reshape(predictions, shape=[-1])
-      flatten_masks = tf.reshape(masks, shape=[-1])
-      flatten_valid_masks = tf.reshape(valid_masks, shape=[-1])
-      super().update_state(flatten_masks, flatten_predictions,
-                           tf.cast(flatten_valid_masks, tf.float32))
+          predictions, masks_size, method=tf.image.ResizeMethod.BILINEAR)
+
+    predictions = tf.argmax(predictions, axis=3)
+    flatten_predictions = tf.reshape(predictions, shape=[-1])
+    flatten_masks = tf.reshape(masks, shape=[-1])
+    flatten_valid_masks = tf.reshape(valid_masks, shape=[-1])
+    super().update_state(
+        y_true=flatten_masks,
+        y_pred=flatten_predictions,
+        sample_weight=tf.cast(flatten_valid_masks, tf.float32))
 
 
 class PerClassIoU(iou.PerClassIoU):
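
The rewrite above is what makes the TPU special-casing unnecessary: the old per-image Python loop resized each prediction to a data-dependent `rescale_size` and cropped both masks and predictions to the per-image shape, producing dynamic shapes that XLA cannot compile. The new path keeps every tensor at the fixed `masks_size` by padding the rescaled predictions instead of cropping the masks, and folds the per-image crop into a boolean validity mask, so all shapes are static and the metric can update on-device.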
@@ -153,8 +133,8 @@ class PerClassIoU(iou.PerClassIoU):
     Args:
       y_true: `dict`, dictionary with the following name, and key values.
-        - masks: [batch, width, height, 1], groundtruth masks.
-        - valid_masks: [batch, width, height, 1], valid elements in the mask.
+        - masks: [batch, height, width, 1], groundtruth masks.
+        - valid_masks: [batch, height, width, 1], valid elements in the mask.
         - image_info: [batch, 4, 2], a tensor that holds information about
           original and preprocessed images. Each entry is in the format of
          [[original_height, original_width], [input_height, input_width],
@@ -162,7 +142,7 @@ class PerClassIoU(iou.PerClassIoU):
           desired_width] is the actual scaled image size, and [y_scale, x_scale]
           is the scaling factor, which is the ratio of scaled dimension /
           original dimension.
-      y_pred: Tensor [batch, width_p, height_p, num_classes], predicated masks.
+      y_pred: Tensor [batch, height_p, width_p, num_classes], predicated masks.
     """
     predictions = y_pred
     masks = y_true['masks']
@@ -177,51 +157,83 @@ class PerClassIoU(iou.PerClassIoU):
     # Ignore mask elements is set to zero for argmax op.
     masks = tf.where(valid_masks, masks, tf.zeros_like(masks))
 
+    masks_size = tf.shape(masks)[1:3]
     if self._rescale_predictions:
-      # This part can only run on cpu/gpu due to dynamic image resizing.
-      for i in range(tf.shape(predictions)[0]):
-        mask = masks[i]
-        valid_mask = valid_masks[i]
-        predicted_mask = predictions[i]
-        image_info = images_info[i]
-
-        rescale_size = tf.cast(
-            tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
-        image_shape = tf.cast(image_info[0, :], tf.int32)
-        offsets = tf.cast(image_info[3, :], tf.int32)
-
-        predicted_mask = tf.image.resize(
-            predicted_mask,
-            rescale_size,
-            method=tf.image.ResizeMethod.BILINEAR)
-        predicted_mask = tf.image.crop_to_bounding_box(predicted_mask,
-                                                       offsets[0], offsets[1],
-                                                       image_shape[0],
-                                                       image_shape[1])
-        mask = tf.image.crop_to_bounding_box(mask, 0, 0, image_shape[0],
-                                             image_shape[1])
-        valid_mask = tf.image.crop_to_bounding_box(valid_mask, 0, 0,
-                                                   image_shape[0],
-                                                   image_shape[1])
-
-        predicted_mask = tf.argmax(predicted_mask, axis=2)
-        flatten_predictions = tf.reshape(predicted_mask, shape=[1, -1])
-        flatten_masks = tf.reshape(mask, shape=[1, -1])
-        flatten_valid_masks = tf.reshape(valid_mask, shape=[1, -1])
-        super().update_state(flatten_masks, flatten_predictions,
-                             tf.cast(flatten_valid_masks, tf.float32))
+      # Scale back predictions to original image shapes and pad to mask size.
+      # Note: instead of cropping the masks to image shape (dynamic), here we
+      # pad the rescaled predictions to mask size (fixed). And update the
+      # valid_masks to mask out the pixels outside the original image shape.
+      predictions, image_shape_masks = _rescale_and_pad_predictions(
+          predictions, images_info, output_size=masks_size)
+      # Only the area within the original image shape is valid.
+      # (batch_size, height, width, 1)
+      valid_masks = tf.cast(valid_masks, tf.bool) & tf.expand_dims(
+          image_shape_masks, axis=-1)
     else:
       predictions = tf.image.resize(
-          predictions,
-          tf.shape(masks)[1:3],
-          method=tf.image.ResizeMethod.BILINEAR)
-      predictions = tf.argmax(predictions, axis=3)
-      flatten_predictions = tf.reshape(predictions, shape=[-1])
-      flatten_masks = tf.reshape(masks, shape=[-1])
-      flatten_valid_masks = tf.reshape(valid_masks, shape=[-1])
-      super().update_state(flatten_masks, flatten_predictions,
-                           tf.cast(flatten_valid_masks, tf.float32))
+          predictions, masks_size, method=tf.image.ResizeMethod.BILINEAR)
+
+    predictions = tf.argmax(predictions, axis=3)
+    flatten_predictions = tf.reshape(predictions, shape=[-1])
+    flatten_masks = tf.reshape(masks, shape=[-1])
+    flatten_valid_masks = tf.reshape(valid_masks, shape=[-1])
+    super().update_state(
+        y_true=flatten_masks,
+        y_pred=flatten_predictions,
+        sample_weight=tf.cast(flatten_valid_masks, tf.float32))
+
+
+def _rescale_and_pad_predictions(predictions, images_info, output_size):
+  """Scales back predictions to original image shapes and pads to output size.
+
+  Args:
+    predictions: A tensor in shape [batch, height, width, num_classes] which
+      stores the model predictions.
+    images_info: A tensor in shape [batch, 4, 2] that holds information about
+      original and preprocessed images. Each entry is in the format of
+      [[original_height, original_width], [input_height, input_width],
+      [y_scale, x_scale], [y_offset, x_offset]], where
+      [desired_height, desired_width] is the actual scaled image size, and
+      [y_scale, x_scale] is the scaling factor, which is the ratio of scaled
+      dimension / original dimension.
+    output_size: A list/tuple/tensor stores the size of the padded output in
+      [output_height, output_width].
+
+  Returns:
+    predictions: A tensor in shape [batch, output_height, output_width,
+      num_classes] which stores the rescaled and padded predictions.
+    image_shape_masks: A bool tensor in shape [batch, output_height,
+      output_width] where the pixels inside the original image shape are true,
+      otherwise false.
+  """
+  # (batch_size, 2)
+  image_shape = tf.cast(images_info[:, 0, :], tf.int32)
+  desired_size = tf.cast(images_info[:, 1, :], tf.float32)
+  image_scale = tf.cast(images_info[:, 2, :], tf.float32)
+  offset = tf.cast(images_info[:, 3, :], tf.int32)
+  rescale_size = tf.cast(tf.math.ceil(desired_size / image_scale), tf.int32)
+
+  # Rescale the predictions, then crop to the original image shape and
+  # finally pad zeros to match the mask size.
+  predictions = (
+      spatial_transform_ops.bilinear_resize_with_crop_and_pad(
+          predictions,
+          rescale_size,
+          crop_offset=offset,
+          crop_size=image_shape,
+          output_size=output_size))
+
+  # (batch_size, 2)
+  y0_x0 = tf.broadcast_to(
+      tf.constant([[0, 0]], dtype=image_shape.dtype), tf.shape(image_shape))
+  # (batch_size, 4)
+  image_shape_bbox = tf.concat([y0_x0, image_shape], axis=1)
+  # (batch_size, height, width)
+  image_shape_masks = box_ops.bbox2mask(
+      bbox=image_shape_bbox,
+      image_height=output_size[0],
+      image_width=output_size[1],
+      dtype=tf.bool)
+
+  return predictions, image_shape_masks
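
A toy invocation of the new helper, using the same geometry as the unit-test fixture further down (a sketch only; it assumes the imports at the top of this file, with shapes noted in comments):

    # One image: 3x3 logits over 2 classes, compared against 6x6 masks.
    predictions = tf.random.normal([1, 3, 3, 2])
    # [[original_hw], [scaled_hw], [y_scale, x_scale], [y_offset, x_offset]]
    images_info = tf.constant([[[6, 6], [3, 3], [0.5, 0.5], [0, 0]]],
                              tf.float32)
    preds, valid = _rescale_and_pad_predictions(
        predictions, images_info, output_size=[6, 6])
    # preds: [1, 6, 6, 2] rescaled logits; valid: [1, 6, 6] bool, all True
    # here because the 6x6 original image fills the whole 6x6 output canvas.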
@@ -15,7 +15,6 @@
 """Tests for segmentation_metrics."""
 
 from absl.testing import parameterized
-import numpy as np
 import tensorflow as tf
 
 from official.vision.evaluation import segmentation_metrics
@@ -24,26 +23,23 @@ from official.vision.evaluation import segmentation_metrics
 
 class SegmentationMetricsTest(parameterized.TestCase, tf.test.TestCase):
 
   def _create_test_data(self):
-    y_pred_cls0 = np.expand_dims(
-        np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.uint16),
-        axis=(0, -1))
-    y_pred_cls1 = np.expand_dims(
-        np.array([[0, 0, 0], [0, 0, 1], [0, 0, 1]], dtype=np.uint16),
-        axis=(0, -1))
-    y_pred = np.concatenate((y_pred_cls0, y_pred_cls1), axis=-1)
+    y_pred_cls0 = tf.constant([[1, 1, 0], [1, 1, 0], [0, 0, 0]],
+                              dtype=tf.uint16)[tf.newaxis, :, :, tf.newaxis]
+    y_pred_cls1 = tf.constant([[0, 0, 0], [0, 0, 1], [0, 0, 1]],
+                              dtype=tf.uint16)[tf.newaxis, :, :, tf.newaxis]
+    y_pred = tf.concat((y_pred_cls0, y_pred_cls1), axis=-1)
     y_true = {
         'masks':
-            np.expand_dims(
-                np.array([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
-                          [0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1],
-                          [0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1]],
-                         dtype=np.uint16),
-                axis=(0, -1)),
+            tf.constant(
+                [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1]],
+                dtype=tf.uint16)[tf.newaxis, :, :, tf.newaxis],
         'valid_masks':
-            np.ones([1, 6, 6, 1], dtype=np.uint16),
+            tf.ones([1, 6, 6, 1], dtype=tf.bool),
         'image_info':
-            np.array([[[6, 6], [3, 3], [0.5, 0.5], [0, 0]]], dtype=np.float32)
+            tf.constant([[[6, 6], [3, 3], [0.5, 0.5], [0, 0]]],
+                        dtype=tf.float32)
     }
     return y_pred, y_true
...
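
Worked through with this fixture's numbers: `image_info = [[6, 6], [3, 3], [0.5, 0.5], [0, 0]]` says the original image is 6x6, the scaled model input is 3x3, the scale factor is 0.5, and there is no crop offset. Rescaling therefore computes `rescale_size = ceil([3, 3] / [0.5, 0.5]) = [6, 6]`, resizes the 3x3 predictions back up to 6x6, and needs no padding because the ground-truth masks are already 6x6, so the rescaled and non-rescaled code paths should agree on this input.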
@@ -137,8 +137,7 @@ class SemanticSegmentationTask(base_task.Task):
         loss_params.ignore_label,
         use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
         top_k_percent_pixels=loss_params.top_k_percent_pixels,
-        gt_is_matting_map=loss_params.gt_is_matting_map
-    )
+        gt_is_matting_map=loss_params.gt_is_matting_map)
 
     total_loss = segmentation_loss_fn(model_outputs['logits'], labels['masks'])
@@ -181,6 +180,8 @@ class SemanticSegmentationTask(base_task.Task):
   def build_metrics(self, training: bool = True):
     """Gets streaming metrics for training/validation."""
     metrics = []
+    self.iou_metric = None
+
     if training and self.task_config.evaluation.report_train_mean_iou:
       metrics.append(
           segmentation_metrics.MeanIoU(
@@ -196,8 +197,8 @@ class SemanticSegmentationTask(base_task.Task):
       self.iou_metric = segmentation_metrics.PerClassIoU(
           name='per_class_iou',
           num_classes=self.task_config.model.num_classes,
-          rescale_predictions=not self.task_config.validation_data
-          .resize_eval_groundtruth,
+          rescale_predictions=(
+              not self.task_config.validation_data.resize_eval_groundtruth),
           dtype=tf.float32)
       if (self.task_config.validation_data.resize_eval_groundtruth and
           self.task_config.model.get('mask_scoring_head')):
@@ -206,10 +207,6 @@ class SemanticSegmentationTask(base_task.Task):
         metrics.append(
             tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
 
-    # Update state on CPU if TPUStrategy due to dynamic resizing.
-    self._process_iou_metric_on_cpu = isinstance(tf.distribute.get_strategy(),
-                                                 tf.distribute.TPUStrategy)
-
     return metrics
 
   def train_step(self,
@@ -307,11 +304,8 @@ class SemanticSegmentationTask(base_task.Task):
 
     logs = {self.loss: loss}
 
-    if self._process_iou_metric_on_cpu:
-      logs.update({self.iou_metric.name: (labels, outputs['logits'])})
-    else:
+    if self.iou_metric is not None:
       self.iou_metric.update_state(labels, outputs['logits'])
 
     if metrics:
       self.process_metrics(metrics, labels, outputs)
       logs.update({m.name: m.result() for m in metrics})
@@ -323,21 +317,19 @@ class SemanticSegmentationTask(base_task.Task):
     return model(inputs, training=False)
 
   def aggregate_logs(self, state=None, step_outputs=None):
-    if state is None:
+    if state is None and self.iou_metric is not None:
       self.iou_metric.reset_states()
       state = self.iou_metric
-    if self._process_iou_metric_on_cpu:
-      self.iou_metric.update_state(step_outputs[self.iou_metric.name][0],
-                                   step_outputs[self.iou_metric.name][1])
     return state
 
   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
     result = {}
-    ious = self.iou_metric.result()
-    # TODO(arashwan): support loading class name from a label map file.
-    if self.task_config.evaluation.report_per_class_iou:
-      for i, value in enumerate(ious.numpy()):
-        result.update({'iou/{}'.format(i): value})
-    # Computes mean IoU
-    result.update({'mean_iou': tf.reduce_mean(ious).numpy()})
+    if self.iou_metric is not None:
+      ious = self.iou_metric.result()
+      # TODO(arashwan): support loading class name from a label map file.
+      if self.task_config.evaluation.report_per_class_iou:
+        for i, value in enumerate(ious.numpy()):
+          result.update({'iou/{}'.format(i): value})
+      # Computes mean IoU
+      result.update({'mean_iou': tf.reduce_mean(ious)})
     return result
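
A note on the `self.iou_metric = None` default added in `build_metrics`: the IoU metric is only constructed on the branches that need it, so the new `is not None` guards in `validation_step`, `aggregate_logs`, and `reduce_aggregated_logs` appear intended to let configurations that skip IoU reporting run these methods without touching an unset attribute.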