Commit bfc65896 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 416953858
parent b5248395
...@@ -183,7 +183,7 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -183,7 +183,7 @@ class SemanticSegmentationTask(base_task.Task):
num_classes=self.task_config.model.num_classes, num_classes=self.task_config.model.num_classes,
rescale_predictions=False, rescale_predictions=False,
dtype=tf.float32)) dtype=tf.float32))
if self.task_config.model.mask_scoring_head: if self.task_config.model.get('mask_scoring_head'):
metrics.append( metrics.append(
tf.keras.metrics.MeanSquaredError(name='mask_scores_mse')) tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
else: else:
...@@ -193,7 +193,7 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -193,7 +193,7 @@ class SemanticSegmentationTask(base_task.Task):
rescale_predictions=not self.task_config.validation_data rescale_predictions=not self.task_config.validation_data
.resize_eval_groundtruth, .resize_eval_groundtruth,
dtype=tf.float32) dtype=tf.float32)
if self.task_config.validation_data.resize_eval_groundtruth and self.task_config.model.mask_scoring_head: if self.task_config.validation_data.resize_eval_groundtruth and self.task_config.model.get('mask_scoring_head'): # pylint: disable=line-too-long
# Masks scores metric can only be computed if labels are scaled to match # Masks scores metric can only be computed if labels are scaled to match
# preticted mask scores. # preticted mask scores.
metrics.append( metrics.append(
...@@ -232,6 +232,8 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -232,6 +232,8 @@ class SemanticSegmentationTask(base_task.Task):
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
outputs = model(features, training=True) outputs = model(features, training=True)
if isinstance(outputs, tf.Tensor):
outputs = {'logits': outputs}
# Casting output layer as float32 is necessary when mixed_precision is # Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32. # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure( outputs = tf.nest.map_structure(
...@@ -287,6 +289,8 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -287,6 +289,8 @@ class SemanticSegmentationTask(base_task.Task):
features, input_partition_dims) features, input_partition_dims)
outputs = self.inference_step(features, model) outputs = self.inference_step(features, model)
if isinstance(outputs, tf.Tensor):
outputs = {'logits': outputs}
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs) outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
if self.task_config.validation_data.resize_eval_groundtruth: if self.task_config.validation_data.resize_eval_groundtruth:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment