Commit 4bf492a8 authored by Ronny Votel, committed by TF Object Detection Team
Browse files

Updating the centernet mask target assigner.

PiperOrigin-RevId: 377511299
parent 33d1ce83
...@@ -2001,8 +2001,8 @@ class CenterNetMaskTargetAssigner(object): ...@@ -2001,8 +2001,8 @@ class CenterNetMaskTargetAssigner(object):
self._stride = stride self._stride = stride
def assign_segmentation_targets( def assign_segmentation_targets(
self, gt_masks_list, gt_classes_list, self, gt_masks_list, gt_classes_list, gt_boxes_list=None,
mask_resize_method=ResizeMethod.BILINEAR): gt_mask_weights_list=None, mask_resize_method=ResizeMethod.BILINEAR):
"""Computes the segmentation targets. """Computes the segmentation targets.
This utility produces a semantic segmentation mask for each class, starting This utility produces a semantic segmentation mask for each class, starting
...@@ -2016,15 +2016,25 @@ class CenterNetMaskTargetAssigner(object): ...@@ -2016,15 +2016,25 @@ class CenterNetMaskTargetAssigner(object):
gt_classes_list: A list of float tensors with shape [num_boxes, gt_classes_list: A list of float tensors with shape [num_boxes,
num_classes] representing the one-hot encoded class labels for each box num_classes] representing the one-hot encoded class labels for each box
in the gt_boxes_list. in the gt_boxes_list.
gt_boxes_list: An optional list of float tensors with shape [num_boxes, 4]
with normalized boxes corresponding to each mask. The boxes are used to
spatially allocate mask weights.
gt_mask_weights_list: An optional list of float tensors with shape
[num_boxes] with weights for each mask. If a mask has a zero weight, it
indicates that the box region associated with the mask should not
contribute to the loss. If not provided, will use a per-pixel weight of
1.
mask_resize_method: A `tf.compat.v2.image.ResizeMethod`. The method to use mask_resize_method: A `tf.compat.v2.image.ResizeMethod`. The method to use
when resizing masks from input resolution to output resolution. when resizing masks from input resolution to output resolution.
Returns: Returns:
segmentation_targets: An int32 tensor of size [batch_size, output_height, segmentation_targets: An int32 tensor of size [batch_size, output_height,
output_width, num_classes] representing the class of each location in output_width, num_classes] representing the class of each location in
the output space. the output space.
segmentation_weight: A float32 tensor of size [batch_size, output_height,
output_width] indicating the loss weight to apply at each location.
""" """
# TODO(ronnyvotel): Handle groundtruth weights.
_, num_classes = shape_utils.combined_static_and_dynamic_shape( _, num_classes = shape_utils.combined_static_and_dynamic_shape(
gt_classes_list[0]) gt_classes_list[0])
...@@ -2033,8 +2043,35 @@ class CenterNetMaskTargetAssigner(object): ...@@ -2033,8 +2043,35 @@ class CenterNetMaskTargetAssigner(object):
output_height = tf.maximum(input_height // self._stride, 1) output_height = tf.maximum(input_height // self._stride, 1)
output_width = tf.maximum(input_width // self._stride, 1) output_width = tf.maximum(input_width // self._stride, 1)
if gt_boxes_list is None:
gt_boxes_list = [None] * len(gt_masks_list)
if gt_mask_weights_list is None:
gt_mask_weights_list = [None] * len(gt_masks_list)
segmentation_targets_list = [] segmentation_targets_list = []
for gt_masks, gt_classes in zip(gt_masks_list, gt_classes_list): segmentation_weights_list = []
for gt_boxes, gt_masks, gt_mask_weights, gt_classes in zip(
gt_boxes_list, gt_masks_list, gt_mask_weights_list, gt_classes_list):
if gt_boxes is not None and gt_mask_weights is not None:
boxes = box_list.BoxList(gt_boxes)
# Convert the box coordinates to absolute output image dimension space.
boxes_absolute = box_list_ops.to_absolute_coordinates(
boxes, output_height, output_width)
# Generate a segmentation weight that applies mask weights in object
# regions.
blackout = gt_mask_weights <= 0
segmentation_weight_for_image = (
ta_utils.blackout_pixel_weights_by_box_regions(
output_height, output_width, boxes_absolute.get(), blackout,
weights=gt_mask_weights))
segmentation_weights_list.append(segmentation_weight_for_image)
else:
segmentation_weights_list.append(tf.ones((output_height, output_width),
dtype=tf.float32))
gt_masks = _resize_masks(gt_masks, output_height, output_width, gt_masks = _resize_masks(gt_masks, output_height, output_width,
mask_resize_method) mask_resize_method)
gt_masks = gt_masks[:, :, :, tf.newaxis] gt_masks = gt_masks[:, :, :, tf.newaxis]
...@@ -2047,7 +2084,8 @@ class CenterNetMaskTargetAssigner(object): ...@@ -2047,7 +2084,8 @@ class CenterNetMaskTargetAssigner(object):
segmentation_targets_list.append(segmentations_for_image) segmentation_targets_list.append(segmentations_for_image)
segmentation_target = tf.stack(segmentation_targets_list, axis=0) segmentation_target = tf.stack(segmentation_targets_list, axis=0)
return segmentation_target segmentation_weight = tf.stack(segmentation_weights_list, axis=0)
return segmentation_target, segmentation_weight
class CenterNetDensePoseTargetAssigner(object): class CenterNetDensePoseTargetAssigner(object):
......
...@@ -2090,13 +2090,31 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase): ...@@ -2090,13 +2090,31 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
tf.constant([[0., 1., 0.], tf.constant([[0., 1., 0.],
[0., 1., 0.]], dtype=tf.float32) [0., 1., 0.]], dtype=tf.float32)
] ]
gt_boxes_list = [
# Example 0.
tf.constant([[0.0, 0.0, 0.5, 0.5],
[0.0, 0.5, 0.5, 1.0],
[0.0, 0.0, 1.0, 1.0]], dtype=tf.float32),
# Example 1.
tf.constant([[0.0, 0.0, 1.0, 1.0],
[0.5, 0.0, 1.0, 0.5]], dtype=tf.float32)
]
gt_mask_weights_list = [
# Example 0.
tf.constant([0.0, 1.0, 1.0], dtype=tf.float32),
# Example 1.
tf.constant([1.0, 1.0], dtype=tf.float32)
]
cn_assigner = targetassigner.CenterNetMaskTargetAssigner(stride=2) cn_assigner = targetassigner.CenterNetMaskTargetAssigner(stride=2)
segmentation_target = cn_assigner.assign_segmentation_targets( segmentation_target, segmentation_weight = (
cn_assigner.assign_segmentation_targets(
gt_masks_list=gt_masks_list, gt_masks_list=gt_masks_list,
gt_classes_list=gt_classes_list, gt_classes_list=gt_classes_list,
mask_resize_method=targetassigner.ResizeMethod.NEAREST_NEIGHBOR) gt_boxes_list=gt_boxes_list,
return segmentation_target gt_mask_weights_list=gt_mask_weights_list,
segmentation_target = self.execute(graph_fn, []) mask_resize_method=targetassigner.ResizeMethod.NEAREST_NEIGHBOR))
return segmentation_target, segmentation_weight
segmentation_target, segmentation_weight = self.execute(graph_fn, [])
expected_seg_target = np.array([ expected_seg_target = np.array([
# Example 0 [[class 0, class 1], [background, class 0]] # Example 0 [[class 0, class 1], [background, class 0]]
...@@ -2108,13 +2126,18 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase): ...@@ -2108,13 +2126,18 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
], dtype=np.float32) ], dtype=np.float32)
np.testing.assert_array_almost_equal( np.testing.assert_array_almost_equal(
expected_seg_target, segmentation_target) expected_seg_target, segmentation_target)
expected_seg_weight = np.array([
[[0, 1], [1, 1]],
[[1, 1], [1, 1]]], dtype=np.float32)
np.testing.assert_array_almost_equal(
expected_seg_weight, segmentation_weight)
def test_assign_segmentation_targets_no_objects(self): def test_assign_segmentation_targets_no_objects(self):
def graph_fn(): def graph_fn():
gt_masks_list = [tf.zeros((0, 5, 5))] gt_masks_list = [tf.zeros((0, 5, 5))]
gt_classes_list = [tf.zeros((0, 10))] gt_classes_list = [tf.zeros((0, 10))]
cn_assigner = targetassigner.CenterNetMaskTargetAssigner(stride=1) cn_assigner = targetassigner.CenterNetMaskTargetAssigner(stride=1)
segmentation_target = cn_assigner.assign_segmentation_targets( segmentation_target, _ = cn_assigner.assign_segmentation_targets(
gt_masks_list=gt_masks_list, gt_masks_list=gt_masks_list,
gt_classes_list=gt_classes_list, gt_classes_list=gt_classes_list,
mask_resize_method=targetassigner.ResizeMethod.NEAREST_NEIGHBOR) mask_resize_method=targetassigner.ResizeMethod.NEAREST_NEIGHBOR)
......
...@@ -2979,20 +2979,32 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2979,20 +2979,32 @@ class CenterNetMetaArch(model.DetectionModel):
Returns: Returns:
A float scalar tensor representing the mask loss. A float scalar tensor representing the mask loss.
""" """
gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
gt_masks_list = self.groundtruth_lists(fields.BoxListFields.masks) gt_masks_list = self.groundtruth_lists(fields.BoxListFields.masks)
gt_mask_weights_list = None
if self.groundtruth_has_field(fields.BoxListFields.mask_weights):
gt_mask_weights_list = self.groundtruth_lists(
fields.BoxListFields.mask_weights)
gt_classes_list = self.groundtruth_lists(fields.BoxListFields.classes) gt_classes_list = self.groundtruth_lists(fields.BoxListFields.classes)
# Convert the groundtruth to targets. # Convert the groundtruth to targets.
assigner = self._target_assigner_dict[SEGMENTATION_TASK] assigner = self._target_assigner_dict[SEGMENTATION_TASK]
heatmap_targets = assigner.assign_segmentation_targets( heatmap_targets, heatmap_weight = assigner.assign_segmentation_targets(
gt_masks_list=gt_masks_list, gt_masks_list=gt_masks_list,
gt_classes_list=gt_classes_list) gt_classes_list=gt_classes_list,
gt_boxes_list=gt_boxes_list,
gt_mask_weights_list=gt_mask_weights_list)
flattened_heatmap_targets = _flatten_spatial_dimensions(heatmap_targets) flattened_heatmap_targets = _flatten_spatial_dimensions(heatmap_targets)
flattened_heatmap_mask = _flatten_spatial_dimensions(
heatmap_weight[:, :, :, tf.newaxis])
per_pixel_weights *= flattened_heatmap_mask
loss = 0.0 loss = 0.0
mask_loss_fn = self._mask_params.classification_loss mask_loss_fn = self._mask_params.classification_loss
total_pixels_in_loss = tf.reduce_sum(per_pixel_weights)
total_pixels_in_loss = tf.math.maximum(
tf.reduce_sum(per_pixel_weights), 1)
# Loop through each feature output head. # Loop through each feature output head.
for pred in segmentation_predictions: for pred in segmentation_predictions:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment