Commit 28c2acd9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 440164472
parent d21b3b3f
...@@ -166,7 +166,7 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -166,7 +166,7 @@ class SemanticSegmentationTask(base_task.Task):
**kwargs: other args. **kwargs: other args.
""" """
for metric in metrics: for metric in metrics:
if 'mask_scores_mse' is metric.name: if 'mask_scores_mse' == metric.name:
actual_mask_scores = segmentation_losses.get_actual_mask_scores( actual_mask_scores = segmentation_losses.get_actual_mask_scores(
model_outputs['logits'], labels['masks'], model_outputs['logits'], labels['masks'],
self.task_config.losses.ignore_label) self.task_config.losses.ignore_label)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment