Commit a04d9e0e authored by Vishnu Banna

merged

parents 64f16d61 bcbce005
@@ -1414,6 +1414,7 @@ def _strict_random_crop_image(image,
label_confidences=None,
multiclass_scores=None,
masks=None,
+ mask_weights=None,
keypoints=None,
keypoint_visibilities=None,
densepose_num_points=None,
@@ -1451,6 +1452,8 @@ def _strict_random_crop_image(image,
masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`.
+ mask_weights: (optional) rank 1 float32 tensor with shape [num_instances]
+ with instance masks weights.
keypoints: (optional) rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]. The keypoints are in y-x
normalized coordinates.
@@ -1488,7 +1491,7 @@ def _strict_random_crop_image(image,
Boxes are in normalized form.
labels: new labels.
- If label_weights, multiclass_scores, masks, keypoints,
+ If label_weights, multiclass_scores, masks, mask_weights, keypoints,
keypoint_visibilities, densepose_num_points, densepose_part_ids, or
densepose_surface_coords is not None, the function also returns:
label_weights: rank 1 float32 tensor with shape [num_instances].
@@ -1496,6 +1499,8 @@ def _strict_random_crop_image(image,
[num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks.
+ mask_weights: rank 1 float32 tensor with shape [num_instances] with mask
+ weights.
keypoints: rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]
keypoint_visibilities: rank 2 bool tensor with shape
@@ -1605,6 +1610,12 @@ def _strict_random_crop_image(image,
0]:im_box_end[0], im_box_begin[1]:im_box_end[1]]
result.append(new_masks)
+ if mask_weights is not None:
+ mask_weights_inside_window = tf.gather(mask_weights, inside_window_ids)
+ mask_weights_completely_inside_window = tf.gather(
+ mask_weights_inside_window, keep_ids)
+ result.append(mask_weights_completely_inside_window)
if keypoints is not None:
keypoints_of_boxes_inside_window = tf.gather(keypoints, inside_window_ids)
keypoints_of_boxes_completely_inside_window = tf.gather(
@@ -1654,6 +1665,7 @@ def random_crop_image(image,
label_confidences=None,
multiclass_scores=None,
masks=None,
+ mask_weights=None,
keypoints=None,
keypoint_visibilities=None,
densepose_num_points=None,
@@ -1701,6 +1713,8 @@ def random_crop_image(image,
masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`.
+ mask_weights: (optional) rank 1 float32 tensor with shape [num_instances]
+ containing weights for each instance mask.
keypoints: (optional) rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]. The keypoints are in y-x
normalized coordinates.
@@ -1751,6 +1765,7 @@ def random_crop_image(image,
[num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks.
+ mask_weights: rank 1 float32 tensor with shape [num_instances].
keypoints: rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]
keypoint_visibilities: rank 2 bool tensor with shape
@@ -1771,6 +1786,7 @@ def random_crop_image(image,
label_confidences=label_confidences,
multiclass_scores=multiclass_scores,
masks=masks,
+ mask_weights=mask_weights,
keypoints=keypoints,
keypoint_visibilities=keypoint_visibilities,
densepose_num_points=densepose_num_points,
@@ -1803,6 +1819,8 @@ def random_crop_image(image,
outputs.append(multiclass_scores)
if masks is not None:
outputs.append(masks)
+ if mask_weights is not None:
+ outputs.append(mask_weights)
if keypoints is not None:
outputs.append(keypoints)
if keypoint_visibilities is not None:
@@ -4388,6 +4406,7 @@ def get_default_func_arg_map(include_label_weights=True,
include_label_confidences=False,
include_multiclass_scores=False,
include_instance_masks=False,
+ include_instance_mask_weights=False,
include_keypoints=False,
include_keypoint_visibilities=False,
include_dense_pose=False,
@@ -4403,6 +4422,8 @@ def get_default_func_arg_map(include_label_weights=True,
multiclass scores, too.
include_instance_masks: If True, preprocessing functions will modify the
instance masks, too.
+ include_instance_mask_weights: If True, preprocessing functions will modify
+ the instance mask weights.
include_keypoints: If True, preprocessing functions will modify the
keypoints, too.
include_keypoint_visibilities: If True, preprocessing functions will modify
@@ -4434,6 +4455,11 @@ def get_default_func_arg_map(include_label_weights=True,
groundtruth_instance_masks = (
fields.InputDataFields.groundtruth_instance_masks)
+ groundtruth_instance_mask_weights = None
+ if include_instance_mask_weights:
+ groundtruth_instance_mask_weights = (
+ fields.InputDataFields.groundtruth_instance_mask_weights)
groundtruth_keypoints = None
if include_keypoints:
groundtruth_keypoints = fields.InputDataFields.groundtruth_keypoints
@@ -4503,7 +4529,8 @@ def get_default_func_arg_map(include_label_weights=True,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights, groundtruth_label_confidences,
- multiclass_scores, groundtruth_instance_masks, groundtruth_keypoints,
+ multiclass_scores, groundtruth_instance_masks,
+ groundtruth_instance_mask_weights, groundtruth_keypoints,
groundtruth_keypoint_visibilities, groundtruth_dp_num_points,
groundtruth_dp_part_ids, groundtruth_dp_surface_coords),
random_pad_image:
......
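The hunks above thread the new `mask_weights` tensor through `_strict_random_crop_image`, `random_crop_image`, and `get_default_func_arg_map`. As a rough illustration (not part of this commit), a caller could opt in to the new field roughly as follows, assuming the existing `preprocessor.preprocess` entry point and the `groundtruth_instance_mask_weights` key added to `standard_fields` later in this diff:

```python
# Illustrative sketch only: shapes and values are placeholders.
import tensorflow as tf
from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields

tensor_dict = {
    fields.InputDataFields.image: tf.zeros([200, 400, 3], tf.float32),
    fields.InputDataFields.groundtruth_boxes:
        tf.constant([[0.1, 0.1, 0.8, 0.9]], tf.float32),
    fields.InputDataFields.groundtruth_classes: tf.constant([1], tf.int32),
    fields.InputDataFields.groundtruth_weights: tf.constant([1.0]),
    fields.InputDataFields.groundtruth_instance_masks:
        tf.zeros([1, 200, 400], tf.float32),
    fields.InputDataFields.groundtruth_instance_mask_weights:
        tf.constant([1.0], tf.float32),
}

# Opt in to mask weights so random_crop_image gathers them alongside masks.
arg_map = preprocessor.get_default_func_arg_map(
    include_instance_masks=True, include_instance_mask_weights=True)
augmented = preprocessor.preprocess(
    tensor_dict, [(preprocessor.random_crop_image, {})], func_arg_map=arg_map)
```

When `include_instance_mask_weights` is left at its default of False, the weights entry stays `None` in the argument map and the existing behavior is unchanged.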
@@ -1894,6 +1894,37 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
self.assertAllClose(
new_boxes.flatten(), expected_boxes.flatten())
+ def testStrictRandomCropImageWithMaskWeights(self):
+ def graph_fn():
+ image = self.createColorfulTestImage()[0]
+ boxes = self.createTestBoxes()
+ labels = self.createTestLabels()
+ weights = self.createTestGroundtruthWeights()
+ masks = tf.random_uniform([2, 200, 400], dtype=tf.float32)
+ mask_weights = tf.constant([1.0, 0.0], dtype=tf.float32)
+ with mock.patch.object(
+ tf.image,
+ 'sample_distorted_bounding_box'
+ ) as mock_sample_distorted_bounding_box:
+ mock_sample_distorted_bounding_box.return_value = (
+ tf.constant([6, 143, 0], dtype=tf.int32),
+ tf.constant([190, 237, -1], dtype=tf.int32),
+ tf.constant([[[0.03, 0.3575, 0.98, 0.95]]], dtype=tf.float32))
+ results = preprocessor._strict_random_crop_image(
+ image, boxes, labels, weights, masks=masks,
+ mask_weights=mask_weights)
+ return results
+ (new_image, new_boxes, _, _,
+ new_masks, new_mask_weights) = self.execute_cpu(graph_fn, [])
+ expected_boxes = np.array(
+ [[0.0, 0.0, 0.75789469, 1.0],
+ [0.23157893, 0.24050637, 0.75789469, 1.0]], dtype=np.float32)
+ self.assertAllEqual(new_image.shape, [190, 237, 3])
+ self.assertAllEqual(new_masks.shape, [2, 190, 237])
+ self.assertAllClose(new_mask_weights, [1.0, 0.0])
+ self.assertAllClose(
+ new_boxes.flatten(), expected_boxes.flatten())
def testStrictRandomCropImageWithKeypoints(self):
def graph_fn():
image = self.createColorfulTestImage()[0]
@@ -1947,6 +1978,7 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
labels = self.createTestLabels()
weights = self.createTestGroundtruthWeights()
masks = tf.random_uniform([2, 200, 400], dtype=tf.float32)
+ mask_weights = tf.constant([1.0, 0.0], dtype=tf.float32)
tensor_dict = {
fields.InputDataFields.image: image,
@@ -1954,10 +1986,12 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
fields.InputDataFields.groundtruth_classes: labels,
fields.InputDataFields.groundtruth_weights: weights,
fields.InputDataFields.groundtruth_instance_masks: masks,
+ fields.InputDataFields.groundtruth_instance_mask_weights:
+ mask_weights
}
preprocessor_arg_map = preprocessor.get_default_func_arg_map(
- include_instance_masks=True)
+ include_instance_masks=True, include_instance_mask_weights=True)
preprocessing_options = [(preprocessor.random_crop_image, {})]
@@ -1980,16 +2014,19 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
fields.InputDataFields.groundtruth_classes]
distorted_masks = distorted_tensor_dict[
fields.InputDataFields.groundtruth_instance_masks]
+ distorted_mask_weights = distorted_tensor_dict[
+ fields.InputDataFields.groundtruth_instance_mask_weights]
return [distorted_image, distorted_boxes, distorted_labels,
- distorted_masks]
+ distorted_masks, distorted_mask_weights]
(distorted_image_, distorted_boxes_, distorted_labels_,
- distorted_masks_) = self.execute_cpu(graph_fn, [])
+ distorted_masks_, distorted_mask_weights_) = self.execute_cpu(graph_fn, [])
expected_boxes = np.array([
[0.0, 0.0, 0.75789469, 1.0],
[0.23157893, 0.24050637, 0.75789469, 1.0],
], dtype=np.float32)
self.assertAllEqual(distorted_image_.shape, [1, 190, 237, 3])
self.assertAllEqual(distorted_masks_.shape, [2, 190, 237])
+ self.assertAllClose(distorted_mask_weights_, [1.0, 0.0])
self.assertAllEqual(distorted_labels_, [1, 2])
self.assertAllClose(
distorted_boxes_.flatten(), expected_boxes.flatten())
......
@@ -64,6 +64,7 @@ class InputDataFields(object):
proposal_boxes: coordinates of object proposal boxes.
proposal_objectness: objectness score of each proposal.
groundtruth_instance_masks: ground truth instance masks.
+ groundtruth_instance_mask_weights: ground truth instance masks weights.
groundtruth_instance_boundaries: ground truth instance boundaries.
groundtruth_instance_classes: instance mask-level class labels.
groundtruth_keypoints: ground truth keypoints.
@@ -122,6 +123,7 @@ class InputDataFields(object):
proposal_boxes = 'proposal_boxes'
proposal_objectness = 'proposal_objectness'
groundtruth_instance_masks = 'groundtruth_instance_masks'
+ groundtruth_instance_mask_weights = 'groundtruth_instance_mask_weights'
groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
groundtruth_instance_classes = 'groundtruth_instance_classes'
groundtruth_keypoints = 'groundtruth_keypoints'
@@ -208,6 +210,7 @@ class BoxListFields(object):
weights: sample weights per bounding box.
objectness: objectness score per bounding box.
masks: masks per bounding box.
+ mask_weights: mask weights for each bounding box.
boundaries: boundaries per bounding box.
keypoints: keypoints per bounding box.
keypoint_visibilities: keypoint visibilities per bounding box.
@@ -228,6 +231,7 @@ class BoxListFields(object):
confidences = 'confidences'
objectness = 'objectness'
masks = 'masks'
+ mask_weights = 'mask_weights'
boundaries = 'boundaries'
keypoints = 'keypoints'
keypoint_visibilities = 'keypoint_visibilities'
......
@@ -1409,8 +1409,10 @@ class CenterNetKeypointTargetAssigner(object):
[batch_size, num_keypoints] representing number of instances for each
keypoint type.
valid_mask: A float tensor with shape [batch_size, output_height,
- output_width] where all values within the regions of the blackout boxes
- are 0.0 and 1.0 else where.
+ output_width, num_keypoints] where all values within the regions of the
+ blackout boxes are 0.0 and 1.0 else where. Note that the blackout boxes
+ are per keypoint type and are blacked out if the keypoint
+ visibility/weight (of the corresponding keypoint type) is zero.
"""
out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32)
out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32)
@@ -1480,13 +1482,17 @@ class CenterNetKeypointTargetAssigner(object):
keypoint_std_dev = keypoint_std_dev * tf.stack(
[sigma] * num_keypoints, axis=1)
- # Generate the valid region mask to ignore regions with target class but
- # no corresponding keypoints.
- # Shape: [num_instances].
- blackout = tf.logical_and(classes[:, self._class_id] > 0,
- tf.reduce_max(kp_weights, axis=1) < 1e-3)
- valid_mask = ta_utils.blackout_pixel_weights_by_box_regions(
- out_height, out_width, boxes.get(), blackout)
+ # Generate the per-keypoint type valid region mask to ignore regions
+ # with keypoint weights equal to zeros (e.g. visibility is 0).
+ # shape of valid_mask: [out_height, out_width, num_keypoints]
+ kp_weight_list = tf.unstack(kp_weights, axis=1)
+ valid_mask_channel_list = []
+ for kp_weight in kp_weight_list:
+ blackout = kp_weight < 1e-3
+ valid_mask_channel_list.append(
+ ta_utils.blackout_pixel_weights_by_box_regions(
+ out_height, out_width, boxes.get(), blackout))
+ valid_mask = tf.stack(valid_mask_channel_list, axis=2)
valid_mask_list.append(valid_mask)
# Apply the Gaussian kernel to the keypoint coordinates. Returned heatmap
@@ -2001,8 +2007,8 @@ class CenterNetMaskTargetAssigner(object):
self._stride = stride
def assign_segmentation_targets(
- self, gt_masks_list, gt_classes_list,
- mask_resize_method=ResizeMethod.BILINEAR):
+ self, gt_masks_list, gt_classes_list, gt_boxes_list=None,
+ gt_mask_weights_list=None, mask_resize_method=ResizeMethod.BILINEAR):
"""Computes the segmentation targets.
This utility produces a semantic segmentation mask for each class, starting
@@ -2016,15 +2022,25 @@ class CenterNetMaskTargetAssigner(object):
gt_classes_list: A list of float tensors with shape [num_boxes,
num_classes] representing the one-hot encoded class labels for each box
in the gt_boxes_list.
+ gt_boxes_list: An optional list of float tensors with shape [num_boxes, 4]
+ with normalized boxes corresponding to each mask. The boxes are used to
+ spatially allocate mask weights.
+ gt_mask_weights_list: An optional list of float tensors with shape
+ [num_boxes] with weights for each mask. If a mask has a zero weight, it
+ indicates that the box region associated with the mask should not
+ contribute to the loss. If not provided, will use a per-pixel weight of
+ 1.
mask_resize_method: A `tf.compat.v2.image.ResizeMethod`. The method to use
when resizing masks from input resolution to output resolution.
Returns:
segmentation_targets: An int32 tensor of size [batch_size, output_height,
output_width, num_classes] representing the class of each location in
the output space.
+ segmentation_weight: A float32 tensor of size [batch_size, output_height,
+ output_width] indicating the loss weight to apply at each location.
"""
- # TODO(ronnyvotel): Handle groundtruth weights.
_, num_classes = shape_utils.combined_static_and_dynamic_shape(
gt_classes_list[0])
@@ -2033,8 +2049,35 @@ class CenterNetMaskTargetAssigner(object):
output_height = tf.maximum(input_height // self._stride, 1)
output_width = tf.maximum(input_width // self._stride, 1)
+ if gt_boxes_list is None:
+ gt_boxes_list = [None] * len(gt_masks_list)
+ if gt_mask_weights_list is None:
+ gt_mask_weights_list = [None] * len(gt_masks_list)
segmentation_targets_list = []
- for gt_masks, gt_classes in zip(gt_masks_list, gt_classes_list):
+ segmentation_weights_list = []
+ for gt_boxes, gt_masks, gt_mask_weights, gt_classes in zip(
+ gt_boxes_list, gt_masks_list, gt_mask_weights_list, gt_classes_list):
+ if gt_boxes is not None and gt_mask_weights is not None:
+ boxes = box_list.BoxList(gt_boxes)
+ # Convert the box coordinates to absolute output image dimension space.
+ boxes_absolute = box_list_ops.to_absolute_coordinates(
+ boxes, output_height, output_width)
+ # Generate a segmentation weight that applies mask weights in object
+ # regions.
+ blackout = gt_mask_weights <= 0
+ segmentation_weight_for_image = (
+ ta_utils.blackout_pixel_weights_by_box_regions(
+ output_height, output_width, boxes_absolute.get(), blackout,
+ weights=gt_mask_weights))
+ segmentation_weights_list.append(segmentation_weight_for_image)
+ else:
+ segmentation_weights_list.append(tf.ones((output_height, output_width),
+ dtype=tf.float32))
gt_masks = _resize_masks(gt_masks, output_height, output_width,
mask_resize_method)
gt_masks = gt_masks[:, :, :, tf.newaxis]
@@ -2047,7 +2090,8 @@ class CenterNetMaskTargetAssigner(object):
segmentation_targets_list.append(segmentations_for_image)
segmentation_target = tf.stack(segmentation_targets_list, axis=0)
- return segmentation_target
+ segmentation_weight = tf.stack(segmentation_weights_list, axis=0)
+ return segmentation_target, segmentation_weight
class CenterNetDensePoseTargetAssigner(object):
......
@@ -1699,7 +1699,7 @@ class CenterNetKeypointTargetAssignerTest(test_case.TestCase):
np.array([[0.0, 0.0, 0.3, 0.3],
[0.0, 0.0, 0.5, 0.5],
[0.0, 0.0, 0.5, 0.5],
- [0.0, 0.0, 1.0, 1.0]]),
+ [0.5, 0.5, 1.0, 1.0]]),
dtype=tf.float32)
]
@@ -1728,15 +1728,20 @@ class CenterNetKeypointTargetAssignerTest(test_case.TestCase):
# Verify the number of instances is correct.
np.testing.assert_array_almost_equal([[0, 1]],
num_instances_batch)
+ self.assertAllEqual([1, 30, 20, 2], valid_mask.shape)
# When calling the function, we specify the class id to be 1 (1th and 3rd)
# instance and the keypoint indices to be [0, 2], meaning that the 1st
# instance is the target class with no valid keypoints in it. As a result,
- # the region of the 1st instance boxing box should be blacked out
- # (0.0, 0.0, 0.5, 0.5), transfering to (0, 0, 15, 10) in absolute output
- # space.
- self.assertAlmostEqual(np.sum(valid_mask[:, 0:16, 0:11]), 0.0)
- # All other values are 1.0 so the sum is: 30 * 20 - 16 * 11 = 424.
- self.assertAlmostEqual(np.sum(valid_mask), 424.0)
+ # the region of both keypoint types of the 1st instance boxing box should be
+ # blacked out (0.0, 0.0, 0.5, 0.5), transfering to (0, 0, 15, 10) in
+ # absolute output space.
+ self.assertAlmostEqual(np.sum(valid_mask[:, 0:15, 0:10, 0:2]), 0.0)
+ # For the 2nd instance, only the 1st keypoint has visibility of 0 so only
+ # the corresponding valid mask contains zeros.
+ self.assertAlmostEqual(np.sum(valid_mask[:, 15:30, 10:20, 0]), 0.0)
+ # All other values are 1.0 so the sum is:
+ # 30 * 20 * 2 - 15 * 10 * 2 - 15 * 10 * 1 = 750.
+ self.assertAlmostEqual(np.sum(valid_mask), 750.0)
def test_assign_keypoints_offset_targets(self):
def graph_fn():
@@ -2090,13 +2095,31 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
tf.constant([[0., 1., 0.],
[0., 1., 0.]], dtype=tf.float32)
]
+ gt_boxes_list = [
+ # Example 0.
+ tf.constant([[0.0, 0.0, 0.5, 0.5],
+ [0.0, 0.5, 0.5, 1.0],
+ [0.0, 0.0, 1.0, 1.0]], dtype=tf.float32),
+ # Example 1.
+ tf.constant([[0.0, 0.0, 1.0, 1.0],
+ [0.5, 0.0, 1.0, 0.5]], dtype=tf.float32)
+ ]
+ gt_mask_weights_list = [
+ # Example 0.
+ tf.constant([0.0, 1.0, 1.0], dtype=tf.float32),
+ # Example 1.
+ tf.constant([1.0, 1.0], dtype=tf.float32)
+ ]
cn_assigner = targetassigner.CenterNetMaskTargetAssigner(stride=2)
- segmentation_target = cn_assigner.assign_segmentation_targets(
- gt_masks_list=gt_masks_list,
- gt_classes_list=gt_classes_list,
- mask_resize_method=targetassigner.ResizeMethod.NEAREST_NEIGHBOR)
- return segmentation_target
- segmentation_target = self.execute(graph_fn, [])
+ segmentation_target, segmentation_weight = (
+ cn_assigner.assign_segmentation_targets(
+ gt_masks_list=gt_masks_list,
+ gt_classes_list=gt_classes_list,
+ gt_boxes_list=gt_boxes_list,
+ gt_mask_weights_list=gt_mask_weights_list,
+ mask_resize_method=targetassigner.ResizeMethod.NEAREST_NEIGHBOR))
+ return segmentation_target, segmentation_weight
+ segmentation_target, segmentation_weight = self.execute(graph_fn, [])
expected_seg_target = np.array([
# Example 0 [[class 0, class 1], [background, class 0]]
@@ -2108,13 +2131,18 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
], dtype=np.float32)
np.testing.assert_array_almost_equal(
expected_seg_target, segmentation_target)
+ expected_seg_weight = np.array([
+ [[0, 1], [1, 1]],
+ [[1, 1], [1, 1]]], dtype=np.float32)
+ np.testing.assert_array_almost_equal(
+ expected_seg_weight, segmentation_weight)
def test_assign_segmentation_targets_no_objects(self):
def graph_fn():
gt_masks_list = [tf.zeros((0, 5, 5))]
gt_classes_list = [tf.zeros((0, 10))]
cn_assigner = targetassigner.CenterNetMaskTargetAssigner(stride=1)
- segmentation_target = cn_assigner.assign_segmentation_targets(
+ segmentation_target, _ = cn_assigner.assign_segmentation_targets(
gt_masks_list=gt_masks_list,
gt_classes_list=gt_classes_list,
mask_resize_method=targetassigner.ResizeMethod.NEAREST_NEIGHBOR)
......
@@ -373,6 +373,11 @@ class TfExampleDecoder(data_decoder.DataDecoder):
self._decode_png_instance_masks))
else:
raise ValueError('Did not recognize the `instance_mask_type` option.')
+ self.keys_to_features['image/object/mask/weight'] = (
+ tf.VarLenFeature(tf.float32))
+ self.items_to_handlers[
+ fields.InputDataFields.groundtruth_instance_mask_weights] = (
+ slim_example_decoder.Tensor('image/object/mask/weight'))
if load_dense_pose:
self.keys_to_features['image/object/densepose/num'] = (
tf.VarLenFeature(tf.int64))
@@ -491,6 +496,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
tensor of shape [None, num_keypoints] containing keypoint visibilites.
fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of
shape [None, None, None] containing instance masks.
+ fields.InputDataFields.groundtruth_instance_mask_weights - 1D float32
+ tensor of shape [None] containing weights. These are typically values
+ in {0.0, 1.0} which indicate whether to consider the mask related to an
+ object.
fields.InputDataFields.groundtruth_image_classes - 1D int64 of shape
[None] containing classes for the boxes.
fields.InputDataFields.multiclass_scores - 1D float32 tensor of shape
@@ -531,6 +540,21 @@ class TfExampleDecoder(data_decoder.DataDecoder):
0), lambda: tensor_dict[fields.InputDataFields.groundtruth_weights],
default_groundtruth_weights)
+ if fields.InputDataFields.groundtruth_instance_masks in tensor_dict:
+ gt_instance_masks = tensor_dict[
+ fields.InputDataFields.groundtruth_instance_masks]
+ num_gt_instance_masks = tf.shape(gt_instance_masks)[0]
+ gt_instance_mask_weights = tensor_dict[
+ fields.InputDataFields.groundtruth_instance_mask_weights]
+ num_gt_instance_mask_weights = tf.shape(gt_instance_mask_weights)[0]
+ def default_groundtruth_instance_mask_weights():
+ return tf.ones([num_gt_instance_masks], dtype=tf.float32)
+ tensor_dict[fields.InputDataFields.groundtruth_instance_mask_weights] = (
+ tf.cond(tf.greater(num_gt_instance_mask_weights, 0),
+ lambda: gt_instance_mask_weights,
+ default_groundtruth_instance_mask_weights))
if fields.InputDataFields.groundtruth_keypoints in tensor_dict:
# Set all keypoints that are not labeled to NaN.
gt_kpt_fld = fields.InputDataFields.groundtruth_keypoints
......
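The decoder hunks above register an optional `image/object/mask/weight` feature and fall back to all-ones weights when the feature is empty. As an illustrative sketch only (not part of this commit), a dataset-creation script could populate that feature as follows; the weight values here are made up and the remaining example fields are elided:

```python
import tensorflow as tf
from object_detection.utils import dataset_util

# One weight per instance mask; 0.0 marks a mask that should be ignored.
mask_weights = [1.0, 0.0, 1.0]

feature = {
    # ... the usual image/encoded, image/object/bbox/*, image/object/mask
    # features would go here ...
    'image/object/mask/weight':
        dataset_util.float_list_feature(mask_weights),
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
```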
@@ -1225,6 +1225,9 @@ class TfExampleDecoderTest(test_case.TestCase):
self.assertAllEqual(
instance_masks.astype(np.float32),
tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
+ self.assertAllEqual(
+ tensor_dict[fields.InputDataFields.groundtruth_instance_mask_weights],
+ [1, 1, 1, 1])
self.assertAllEqual(object_classes,
tensor_dict[fields.InputDataFields.groundtruth_classes])
@@ -1272,6 +1275,71 @@ class TfExampleDecoderTest(test_case.TestCase):
self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
tensor_dict)
+ def testDecodeInstanceSegmentationWithWeights(self):
+ num_instances = 4
+ image_height = 5
+ image_width = 3
+ # Randomly generate image.
+ image_tensor = np.random.randint(
+ 256, size=(image_height, image_width, 3)).astype(np.uint8)
+ encoded_jpeg, _ = self._create_encoded_and_decoded_data(
+ image_tensor, 'jpeg')
+ # Randomly generate instance segmentation masks.
+ instance_masks = (
+ np.random.randint(2, size=(num_instances, image_height,
+ image_width)).astype(np.float32))
+ instance_masks_flattened = np.reshape(instance_masks, [-1])
+ instance_mask_weights = np.array([1, 1, 0, 1], dtype=np.float32)
+ # Randomly generate class labels for each instance.
+ object_classes = np.random.randint(
+ 100, size=(num_instances)).astype(np.int64)
+ def graph_fn():
+ example = tf.train.Example(
+ features=tf.train.Features(
+ feature={
+ 'image/encoded':
+ dataset_util.bytes_feature(encoded_jpeg),
+ 'image/format':
+ dataset_util.bytes_feature(six.b('jpeg')),
+ 'image/height':
+ dataset_util.int64_feature(image_height),
+ 'image/width':
+ dataset_util.int64_feature(image_width),
+ 'image/object/mask':
+ dataset_util.float_list_feature(instance_masks_flattened),
+ 'image/object/mask/weight':
+ dataset_util.float_list_feature(instance_mask_weights),
+ 'image/object/class/label':
+ dataset_util.int64_list_feature(object_classes)
+ })).SerializeToString()
+ example_decoder = tf_example_decoder.TfExampleDecoder(
+ load_instance_masks=True)
+ output = example_decoder.decode(tf.convert_to_tensor(example))
+ self.assertAllEqual(
+ (output[fields.InputDataFields.groundtruth_instance_masks].get_shape(
+ ).as_list()), [4, 5, 3])
+ self.assertAllEqual(
+ output[fields.InputDataFields.groundtruth_instance_mask_weights],
+ [1, 1, 0, 1])
+ self.assertAllEqual((output[
+ fields.InputDataFields.groundtruth_classes].get_shape().as_list()),
+ [4])
+ return output
+ tensor_dict = self.execute_cpu(graph_fn, [])
+ self.assertAllEqual(
+ instance_masks.astype(np.float32),
+ tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
+ self.assertAllEqual(object_classes,
+ tensor_dict[fields.InputDataFields.groundtruth_classes])
def testDecodeImageLabels(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg, _ = self._create_encoded_and_decoded_data(
......
@@ -13,17 +13,22 @@ on-device machine learning inference with low latency and a small binary size.
TensorFlow Lite uses many techniques for this such as quantized kernels that
allow smaller and faster (fixed-point math) models.
- This document shows how elgible models from the
+ This document shows how eligible models from the
[TF2 Detection zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md)
- can be converted for inference with TFLite.
+ can be converted for inference with TFLite. See this Colab tutorial for a
+ runnable tutorial that walks you through the steps explained in this document:
+ <a target="_blank" href="https://colab.research.google.com/github/tensorflow/models/blob/master/research/object_detection/colab_tutorials/convert_odt_model_to_TFLite.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run
+ in Google Colab</a>
For an end-to-end Python guide on how to fine-tune an SSD model for mobile
inference, look at
[this Colab](../colab_tutorials/eager_few_shot_od_training_tflite.ipynb).
**NOTE:** TFLite currently only supports **SSD Architectures** (excluding
- EfficientDet) for boxes-based detection. Support for EfficientDet is coming
- soon.
+ EfficientDet) for boxes-based detection. Support for EfficientDet is provided
+ via the [TFLite Model Maker](https://www.tensorflow.org/lite/tutorials/model_maker_object_detection)
+ library.
The output model has the following inputs & outputs:
@@ -87,9 +92,46 @@ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
converter.representative_dataset = <...>
```
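The converter snippet above leaves `converter.representative_dataset` as a placeholder. A callable along the following lines could be assigned to it; this is only a sketch, and the image directory, the 300x300 input size, and the [-1, 1] normalization are assumptions that have to match the exported model's preprocessing:

```python
import glob
import numpy as np
from PIL import Image

def representative_dataset():
  # Yield a few hundred preprocessed samples for full-integer calibration.
  for path in glob.glob('/tmp/calibration_images/*.jpg')[:100]:
    image = Image.open(path).convert('RGB').resize((300, 300))
    array = np.asarray(image, dtype=np.float32)
    array = (array - 127.5) / 127.5  # Assumed [-1, 1] input normalization.
    yield [array[np.newaxis, ...]]

converter.representative_dataset = representative_dataset
```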
+ ### Step 3: Add Metadata
+ The model needs to be packed with
+ [TFLite Metadata](https://www.tensorflow.org/lite/convert/metadata) to enable
+ easy integration into mobile apps using the
+ [TFLite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector).
+ This metadata helps the inference code perform the correct pre & post processing
+ as required by the model. Use the following code to create the metadata.
+ ```python
+ from tflite_support.metadata_writers import object_detector
+ from tflite_support.metadata_writers import writer_utils
+ writer = object_detector.MetadataWriter.create_for_inference(
+ writer_utils.load_file(_TFLITE_MODEL_PATH), input_norm_mean=[0],
+ input_norm_std=[255], label_file_paths=[_TFLITE_LABEL_PATH])
+ writer_utils.save_file(writer.populate(), _TFLITE_MODEL_WITH_METADATA_PATH)
+ ```
+ See the TFLite Metadata Writer API [documentation](https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#object_detectors)
+ for more details.
## Running our model on Android
- To run our TensorFlow Lite model on device, we will use Android Studio to build
+ ### Integrate the model into your app
+ You can use the TFLite Task Library's [ObjectDetector API](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector)
+ to integrate the model into your Android app.
+ ```java
+ // Initialization
+ ObjectDetectorOptions options = ObjectDetectorOptions.builder().setMaxResults(1).build();
+ ObjectDetector objectDetector = ObjectDetector.createFromFileAndOptions(context, modelFile, options);
+ // Run inference
+ List<Detection> results = objectDetector.detect(image);
+ ```
+ ### Test the model using the TFLite sample app
+ To test our TensorFlow Lite model on device, we will use Android Studio to build
and run the TensorFlow Lite detection example with the new model. The example is
found in the
[TensorFlow examples repository](https://github.com/tensorflow/examples) under
@@ -102,7 +144,7 @@ that support API >= 21. Additional details are available on the
Next we need to point the app to our new detect.tflite file and give it the
names of our new labels. Specifically, we will copy our TensorFlow Lite
- flatbuffer to the app assets directory with the following command:
+ model with metadata to the app assets directory with the following command:
```shell
mkdir $TF_EXAMPLES/lite/examples/object_detection/android/app/src/main/assets
@@ -110,9 +152,6 @@ cp /tmp/tflite/detect.tflite \
$TF_EXAMPLES/lite/examples/object_detection/android/app/src/main/assets
```
- You will also need to copy your new labelmap labelmap.txt to the assets
- directory.
We will now edit the gradle build file to use these assets. First, open the
`build.gradle` file
`$TF_EXAMPLES/lite/examples/object_detection/android/app/build.gradle`. Comment
@@ -122,23 +161,12 @@ out the model download script to avoid your assets being overwritten:
// apply from:'download_model.gradle'
```
- If your model is named `detect.tflite`, and your labels file `labelmap.txt`, the
- example will use them automatically as long as they've been properly copied into
- the base assets directory. If you need to use a custom path or filename, open up
- the
+ If your model is named `detect.tflite`, the example will use it automatically as
+ long as they've been properly copied into the base assets directory. If you need
+ to use a custom path or filename, open up the
$TF_EXAMPLES/lite/examples/object_detection/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java
- file in a text editor and find the definition of TF_OD_API_LABELS_FILE. Update
- this path to point to your new label map file: "labels_list.txt". Note that if
- your model is quantized, the flag TF_OD_API_IS_QUANTIZED is set to true, and if
- your model is floating point, the flag TF_OD_API_IS_QUANTIZED is set to false.
- This new section of DetectorActivity.java should now look as follows for a
- quantized model:
- ```java
- private static final boolean TF_OD_API_IS_QUANTIZED = true;
- private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
- private static final String TF_OD_API_LABELS_FILE = "labels_list.txt";
- ```
+ file in a text editor and find the definition of TF_OD_API_MODEL_FILE. Update
+ this path to point to your new model file.
Once you’ve copied the TensorFlow Lite model and edited the gradle build script
to not use the downloaded assets, you can build and deploy the app using the
......
@@ -479,6 +479,7 @@ def pad_input_data_to_static_shapes(tensor_dict,
input_fields.groundtruth_instance_masks: [
max_num_boxes, height, width
],
+ input_fields.groundtruth_instance_mask_weights: [max_num_boxes],
input_fields.groundtruth_is_crowd: [max_num_boxes],
input_fields.groundtruth_group_of: [max_num_boxes],
input_fields.groundtruth_area: [max_num_boxes],
@@ -601,6 +602,8 @@ def augment_input_data(tensor_dict, data_augmentation_options):
include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
in tensor_dict)
+ include_instance_mask_weights = (
+ fields.InputDataFields.groundtruth_instance_mask_weights in tensor_dict)
include_keypoints = (fields.InputDataFields.groundtruth_keypoints
in tensor_dict)
include_keypoint_visibilities = (
@@ -624,6 +627,7 @@ def augment_input_data(tensor_dict, data_augmentation_options):
include_label_confidences=include_label_confidences,
include_multiclass_scores=include_multiclass_scores,
include_instance_masks=include_instance_masks,
+ include_instance_mask_weights=include_instance_mask_weights,
include_keypoints=include_keypoints,
include_keypoint_visibilities=include_keypoint_visibilities,
include_dense_pose=include_dense_pose,
@@ -652,6 +656,7 @@ def _get_labels_dict(input_dict):
fields.InputDataFields.groundtruth_keypoint_depths,
fields.InputDataFields.groundtruth_keypoint_depth_weights,
fields.InputDataFields.groundtruth_instance_masks,
+ fields.InputDataFields.groundtruth_instance_mask_weights,
fields.InputDataFields.groundtruth_area,
fields.InputDataFields.groundtruth_is_crowd,
fields.InputDataFields.groundtruth_group_of,
@@ -804,6 +809,9 @@ def train_input(train_config, train_input_config,
labels[fields.InputDataFields.groundtruth_instance_masks] is a
[batch_size, num_boxes, H, W] float32 tensor containing only binary
values, which represent instance masks for objects.
+ labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
+ [batch_size, num_boxes] float32 tensor containing groundtruth weights
+ for each instance mask.
labels[fields.InputDataFields.groundtruth_keypoints] is a
[batch_size, num_boxes, num_keypoints, 2] float32 tensor containing
keypoints for each box.
@@ -961,6 +969,9 @@ def eval_input(eval_config, eval_input_config, model_config,
labels[fields.InputDataFields.groundtruth_instance_masks] is a
[1, num_boxes, H, W] float32 tensor containing only binary values,
which represent instance masks for objects.
+ labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
+ [1, num_boxes] float32 tensor containing groundtruth weights for each
+ instance mask.
labels[fields.InputDataFields.groundtruth_weights] is a
[batch_size, num_boxes, num_keypoints] float32 tensor containing
groundtruth weights for the keypoints.
......
@@ -795,15 +795,20 @@ class DataAugmentationFnTest(test_case.TestCase):
fields.InputDataFields.image:
tf.constant(np.random.rand(10, 10, 3).astype(np.float32)),
fields.InputDataFields.groundtruth_instance_masks:
- tf.constant(np.zeros([2, 10, 10], np.uint8))
+ tf.constant(np.zeros([2, 10, 10], np.uint8)),
+ fields.InputDataFields.groundtruth_instance_mask_weights:
+ tf.constant([1.0, 0.0], np.float32)
}
augmented_tensor_dict = data_augmentation_fn(tensor_dict=tensor_dict)
return (augmented_tensor_dict[fields.InputDataFields.image],
augmented_tensor_dict[fields.InputDataFields.
- groundtruth_instance_masks])
- image, masks = self.execute_cpu(graph_fn, [])
+ groundtruth_instance_masks],
+ augmented_tensor_dict[fields.InputDataFields.
+ groundtruth_instance_mask_weights])
+ image, masks, mask_weights = self.execute_cpu(graph_fn, [])
self.assertAllEqual(image.shape, [20, 20, 3])
self.assertAllEqual(masks.shape, [2, 20, 20])
+ self.assertAllClose(mask_weights, [1.0, 0.0])
def test_include_keypoints_in_data_augmentation(self):
data_augmentation_options = [
......
@@ -1668,7 +1668,9 @@ def predicted_embeddings_at_object_centers(embedding_predictions,
class ObjectDetectionParams(
collections.namedtuple('ObjectDetectionParams', [
'localization_loss', 'scale_loss_weight', 'offset_loss_weight',
- 'task_loss_weight'
+ 'task_loss_weight', 'scale_head_num_filters',
+ 'scale_head_kernel_sizes', 'offset_head_num_filters',
+ 'offset_head_kernel_sizes'
])):
"""Namedtuple to host object detection related parameters.
@@ -1684,7 +1686,11 @@ class ObjectDetectionParams(
localization_loss,
scale_loss_weight,
offset_loss_weight,
- task_loss_weight=1.0):
+ task_loss_weight=1.0,
+ scale_head_num_filters=(256),
+ scale_head_kernel_sizes=(3),
+ offset_head_num_filters=(256),
+ offset_head_kernel_sizes=(3)):
"""Constructor with default values for ObjectDetectionParams.
Args:
@@ -1697,13 +1703,23 @@ class ObjectDetectionParams(
depending on the input size.
offset_loss_weight: float, The weight for localizing center offsets.
task_loss_weight: float, the weight of the object detection loss.
+ scale_head_num_filters: filter numbers of the convolutional layers used
+ by the object detection box scale prediction head.
+ scale_head_kernel_sizes: kernel size of the convolutional layers used
+ by the object detection box scale prediction head.
+ offset_head_num_filters: filter numbers of the convolutional layers used
+ by the object detection box offset prediction head.
+ offset_head_kernel_sizes: kernel size of the convolutional layers used
+ by the object detection box offset prediction head.
Returns:
An initialized ObjectDetectionParams namedtuple.
"""
return super(ObjectDetectionParams,
cls).__new__(cls, localization_loss, scale_loss_weight,
- offset_loss_weight, task_loss_weight)
+ offset_loss_weight, task_loss_weight,
+ scale_head_num_filters, scale_head_kernel_sizes,
+ offset_head_num_filters, offset_head_kernel_sizes)
class KeypointEstimationParams(
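The new `scale_head_*` and `offset_head_*` fields make the depth and kernel sizes of the box regression heads configurable. A hypothetical construction with non-default head shapes might look as follows; the loss object and the specific filter/kernel values are illustrative and not taken from this commit:

```python
from object_detection.core import losses
from object_detection.meta_architectures import center_net_meta_arch

# Hypothetical values: two conv layers of 128 and 64 filters, both 3x3,
# for each of the box scale and offset prediction heads.
od_params = center_net_meta_arch.ObjectDetectionParams(
    localization_loss=losses.L1LocalizationLoss(),
    scale_loss_weight=0.1,
    offset_loss_weight=1.0,
    task_loss_weight=1.0,
    scale_head_num_filters=(128, 64),
    scale_head_kernel_sizes=(3, 3),
    offset_head_num_filters=(128, 64),
    offset_head_kernel_sizes=(3, 3))
```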
...@@ -1937,7 +1953,8 @@ class ObjectCenterParams( ...@@ -1937,7 +1953,8 @@ class ObjectCenterParams(
class MaskParams( class MaskParams(
collections.namedtuple('MaskParams', [ collections.namedtuple('MaskParams', [
'classification_loss', 'task_loss_weight', 'mask_height', 'mask_width', 'classification_loss', 'task_loss_weight', 'mask_height', 'mask_width',
'score_threshold', 'heatmap_bias_init' 'score_threshold', 'heatmap_bias_init', 'mask_head_num_filters',
'mask_head_kernel_sizes'
])): ])):
"""Namedtuple to store mask prediction related parameters.""" """Namedtuple to store mask prediction related parameters."""
...@@ -1949,7 +1966,9 @@ class MaskParams( ...@@ -1949,7 +1966,9 @@ class MaskParams(
mask_height=256, mask_height=256,
mask_width=256, mask_width=256,
score_threshold=0.5, score_threshold=0.5,
heatmap_bias_init=-2.19): heatmap_bias_init=-2.19,
mask_head_num_filters=(256),
mask_head_kernel_sizes=(3)):
"""Constructor with default values for MaskParams. """Constructor with default values for MaskParams.
Args: Args:
...@@ -1963,6 +1982,10 @@ class MaskParams( ...@@ -1963,6 +1982,10 @@ class MaskParams(
heatmap_bias_init: float, the initial value of bias in the convolutional heatmap_bias_init: float, the initial value of bias in the convolutional
kernel of the semantic segmentation prediction head. If set to None, the kernel of the semantic segmentation prediction head. If set to None, the
bias is initialized with zeros. bias is initialized with zeros.
mask_head_num_filters: filter numbers of the convolutional layers used
by the mask prediction head.
mask_head_kernel_sizes: kernel size of the convolutional layers used
by the mask prediction head.
Returns: Returns:
An initialized MaskParams namedtuple. An initialized MaskParams namedtuple.
...@@ -1970,7 +1993,8 @@ class MaskParams( ...@@ -1970,7 +1993,8 @@ class MaskParams(
return super(MaskParams, return super(MaskParams,
cls).__new__(cls, classification_loss, cls).__new__(cls, classification_loss,
task_loss_weight, mask_height, mask_width, task_loss_weight, mask_height, mask_width,
score_threshold, heatmap_bias_init) score_threshold, heatmap_bias_init,
mask_head_num_filters, mask_head_kernel_sizes)
class DensePoseParams( class DensePoseParams(
...@@ -2312,10 +2336,18 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2312,10 +2336,18 @@ class CenterNetMetaArch(model.DetectionModel):
if self._od_params is not None: if self._od_params is not None:
prediction_heads[BOX_SCALE] = self._make_prediction_net_list( prediction_heads[BOX_SCALE] = self._make_prediction_net_list(
num_feature_outputs, NUM_SIZE_CHANNELS, name='box_scale', num_feature_outputs,
NUM_SIZE_CHANNELS,
kernel_sizes=self._od_params.scale_head_kernel_sizes,
num_filters=self._od_params.scale_head_num_filters,
name='box_scale',
unit_height_conv=unit_height_conv) unit_height_conv=unit_height_conv)
prediction_heads[BOX_OFFSET] = self._make_prediction_net_list( prediction_heads[BOX_OFFSET] = self._make_prediction_net_list(
num_feature_outputs, NUM_OFFSET_CHANNELS, name='box_offset', num_feature_outputs,
NUM_OFFSET_CHANNELS,
kernel_sizes=self._od_params.offset_head_kernel_sizes,
num_filters=self._od_params.offset_head_num_filters,
name='box_offset',
unit_height_conv=unit_height_conv) unit_height_conv=unit_height_conv)
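_make_prediction_net_list now forwards the configured filter counts and kernel sizes to each head instead of relying on fixed values. The real builder lives elsewhere in this file; as a rough mental model only, each head is a small stack of convolutions followed by a 1x1 projection to the required number of output channels. A simplified Keras sketch under that assumption (names and structure are assumptions, not the library's):

import tensorflow.compat.v2 as tf

def sketch_prediction_head(num_out_channels, num_filters=(256,),
                           kernel_sizes=(3,), bias_fill=None):
  """Toy approximation of a CenterNet-style prediction head."""
  layers = []
  for filters, kernel_size in zip(num_filters, kernel_sizes):
    layers.append(tf.keras.layers.Conv2D(
        filters, kernel_size, padding='same', activation='relu'))
  # Final 1x1 conv producing the per-pixel outputs, e.g. 2 channels for box
  # scale/offset or num_classes for the segmentation heatmap.
  bias_init = (tf.keras.initializers.Constant(bias_fill)
               if bias_fill is not None else 'zeros')
  layers.append(tf.keras.layers.Conv2D(
      num_out_channels, 1, padding='same', bias_initializer=bias_init))
  return tf.keras.Sequential(layers)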
if self._kp_params_dict is not None: if self._kp_params_dict is not None:
...@@ -2370,6 +2402,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2370,6 +2402,8 @@ class CenterNetMetaArch(model.DetectionModel):
prediction_heads[SEGMENTATION_HEATMAP] = self._make_prediction_net_list( prediction_heads[SEGMENTATION_HEATMAP] = self._make_prediction_net_list(
num_feature_outputs, num_feature_outputs,
num_classes, num_classes,
kernel_sizes=self._mask_params.mask_head_kernel_sizes,
num_filters=self._mask_params.mask_head_num_filters,
bias_fill=self._mask_params.heatmap_bias_init, bias_fill=self._mask_params.heatmap_bias_init,
name='seg_heatmap', name='seg_heatmap',
unit_height_conv=unit_height_conv) unit_height_conv=unit_height_conv)
...@@ -2721,8 +2755,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2721,8 +2755,7 @@ class CenterNetMetaArch(model.DetectionModel):
gt_weights_list=gt_weights_list, gt_weights_list=gt_weights_list,
gt_classes_list=gt_classes_list, gt_classes_list=gt_classes_list,
gt_boxes_list=gt_boxes_list) gt_boxes_list=gt_boxes_list)
flattened_valid_mask = _flatten_spatial_dimensions( flattened_valid_mask = _flatten_spatial_dimensions(valid_mask_batch)
tf.expand_dims(valid_mask_batch, axis=-1))
flattened_heapmap_targets = _flatten_spatial_dimensions(keypoint_heatmap) flattened_heapmap_targets = _flatten_spatial_dimensions(keypoint_heatmap)
# Sum over the number of instances per keypoint type to get the total # Sum over the number of instances per keypoint type to get the total
# number of keypoints. Note that this is used to normalize the loss and we # number of keypoints. Note that this is used to normalize the loss and we
...@@ -2945,20 +2978,32 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2945,20 +2978,32 @@ class CenterNetMetaArch(model.DetectionModel):
Returns: Returns:
A float scalar tensor representing the mask loss. A float scalar tensor representing the mask loss.
""" """
gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
gt_masks_list = self.groundtruth_lists(fields.BoxListFields.masks) gt_masks_list = self.groundtruth_lists(fields.BoxListFields.masks)
gt_mask_weights_list = None
if self.groundtruth_has_field(fields.BoxListFields.mask_weights):
gt_mask_weights_list = self.groundtruth_lists(
fields.BoxListFields.mask_weights)
gt_classes_list = self.groundtruth_lists(fields.BoxListFields.classes) gt_classes_list = self.groundtruth_lists(fields.BoxListFields.classes)
# Convert the groundtruth to targets. # Convert the groundtruth to targets.
assigner = self._target_assigner_dict[SEGMENTATION_TASK] assigner = self._target_assigner_dict[SEGMENTATION_TASK]
heatmap_targets = assigner.assign_segmentation_targets( heatmap_targets, heatmap_weight = assigner.assign_segmentation_targets(
gt_masks_list=gt_masks_list, gt_masks_list=gt_masks_list,
gt_classes_list=gt_classes_list) gt_classes_list=gt_classes_list,
gt_boxes_list=gt_boxes_list,
gt_mask_weights_list=gt_mask_weights_list)
flattened_heatmap_targets = _flatten_spatial_dimensions(heatmap_targets) flattened_heatmap_targets = _flatten_spatial_dimensions(heatmap_targets)
flattened_heatmap_mask = _flatten_spatial_dimensions(
heatmap_weight[:, :, :, tf.newaxis])
per_pixel_weights *= flattened_heatmap_mask
loss = 0.0 loss = 0.0
mask_loss_fn = self._mask_params.classification_loss mask_loss_fn = self._mask_params.classification_loss
total_pixels_in_loss = tf.reduce_sum(per_pixel_weights)
total_pixels_in_loss = tf.math.maximum(
tf.reduce_sum(per_pixel_weights), 1)
# Loop through each feature output head. # Loop through each feature output head.
for pred in segmentation_predictions: for pred in segmentation_predictions:
......
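Two behavioural changes in the mask loss above are easy to miss: the per-pixel weights are multiplied by the heatmap weight mask returned by the target assigner (so down-weighted or blacked-out instances stop contributing), and the normalizer is clamped to at least 1 so a batch with no valid mask pixels cannot divide by zero. A tiny numeric illustration of the clamp, independent of the surrounding code:

import tensorflow.compat.v2 as tf

# If every pixel weight is zero (e.g. all masks blacked out), the raw
# reduce_sum is 0.0 and dividing by it would yield inf/nan.
per_pixel_weights = tf.zeros([1, 16, 1])
per_pixel_loss = tf.zeros([1, 16, 1])

total_pixels_in_loss = tf.math.maximum(tf.reduce_sum(per_pixel_weights), 1)
loss = tf.reduce_sum(per_pixel_loss * per_pixel_weights) / total_pixels_in_loss
print(float(loss))  # 0.0 rather than nan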
...@@ -1539,7 +1539,9 @@ def get_fake_mask_params(): ...@@ -1539,7 +1539,9 @@ def get_fake_mask_params():
classification_loss=losses.WeightedSoftmaxClassificationLoss(), classification_loss=losses.WeightedSoftmaxClassificationLoss(),
task_loss_weight=1.0, task_loss_weight=1.0,
mask_height=4, mask_height=4,
mask_width=4) mask_width=4,
mask_head_num_filters=[96],
mask_head_kernel_sizes=[3])
def get_fake_densepose_params(): def get_fake_densepose_params():
......
...@@ -266,6 +266,7 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True): ...@@ -266,6 +266,7 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
# dimension. This list has to be kept in sync with InputDataFields in # dimension. This list has to be kept in sync with InputDataFields in
# standard_fields.py. # standard_fields.py.
fields.InputDataFields.groundtruth_instance_masks, fields.InputDataFields.groundtruth_instance_masks,
fields.InputDataFields.groundtruth_instance_mask_weights,
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_keypoints, fields.InputDataFields.groundtruth_keypoints,
...@@ -319,6 +320,10 @@ def provide_groundtruth(model, labels): ...@@ -319,6 +320,10 @@ def provide_groundtruth(model, labels):
if fields.InputDataFields.groundtruth_instance_masks in labels: if fields.InputDataFields.groundtruth_instance_masks in labels:
gt_masks_list = labels[ gt_masks_list = labels[
fields.InputDataFields.groundtruth_instance_masks] fields.InputDataFields.groundtruth_instance_masks]
gt_mask_weights_list = None
if fields.InputDataFields.groundtruth_instance_mask_weights in labels:
gt_mask_weights_list = labels[
fields.InputDataFields.groundtruth_instance_mask_weights]
gt_keypoints_list = None gt_keypoints_list = None
if fields.InputDataFields.groundtruth_keypoints in labels: if fields.InputDataFields.groundtruth_keypoints in labels:
gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints] gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
...@@ -383,6 +388,7 @@ def provide_groundtruth(model, labels): ...@@ -383,6 +388,7 @@ def provide_groundtruth(model, labels):
groundtruth_confidences_list=gt_confidences_list, groundtruth_confidences_list=gt_confidences_list,
groundtruth_labeled_classes=gt_labeled_classes, groundtruth_labeled_classes=gt_labeled_classes,
groundtruth_masks_list=gt_masks_list, groundtruth_masks_list=gt_masks_list,
groundtruth_mask_weights_list=gt_mask_weights_list,
groundtruth_keypoints_list=gt_keypoints_list, groundtruth_keypoints_list=gt_keypoints_list,
groundtruth_keypoint_visibilities_list=gt_keypoint_visibilities_list, groundtruth_keypoint_visibilities_list=gt_keypoint_visibilities_list,
groundtruth_dp_num_points_list=gt_dp_num_points_list, groundtruth_dp_num_points_list=gt_dp_num_points_list,
......
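The new field is only forwarded when it is present in the labels dictionary, so existing pipelines without mask weights are untouched. A hedged sketch of the relevant slice of a batched labels dict for one image with two instances (all tensor values are made up):

import tensorflow.compat.v2 as tf
from object_detection.core import standard_fields as fields

labels = {
    fields.InputDataFields.groundtruth_boxes:
        tf.constant([[[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.9, 0.9]]]),
    fields.InputDataFields.groundtruth_classes:
        tf.constant([[[1.0, 0.0], [0.0, 1.0]]]),
    fields.InputDataFields.groundtruth_instance_masks:
        tf.zeros([1, 2, 64, 64]),
    # New: per-instance mask weights. A weight of 0.0 removes that
    # instance's mask from the segmentation loss; 1.0 keeps it.
    fields.InputDataFields.groundtruth_instance_mask_weights:
        tf.constant([[1.0, 0.0]]),
}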
...@@ -20,11 +20,11 @@ from __future__ import print_function ...@@ -20,11 +20,11 @@ from __future__ import print_function
import copy import copy
import os import os
import pprint
import time import time
import numpy as np
import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
from object_detection import eval_util from object_detection import eval_util
from object_detection import inputs from object_detection import inputs
...@@ -87,6 +87,8 @@ def _compute_losses_and_predictions_dicts( ...@@ -87,6 +87,8 @@ def _compute_losses_and_predictions_dicts(
labels[fields.InputDataFields.groundtruth_instance_masks] is a labels[fields.InputDataFields.groundtruth_instance_masks] is a
float32 tensor containing only binary values, which represent float32 tensor containing only binary values, which represent
instance masks for objects. instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
float32 tensor containing weights for the instance masks.
labels[fields.InputDataFields.groundtruth_keypoints] is a labels[fields.InputDataFields.groundtruth_keypoints] is a
float32 tensor containing keypoints for each box. float32 tensor containing keypoints for each box.
labels[fields.InputDataFields.groundtruth_dp_num_points] is an int32 labels[fields.InputDataFields.groundtruth_dp_num_points] is an int32
...@@ -181,6 +183,22 @@ def _ensure_model_is_built(model, input_dataset, unpad_groundtruth_tensors): ...@@ -181,6 +183,22 @@ def _ensure_model_is_built(model, input_dataset, unpad_groundtruth_tensors):
)) ))
def normalize_dict(values_dict, num_replicas):
num_replicas = tf.constant(num_replicas, dtype=tf.float32)
return {key: tf.math.divide(loss, num_replicas) for key, loss
in values_dict.items()}
def reduce_dict(strategy, reduction_dict, reduction_op):
# TODO(anjalisridhar): explore if it is safe to remove the
# num_replicas scaling of the loss and switch this to a ReduceOp.Mean
return {
name: strategy.reduce(reduction_op, loss, axis=None)
for name, loss in reduction_dict.items()
}
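Taken together, these helpers replace the scaling that used to happen inline in eager_train_step: every entry of the per-replica loss dict is divided by the replica count, so the cross-replica SUM performed by reduce_dict is equivalent to a mean. A toy illustration of the arithmetic without a distribution strategy:

import tensorflow.compat.v2 as tf

num_replicas = 4
# Pretend each of 4 replicas computed a local total loss of 8.0.
per_replica_losses = [tf.constant(8.0)] * num_replicas

# normalize_dict divides each loss by num_replicas on its own replica...
scaled = [loss / num_replicas for loss in per_replica_losses]

# ...so summing across replicas (ReduceOp.SUM in reduce_dict) gives the
# mean of the per-replica losses: 8.0, not 32.0.
print(float(tf.add_n(scaled)))  # 8.0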
# TODO(kaftan): Explore removing learning_rate from this method & returning # TODO(kaftan): Explore removing learning_rate from this method & returning
## The full losses dict instead of just total_loss, then doing all summaries ## The full losses dict instead of just total_loss, then doing all summaries
## saving in a utility method called by the outer training loop. ## saving in a utility method called by the outer training loop.
...@@ -190,10 +208,8 @@ def eager_train_step(detection_model, ...@@ -190,10 +208,8 @@ def eager_train_step(detection_model,
labels, labels,
unpad_groundtruth_tensors, unpad_groundtruth_tensors,
optimizer, optimizer,
learning_rate,
add_regularization_loss=True, add_regularization_loss=True,
clip_gradients_value=None, clip_gradients_value=None,
global_step=None,
num_replicas=1.0): num_replicas=1.0):
"""Process a single training batch. """Process a single training batch.
...@@ -237,6 +253,9 @@ def eager_train_step(detection_model, ...@@ -237,6 +253,9 @@ def eager_train_step(detection_model,
labels[fields.InputDataFields.groundtruth_instance_masks] is a labels[fields.InputDataFields.groundtruth_instance_masks] is a
[batch_size, num_boxes, H, W] float32 tensor containing only binary [batch_size, num_boxes, H, W] float32 tensor containing only binary
values, which represent instance masks for objects. values, which represent instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
[batch_size, num_boxes] float32 tensor containing weights for the
instance masks.
labels[fields.InputDataFields.groundtruth_keypoints] is a labels[fields.InputDataFields.groundtruth_keypoints] is a
[batch_size, num_boxes, num_keypoints, 2] float32 tensor containing [batch_size, num_boxes, num_keypoints, 2] float32 tensor containing
keypoints for each box. keypoints for each box.
...@@ -261,16 +280,10 @@ def eager_train_step(detection_model, ...@@ -261,16 +280,10 @@ def eager_train_step(detection_model,
float32 tensor containing the weights of the keypoint depth feature. float32 tensor containing the weights of the keypoint depth feature.
unpad_groundtruth_tensors: A parameter passed to unstack_batch. unpad_groundtruth_tensors: A parameter passed to unstack_batch.
optimizer: The training optimizer that will update the variables. optimizer: The training optimizer that will update the variables.
learning_rate: The learning rate tensor for the current training step.
This is used only for TensorBoard logging purposes, it does not affect
model training.
add_regularization_loss: Whether or not to include the model's add_regularization_loss: Whether or not to include the model's
regularization loss in the losses dictionary. regularization loss in the losses dictionary.
clip_gradients_value: If this is present, clip the gradients global norm clip_gradients_value: If this is present, clip the gradients global norm
at this value using `tf.clip_by_global_norm`. at this value using `tf.clip_by_global_norm`.
global_step: The current training step. Used for TensorBoard logging
purposes. This step is not updated by this function and must be
incremented separately.
num_replicas: The number of replicas in the current distribution strategy. num_replicas: The number of replicas in the current distribution strategy.
This is used to scale the total loss so that training in a distribution This is used to scale the total loss so that training in a distribution
strategy works correctly. strategy works correctly.
...@@ -291,31 +304,18 @@ def eager_train_step(detection_model, ...@@ -291,31 +304,18 @@ def eager_train_step(detection_model,
losses_dict, _ = _compute_losses_and_predictions_dicts( losses_dict, _ = _compute_losses_and_predictions_dicts(
detection_model, features, labels, add_regularization_loss) detection_model, features, labels, add_regularization_loss)
total_loss = losses_dict['Loss/total_loss'] losses_dict = normalize_dict(losses_dict, num_replicas)
# Normalize loss for num replicas
total_loss = tf.math.divide(total_loss,
tf.constant(num_replicas, dtype=tf.float32))
losses_dict['Loss/normalized_total_loss'] = total_loss
for loss_type in losses_dict:
tf.compat.v2.summary.scalar(
loss_type, losses_dict[loss_type], step=global_step)
trainable_variables = detection_model.trainable_variables trainable_variables = detection_model.trainable_variables
total_loss = losses_dict['Loss/total_loss']
gradients = tape.gradient(total_loss, trainable_variables) gradients = tape.gradient(total_loss, trainable_variables)
if clip_gradients_value: if clip_gradients_value:
gradients, _ = tf.clip_by_global_norm(gradients, clip_gradients_value) gradients, _ = tf.clip_by_global_norm(gradients, clip_gradients_value)
optimizer.apply_gradients(zip(gradients, trainable_variables)) optimizer.apply_gradients(zip(gradients, trainable_variables))
tf.compat.v2.summary.scalar('learning_rate', learning_rate, step=global_step)
tf.compat.v2.summary.image( return losses_dict
name='train_input_images',
step=global_step,
data=features[fields.InputDataFields.image],
max_outputs=3)
return total_loss
def validate_tf_v2_checkpoint_restore_map(checkpoint_restore_map): def validate_tf_v2_checkpoint_restore_map(checkpoint_restore_map):
...@@ -397,7 +397,8 @@ def load_fine_tune_checkpoint(model, checkpoint_path, checkpoint_type, ...@@ -397,7 +397,8 @@ def load_fine_tune_checkpoint(model, checkpoint_path, checkpoint_type,
fine_tune_checkpoint_type=checkpoint_type) fine_tune_checkpoint_type=checkpoint_type)
validate_tf_v2_checkpoint_restore_map(restore_from_objects_dict) validate_tf_v2_checkpoint_restore_map(restore_from_objects_dict)
ckpt = tf.train.Checkpoint(**restore_from_objects_dict) ckpt = tf.train.Checkpoint(**restore_from_objects_dict)
ckpt.restore(checkpoint_path).assert_existing_objects_matched() ckpt.restore(
checkpoint_path).expect_partial().assert_existing_objects_matched()
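Adding expect_partial() silences the "unresolved object in checkpoint" warnings that fire when a fine-tuning checkpoint carries values (optimizer slots, heads for other tasks) the current model never asks for, while assert_existing_objects_matched still fails if something the model does declare is missing. A self-contained sketch of the pattern with throwaway variables and a temporary path:

import tensorflow.compat.v2 as tf

# Save a checkpoint containing more than the consumer will use.
backbone = tf.Variable(1.0)
extra = tf.Variable(2.0)
path = tf.train.Checkpoint(backbone=backbone, extra=extra).save('/tmp/demo_ckpt')

# Restore only the backbone. Without expect_partial(), TF warns about the
# unclaimed `extra` value; with it, the restore is quiet while
# assert_existing_objects_matched still verifies the backbone was matched.
consumer = tf.train.Checkpoint(backbone=tf.Variable(0.0))
consumer.restore(path).expect_partial().assert_existing_objects_matched()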
def get_filepath(strategy, filepath): def get_filepath(strategy, filepath):
...@@ -474,7 +475,12 @@ def train_loop( ...@@ -474,7 +475,12 @@ def train_loop(
Checkpoint every n training steps. Checkpoint every n training steps.
checkpoint_max_to_keep: checkpoint_max_to_keep:
int, the number of most recent checkpoints to keep in the model directory. int, the number of most recent checkpoints to keep in the model directory.
record_summaries: Boolean, whether or not to record summaries. record_summaries: Boolean, whether or not to record summaries defined by
the model or the training pipeline. This does not impact the summaries
of the loss values, which are always recorded. Examples of summaries
that are controlled by this flag include:
- Image summaries of training images.
- Intermediate tensors that may be logged by meta architectures.
performance_summary_exporter: function for exporting performance metrics. performance_summary_exporter: function for exporting performance metrics.
num_steps_per_iteration: int, The number of training steps to perform num_steps_per_iteration: int, The number of training steps to perform
in each iteration. in each iteration.
...@@ -533,7 +539,8 @@ def train_loop( ...@@ -533,7 +539,8 @@ def train_loop(
strategy = tf.compat.v2.distribute.get_strategy() strategy = tf.compat.v2.distribute.get_strategy()
with strategy.scope(): with strategy.scope():
detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base']( detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
model_config=model_config, is_training=True) model_config=model_config, is_training=True,
add_summaries=record_summaries)
def train_dataset_fn(input_context): def train_dataset_fn(input_context):
"""Callable to create train input.""" """Callable to create train input."""
...@@ -576,11 +583,9 @@ def train_loop( ...@@ -576,11 +583,9 @@ def train_loop(
# is the chief. # is the chief.
summary_writer_filepath = get_filepath(strategy, summary_writer_filepath = get_filepath(strategy,
os.path.join(model_dir, 'train')) os.path.join(model_dir, 'train'))
if record_summaries:
summary_writer = tf.compat.v2.summary.create_file_writer( summary_writer = tf.compat.v2.summary.create_file_writer(
summary_writer_filepath) summary_writer_filepath)
else:
summary_writer = tf2.summary.create_noop_writer()
with summary_writer.as_default(): with summary_writer.as_default():
with strategy.scope(): with strategy.scope():
...@@ -614,32 +619,37 @@ def train_loop( ...@@ -614,32 +619,37 @@ def train_loop(
def train_step_fn(features, labels): def train_step_fn(features, labels):
"""Single train step.""" """Single train step."""
loss = eager_train_step(
if record_summaries:
tf.compat.v2.summary.image(
name='train_input_images',
step=global_step,
data=features[fields.InputDataFields.image],
max_outputs=3)
losses_dict = eager_train_step(
detection_model, detection_model,
features, features,
labels, labels,
unpad_groundtruth_tensors, unpad_groundtruth_tensors,
optimizer, optimizer,
learning_rate=learning_rate_fn(),
add_regularization_loss=add_regularization_loss, add_regularization_loss=add_regularization_loss,
clip_gradients_value=clip_gradients_value, clip_gradients_value=clip_gradients_value,
global_step=global_step,
num_replicas=strategy.num_replicas_in_sync) num_replicas=strategy.num_replicas_in_sync)
global_step.assign_add(1) global_step.assign_add(1)
return loss return losses_dict
def _sample_and_train(strategy, train_step_fn, data_iterator): def _sample_and_train(strategy, train_step_fn, data_iterator):
features, labels = data_iterator.next() features, labels = data_iterator.next()
if hasattr(tf.distribute.Strategy, 'run'): if hasattr(tf.distribute.Strategy, 'run'):
per_replica_losses = strategy.run( per_replica_losses_dict = strategy.run(
train_step_fn, args=(features, labels)) train_step_fn, args=(features, labels))
else: else:
per_replica_losses = strategy.experimental_run_v2( per_replica_losses_dict = (
train_step_fn, args=(features, labels)) strategy.experimental_run_v2(
# TODO(anjalisridhar): explore if it is safe to remove the train_step_fn, args=(features, labels)))
## num_replicas scaling of the loss and switch this to a ReduceOp.Mean
return strategy.reduce(tf.distribute.ReduceOp.SUM, return reduce_dict(
per_replica_losses, axis=None) strategy, per_replica_losses_dict, tf.distribute.ReduceOp.SUM)
@tf.function @tf.function
def _dist_train_step(data_iterator): def _dist_train_step(data_iterator):
...@@ -665,7 +675,7 @@ def train_loop( ...@@ -665,7 +675,7 @@ def train_loop(
for _ in range(global_step.value(), train_steps, for _ in range(global_step.value(), train_steps,
num_steps_per_iteration): num_steps_per_iteration):
loss = _dist_train_step(train_input_iter) losses_dict = _dist_train_step(train_input_iter)
time_taken = time.time() - last_step_time time_taken = time.time() - last_step_time
last_step_time = time.time() last_step_time = time.time()
...@@ -676,11 +686,19 @@ def train_loop( ...@@ -676,11 +686,19 @@ def train_loop(
steps_per_sec_list.append(steps_per_sec) steps_per_sec_list.append(steps_per_sec)
logged_dict = losses_dict.copy()
logged_dict['learning_rate'] = learning_rate_fn()
for key, val in logged_dict.items():
tf.compat.v2.summary.scalar(key, val, step=global_step)
if global_step.value() - logged_step >= 100: if global_step.value() - logged_step >= 100:
logged_dict_np = {name: value.numpy() for name, value in
logged_dict.items()}
tf.logging.info( tf.logging.info(
'Step {} per-step time {:.3f}s loss={:.3f}'.format( 'Step {} per-step time {:.3f}s'.format(
global_step.value(), time_taken / num_steps_per_iteration, global_step.value(), time_taken / num_steps_per_iteration))
loss)) tf.logging.info(pprint.pformat(logged_dict_np, width=40))
logged_step = global_step.value() logged_step = global_step.value()
if ((int(global_step.value()) - checkpointed_step) >= if ((int(global_step.value()) - checkpointed_step) >=
...@@ -699,7 +717,7 @@ def train_loop( ...@@ -699,7 +717,7 @@ def train_loop(
'steps_per_sec': np.mean(steps_per_sec_list), 'steps_per_sec': np.mean(steps_per_sec_list),
'steps_per_sec_p50': np.median(steps_per_sec_list), 'steps_per_sec_p50': np.median(steps_per_sec_list),
'steps_per_sec_max': max(steps_per_sec_list), 'steps_per_sec_max': max(steps_per_sec_list),
'last_batch_loss': float(loss) 'last_batch_loss': float(losses_dict['Loss/total_loss'])
} }
mixed_precision = 'bf16' if kwargs['use_bfloat16'] else 'fp32' mixed_precision = 'bf16' if kwargs['use_bfloat16'] else 'fp32'
performance_summary_exporter(metrics, mixed_precision) performance_summary_exporter(metrics, mixed_precision)
......
...@@ -65,8 +65,10 @@ flags.DEFINE_integer( ...@@ -65,8 +65,10 @@ flags.DEFINE_integer(
flags.DEFINE_integer( flags.DEFINE_integer(
'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.') 'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
flags.DEFINE_boolean('record_summaries', True, flags.DEFINE_boolean('record_summaries', True,
('Whether or not to record summaries during' ('Whether or not to record summaries defined by the model'
' training.')) ' or the training pipeline. This does not impact the'
' summaries of the loss values which are always'
' recorded.'))
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -19,9 +19,10 @@ from __future__ import absolute_import ...@@ -19,9 +19,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from keras.applications import resnet
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow.python.keras.applications import resnet
from object_detection.core import freezable_batch_norm from object_detection.core import freezable_batch_norm
from object_detection.models.keras_models import model_utils from object_detection.models.keras_models import model_utils
......
...@@ -65,6 +65,14 @@ message CenterNet { ...@@ -65,6 +65,14 @@ message CenterNet {
// Localization loss configuration for object scale and offset losses. // Localization loss configuration for object scale and offset losses.
optional LocalizationLoss localization_loss = 8; optional LocalizationLoss localization_loss = 8;
// Parameters to determine the architecture of the object scale prediction
// head.
optional PredictionHeadParams scale_head_params = 9;
// Parameters to determine the architecture of the object offset prediction
// head.
optional PredictionHeadParams offset_head_params = 10;
} }
optional ObjectDetection object_detection_task = 4; optional ObjectDetection object_detection_task = 4;
...@@ -268,6 +276,10 @@ message CenterNet { ...@@ -268,6 +276,10 @@ message CenterNet {
// prediction head. -2.19 corresponds to predicting foreground with // prediction head. -2.19 corresponds to predicting foreground with
// a probability of 0.1. // a probability of 0.1.
optional float heatmap_bias_init = 3 [default = -2.19]; optional float heatmap_bias_init = 3 [default = -2.19];
// Parameters to determine the architecture of the segmentation mask
// prediction head.
optional PredictionHeadParams mask_head_params = 7;
} }
optional MaskEstimation mask_estimation_task = 8; optional MaskEstimation mask_estimation_task = 8;
......
...@@ -19,6 +19,7 @@ from __future__ import division ...@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from object_detection.utils import shape_utils
def _coordinate_vector_1d(start, end, size, align_endpoints): def _coordinate_vector_1d(start, end, size, align_endpoints):
...@@ -322,7 +323,7 @@ def multilevel_roi_align(features, boxes, box_levels, output_size, ...@@ -322,7 +323,7 @@ def multilevel_roi_align(features, boxes, box_levels, output_size,
""" """
with tf.name_scope(scope, 'MultiLevelRoIAlign'): with tf.name_scope(scope, 'MultiLevelRoIAlign'):
features, true_feature_shapes = pad_to_max_size(features) features, true_feature_shapes = pad_to_max_size(features)
batch_size = tf.shape(features)[0] batch_size = shape_utils.combined_static_and_dynamic_shape(features)[0]
num_levels = features.get_shape().as_list()[1] num_levels = features.get_shape().as_list()[1]
max_feature_height = tf.shape(features)[2] max_feature_height = tf.shape(features)[2]
max_feature_width = tf.shape(features)[3] max_feature_width = tf.shape(features)[3]
......
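Switching from tf.shape to shape_utils.combined_static_and_dynamic_shape lets batch_size stay a plain Python int whenever the static shape is known at graph-construction time, falling back to a dynamic tensor only for genuinely unknown dimensions. Roughly, the helper behaves like the following sketch (the real implementation is in object_detection/utils/shape_utils.py):

import tensorflow.compat.v1 as tf

def combined_static_and_dynamic_shape_sketch(tensor):
  """Per dimension, return the static int if known, else a dynamic scalar."""
  static_shape = tensor.shape.as_list()
  dynamic_shape = tf.shape(tensor)
  return [static_dim if static_dim is not None else dynamic_shape[i]
          for i, static_dim in enumerate(static_shape)]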
...@@ -289,12 +289,38 @@ def get_valid_keypoint_mask_for_class(keypoint_coordinates, ...@@ -289,12 +289,38 @@ def get_valid_keypoint_mask_for_class(keypoint_coordinates,
return mask, keypoints_nan_to_zeros return mask, keypoints_nan_to_zeros
def blackout_pixel_weights_by_box_regions(height, width, boxes, blackout): def blackout_pixel_weights_by_box_regions(height, width, boxes, blackout,
"""Blackout the pixel weights in the target box regions. weights=None):
"""Apply weights at pixel locations.
This function is used to generate the pixel weight mask (usually in the output This function is used to generate the pixel weight mask (usually in the output
image dimension). The mask is to ignore some regions when computing loss. image dimension). The mask is to ignore some regions when computing loss.
Weights are applied as follows:
- Any region outside of a box gets the default weight 1.0.
- Any box for which an explicit weight is specified gets that weight. If
multiple boxes overlap, the maximum of the weights is applied.
- Any box for which blackout=True is specified will get a weight of 0.0,
regardless of whether an equivalent non-zero weight is specified. Also, the
blackout region takes precedence over other boxes which may overlap with
non-zero weight.
Example:
height = 4
width = 4
boxes = [[0., 0., 2., 2.],
[0., 0., 4., 2.],
[3., 0., 4., 4.]]
blackout = [False, False, True]
weights = [4.0, 3.0, 2.0]
blackout_pixel_weights_by_box_regions(height, width, boxes, blackout,
weights)
>> [[4.0, 4.0, 1.0, 1.0],
[4.0, 4.0, 1.0, 1.0],
[3.0, 3.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0]]
Args: Args:
height: int, height of the (output) image. height: int, height of the (output) image.
width: int, width of the (output) image. width: int, width of the (output) image.
...@@ -302,10 +328,15 @@ def blackout_pixel_weights_by_box_regions(height, width, boxes, blackout): ...@@ -302,10 +328,15 @@ def blackout_pixel_weights_by_box_regions(height, width, boxes, blackout):
coordinates of the four corners of the boxes. coordinates of the four corners of the boxes.
blackout: A boolean tensor with shape [num_instances] indicating whether to blackout: A boolean tensor with shape [num_instances] indicating whether to
blackout (zero-out) the weights within the box regions. blackout (zero-out) the weights within the box regions.
weights: An optional float32 tensor with shape [num_instances] indicating
a value to apply in each box region. Note that if blackout=True for a
given box, the weight will be zero. If None, all weights are assumed to be
1.
Returns: Returns:
A float tensor with shape [height, width] where all values within the A float tensor with shape [height, width] where all values within the
regions of the blackout boxes are 0.0 and 1.0 else where. regions of the blackout boxes are 0.0 and 1.0 (or weights if supplied)
elsewhere.
""" """
num_instances, _ = shape_utils.combined_static_and_dynamic_shape(boxes) num_instances, _ = shape_utils.combined_static_and_dynamic_shape(boxes)
# If no annotation instance is provided, return all ones (instead of # If no annotation instance is provided, return all ones (instead of
...@@ -323,22 +354,36 @@ def blackout_pixel_weights_by_box_regions(height, width, boxes, blackout): ...@@ -323,22 +354,36 @@ def blackout_pixel_weights_by_box_regions(height, width, boxes, blackout):
# Make the mask with all 1.0 in the box regions. # Make the mask with all 1.0 in the box regions.
# Shape: [num_instances, height, width] # Shape: [num_instances, height, width]
in_boxes = tf.cast( in_boxes = tf.math.logical_and(
tf.logical_and( tf.math.logical_and(y_grid >= y_min, y_grid < y_max),
tf.logical_and(y_grid >= y_min, y_grid <= y_max), tf.math.logical_and(x_grid >= x_min, x_grid < x_max))
tf.logical_and(x_grid >= x_min, x_grid <= x_max)),
dtype=tf.float32) if weights is None:
weights = tf.ones_like(blackout, dtype=tf.float32)
# Shape: [num_instances, height, width]
blackout = tf.tile( # Compute a [height, width] tensor with the maximum weight in each box, and
tf.expand_dims(tf.expand_dims(blackout, axis=-1), axis=-1), # 0.0 elsewhere.
[1, height, width]) weights_tiled = tf.tile(
weights[:, tf.newaxis, tf.newaxis], [1, height, width])
# Select only the boxes specified by blackout. weights_3d = tf.where(in_boxes, weights_tiled,
selected_in_boxes = tf.where(blackout, in_boxes, tf.zeros_like(in_boxes)) tf.zeros_like(weights_tiled))
out_boxes = tf.reduce_max(selected_in_boxes, axis=0) weights_2d = tf.math.maximum(
out_boxes = tf.ones_like(out_boxes) - out_boxes tf.math.reduce_max(weights_3d, axis=0), 0.0)
return out_boxes
# Add 1.0 to all regions outside a box.
weights_2d = tf.where(
tf.math.reduce_any(in_boxes, axis=0),
weights_2d,
tf.ones_like(weights_2d))
# Now enforce that blackout regions all have zero weights.
keep_region = tf.cast(tf.math.logical_not(blackout), tf.float32)
keep_region_tiled = tf.tile(
keep_region[:, tf.newaxis, tf.newaxis], [1, height, width])
keep_region_3d = tf.where(in_boxes, keep_region_tiled,
tf.ones_like(keep_region_tiled))
keep_region_2d = tf.math.reduce_min(keep_region_3d, axis=0)
return weights_2d * keep_region_2d
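The rewritten weighting composes three layers: the per-pixel maximum of the box weights inside boxes, a default of 1.0 outside every box, and a hard zero wherever a blackout box covers. The box bounds also became half-open (>= min, < max) instead of inclusive, which is why the expected sums in the updated test further down changed. The same rules can be restated in a few lines of NumPy, assuming integer pixel-grid boxes as in the docstring example:

import numpy as np

def blackout_weights_sketch(height, width, boxes, blackout, weights):
  """NumPy restatement of the weighting rules for integer-coordinate boxes."""
  y, x = np.mgrid[0:height, 0:width]
  max_w = np.zeros((height, width), dtype=np.float32)   # max weight per pixel
  covered = np.zeros((height, width), dtype=bool)       # inside any box?
  keep = np.ones((height, width), dtype=np.float32)     # 0 where blacked out
  for (ymin, xmin, ymax, xmax), black, w in zip(boxes, blackout, weights):
    in_box = (y >= ymin) & (y < ymax) & (x >= xmin) & (x < xmax)
    covered |= in_box
    max_w = np.where(in_box, np.maximum(max_w, w), max_w)
    if black:
      keep = np.where(in_box, 0.0, keep)
  return np.where(covered, max_w, 1.0) * keep

print(blackout_weights_sketch(
    4, 4,
    boxes=[[0., 0., 2., 2.], [0., 0., 4., 2.], [3., 0., 4., 4.]],
    blackout=[False, False, True],
    weights=[4.0, 3.0, 2.0]))
# [[4. 4. 1. 1.]
#  [4. 4. 1. 1.]
#  [3. 3. 1. 1.]
#  [0. 0. 0. 0.]]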
def _get_yx_indices_offset_by_radius(radius): def _get_yx_indices_offset_by_radius(radius):
......
...@@ -196,13 +196,36 @@ class TargetUtilTest(parameterized.TestCase, test_case.TestCase): ...@@ -196,13 +196,36 @@ class TargetUtilTest(parameterized.TestCase, test_case.TestCase):
return output return output
output = self.execute(graph_fn, []) output = self.execute(graph_fn, [])
# All zeros in region [0:6, 0:6]. # All zeros in region [0:5, 0:5].
self.assertAlmostEqual(np.sum(output[0:6, 0:6]), 0.0) self.assertAlmostEqual(np.sum(output[0:5, 0:5]), 0.0)
# All zeros in region [12:19, 6:9]. # All zeros in region [12:18, 6:8].
self.assertAlmostEqual(np.sum(output[6:9, 12:19]), 0.0) self.assertAlmostEqual(np.sum(output[6:8, 12:18]), 0.0)
# All other pixel weights should be 1.0. # All other pixel weights should be 1.0.
# 20 * 10 - 6 * 6 - 3 * 7 = 143.0 # 20 * 10 - 5 * 5 - 2 * 6 = 163.0
self.assertAlmostEqual(np.sum(output), 143.0) self.assertAlmostEqual(np.sum(output), 163.0)
def test_blackout_pixel_weights_by_box_regions_with_weights(self):
def graph_fn():
boxes = tf.constant(
[[0.0, 0.0, 2.0, 2.0],
[0.0, 0.0, 4.0, 2.0],
[3.0, 0.0, 4.0, 4.0]],
dtype=tf.float32)
blackout = tf.constant([False, False, True], dtype=tf.bool)
weights = tf.constant([0.4, 0.3, 0.2], tf.float32)
blackout_pixel_weights_by_box_regions = tf.function(
ta_utils.blackout_pixel_weights_by_box_regions)
output = blackout_pixel_weights_by_box_regions(
4, 4, boxes, blackout, weights)
return output
output = self.execute(graph_fn, [])
expected_weights = [
[0.4, 0.4, 1.0, 1.0],
[0.4, 0.4, 1.0, 1.0],
[0.3, 0.3, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0]]
np.testing.assert_array_almost_equal(expected_weights, output)
def test_blackout_pixel_weights_by_box_regions_zero_instance(self): def test_blackout_pixel_weights_by_box_regions_zero_instance(self):
def graph_fn(): def graph_fn():
......