Commit fcd09603 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 406022769
parent 4b06b4b3
......@@ -13,7 +13,6 @@
# limitations under the License.
"""Metrics for segmentation."""
import tensorflow as tf
from official.vision import keras_cv
......@@ -77,11 +76,11 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
if self._rescale_predictions:
# This part can only run on cpu/gpu due to dynamic image resizing.
flatten_predictions = []
flatten_masks = []
flatten_valid_masks = []
for mask, valid_mask, predicted_mask, image_info in zip(
masks, valid_masks, predictions, images_info):
for i in range(tf.shape(predictions)[0]):
mask = masks[i]
valid_mask = valid_masks[i]
predicted_mask = predictions[i]
image_info = images_info[i]
rescale_size = tf.cast(
tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
......@@ -104,12 +103,12 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
image_shape[1])
predicted_mask = tf.argmax(predicted_mask, axis=2)
flatten_predictions.append(tf.reshape(predicted_mask, shape=[1, -1]))
flatten_masks.append(tf.reshape(mask, shape=[1, -1]))
flatten_valid_masks.append(tf.reshape(valid_mask, shape=[1, -1]))
flatten_predictions = tf.concat(flatten_predictions, axis=1)
flatten_masks = tf.concat(flatten_masks, axis=1)
flatten_valid_masks = tf.concat(flatten_valid_masks, axis=1)
flatten_predictions = tf.reshape(predicted_mask, shape=[1, -1])
flatten_masks = tf.reshape(mask, shape=[1, -1])
flatten_valid_masks = tf.reshape(valid_mask, shape=[1, -1])
super(MeanIoU, self).update_state(
flatten_masks, flatten_predictions,
tf.cast(flatten_valid_masks, tf.float32))
else:
predictions = tf.image.resize(
......@@ -184,11 +183,11 @@ class PerClassIoU(keras_cv.metrics.PerClassIoU):
if self._rescale_predictions:
# This part can only run on cpu/gpu due to dynamic image resizing.
flatten_predictions = []
flatten_masks = []
flatten_valid_masks = []
for mask, valid_mask, predicted_mask, image_info in zip(
masks, valid_masks, predictions, images_info):
for i in range(tf.shape(predictions)[0]):
mask = masks[i]
valid_mask = valid_masks[i]
predicted_mask = predictions[i]
image_info = images_info[i]
rescale_size = tf.cast(
tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
......@@ -211,12 +210,12 @@ class PerClassIoU(keras_cv.metrics.PerClassIoU):
image_shape[1])
predicted_mask = tf.argmax(predicted_mask, axis=2)
flatten_predictions.append(tf.reshape(predicted_mask, shape=[1, -1]))
flatten_masks.append(tf.reshape(mask, shape=[1, -1]))
flatten_valid_masks.append(tf.reshape(valid_mask, shape=[1, -1]))
flatten_predictions = tf.concat(flatten_predictions, axis=1)
flatten_masks = tf.concat(flatten_masks, axis=1)
flatten_valid_masks = tf.concat(flatten_valid_masks, axis=1)
flatten_predictions = tf.reshape(predicted_mask, shape=[1, -1])
flatten_masks = tf.reshape(mask, shape=[1, -1])
flatten_valid_masks = tf.reshape(valid_mask, shape=[1, -1])
super(PerClassIoU, self).update_state(
flatten_masks, flatten_predictions,
tf.cast(flatten_valid_masks, tf.float32))
else:
predictions = tf.image.resize(
......
......@@ -158,6 +158,10 @@ class SemanticSegmentationTask(base_task.Task):
.resize_eval_groundtruth,
dtype=tf.float32)
# Update state on CPU if TPUStrategy due to dynamic resizing.
self._process_iou_metric_on_cpu = isinstance(
tf.distribute.get_strategy(), tf.distribute.TPUStrategy)
return metrics
def train_step(self,
......@@ -251,7 +255,11 @@ class SemanticSegmentationTask(base_task.Task):
loss = 0
logs = {self.loss: loss}
if self._process_iou_metric_on_cpu:
logs.update({self.iou_metric.name: (labels, outputs)})
else:
self.iou_metric.update_state(labels, outputs)
if metrics:
self.process_metrics(metrics, labels, outputs)
......@@ -267,6 +275,7 @@ class SemanticSegmentationTask(base_task.Task):
if state is None:
self.iou_metric.reset_states()
state = self.iou_metric
if self._process_iou_metric_on_cpu:
self.iou_metric.update_state(step_outputs[self.iou_metric.name][0],
step_outputs[self.iou_metric.name][1])
return state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment