Commit bfc65896 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 416953858
parent b5248395
......@@ -183,7 +183,7 @@ class SemanticSegmentationTask(base_task.Task):
num_classes=self.task_config.model.num_classes,
rescale_predictions=False,
dtype=tf.float32))
if self.task_config.model.mask_scoring_head:
if self.task_config.model.get('mask_scoring_head'):
metrics.append(
tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
else:
......@@ -193,7 +193,7 @@ class SemanticSegmentationTask(base_task.Task):
rescale_predictions=not self.task_config.validation_data
.resize_eval_groundtruth,
dtype=tf.float32)
if self.task_config.validation_data.resize_eval_groundtruth and self.task_config.model.mask_scoring_head:
if self.task_config.validation_data.resize_eval_groundtruth and self.task_config.model.get('mask_scoring_head'): # pylint: disable=line-too-long
# Masks scores metric can only be computed if labels are scaled to match
# preticted mask scores.
metrics.append(
......@@ -232,6 +232,8 @@ class SemanticSegmentationTask(base_task.Task):
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features, training=True)
if isinstance(outputs, tf.Tensor):
outputs = {'logits': outputs}
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure(
......@@ -287,6 +289,8 @@ class SemanticSegmentationTask(base_task.Task):
features, input_partition_dims)
outputs = self.inference_step(features, model)
if isinstance(outputs, tf.Tensor):
outputs = {'logits': outputs}
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
if self.task_config.validation_data.resize_eval_groundtruth:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment