Commit 8d71f896 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 476492677
parent d38e8da1
......@@ -29,9 +29,10 @@ class SegmentationLoss:
label_smoothing,
class_weights,
ignore_label,
gt_is_matting_map,
use_groundtruth_dimension,
top_k_percent_pixels=1.0):
top_k_percent_pixels=1.0,
gt_is_matting_map=False
):
"""Initializes `SegmentationLoss`.
Args:
......@@ -39,20 +40,21 @@ class SegmentationLoss:
spreading the amount of probability to all other label classes.
class_weights: A float list containing the weight of each class.
ignore_label: An integer specifying the ignore label.
gt_is_matting_map: If or not the groundtruth mask is a matting map. Note
that the matting map is only supported for 2 class segmentation.
use_groundtruth_dimension: A boolean, whether to resize the output to
match the dimension of the ground truth.
top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its
value < 1., only compute the loss for the top k percent pixels. This is
useful for hard pixel mining.
gt_is_matting_map: If or not the groundtruth mask is a matting map. Note
that the matting map is only supported for 2 class segmentation.
"""
self._label_smoothing = label_smoothing
self._class_weights = class_weights
self._ignore_label = ignore_label
self._gt_is_matting_map = gt_is_matting_map
self._use_groundtruth_dimension = use_groundtruth_dimension
self._top_k_percent_pixels = top_k_percent_pixels
self._gt_is_matting_map = gt_is_matting_map
def __call__(self, logits, labels, **kwargs):
"""Computes `SegmentationLoss`.
......
......@@ -135,9 +135,10 @@ class SemanticSegmentationTask(base_task.Task):
loss_params.label_smoothing,
loss_params.class_weights,
loss_params.ignore_label,
loss_params.gt_is_matting_map,
use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
top_k_percent_pixels=loss_params.top_k_percent_pixels)
top_k_percent_pixels=loss_params.top_k_percent_pixels,
gt_is_matting_map=loss_params.gt_is_matting_map
)
total_loss = segmentation_loss_fn(model_outputs['logits'], labels['masks'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment