Commit fcd09603 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 406022769
parent 4b06b4b3
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
"""Metrics for segmentation.""" """Metrics for segmentation."""
import tensorflow as tf import tensorflow as tf
from official.vision import keras_cv from official.vision import keras_cv
...@@ -77,11 +76,11 @@ class MeanIoU(tf.keras.metrics.MeanIoU): ...@@ -77,11 +76,11 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
if self._rescale_predictions: if self._rescale_predictions:
# This part can only run on cpu/gpu due to dynamic image resizing. # This part can only run on cpu/gpu due to dynamic image resizing.
flatten_predictions = [] for i in range(tf.shape(predictions)[0]):
flatten_masks = [] mask = masks[i]
flatten_valid_masks = [] valid_mask = valid_masks[i]
for mask, valid_mask, predicted_mask, image_info in zip( predicted_mask = predictions[i]
masks, valid_masks, predictions, images_info): image_info = images_info[i]
rescale_size = tf.cast( rescale_size = tf.cast(
tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32) tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
...@@ -104,12 +103,12 @@ class MeanIoU(tf.keras.metrics.MeanIoU): ...@@ -104,12 +103,12 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
image_shape[1]) image_shape[1])
predicted_mask = tf.argmax(predicted_mask, axis=2) predicted_mask = tf.argmax(predicted_mask, axis=2)
flatten_predictions.append(tf.reshape(predicted_mask, shape=[1, -1])) flatten_predictions = tf.reshape(predicted_mask, shape=[1, -1])
flatten_masks.append(tf.reshape(mask, shape=[1, -1])) flatten_masks = tf.reshape(mask, shape=[1, -1])
flatten_valid_masks.append(tf.reshape(valid_mask, shape=[1, -1])) flatten_valid_masks = tf.reshape(valid_mask, shape=[1, -1])
flatten_predictions = tf.concat(flatten_predictions, axis=1) super(MeanIoU, self).update_state(
flatten_masks = tf.concat(flatten_masks, axis=1) flatten_masks, flatten_predictions,
flatten_valid_masks = tf.concat(flatten_valid_masks, axis=1) tf.cast(flatten_valid_masks, tf.float32))
else: else:
predictions = tf.image.resize( predictions = tf.image.resize(
...@@ -184,11 +183,11 @@ class PerClassIoU(keras_cv.metrics.PerClassIoU): ...@@ -184,11 +183,11 @@ class PerClassIoU(keras_cv.metrics.PerClassIoU):
if self._rescale_predictions: if self._rescale_predictions:
# This part can only run on cpu/gpu due to dynamic image resizing. # This part can only run on cpu/gpu due to dynamic image resizing.
flatten_predictions = [] for i in range(tf.shape(predictions)[0]):
flatten_masks = [] mask = masks[i]
flatten_valid_masks = [] valid_mask = valid_masks[i]
for mask, valid_mask, predicted_mask, image_info in zip( predicted_mask = predictions[i]
masks, valid_masks, predictions, images_info): image_info = images_info[i]
rescale_size = tf.cast( rescale_size = tf.cast(
tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32) tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
...@@ -211,12 +210,12 @@ class PerClassIoU(keras_cv.metrics.PerClassIoU): ...@@ -211,12 +210,12 @@ class PerClassIoU(keras_cv.metrics.PerClassIoU):
image_shape[1]) image_shape[1])
predicted_mask = tf.argmax(predicted_mask, axis=2) predicted_mask = tf.argmax(predicted_mask, axis=2)
flatten_predictions.append(tf.reshape(predicted_mask, shape=[1, -1])) flatten_predictions = tf.reshape(predicted_mask, shape=[1, -1])
flatten_masks.append(tf.reshape(mask, shape=[1, -1])) flatten_masks = tf.reshape(mask, shape=[1, -1])
flatten_valid_masks.append(tf.reshape(valid_mask, shape=[1, -1])) flatten_valid_masks = tf.reshape(valid_mask, shape=[1, -1])
flatten_predictions = tf.concat(flatten_predictions, axis=1) super(PerClassIoU, self).update_state(
flatten_masks = tf.concat(flatten_masks, axis=1) flatten_masks, flatten_predictions,
flatten_valid_masks = tf.concat(flatten_valid_masks, axis=1) tf.cast(flatten_valid_masks, tf.float32))
else: else:
predictions = tf.image.resize( predictions = tf.image.resize(
......
...@@ -158,6 +158,10 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -158,6 +158,10 @@ class SemanticSegmentationTask(base_task.Task):
.resize_eval_groundtruth, .resize_eval_groundtruth,
dtype=tf.float32) dtype=tf.float32)
# Update state on CPU if TPUStrategy due to dynamic resizing.
self._process_iou_metric_on_cpu = isinstance(
tf.distribute.get_strategy(), tf.distribute.TPUStrategy)
return metrics return metrics
def train_step(self, def train_step(self,
...@@ -251,7 +255,11 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -251,7 +255,11 @@ class SemanticSegmentationTask(base_task.Task):
loss = 0 loss = 0
logs = {self.loss: loss} logs = {self.loss: loss}
if self._process_iou_metric_on_cpu:
logs.update({self.iou_metric.name: (labels, outputs)}) logs.update({self.iou_metric.name: (labels, outputs)})
else:
self.iou_metric.update_state(labels, outputs)
if metrics: if metrics:
self.process_metrics(metrics, labels, outputs) self.process_metrics(metrics, labels, outputs)
...@@ -267,6 +275,7 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -267,6 +275,7 @@ class SemanticSegmentationTask(base_task.Task):
if state is None: if state is None:
self.iou_metric.reset_states() self.iou_metric.reset_states()
state = self.iou_metric state = self.iou_metric
if self._process_iou_metric_on_cpu:
self.iou_metric.update_state(step_outputs[self.iou_metric.name][0], self.iou_metric.update_state(step_outputs[self.iou_metric.name][0],
step_outputs[self.iou_metric.name][1]) step_outputs[self.iou_metric.name][1])
return state return state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment