Commit 152a0ce4 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Implement feature based consistency in DeepMAC.

PiperOrigin-RevId: 474123579
parent 40d0862d
...@@ -33,17 +33,18 @@ PIXEL_EMBEDDING = 'PIXEL_EMBEDDING' ...@@ -33,17 +33,18 @@ PIXEL_EMBEDDING = 'PIXEL_EMBEDDING'
MASK_LOGITS_GT_BOXES = 'MASK_LOGITS_GT_BOXES' MASK_LOGITS_GT_BOXES = 'MASK_LOGITS_GT_BOXES'
DEEP_MASK_ESTIMATION = 'deep_mask_estimation' DEEP_MASK_ESTIMATION = 'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency' DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency'
DEEP_MASK_COLOR_CONSISTENCY = 'deep_mask_color_consistency' DEEP_MASK_FEATURE_CONSISTENCY = 'deep_mask_feature_consistency'
DEEP_MASK_POINTLY_SUPERVISED = 'deep_mask_pointly_supervised' DEEP_MASK_POINTLY_SUPERVISED = 'deep_mask_pointly_supervised'
SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS = ( SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS = (
'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS') 'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS')
DEEP_MASK_AUGMENTED_SELF_SUPERVISION = 'deep_mask_augmented_self_supervision' DEEP_MASK_AUGMENTED_SELF_SUPERVISION = 'deep_mask_augmented_self_supervision'
CONSISTENCY_FEATURE_MAP = 'CONSISTENCY_FEATURE_MAP'
LOSS_KEY_PREFIX = center_net_meta_arch.LOSS_KEY_PREFIX LOSS_KEY_PREFIX = center_net_meta_arch.LOSS_KEY_PREFIX
NEIGHBORS_2D = [[-1, -1], [-1, 0], [-1, 1], NEIGHBORS_2D = [[-1, -1], [-1, 0], [-1, 1],
[0, -1], [0, 1], [0, -1], [0, 1],
[1, -1], [1, 0], [1, 1]] [1, -1], [1, 0], [1, 1]]
WEAK_LOSSES = [DEEP_MASK_BOX_CONSISTENCY, DEEP_MASK_COLOR_CONSISTENCY, WEAK_LOSSES = [DEEP_MASK_BOX_CONSISTENCY, DEEP_MASK_FEATURE_CONSISTENCY,
DEEP_MASK_AUGMENTED_SELF_SUPERVISION, DEEP_MASK_AUGMENTED_SELF_SUPERVISION,
DEEP_MASK_POINTLY_SUPERVISED] DEEP_MASK_POINTLY_SUPERVISED]
...@@ -56,10 +57,10 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [ ...@@ -56,10 +57,10 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [
'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels', 'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels',
'predict_full_resolution_masks', 'postprocess_crop_size', 'predict_full_resolution_masks', 'postprocess_crop_size',
'max_roi_jitter_ratio', 'roi_jitter_mode', 'max_roi_jitter_ratio', 'roi_jitter_mode',
'box_consistency_loss_weight', 'color_consistency_threshold', 'box_consistency_loss_weight', 'feature_consistency_threshold',
'color_consistency_dilation', 'color_consistency_loss_weight', 'feature_consistency_dilation', 'feature_consistency_loss_weight',
'box_consistency_loss_normalize', 'box_consistency_tightness', 'box_consistency_loss_normalize', 'box_consistency_tightness',
'color_consistency_warmup_steps', 'color_consistency_warmup_start', 'feature_consistency_warmup_steps', 'feature_consistency_warmup_start',
'use_only_last_stage', 'augmented_self_supervision_max_translation', 'use_only_last_stage', 'augmented_self_supervision_max_translation',
'augmented_self_supervision_loss_weight', 'augmented_self_supervision_loss_weight',
'augmented_self_supervision_flip_probability', 'augmented_self_supervision_flip_probability',
...@@ -69,7 +70,9 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [ ...@@ -69,7 +70,9 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [
'augmented_self_supervision_scale_min', 'augmented_self_supervision_scale_min',
'augmented_self_supervision_scale_max', 'augmented_self_supervision_scale_max',
'pointly_supervised_keypoint_loss_weight', 'pointly_supervised_keypoint_loss_weight',
'ignore_per_class_box_overlap' 'ignore_per_class_box_overlap',
'feature_consistency_type',
'feature_consistency_comparison'
]) ])
...@@ -77,8 +80,8 @@ def _get_loss_weight(loss_name, config): ...@@ -77,8 +80,8 @@ def _get_loss_weight(loss_name, config):
"""Utility function to get loss weights by name.""" """Utility function to get loss weights by name."""
if loss_name == DEEP_MASK_ESTIMATION: if loss_name == DEEP_MASK_ESTIMATION:
return config.task_loss_weight return config.task_loss_weight
elif loss_name == DEEP_MASK_COLOR_CONSISTENCY: elif loss_name == DEEP_MASK_FEATURE_CONSISTENCY:
return config.color_consistency_loss_weight return config.feature_consistency_loss_weight
elif loss_name == DEEP_MASK_BOX_CONSISTENCY: elif loss_name == DEEP_MASK_BOX_CONSISTENCY:
return config.box_consistency_loss_weight return config.box_consistency_loss_weight
elif loss_name == DEEP_MASK_AUGMENTED_SELF_SUPERVISION: elif loss_name == DEEP_MASK_AUGMENTED_SELF_SUPERVISION:
...@@ -443,6 +446,10 @@ def generate_2d_neighbors(input_tensor, dilation=2): ...@@ -443,6 +446,10 @@ def generate_2d_neighbors(input_tensor, dilation=2):
return tf.transpose(output, [4, 0, 2, 3, 1]) return tf.transpose(output, [4, 0, 2, 3, 1])
def normalize_feature_map(feature_map):
return tf.math.l2_normalize(feature_map, axis=3, epsilon=1e-4)
def gaussian_pixel_similarity(a, b, theta): def gaussian_pixel_similarity(a, b, theta):
norm_difference = tf.linalg.norm(a - b, axis=-1) norm_difference = tf.linalg.norm(a - b, axis=-1)
similarity = tf.exp(-norm_difference / theta) similarity = tf.exp(-norm_difference / theta)
...@@ -1031,7 +1038,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1031,7 +1038,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
image_resizer_fn, image_resizer_fn,
object_center_params, object_center_params,
object_detection_params, object_detection_params,
deepmac_params, deepmac_params: DeepMACParams,
compute_heatmap_sparse=False): compute_heatmap_sparse=False):
"""Constructs the super class with object center & detection params only.""" """Constructs the super class with object center & detection params only."""
...@@ -1507,14 +1514,14 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1507,14 +1514,14 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return loss return loss
def _compute_color_consistency_loss( def _compute_feature_consistency_loss(
self, boxes, preprocessed_image, mask_logits): self, boxes, consistency_feature_map, mask_logits):
"""Compute the per-instance color consistency loss. """Compute the per-instance feature consistency loss.
Args: Args:
boxes: A [batch_size, num_instances, 4] float tensor of GT boxes. boxes: A [batch_size, num_instances, 4] float tensor of GT boxes.
preprocessed_image: A [batch_size, height, width, 3] consistency_feature_map: A [batch_size, height, width, 3]
float tensor containing the preprocessed image. float tensor containing the feature map to use for consistency.
mask_logits: A [batch_size, num_instances, height, width] float tensor of mask_logits: A [batch_size, num_instances, height, width] float tensor of
predicted masks. predicted masks.
...@@ -1524,27 +1531,40 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1524,27 +1531,40 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
""" """
if not self._deepmac_params.predict_full_resolution_masks: if not self._deepmac_params.predict_full_resolution_masks:
logging.info('Color consistency is not implemented with RoIAlign ' logging.info('Feature consistency is not implemented with RoIAlign '
', i.e, fixed sized masks. Returning 0 loss.') ', i.e, fixed sized masks. Returning 0 loss.')
return tf.zeros(tf.shape(boxes)[:2]) return tf.zeros(tf.shape(boxes)[:2])
dilation = self._deepmac_params.color_consistency_dilation dilation = self._deepmac_params.feature_consistency_dilation
height, width = (tf.shape(consistency_feature_map)[1],
tf.shape(consistency_feature_map)[2])
comparison = self._deepmac_params.feature_consistency_comparison
if comparison == 'comparison_default_gaussian':
similarity = dilated_cross_pixel_similarity(
consistency_feature_map, dilation=dilation, theta=2.0,
method='gaussian')
elif comparison == 'comparison_normalized_dotprod':
consistency_feature_map = normalize_feature_map(consistency_feature_map)
similarity = dilated_cross_pixel_similarity(
consistency_feature_map, dilation=dilation, theta=2.0,
method='dotprod')
else:
raise ValueError('Unknown comparison type - %s' % comparison)
height, width = (tf.shape(preprocessed_image)[1],
tf.shape(preprocessed_image)[2])
color_similarity = dilated_cross_pixel_similarity(
preprocessed_image, dilation=dilation, theta=2.0)
mask_probs = tf.nn.sigmoid(mask_logits) mask_probs = tf.nn.sigmoid(mask_logits)
same_mask_label_probability = dilated_cross_same_mask_label( same_mask_label_probability = dilated_cross_same_mask_label(
mask_probs, dilation=dilation) mask_probs, dilation=dilation)
same_mask_label_probability = tf.clip_by_value( same_mask_label_probability = tf.clip_by_value(
same_mask_label_probability, 1e-3, 1.0) same_mask_label_probability, 1e-3, 1.0)
color_similarity_mask = ( similarity_mask = (
color_similarity > self._deepmac_params.color_consistency_threshold) similarity > self._deepmac_params.feature_consistency_threshold)
color_similarity_mask = tf.cast( similarity_mask = tf.cast(
color_similarity_mask[:, :, tf.newaxis, :, :], tf.float32) similarity_mask[:, :, tf.newaxis, :, :], tf.float32)
per_pixel_loss = -(color_similarity_mask * per_pixel_loss = -(similarity_mask *
tf.math.log(same_mask_label_probability)) tf.math.log(same_mask_label_probability))
# TODO(vighneshb) explore if shrinking the box by 1px helps. # TODO(vighneshb) explore if shrinking the box by 1px helps.
box_mask = fill_boxes(boxes, height, width) box_mask = fill_boxes(boxes, height, width)
...@@ -1558,8 +1578,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1558,8 +1578,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
if tf.keras.backend.learning_phase(): if tf.keras.backend.learning_phase():
loss *= _warmup_weight( loss *= _warmup_weight(
current_training_step=self._training_step, current_training_step=self._training_step,
warmup_start=self._deepmac_params.color_consistency_warmup_start, warmup_start=self._deepmac_params.feature_consistency_warmup_start,
warmup_steps=self._deepmac_params.color_consistency_warmup_steps) warmup_steps=self._deepmac_params.feature_consistency_warmup_steps)
return loss return loss
...@@ -1720,7 +1740,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1720,7 +1740,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return tf.reshape(loss, [batch_size, num_instances]) return tf.reshape(loss, [batch_size, num_instances])
def _compute_deepmac_losses( def _compute_deepmac_losses(
self, boxes, masks_logits, masks_gt, classes, image, self, boxes, masks_logits, masks_gt, classes, consistency_feature_map,
self_supervised_masks_logits=None, keypoints_gt=None, self_supervised_masks_logits=None, keypoints_gt=None,
keypoints_depth_gt=None): keypoints_depth_gt=None):
"""Returns the mask loss per instance. """Returns the mask loss per instance.
...@@ -1736,8 +1756,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1736,8 +1756,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
DEEP_MASK_ESTIMATION is filled with 0s. DEEP_MASK_ESTIMATION is filled with 0s.
classes: A [batch_size, num_instances, num_classes] tensor of one-hot classes: A [batch_size, num_instances, num_classes] tensor of one-hot
encoded classes. encoded classes.
image: [batch_size, output_height, output_width, channels] float tensor consistency_feature_map: [batch_size, output_height, output_width,
denoting the input image. channels] float tensor denoting the image to use for consistency.
self_supervised_masks_logits: Optional self-supervised mask logits to self_supervised_masks_logits: Optional self-supervised mask logits to
compare against of same shape as mask_logits. compare against of same shape as mask_logits.
keypoints_gt: A float tensor of shape keypoints_gt: A float tensor of shape
...@@ -1753,7 +1773,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1753,7 +1773,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
[batch_size, num_instances]. The 4 keys are: [batch_size, num_instances]. The 4 keys are:
- DEEP_MASK_ESTIMATION - DEEP_MASK_ESTIMATION
- DEEP_MASK_BOX_CONSISTENCY - DEEP_MASK_BOX_CONSISTENCY
- DEEP_MASK_COLOR_CONSISTENCY - DEEP_MASK_FEATURE_CONSISTENCY
- DEEP_MASK_AUGMENTED_SELF_SUPERVISION - DEEP_MASK_AUGMENTED_SELF_SUPERVISION
- DEEP_MASK_POINTLY_SUPERVISED - DEEP_MASK_POINTLY_SUPERVISED
""" """
...@@ -1779,8 +1799,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1779,8 +1799,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
box_consistency_loss = self._compute_box_consistency_loss( box_consistency_loss = self._compute_box_consistency_loss(
boxes, boxes_for_crop, masks_logits) boxes, boxes_for_crop, masks_logits)
color_consistency_loss = self._compute_color_consistency_loss( feature_consistency_loss = self._compute_feature_consistency_loss(
boxes, image, masks_logits) boxes, consistency_feature_map, masks_logits)
self_supervised_loss = self._compute_self_supervised_augmented_loss( self_supervised_loss = self._compute_self_supervised_augmented_loss(
masks_logits, self_supervised_masks_logits, boxes, masks_logits, self_supervised_masks_logits, boxes,
...@@ -1793,7 +1813,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1793,7 +1813,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return { return {
DEEP_MASK_ESTIMATION: mask_prediction_loss, DEEP_MASK_ESTIMATION: mask_prediction_loss,
DEEP_MASK_BOX_CONSISTENCY: box_consistency_loss, DEEP_MASK_BOX_CONSISTENCY: box_consistency_loss,
DEEP_MASK_COLOR_CONSISTENCY: color_consistency_loss, DEEP_MASK_FEATURE_CONSISTENCY: feature_consistency_loss,
DEEP_MASK_AUGMENTED_SELF_SUPERVISION: self_supervised_loss, DEEP_MASK_AUGMENTED_SELF_SUPERVISION: self_supervised_loss,
DEEP_MASK_POINTLY_SUPERVISED: pointly_supervised_loss, DEEP_MASK_POINTLY_SUPERVISED: pointly_supervised_loss,
} }
...@@ -1815,6 +1835,26 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1815,6 +1835,26 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else: else:
return None return None
def _get_consistency_feature_map(self, prediction_dict):
prediction_shape = tf.shape(prediction_dict[MASK_LOGITS_GT_BOXES][0])
height, width = prediction_shape[2], prediction_shape[3]
consistency_type = self._deepmac_params.feature_consistency_type
if consistency_type == 'consistency_default_lab':
preprocessed_image = tf.image.resize(
prediction_dict['preprocessed_inputs'], (height, width))
consistency_feature_map = self._get_lab_image(preprocessed_image)
elif consistency_type == 'consistency_feature_map':
consistency_feature_map = prediction_dict['extracted_features'][-1]
consistency_feature_map = tf.image.resize(
consistency_feature_map, (height, width))
else:
raise ValueError('Unknown feature consistency type - {}.'.format(
self._deepmac_params.feature_consistency_type))
return tf.stop_gradient(consistency_feature_map)
def _compute_masks_loss(self, prediction_dict): def _compute_masks_loss(self, prediction_dict):
"""Computes the mask loss. """Computes the mask loss.
...@@ -1835,13 +1875,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1835,13 +1875,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
for loss_name in MASK_LOSSES: for loss_name in MASK_LOSSES:
loss_dict[loss_name] = 0.0 loss_dict[loss_name] = 0.0
prediction_shape = tf.shape(prediction_dict[MASK_LOGITS_GT_BOXES][0])
height, width = prediction_shape[2], prediction_shape[3]
preprocessed_image = tf.image.resize(
prediction_dict['preprocessed_inputs'], (height, width))
image = self._get_lab_image(preprocessed_image)
gt_boxes = self._maybe_get_gt_batch(fields.BoxListFields.boxes) gt_boxes = self._maybe_get_gt_batch(fields.BoxListFields.boxes)
gt_weights = self._maybe_get_gt_batch(fields.BoxListFields.weights) gt_weights = self._maybe_get_gt_batch(fields.BoxListFields.weights)
gt_classes = self._maybe_get_gt_batch(fields.BoxListFields.classes) gt_classes = self._maybe_get_gt_batch(fields.BoxListFields.classes)
...@@ -1855,6 +1888,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1855,6 +1888,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
[None] * len(mask_logits_list)) [None] * len(mask_logits_list))
assert len(mask_logits_list) == len(self_supervised_mask_logits_list) assert len(mask_logits_list) == len(self_supervised_mask_logits_list)
consistency_feature_map = self._get_consistency_feature_map(prediction_dict)
# Iterate over multiple preidctions by backbone (for hourglass length=2) # Iterate over multiple preidctions by backbone (for hourglass length=2)
for (mask_logits, self_supervised_mask_logits) in zip( for (mask_logits, self_supervised_mask_logits) in zip(
mask_logits_list, self_supervised_mask_logits_list): mask_logits_list, self_supervised_mask_logits_list):
...@@ -1866,7 +1901,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1866,7 +1901,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
sample_loss_dict = self._compute_deepmac_losses( sample_loss_dict = self._compute_deepmac_losses(
boxes=gt_boxes, masks_logits=mask_logits, masks_gt=gt_masks, boxes=gt_boxes, masks_logits=mask_logits, masks_gt=gt_masks,
classes=gt_classes, image=image, classes=gt_classes, consistency_feature_map=consistency_feature_map,
self_supervised_masks_logits=self_supervised_mask_logits, self_supervised_masks_logits=self_supervised_mask_logits,
keypoints_gt=gt_keypoints, keypoints_depth_gt=gt_depths) keypoints_gt=gt_keypoints, keypoints_depth_gt=gt_depths)
......
...@@ -83,12 +83,12 @@ def build_meta_arch(**override_params): ...@@ -83,12 +83,12 @@ def build_meta_arch(**override_params):
use_xy=True, use_xy=True,
pixel_embedding_dim=2, pixel_embedding_dim=2,
dice_loss_prediction_probability=False, dice_loss_prediction_probability=False,
color_consistency_threshold=0.5, feature_consistency_threshold=0.5,
use_dice_loss=False, use_dice_loss=False,
box_consistency_loss_normalize='normalize_auto', box_consistency_loss_normalize='normalize_auto',
box_consistency_tightness=False, box_consistency_tightness=False,
task_loss_weight=1.0, task_loss_weight=1.0,
color_consistency_loss_weight=1.0, feature_consistency_loss_weight=1.0,
box_consistency_loss_weight=1.0, box_consistency_loss_weight=1.0,
num_init_channels=8, num_init_channels=8,
dim=8, dim=8,
...@@ -97,9 +97,9 @@ def build_meta_arch(**override_params): ...@@ -97,9 +97,9 @@ def build_meta_arch(**override_params):
postprocess_crop_size=128, postprocess_crop_size=128,
max_roi_jitter_ratio=0.0, max_roi_jitter_ratio=0.0,
roi_jitter_mode='default', roi_jitter_mode='default',
color_consistency_dilation=2, feature_consistency_dilation=2,
color_consistency_warmup_steps=0, feature_consistency_warmup_steps=0,
color_consistency_warmup_start=0, feature_consistency_warmup_start=0,
use_only_last_stage=True, use_only_last_stage=True,
augmented_self_supervision_max_translation=0.0, augmented_self_supervision_max_translation=0.0,
augmented_self_supervision_loss_weight=0.0, augmented_self_supervision_loss_weight=0.0,
...@@ -110,7 +110,9 @@ def build_meta_arch(**override_params): ...@@ -110,7 +110,9 @@ def build_meta_arch(**override_params):
augmented_self_supervision_scale_min=1.0, augmented_self_supervision_scale_min=1.0,
augmented_self_supervision_scale_max=1.0, augmented_self_supervision_scale_max=1.0,
pointly_supervised_keypoint_loss_weight=1.0, pointly_supervised_keypoint_loss_weight=1.0,
ignore_per_class_box_overlap=False) ignore_per_class_box_overlap=False,
feature_consistency_type='consistency_default_lab',
feature_consistency_comparison='comparison_default_gaussian')
params.update(override_params) params.update(override_params)
...@@ -183,13 +185,13 @@ DEEPMAC_PROTO_TEXT = """ ...@@ -183,13 +185,13 @@ DEEPMAC_PROTO_TEXT = """
predict_full_resolution_masks: true predict_full_resolution_masks: true
allowed_masked_classes_ids: [99] allowed_masked_classes_ids: [99]
box_consistency_loss_weight: 1.0 box_consistency_loss_weight: 1.0
color_consistency_loss_weight: 1.0 feature_consistency_loss_weight: 1.0
color_consistency_threshold: 0.1 feature_consistency_threshold: 0.1
box_consistency_tightness: false box_consistency_tightness: false
box_consistency_loss_normalize: NORMALIZE_AUTO box_consistency_loss_normalize: NORMALIZE_AUTO
color_consistency_warmup_steps: 20 feature_consistency_warmup_steps: 20
color_consistency_warmup_start: 10 feature_consistency_warmup_start: 10
use_only_last_stage: false use_only_last_stage: false
augmented_self_supervision_warmup_start: 13 augmented_self_supervision_warmup_start: 13
augmented_self_supervision_warmup_steps: 14 augmented_self_supervision_warmup_steps: 14
...@@ -201,6 +203,8 @@ DEEPMAC_PROTO_TEXT = """ ...@@ -201,6 +203,8 @@ DEEPMAC_PROTO_TEXT = """
augmented_self_supervision_scale_max: 1.42 augmented_self_supervision_scale_max: 1.42
pointly_supervised_keypoint_loss_weight: 0.13 pointly_supervised_keypoint_loss_weight: 0.13
ignore_per_class_box_overlap: true ignore_per_class_box_overlap: true
feature_consistency_type: CONSISTENCY_FEATURE_MAP
feature_consistency_comparison: COMPARISON_NORMALIZED_DOTPROD
""" """
...@@ -232,6 +236,9 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -232,6 +236,9 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
self.assertAlmostEqual( self.assertAlmostEqual(
params.pointly_supervised_keypoint_loss_weight, 0.13) params.pointly_supervised_keypoint_loss_weight, 0.13)
self.assertTrue(params.ignore_per_class_box_overlap) self.assertTrue(params.ignore_per_class_box_overlap)
self.assertEqual(params.feature_consistency_type, 'consistency_feature_map')
self.assertEqual(
params.feature_consistency_comparison, 'comparison_normalized_dotprod')
def test_subsample_trivial(self): def test_subsample_trivial(self):
"""Test subsampling masks.""" """Test subsampling masks."""
...@@ -1255,7 +1262,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1255,7 +1262,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(loss, [[yloss + xloss]]) self.assertAllClose(loss, [[yloss + xloss]])
def test_color_consistency_loss_full_res_shape(self): def test_feature_consistency_loss_full_res_shape(self):
model = build_meta_arch(use_dice_loss=True, model = build_meta_arch(use_dice_loss=True,
predict_full_resolution_masks=True) predict_full_resolution_masks=True)
...@@ -1263,18 +1270,18 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1263,18 +1270,18 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
img = tf.zeros((5, 32, 32, 3)) img = tf.zeros((5, 32, 32, 3))
mask_logits = tf.zeros((5, 3, 32, 32)) mask_logits = tf.zeros((5, 3, 32, 32))
loss = model._compute_color_consistency_loss( loss = model._compute_feature_consistency_loss(
boxes, img, mask_logits) boxes, img, mask_logits)
self.assertEqual([5, 3], loss.shape) self.assertEqual([5, 3], loss.shape)
def test_color_consistency_1_threshold(self): def test_feature_consistency_1_threshold(self):
model = build_meta_arch(predict_full_resolution_masks=True, model = build_meta_arch(predict_full_resolution_masks=True,
color_consistency_threshold=0.99) feature_consistency_threshold=0.99)
boxes = tf.zeros((5, 3, 4)) boxes = tf.zeros((5, 3, 4))
img = tf.zeros((5, 32, 32, 3)) img = tf.zeros((5, 32, 32, 3))
mask_logits = tf.zeros((5, 3, 32, 32)) - 1e4 mask_logits = tf.zeros((5, 3, 32, 32)) - 1e4
loss = model._compute_color_consistency_loss( loss = model._compute_feature_consistency_loss(
boxes, img, mask_logits) boxes, img, mask_logits)
self.assertAllClose(loss, np.zeros((5, 3))) self.assertAllClose(loss, np.zeros((5, 3)))
...@@ -1414,7 +1421,8 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1414,7 +1421,8 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
[tf.random.normal((1, 5, 8, 8))] * num_stages, [tf.random.normal((1, 5, 8, 8))] * num_stages,
'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages, 'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages, 'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages 'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'extracted_features': [tf.random.normal((3, 32, 32, 7))] * num_stages
} }
boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)] boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)]
...@@ -1477,7 +1485,8 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1477,7 +1485,8 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
'box/offset': [tf.random.normal((3, 8, 8, 2))] * 2, 'box/offset': [tf.random.normal((3, 8, 8, 2))] * 2,
'box/scale': [tf.random.normal((3, 8, 8, 2))] * 2, 'box/scale': [tf.random.normal((3, 8, 8, 2))] * 2,
'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS': ( 'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS': (
[tf.random.normal((3, 5, 8, 8))] * 2) [tf.random.normal((3, 5, 8, 8))] * 2),
'extracted_features': [tf.random.normal((3, 32, 32, 7))] * 2
} }
model.provide_groundtruth( model.provide_groundtruth(
groundtruth_boxes_list=[ groundtruth_boxes_list=[
...@@ -1491,7 +1500,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1491,7 +1500,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0) self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0)
for weak_loss in deepmac_meta_arch.MASK_LOSSES: for weak_loss in deepmac_meta_arch.MASK_LOSSES:
if weak_loss == deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY: if weak_loss == deepmac_meta_arch.DEEP_MASK_FEATURE_CONSISTENCY:
continue continue
self.assertGreater(loss['Loss/' + weak_loss], 0.0, self.assertGreater(loss['Loss/' + weak_loss], 0.0,
'{} was <= 0'.format(weak_loss)) '{} was <= 0'.format(weak_loss))
...@@ -1544,7 +1553,8 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1544,7 +1553,8 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages, 'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages, 'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS': ( 'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS': (
[tf.random.normal((1, 5, 8, 8))] * num_stages) [tf.random.normal((1, 5, 8, 8))] * num_stages),
'extracted_features': [tf.random.normal((3, 32, 32, 7))] * num_stages
} }
boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)] boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)]
...@@ -1571,7 +1581,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1571,7 +1581,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
loss_weights = { loss_weights = {
deepmac_meta_arch.DEEP_MASK_ESTIMATION: rng.uniform(1, 5), deepmac_meta_arch.DEEP_MASK_ESTIMATION: rng.uniform(1, 5),
deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY: rng.uniform(1, 5), deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY: rng.uniform(1, 5),
deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY: rng.uniform(1, 5), deepmac_meta_arch.DEEP_MASK_FEATURE_CONSISTENCY: rng.uniform(1, 5),
deepmac_meta_arch.DEEP_MASK_AUGMENTED_SELF_SUPERVISION: ( deepmac_meta_arch.DEEP_MASK_AUGMENTED_SELF_SUPERVISION: (
rng.uniform(1, 5)), rng.uniform(1, 5)),
deepmac_meta_arch.DEEP_MASK_POINTLY_SUPERVISED: rng.uniform(1, 5) deepmac_meta_arch.DEEP_MASK_POINTLY_SUPERVISED: rng.uniform(1, 5)
...@@ -1588,8 +1598,8 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1588,8 +1598,8 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
task_loss_weight=loss_weights[deepmac_meta_arch.DEEP_MASK_ESTIMATION], task_loss_weight=loss_weights[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
box_consistency_loss_weight=( box_consistency_loss_weight=(
loss_weights[deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY]), loss_weights[deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY]),
color_consistency_loss_weight=( feature_consistency_loss_weight=(
loss_weights[deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY]), loss_weights[deepmac_meta_arch.DEEP_MASK_FEATURE_CONSISTENCY]),
augmented_self_supervision_loss_weight=( augmented_self_supervision_loss_weight=(
loss_weights[deepmac_meta_arch.DEEP_MASK_AUGMENTED_SELF_SUPERVISION] loss_weights[deepmac_meta_arch.DEEP_MASK_AUGMENTED_SELF_SUPERVISION]
), ),
...@@ -1612,7 +1622,14 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1612,7 +1622,14 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
weighted_loss[loss_key], loss[loss_key] * loss_weights[mask_loss], weighted_loss[loss_key], loss[loss_key] * loss_weights[mask_loss],
f'{mask_loss} did not respond to change in weight.') f'{mask_loss} did not respond to change in weight.')
def test_color_consistency_warmup(self): @parameterized.parameters(
[dict(feature_consistency_type='consistency_default_lab',
feature_consistency_comparison='comparison_default_gaussian'),
dict(feature_consistency_type='consistency_feature_map',
feature_consistency_comparison='comparison_normalized_dotprod')],
)
def test_feature_consistency_warmup(
self, feature_consistency_type, feature_consistency_comparison):
tf.keras.backend.set_learning_phase(True) tf.keras.backend.set_learning_phase(True)
model = build_meta_arch( model = build_meta_arch(
use_dice_loss=True, use_dice_loss=True,
...@@ -1622,15 +1639,19 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1622,15 +1639,19 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
pixel_embedding_dim=8, pixel_embedding_dim=8,
use_instance_embedding=False, use_instance_embedding=False,
use_xy=False, use_xy=False,
color_consistency_warmup_steps=10, feature_consistency_warmup_steps=10,
color_consistency_warmup_start=10) feature_consistency_warmup_start=10,
feature_consistency_type=feature_consistency_type,
feature_consistency_comparison=feature_consistency_comparison)
num_stages = 1 num_stages = 1
prediction = { prediction = {
'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)), 'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
'MASK_LOGITS_GT_BOXES': [tf.random.normal((1, 5, 8, 8))] * num_stages, 'MASK_LOGITS_GT_BOXES': [tf.random.normal((1, 5, 8, 8))] * num_stages,
'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages, 'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages, 'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages 'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'extracted_features': [tf.random.normal((3, 32, 32, 7))] * num_stages
} }
boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)] boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)]
...@@ -1670,7 +1691,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1670,7 +1691,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
training_step=100) training_step=100)
loss_at_100 = model.loss(prediction, tf.constant([[32, 32, 3.0]])) loss_at_100 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
loss_key = 'Loss/' + deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY loss_key = 'Loss/' + deepmac_meta_arch.DEEP_MASK_FEATURE_CONSISTENCY
self.assertAlmostEqual(loss_at_5[loss_key].numpy(), 0.0) self.assertAlmostEqual(loss_at_5[loss_key].numpy(), 0.0)
self.assertGreater(loss_at_15[loss_key], 0.0) self.assertGreater(loss_at_15[loss_key], 0.0)
self.assertAlmostEqual(loss_at_15[loss_key].numpy(), self.assertAlmostEqual(loss_at_15[loss_key].numpy(),
......
...@@ -21,7 +21,8 @@ REQUIRED_PACKAGES = [ ...@@ -21,7 +21,8 @@ REQUIRED_PACKAGES = [
'tf-models-official>=2.5.1', 'tf-models-official>=2.5.1',
'tensorflow_io', 'tensorflow_io',
'keras', 'keras',
'pyparsing==2.4.7' # TODO(b/204103388) 'pyparsing==2.4.7', # TODO(b/204103388)
'sacrebleu<=2.2.0' # https://github.com/mjpost/sacrebleu/issues/209
] ]
setup( setup(
......
...@@ -403,7 +403,7 @@ message CenterNet { ...@@ -403,7 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613 // Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 35 // Next ID 37
message DeepMACMaskEstimation { message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions. // The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1; optional ClassificationLoss classification_loss = 1;
...@@ -466,11 +466,11 @@ message CenterNet { ...@@ -466,11 +466,11 @@ message CenterNet {
// https://arxiv.org/abs/2012.02310 // https://arxiv.org/abs/2012.02310
optional float box_consistency_loss_weight = 16 [default = 0.0]; optional float box_consistency_loss_weight = 16 [default = 0.0];
optional float color_consistency_threshold = 17 [default = 0.4]; optional float feature_consistency_threshold = 17 [default = 0.4];
optional int32 color_consistency_dilation = 18 [default = 2]; optional int32 feature_consistency_dilation = 18 [default = 2];
optional float color_consistency_loss_weight = 19 [default = 0.0]; optional float feature_consistency_loss_weight = 19 [default = 0.0];
optional LossNormalize box_consistency_loss_normalize = 20 optional LossNormalize box_consistency_loss_normalize = 20
[default = NORMALIZE_AUTO]; [default = NORMALIZE_AUTO];
...@@ -481,9 +481,16 @@ message CenterNet { ...@@ -481,9 +481,16 @@ message CenterNet {
// https://papers.nips.cc/paper/2019/hash/e6e713296627dff6475085cc6a224464-Abstract.html // https://papers.nips.cc/paper/2019/hash/e6e713296627dff6475085cc6a224464-Abstract.html
optional bool box_consistency_tightness = 21 [default = false]; optional bool box_consistency_tightness = 21 [default = false];
optional int32 color_consistency_warmup_steps = 22 [default = 0]; optional int32 feature_consistency_warmup_steps = 22 [default = 0];
optional int32 color_consistency_warmup_start = 23 [default = 0]; optional int32 feature_consistency_warmup_start = 23 [default = 0];
// TODO(vighneshb)
optional FeatureConsistencyType feature_consistency_type = 35
[default = CONSISTENCY_DEFAULT_LAB];
optional FeatureConsistencyComparison feature_consistency_comparison = 36
[default = COMPARISON_DEFAULT_GAUSSIAN];
// This flag controls whether or not we use the outputs from only the // This flag controls whether or not we use the outputs from only the
// last stage of the hourglass for training the mask-heads. // last stage of the hourglass for training the mask-heads.
...@@ -505,21 +512,24 @@ message CenterNet { ...@@ -505,21 +512,24 @@ message CenterNet {
optional bool use_only_last_stage = 24 [default = false]; optional bool use_only_last_stage = 24 [default = false];
optional float augmented_self_supervision_max_translation = 25 [default=0.0]; optional float augmented_self_supervision_max_translation = 25
[default = 0.0];
optional float augmented_self_supervision_flip_probability = 26 [default=0.0]; optional float augmented_self_supervision_flip_probability = 26
[default = 0.0];
optional float augmented_self_supervision_loss_weight = 27 [default=0.0]; optional float augmented_self_supervision_loss_weight = 27 [default = 0.0];
optional int32 augmented_self_supervision_warmup_start = 28 [default=0]; optional int32 augmented_self_supervision_warmup_start = 28 [default = 0];
optional int32 augmented_self_supervision_warmup_steps = 29 [default=0]; optional int32 augmented_self_supervision_warmup_steps = 29 [default = 0];
optional AugmentedSelfSupervisionLoss augmented_self_supervision_loss = 30 [default=LOSS_DICE]; optional AugmentedSelfSupervisionLoss augmented_self_supervision_loss = 30
[default = LOSS_DICE];
optional float augmented_self_supervision_scale_min = 31 [default=1.0]; optional float augmented_self_supervision_scale_min = 31 [default = 1.0];
optional float augmented_self_supervision_scale_max = 32 [default=1.0]; optional float augmented_self_supervision_scale_max = 32 [default = 1.0];
// The loss weight for the pointly supervised loss as defined in the paper // The loss weight for the pointly supervised loss as defined in the paper
// https://arxiv.org/abs/2104.06404 // https://arxiv.org/abs/2104.06404
...@@ -534,7 +544,6 @@ message CenterNet { ...@@ -534,7 +544,6 @@ message CenterNet {
// When set, loss computation is ignored at pixels that fall within // When set, loss computation is ignored at pixels that fall within
// 2 boxes of the same class. // 2 boxes of the same class.
optional bool ignore_per_class_box_overlap = 34 [default = false]; optional bool ignore_per_class_box_overlap = 34 [default = false];
} }
optional DeepMACMaskEstimation deepmac_mask_estimation = 14; optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
...@@ -564,6 +573,16 @@ enum AugmentedSelfSupervisionLoss { ...@@ -564,6 +573,16 @@ enum AugmentedSelfSupervisionLoss {
LOSS_KL_DIV = 3; LOSS_KL_DIV = 3;
} }
enum FeatureConsistencyType {
CONSISTENCY_DEFAULT_LAB = 0;
CONSISTENCY_FEATURE_MAP = 1;
}
enum FeatureConsistencyComparison {
COMPARISON_DEFAULT_GAUSSIAN = 0;
COMPARISON_NORMALIZED_DOTPROD = 1;
}
message CenterNetFeatureExtractor { message CenterNetFeatureExtractor {
optional string type = 1; optional string type = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment