Commit 20949dd8 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Add self-supervised augmentation to DeepMAC HR model.

PiperOrigin-RevId: 440225825
parent a98cb7d2
......@@ -12,6 +12,7 @@ import tensorflow as tf
from object_detection.builders import losses_builder
from object_detection.core import box_list
from object_detection.core import box_list_ops
from object_detection.core import losses
from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields
from object_detection.meta_architectures import center_net_meta_arch
......@@ -19,7 +20,6 @@ from object_detection.models.keras_models import hourglass_network
from object_detection.models.keras_models import resnet_v1
from object_detection.protos import center_net_pb2
from object_detection.protos import losses_pb2
from object_detection.protos import preprocessor_pb2
from object_detection.utils import shape_utils
from object_detection.utils import spatial_transform_ops
from object_detection.utils import tf_version
......@@ -34,11 +34,16 @@ MASK_LOGITS_GT_BOXES = 'MASK_LOGITS_GT_BOXES'
DEEP_MASK_ESTIMATION = 'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency'
DEEP_MASK_COLOR_CONSISTENCY = 'deep_mask_color_consistency'
SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS = (
'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS')
DEEP_MASK_AUGMENTED_SELF_SUPERVISION = 'deep_mask_augmented_self_supervision'
LOSS_KEY_PREFIX = center_net_meta_arch.LOSS_KEY_PREFIX
NEIGHBORS_2D = [[-1, -1], [-1, 0], [-1, 1],
[0, -1], [0, 1],
[1, -1], [1, 0], [1, 1]]
WEAK_LOSSES = [DEEP_MASK_BOX_CONSISTENCY, DEEP_MASK_COLOR_CONSISTENCY]
WEAK_LOSSES = [DEEP_MASK_BOX_CONSISTENCY, DEEP_MASK_COLOR_CONSISTENCY,
DEEP_MASK_AUGMENTED_SELF_SUPERVISION]
MASK_LOSSES = WEAK_LOSSES + [DEEP_MASK_ESTIMATION]
......@@ -52,17 +57,27 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [
'color_consistency_dilation', 'color_consistency_loss_weight',
'box_consistency_loss_normalize', 'box_consistency_tightness',
'color_consistency_warmup_steps', 'color_consistency_warmup_start',
'use_only_last_stage'
'use_only_last_stage', 'augmented_self_supervision_max_translation',
'augmented_self_supervision_loss_weight',
'augmented_self_supervision_flip_probability',
'augmented_self_supervision_warmup_start',
'augmented_self_supervision_warmup_steps',
'augmented_self_supervision_loss',
'augmented_self_supervision_scale_min',
'augmented_self_supervision_scale_max'
])
def _get_loss_weight(loss_name, config):
  """Utility function to get loss weights by name."""
  # Map each loss name to the config attribute holding its weight.
  weight_attr_by_loss = {
      DEEP_MASK_ESTIMATION: 'task_loss_weight',
      DEEP_MASK_COLOR_CONSISTENCY: 'color_consistency_loss_weight',
      DEEP_MASK_BOX_CONSISTENCY: 'box_consistency_loss_weight',
      DEEP_MASK_AUGMENTED_SELF_SUPERVISION:
          'augmented_self_supervision_loss_weight',
  }
  try:
    return getattr(config, weight_attr_by_loss[loss_name])
  except KeyError:
    raise ValueError('Unknown loss - {}'.format(loss_name))
......@@ -142,6 +157,28 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
raise ValueError('Unknown network type {}'.format(name))
def boxes_batch_normalized_to_absolute_coordinates(boxes, height, width):
  """Scales [batch, num_instances, 4] normalized boxes to absolute pixels."""
  height = tf.cast(height, tf.float32)
  width = tf.cast(width, tf.float32)
  ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=2)
  return tf.stack(
      [ymin * height, xmin * width, ymax * height, xmax * width], axis=2)
def boxes_batch_absolute_to_normalized_coordinates(boxes, height, width):
  """Scales [batch, num_instances, 4] absolute boxes to normalized coords."""
  height = tf.cast(height, tf.float32)
  width = tf.cast(width, tf.float32)
  ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=2)
  return tf.stack(
      [ymin / height, xmin / width, ymax / height, xmax / width], axis=2)
def _resize_instance_masks_non_empty(masks, shape):
"""Resize a non-empty tensor of masks to the given shape."""
height, width = shape
......@@ -272,7 +309,7 @@ def fill_boxes(boxes, height, width):
"""Fills the area included in the boxes with 1s.
Args:
boxes: A [batch_size, num_instances, 4] shapes float tensor of boxes given
boxes: A [batch_size, num_instances, 4] shaped float tensor of boxes given
in the normalized coordinate space.
height: int, height of the output image.
width: int, width of the output image.
......@@ -282,13 +319,10 @@ def fill_boxes(boxes, height, width):
tensor with 1s in the area that falls inside each box.
"""
boxes_abs = boxes_batch_normalized_to_absolute_coordinates(
boxes, height, width)
ymin, xmin, ymax, xmax = tf.unstack(
boxes[:, :, tf.newaxis, tf.newaxis, :], 4, axis=4)
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
ymin *= height
ymax *= height
xmin *= width
xmax *= width
boxes_abs[:, :, tf.newaxis, tf.newaxis, :], 4, axis=4)
ygrid, xgrid = tf.meshgrid(tf.range(height), tf.range(width), indexing='ij')
ygrid, xgrid = tf.cast(ygrid, tf.float32), tf.cast(xgrid, tf.float32)
......@@ -513,6 +547,120 @@ def per_pixel_conditional_conv(input_tensor, parameters, channels, depth):
return output
def flip_boxes_left_right(boxes):
  """Horizontally mirrors normalized [batch, num_instances, 4] boxes."""
  ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=2)
  # Mirroring about x = 0.5 swaps and complements the x extents.
  mirrored_xmin = 1.0 - xmax
  mirrored_xmax = 1.0 - xmin
  return tf.stack([ymin, mirrored_xmin, ymax, mirrored_xmax], axis=2)
def transform_images_and_boxes(images, boxes, tx, ty, sx, sy, flip):
  """Translate and scale a batch of images and boxes by the given amount.

  The function first translates and then scales the image and assumes the
  origin to be at the center of the image.

  Args:
    images: A [batch_size, height, width, 3] float tensor of images.
    boxes: optional, A [batch_size, num_instances, 4] shaped float tensor of
      normalized bounding boxes. If None, the second return value is always
      None.
    tx: A [batch_size] shaped float tensor of x translations.
    ty: A [batch_size] shaped float tensor of y translations.
    sx: A [batch_size] shaped float tensor of x scale factor.
    sy: A [batch_size] shaped float tensor of y scale factor.
    flip: A [batch_size] shaped bool tensor indicating whether or not we
      flip the image.

  Returns:
    transformed_images: Transformed images of same shape as `images`.
    transformed_boxes: If `boxes` was not None, transformed boxes of same
      shape as boxes.
  """
  _, height, width, _ = shape_utils.combined_static_and_dynamic_shape(
      images)

  # Blend flipped and original content with a per-sample 0/1 selector so the
  # whole batch goes through the same ops regardless of the `flip` values.
  flip_selector = tf.cast(flip, tf.float32)
  flip_selector_4d = flip_selector[:, tf.newaxis, tf.newaxis, tf.newaxis]
  flip_selector_3d = flip_selector[:, tf.newaxis, tf.newaxis]
  flipped_images = tf.image.flip_left_right(images)
  images = flipped_images * flip_selector_4d + (1.0 - flip_selector_4d) * images

  # Implement translate+scale as a crop-and-resize: a window of normalized
  # size (sy, sx) centered at (0.5 - ty*sy, 0.5 - tx*sx) is resized back to
  # the full image size, so the visible content ends up scaled by
  # (1/sy, 1/sx).
  cy = cx = tf.zeros_like(tx) + 0.5
  ymin = -ty*sy + cy - sy * 0.5
  xmin = -tx*sx + cx - sx * 0.5
  ymax = -ty*sy + cy + sy * 0.5
  xmax = -tx*sx + cx + sx * 0.5
  crop_box = tf.stack([ymin, xmin, ymax, xmax], axis=1)
  crop_box_expanded = crop_box[:, tf.newaxis, :]
  images_transformed = spatial_transform_ops.matmul_crop_and_resize(
      images, crop_box_expanded, (height, width)
  )
  images_transformed = images_transformed[:, 0, :, :, :]

  if boxes is not None:
    flipped_boxes = flip_boxes_left_right(boxes)
    boxes = flipped_boxes * flip_selector_3d + (1.0 - flip_selector_3d) * boxes
    # Re-express the boxes in the coordinate frame of the crop window:
    # subtract the window origin and normalize by the window size.
    win_height = ymax - ymin
    win_width = xmax - xmin
    win_height = win_height[:, tf.newaxis]
    win_width = win_width[:, tf.newaxis]
    boxes_transformed = (
        boxes - tf.stack([ymin, xmin, ymin, xmin], axis=1)[:, tf.newaxis, :])
    boxes_ymin, boxes_xmin, boxes_ymax, boxes_xmax = tf.unstack(
        boxes_transformed, axis=2)
    boxes_ymin *= 1.0 / win_height
    boxes_xmin *= 1.0 / win_width
    boxes_ymax *= 1.0 / win_height
    boxes_xmax *= 1.0 / win_width
    boxes = tf.stack([boxes_ymin, boxes_xmin, boxes_ymax, boxes_xmax], axis=2)

  return images_transformed, boxes
def transform_instance_masks(instance_masks, tx, ty, sx, sy, flip):
  """Transforms a batch of instance masks by the given amount.

  Each mask is translated/scaled/flipped with its batch element's
  parameters, with the same semantics as `transform_images_and_boxes`.

  Args:
    instance_masks: A [batch_size, num_instances, height, width] float
      tensor of instance masks.
    tx: A [batch_size] shaped float tensor of x translations.
    ty: A [batch_size] shaped float tensor of y translations.
    sx: A [batch_size] shaped float tensor of x scale factor.
    sy: A [batch_size] shaped float tensor of y scale factor.
    flip: A [batch_size] shaped bool tensor indicating whether or not we
      flip the image.

  Returns:
    transformed_masks: Transformed masks of the same shape as
      `instance_masks`.
  """
  # Collapse [batch, num_instances] into one leading dimension so the masks
  # can be pushed through the image transform as a batch of 1-channel images.
  instance_masks, batch_size, num_instances = flatten_first2_dims(
      instance_masks)
  # Every instance of a batch element shares that element's transform
  # parameters, so repeat each parameter `num_instances` times.
  repeat = tf.zeros_like(tx, dtype=tf.int32) + num_instances
  tx = tf.repeat(tx, repeat)
  ty = tf.repeat(ty, repeat)
  sx = tf.repeat(sx, repeat)
  sy = tf.repeat(sy, repeat)
  flip = tf.repeat(flip, repeat)

  instance_masks = instance_masks[:, :, :, tf.newaxis]
  instance_masks, _ = transform_images_and_boxes(
      instance_masks, boxes=None, tx=tx, ty=ty, sx=sx, sy=sy, flip=flip)
  return unpack_first2_dims(
      instance_masks[:, :, :, 0], batch_size, num_instances)
class ResNetMaskNetwork(tf.keras.layers.Layer):
"""A small wrapper around ResNet blocks to predict masks."""
......@@ -783,45 +931,39 @@ def deepmac_proto_to_params(deepmac_config):
# Add dummy localization loss to avoid the loss_builder throwing error.
loss.localization_loss.weighted_l2.CopyFrom(
losses_pb2.WeightedL2LocalizationLoss())
loss.classification_loss.CopyFrom(deepmac_config.classification_loss)
classification_loss, _, _, _, _, _, _ = (losses_builder.build(loss))
jitter_mode = preprocessor_pb2.RandomJitterBoxes.JitterMode.Name(
deepmac_config.jitter_mode).lower()
box_consistency_loss_normalize = center_net_pb2.LossNormalize.Name(
deepmac_config.box_consistency_loss_normalize).lower()
return DeepMACParams(
dim=deepmac_config.dim,
classification_loss=classification_loss,
task_loss_weight=deepmac_config.task_loss_weight,
pixel_embedding_dim=deepmac_config.pixel_embedding_dim,
allowed_masked_classes_ids=deepmac_config.allowed_masked_classes_ids,
mask_size=deepmac_config.mask_size,
mask_num_subsamples=deepmac_config.mask_num_subsamples,
use_xy=deepmac_config.use_xy,
network_type=deepmac_config.network_type,
use_instance_embedding=deepmac_config.use_instance_embedding,
num_init_channels=deepmac_config.num_init_channels,
predict_full_resolution_masks=
deepmac_config.predict_full_resolution_masks,
postprocess_crop_size=deepmac_config.postprocess_crop_size,
max_roi_jitter_ratio=deepmac_config.max_roi_jitter_ratio,
roi_jitter_mode=jitter_mode,
box_consistency_loss_weight=deepmac_config.box_consistency_loss_weight,
color_consistency_threshold=deepmac_config.color_consistency_threshold,
color_consistency_dilation=deepmac_config.color_consistency_dilation,
color_consistency_loss_weight=
deepmac_config.color_consistency_loss_weight,
box_consistency_loss_normalize=box_consistency_loss_normalize,
box_consistency_tightness=deepmac_config.box_consistency_tightness,
color_consistency_warmup_steps=
deepmac_config.color_consistency_warmup_steps,
color_consistency_warmup_start=
deepmac_config.color_consistency_warmup_start,
use_only_last_stage=deepmac_config.use_only_last_stage
)
deepmac_field_class = (
center_net_pb2.CenterNet.DESCRIPTOR.nested_types_by_name[
'DeepMACMaskEstimation'])
params = {}
for field in deepmac_field_class.fields:
value = getattr(deepmac_config, field.name)
if field.enum_type:
params[field.name] = field.enum_type.values_by_number[value].name.lower()
else:
params[field.name] = value
params['roi_jitter_mode'] = params.pop('jitter_mode')
params['classification_loss'] = classification_loss
return DeepMACParams(**params)
def _warmup_weight(current_training_step, warmup_start, warmup_steps):
  """Utility function for warming up loss weights."""
  # No warmup configured: the loss is applied at full strength immediately.
  if warmup_steps == 0:
    return 1.0

  # Linearly ramp from 0 to 1 over [warmup_start, warmup_start + warmup_steps].
  step = tf.cast(current_training_step, tf.float32)
  start = tf.cast(warmup_start, tf.float32)
  total = tf.cast(warmup_steps, tf.float32)
  fraction = (step - start) / total
  return tf.clip_by_value(fraction, 0.0, 1.0)
class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
......@@ -863,6 +1005,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
f'pixel_embedding_dim({pixel_embedding_dim}) '
f'must be same as dim({dim}).')
generator_class = tf.random.Generator
self._self_supervised_rng = generator_class.from_non_deterministic_state()
super(DeepMACMetaArch, self).__init__(
is_training=is_training, add_summaries=add_summaries,
num_classes=num_classes, feature_extractor=feature_extractor,
......@@ -991,13 +1135,74 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return center_latents
def predict(self, preprocessed_inputs, other_inputs):
def predict(self, preprocessed_inputs, true_image_shapes):
prediction_dict = super(DeepMACMetaArch, self).predict(
preprocessed_inputs, other_inputs)
preprocessed_inputs, true_image_shapes)
mask_logits = self._predict_mask_logits_from_gt_boxes(prediction_dict)
prediction_dict[MASK_LOGITS_GT_BOXES] = mask_logits
if self._deepmac_params.augmented_self_supervision_loss_weight > 0.0:
prediction_dict[SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS] = (
self._predict_deaugmented_mask_logits_on_augmented_inputs(
preprocessed_inputs, true_image_shapes))
return prediction_dict
def _predict_deaugmented_mask_logits_on_augmented_inputs(
    self, preprocessed_inputs, true_image_shapes):
  """Predicts masks on augmented images and reverses that augmentation.

  The masks are de-augmented so that they are aligned with the original
  image.

  Args:
    preprocessed_inputs: A batch of images of shape
      [batch_size, height, width, 3].
    true_image_shapes: True shape of the image in case there is any padding.

  Returns:
    deaugmented_masks_list: A list of de-augmented mask logit tensors (one
      entry per backbone prediction stage), each of shape
      [batch_size, num_instances, output_height, output_width].
  """
  batch_size = tf.shape(preprocessed_inputs)[0]
  gt_boxes = _batch_gt_list(
      self.groundtruth_lists(fields.BoxListFields.boxes))

  # Sample a random translation, scale and horizontal flip per batch element
  # from the dedicated self-supervision RNG.
  max_t = self._deepmac_params.augmented_self_supervision_max_translation
  tx = self._self_supervised_rng.uniform(
      [batch_size], minval=-max_t, maxval=max_t)
  ty = self._self_supervised_rng.uniform(
      [batch_size], minval=-max_t, maxval=max_t)

  scale_min = self._deepmac_params.augmented_self_supervision_scale_min
  scale_max = self._deepmac_params.augmented_self_supervision_scale_max
  sx = self._self_supervised_rng.uniform([batch_size], minval=scale_min,
                                         maxval=scale_max)
  sy = self._self_supervised_rng.uniform([batch_size], minval=scale_min,
                                         maxval=scale_max)

  flip = (self._self_supervised_rng.uniform(
      [batch_size], minval=0.0, maxval=1.0) <
          self._deepmac_params.augmented_self_supervision_flip_probability)

  # Apply the same augmentation to both the images and the groundtruth
  # boxes, then predict masks for the augmented boxes.
  augmented_inputs, augmented_boxes = transform_images_and_boxes(
      preprocessed_inputs, gt_boxes, tx=tx, ty=ty, sx=sx, sy=sy, flip=flip
  )

  augmented_prediction_dict = super(DeepMACMetaArch, self).predict(
      augmented_inputs, true_image_shapes)

  augmented_masks_lists = self._predict_mask_logits_from_boxes(
      augmented_prediction_dict, augmented_boxes)

  deaugmented_masks_list = []

  # Invert the augmentation on the predicted masks so they align with the
  # original (un-augmented) image. Gradients are stopped so this branch
  # serves purely as a target for the consistency loss.
  for mask_logits in augmented_masks_lists:
    deaugmented_masks = transform_instance_masks(
        mask_logits, tx=-tx, ty=-ty, sx=1.0/sx, sy=1.0/sy, flip=flip)
    deaugmented_masks = tf.stop_gradient(deaugmented_masks)
    deaugmented_masks_list.append(deaugmented_masks)

  return deaugmented_masks_list
def _predict_mask_logits_from_embeddings(
self, pixel_embedding, instance_embedding, boxes):
mask_input = self._get_mask_head_input(boxes, pixel_embedding)
......@@ -1014,10 +1219,22 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_logits, batch_size, num_instances)
return mask_logits
def _predict_mask_logits_from_gt_boxes(self, prediction_dict):
def _predict_mask_logits_from_boxes(self, prediction_dict, boxes):
"""Predict mask logits using the predict dict and the given set of boxes.
Args:
prediction_dict: a dict containing the keys INSTANCE_EMBEDDING and
PIXEL_EMBEDDING, both expected to be list of tensors.
boxes: A [batch_size, num_instances, 4] float tensor of boxes in the
normalized coordinate system.
Returns:
mask_logits_list: A list of mask logits with the same spatial extents
as prediction_dict[PIXEL_EMBEDDING].
"""
mask_logits_list = []
boxes = _batch_gt_list(self.groundtruth_lists(fields.BoxListFields.boxes))
instance_embedding_list = prediction_dict[INSTANCE_EMBEDDING]
pixel_embedding_list = prediction_dict[PIXEL_EMBEDDING]
......@@ -1035,6 +1252,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return mask_logits_list
def _predict_mask_logits_from_gt_boxes(self, prediction_dict):
  """Predicts mask logits at the groundtruth boxes."""
  gt_boxes = _batch_gt_list(
      self.groundtruth_lists(fields.BoxListFields.boxes))
  return self._predict_mask_logits_from_boxes(prediction_dict, gt_boxes)
def _get_groundtruth_mask_output(self, boxes, masks):
"""Get the expected mask output for each box.
......@@ -1262,33 +1484,111 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[2, 3]))
loss = loss / num_box_pixels
if ((self._deepmac_params.color_consistency_warmup_steps > 0) and
tf.keras.backend.learning_phase()):
training_step = tf.cast(self.training_step, tf.float32)
warmup_steps = tf.cast(
self._deepmac_params.color_consistency_warmup_steps, tf.float32)
start_step = tf.cast(
self._deepmac_params.color_consistency_warmup_start, tf.float32)
warmup_weight = (training_step - start_step) / warmup_steps
warmup_weight = tf.clip_by_value(warmup_weight, 0.0, 1.0)
loss *= warmup_weight
if tf.keras.backend.learning_phase():
loss *= _warmup_weight(
current_training_step=self._training_step,
warmup_start=self._deepmac_params.color_consistency_warmup_start,
warmup_steps=self._deepmac_params.color_consistency_warmup_steps)
return loss
def _self_supervision_loss(
    self, predicted_logits, self_supervised_logits, boxes, loss_name):
  """Computes a per-instance consistency loss between two sets of logits.

  Only pixels inside the given boxes contribute to the loss; pixels outside
  get zero weight.

  Args:
    predicted_logits: A [batch_size, num_instances, height, width] float
      tensor of mask logits predicted on the original input.
    self_supervised_logits: A float tensor of the same shape as
      `predicted_logits` holding the (de-augmented) target logits.
    boxes: A [batch_size, num_instances, 4] float tensor of normalized
      boxes used to restrict the loss to in-box pixels.
    loss_name: One of 'loss_dice', 'loss_mse' or 'loss_kl_div'.

  Returns:
    A [batch_size, num_instances] float tensor of per-instance losses.

  Raises:
    RuntimeError: If `loss_name` is not one of the supported losses.
  """
  original_shape = tf.shape(predicted_logits)
  batch_size, num_instances = original_shape[0], original_shape[1]
  # Per-pixel weights: 1 inside each groundtruth box, 0 outside.
  box_mask = fill_boxes(boxes, original_shape[2], original_shape[3])

  # Flatten to [batch * instances, pixels, 1] for the per-pixel loss ops.
  loss_tensor_shape = [batch_size * num_instances, -1, 1]
  weights = tf.reshape(box_mask, loss_tensor_shape)
  predicted_logits = tf.reshape(predicted_logits, loss_tensor_shape)
  self_supervised_logits = tf.reshape(self_supervised_logits,
                                      loss_tensor_shape)

  self_supervised_probs = tf.nn.sigmoid(self_supervised_logits)
  predicted_probs = tf.nn.sigmoid(predicted_logits)
  num_box_pixels = tf.reduce_sum(weights, axis=[1, 2])
  # Guard against empty boxes to avoid division by zero below.
  num_box_pixels = tf.maximum(num_box_pixels, 1.0)

  if loss_name == 'loss_dice':
    # Dice loss against the target binarized at logit 0 (prob 0.5).
    self_supervised_binary_probs = tf.cast(
        self_supervised_logits > 0.0, tf.float32)
    loss_class = losses.WeightedDiceClassificationLoss(
        squared_normalization=False)
    loss = loss_class(prediction_tensor=predicted_logits,
                      target_tensor=self_supervised_binary_probs,
                      weights=weights)
    agg_loss = self._aggregate_classification_loss(
        loss, gt=self_supervised_probs, pred=predicted_logits,
        method='normalize_auto')
  elif loss_name == 'loss_mse':
    # Mean squared error between the two probability maps, averaged over
    # the in-box pixels only.
    diff = self_supervised_probs - predicted_probs
    diff_sq = (diff * diff)
    diff_sq_sum = tf.reduce_sum(diff_sq * weights, axis=[1, 2])
    agg_loss = diff_sq_sum / num_box_pixels
  elif loss_name == 'loss_kl_div':
    # Pixel-wise KL divergence between 2-class (fg/bg) distributions.
    loss_class = tf.keras.losses.KLDivergence(
        reduction=tf.keras.losses.Reduction.NONE)
    predicted_2class_probability = tf.stack(
        [predicted_probs, 1 - predicted_probs], axis=2
    )
    target_2class_probability = tf.stack(
        [self_supervised_probs, 1 - self_supervised_probs], axis=2
    )
    loss = loss_class(
        y_pred=predicted_2class_probability,
        y_true=target_2class_probability)
    agg_loss = tf.reduce_sum(loss * weights, axis=[1, 2]) / num_box_pixels
  else:
    raise RuntimeError('Unknown self-supervision loss %s' % loss_name)

  return tf.reshape(agg_loss, [batch_size, num_instances])
def _compute_self_supervised_augmented_loss(
    self, original_logits, deaugmented_logits, boxes):
  """Computes the augmented self-supervision loss, with training warmup.

  Args:
    original_logits: A [batch_size, num_instances, height, width] float
      tensor of mask logits predicted on the original input.
    deaugmented_logits: De-augmented target logits of the same shape, or
      None when self-supervision is disabled.
    boxes: A [batch_size, num_instances, 4] float tensor of normalized
      groundtruth boxes.

  Returns:
    A [batch_size, num_instances] float tensor of per-instance losses;
    all zeros when `deaugmented_logits` is None.
  """
  if deaugmented_logits is None:
    logging.info('No self supervised masks provided. '
                 'Returning 0 self-supervised loss,')
    return tf.zeros(tf.shape(original_logits)[:2])

  loss = self._self_supervision_loss(
      predicted_logits=original_logits,
      self_supervised_logits=deaugmented_logits,
      boxes=boxes,
      loss_name=self._deepmac_params.augmented_self_supervision_loss)

  # Ramp the loss up linearly at training time per the configured warmup.
  if tf.keras.backend.learning_phase():
    loss *= _warmup_weight(
        current_training_step=self._training_step,
        warmup_start=
        self._deepmac_params.augmented_self_supervision_warmup_start,
        warmup_steps=
        self._deepmac_params.augmented_self_supervision_warmup_steps)

  return loss
def _compute_deepmac_losses(
self, boxes, masks_logits, masks_gt, image):
self, boxes, masks_logits, masks_gt, image,
self_supervised_masks_logits=None):
"""Returns the mask loss per instance.
Args:
boxes: A [batch_size, num_instances, 4] float tensor holding bounding
boxes. The coordinates are in normalized input space.
masks_logits: A [batch_size, num_instances, input_height, input_width]
masks_logits: A [batch_size, num_instances, output_height, output_width]
float tensor containing the instance mask predictions in their logit
form.
masks_gt: A [batch_size, num_instances, input_height, input_width] float
masks_gt: A [batch_size, num_instances, output_height, output_width] float
tensor containing the groundtruth masks.
image: [batch_size, output_height, output_width, channels] float tensor
denoting the input image.
self_supervised_masks_logits: Optional self-supervised mask logits to
compare against of same shape as mask_logits.
Returns:
mask_prediction_loss: A [batch_size, num_instances] shaped float tensor
......@@ -1323,10 +1623,15 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
color_consistency_loss = self._compute_color_consistency_loss(
boxes, image, masks_logits)
self_supervised_loss = self._compute_self_supervised_augmented_loss(
masks_logits, self_supervised_masks_logits, boxes,
)
return {
DEEP_MASK_ESTIMATION: mask_prediction_loss,
DEEP_MASK_BOX_CONSISTENCY: box_consistency_loss,
DEEP_MASK_COLOR_CONSISTENCY: color_consistency_loss
DEEP_MASK_COLOR_CONSISTENCY: color_consistency_loss,
DEEP_MASK_AUGMENTED_SELF_SUPERVISION: self_supervised_loss
}
def _get_lab_image(self, preprocessed_image):
......@@ -1366,8 +1671,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
prediction_dict['preprocessed_inputs'], (height, width))
image = self._get_lab_image(preprocessed_image)
# TODO(vighneshb) See if we can save memory by only using the final
# prediction
# Iterate over multiple predictions by backbone (for hourglass length=2)
gt_boxes = _batch_gt_list(
......@@ -1380,7 +1683,13 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
self.groundtruth_lists(fields.BoxListFields.classes))
mask_logits_list = prediction_dict[MASK_LOGITS_GT_BOXES]
for mask_logits in mask_logits_list:
self_supervised_mask_logits_list = prediction_dict.get(
SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS,
[None] * len(mask_logits_list))
assert len(mask_logits_list) == len(self_supervised_mask_logits_list)
for (mask_logits, self_supervised_mask_logits) in zip(
mask_logits_list, self_supervised_mask_logits_list):
# TODO(vighneshb) Add sub-sampling back if required.
_, valid_mask_weights, gt_masks = filter_masked_classes(
......@@ -1388,7 +1697,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
gt_weights, gt_masks)
sample_loss_dict = self._compute_deepmac_losses(
gt_boxes, mask_logits, gt_masks, image)
gt_boxes, mask_logits, gt_masks, image,
self_supervised_masks_logits=self_supervised_mask_logits)
sample_loss_dict[DEEP_MASK_ESTIMATION] *= valid_mask_weights
for loss_name in WEAK_LOSSES:
......@@ -1435,7 +1745,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
prediction_dict: a dictionary holding predicted tensors from "predict"
function.
true_image_shapes: int32 tensor of shape [batch, 3] where each row is of
the form [height, width, channels] indicating the shapes of true images
the form [height, width, channels] indicating the shapes of true images
in the resized images, as resized images can be padded with zeros.
**params: Currently ignored.
......@@ -1480,13 +1790,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
height, width = (tf.shape(instance_embedding)[1],
tf.shape(instance_embedding)[2])
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
ymin, xmin, ymax, xmax = tf.unstack(boxes_output_stride, axis=2)
ymin /= height
ymax /= height
xmin /= width
xmax /= width
boxes = tf.stack([ymin, xmin, ymax, xmax], axis=2)
boxes = boxes_batch_absolute_to_normalized_coordinates(
boxes_output_stride, height, width)
mask_logits = self._predict_mask_logits_from_embeddings(
pixel_embedding, instance_embedding, boxes)
......
"""Tests for google3.third_party.tensorflow_models.object_detection.meta_architectures.deepmac_meta_arch."""
import functools
import math
import random
import unittest
......@@ -17,35 +18,12 @@ from object_detection.protos import center_net_pb2
from object_detection.utils import tf_version
DEEPMAC_PROTO_TEXT = """
dim: 153
task_loss_weight: 5.0
pixel_embedding_dim: 8
use_xy: false
use_instance_embedding: false
network_type: "cond_inst3"
num_init_channels: 8
def _logit(probability):
return math.log(probability / (1. - probability))
classification_loss {
weighted_dice_classification_loss {
squared_normalization: false
is_prediction_probability: false
}
}
jitter_mode: EXPAND_SYMMETRIC_XY
max_roi_jitter_ratio: 0.0
predict_full_resolution_masks: true
allowed_masked_classes_ids: [99]
box_consistency_loss_weight: 1.0
color_consistency_loss_weight: 1.0
color_consistency_threshold: 0.1
box_consistency_tightness: false
box_consistency_loss_normalize: NORMALIZE_AUTO
color_consistency_warmup_steps: 20
color_consistency_warmup_start: 10
use_only_last_stage: false
"""
# Logit values for probabilities 0.5 and 0.25.
LOGIT_HALF = _logit(0.5)
LOGIT_QUARTER = _logit(0.25)
class DummyFeatureExtractor(center_net_meta_arch.CenterNetFeatureExtractor):
......@@ -122,7 +100,15 @@ def build_meta_arch(**override_params):
color_consistency_dilation=2,
color_consistency_warmup_steps=0,
color_consistency_warmup_start=0,
use_only_last_stage=True)
use_only_last_stage=True,
augmented_self_supervision_max_translation=0.0,
augmented_self_supervision_loss_weight=0.0,
augmented_self_supervision_flip_probability=0.0,
augmented_self_supervision_warmup_start=0,
augmented_self_supervision_warmup_steps=0,
augmented_self_supervision_loss='loss_dice',
augmented_self_supervision_scale_min=1.0,
augmented_self_supervision_scale_max=1.0)
params.update(override_params)
......@@ -176,6 +162,45 @@ def build_meta_arch(**override_params):
image_resizer_fn=image_resizer_fn)
# Text-format DeepMACMaskEstimation config covering the augmented
# self-supervision fields, used to exercise proto parsing in the tests.
DEEPMAC_PROTO_TEXT = """
dim: 153
task_loss_weight: 5.0
pixel_embedding_dim: 8
use_xy: false
use_instance_embedding: false
network_type: "cond_inst3"
classification_loss {
weighted_dice_classification_loss {
squared_normalization: false
is_prediction_probability: false
}
}
jitter_mode: EXPAND_SYMMETRIC_XY
max_roi_jitter_ratio: 0.0
predict_full_resolution_masks: true
allowed_masked_classes_ids: [99]
box_consistency_loss_weight: 1.0
color_consistency_loss_weight: 1.0
color_consistency_threshold: 0.1
box_consistency_tightness: false
box_consistency_loss_normalize: NORMALIZE_AUTO
color_consistency_warmup_steps: 20
color_consistency_warmup_start: 10
use_only_last_stage: false
augmented_self_supervision_warmup_start: 13
augmented_self_supervision_warmup_steps: 14
augmented_self_supervision_loss: LOSS_MSE
augmented_self_supervision_loss_weight: 11.0
augmented_self_supervision_max_translation: 2.5
augmented_self_supervision_flip_probability: 0.9
augmented_self_supervision_scale_min: 0.42
augmented_self_supervision_scale_max: 1.42
"""
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
......@@ -185,9 +210,21 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
text_format.Parse(DEEPMAC_PROTO_TEXT, proto)
params = deepmac_meta_arch.deepmac_proto_to_params(proto)
self.assertIsInstance(params, deepmac_meta_arch.DeepMACParams)
self.assertEqual(params.num_init_channels, 64)
self.assertEqual(params.dim, 153)
self.assertEqual(params.box_consistency_loss_normalize, 'normalize_auto')
self.assertFalse(params.use_only_last_stage)
self.assertEqual(params.augmented_self_supervision_warmup_start, 13)
self.assertEqual(params.augmented_self_supervision_warmup_steps, 14)
self.assertEqual(params.augmented_self_supervision_loss, 'loss_mse')
self.assertEqual(params.augmented_self_supervision_loss_weight, 11.0)
self.assertEqual(params.augmented_self_supervision_max_translation, 2.5)
self.assertAlmostEqual(
params.augmented_self_supervision_flip_probability, 0.9)
self.assertAlmostEqual(
params.augmented_self_supervision_scale_min, 0.42)
self.assertAlmostEqual(
params.augmented_self_supervision_scale_max, 1.42)
def test_subsample_trivial(self):
"""Test subsampling masks."""
......@@ -590,6 +627,190 @@ class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
tf.zeros((0, 32, 32, input_channels)), training=True)
self.assertEqual(out.shape, (0, 32, 32))
@parameterized.parameters(
    [
        dict(x=4, y=4, height=4, width=4),
        dict(x=1, y=2, height=3, width=4),
        dict(x=14, y=14, height=5, width=5),
    ]
)
def test_transform_images_and_boxes_identity(self, x, y, height, width):
  """Identity parameters must leave images and boxes unchanged."""
  image_array = np.zeros((1, 32, 32, 3), dtype=np.float32)
  image_array[:, y:y + height, x:x + width, :] = 1.0
  image_batch = tf.constant(image_array)
  box_batch = tf.constant([[[y / 32., x / 32.,
                             y / 32. + height/32, x/32. + width / 32]]])

  no_shift = tf.zeros(1)
  unit_scale = tf.ones(1)
  no_flip = tf.zeros(1, dtype=tf.bool)
  images_out, boxes_out = deepmac_meta_arch.transform_images_and_boxes(
      image_batch, box_batch, no_shift, no_shift, unit_scale, unit_scale,
      no_flip)

  self.assertAllClose(image_batch, images_out)
  self.assertAllClose(box_batch, boxes_out)

  # The filled rectangle must still occupy exactly its original extent.
  filled = np.argwhere(images_out.numpy()[0, :, :, 0] > 0.5)
  self.assertEqual(np.min(filled[:, 0]), y)
  self.assertEqual(np.min(filled[:, 1]), x)
  self.assertEqual(np.max(filled[:, 0]), y + height - 1)
  self.assertEqual(np.max(filled[:, 1]), x + width - 1)
def test_transform_images_and_boxes(self):
  """Checks scaled/translated pixel extents match the transformed boxes."""
  images = np.zeros((2, 32, 32, 3), dtype=np.float32)
  # A 5x5 filled square centered at pixel (16, 16) in both samples.
  images[:, 14:19, 14:19, :] = 1.0
  boxes = tf.constant(
      [[[14.0 / 32, 14.0 / 32, 18.0 / 32, 18.0 / 32]] * 2] * 2)
  flip = tf.constant([False, False])

  # Desired content scale/translation (in pixels) per sample; the function's
  # (t, s) arguments are the normalized crop-window parameters, hence the
  # reciprocal scales and /32 translations below.
  scale_y0 = 2.0
  translate_y0 = 1.0
  scale_x0 = 4.0
  translate_x0 = 4.0

  scale_y1 = 3.0
  translate_y1 = 3.0
  scale_x1 = 0.5
  translate_x1 = 2.0

  ty = tf.constant([translate_y0/32, translate_y1/32])
  sy = tf.constant([1./scale_y0, 1.0 / scale_y1])
  tx = tf.constant([translate_x0/32, translate_x1/32])
  sx = tf.constant([1 / scale_x0, 1.0 / scale_x1])

  images = tf.constant(images)
  images_out, boxes_out = deepmac_meta_arch.transform_images_and_boxes(
      images, boxes, tx=tx, ty=ty, sx=sx, sy=sy, flip=flip)
  boxes_out = boxes_out.numpy() * 32

  # Sample 0: the square's extent should be scaled about the center and
  # shifted by the translation; the output box should track the pixels.
  coords = np.argwhere(images_out[0, :, :, 0] >= 0.9)
  ymin = np.min(coords[:, 0])
  ymax = np.max(coords[:, 0])
  xmin = np.min(coords[:, 1])
  xmax = np.max(coords[:, 1])
  self.assertAlmostEqual(
      ymin, 16 - 2*scale_y0 + translate_y0, delta=1)
  self.assertAlmostEqual(
      ymax, 16 + 2*scale_y0 + translate_y0, delta=1)
  self.assertAlmostEqual(
      xmin, 16 - 2*scale_x0 + translate_x0, delta=1)
  self.assertAlmostEqual(
      xmax, 16 + 2*scale_x0 + translate_x0, delta=1)
  self.assertAlmostEqual(ymin, boxes_out[0, 0, 0], delta=1)
  self.assertAlmostEqual(xmin, boxes_out[0, 0, 1], delta=1)
  self.assertAlmostEqual(ymax, boxes_out[0, 0, 2], delta=1)
  self.assertAlmostEqual(xmax, boxes_out[0, 0, 3], delta=1)

  # Sample 1: same checks with the second set of parameters.
  coords = np.argwhere(images_out[1, :, :, 0] >= 0.9)
  ymin = np.min(coords[:, 0])
  ymax = np.max(coords[:, 0])
  xmin = np.min(coords[:, 1])
  xmax = np.max(coords[:, 1])
  self.assertAlmostEqual(
      ymin, 16 - 2*scale_y1 + translate_y1, delta=1)
  self.assertAlmostEqual(
      ymax, 16 + 2*scale_y1 + translate_y1, delta=1)
  self.assertAlmostEqual(
      xmin, 16 - 2*scale_x1 + translate_x1, delta=1)
  self.assertAlmostEqual(
      xmax, 16 + 2*scale_x1 + translate_x1, delta=1)
  self.assertAlmostEqual(ymin, boxes_out[1, 0, 0], delta=1)
  self.assertAlmostEqual(xmin, boxes_out[1, 0, 1], delta=1)
  self.assertAlmostEqual(ymax, boxes_out[1, 0, 2], delta=1)
  self.assertAlmostEqual(xmax, boxes_out[1, 0, 3], delta=1)
def test_transform_images_and_boxes_flip(self):
images = np.zeros((2, 2, 2, 1), dtype=np.float32)
images[0, :, :, 0] = [[1, 2], [3, 4]]
images[1, :, :, 0] = [[1, 2], [3, 4]]
images = tf.constant(images)
boxes = tf.constant(
[[[0.1, 0.2, 0.3, 0.4]], [[0.1, 0.2, 0.3, 0.4]]], dtype=tf.float32)
tx = ty = tf.zeros([2], dtype=tf.float32)
sx = sy = tf.ones([2], dtype=tf.float32)
flip = tf.constant([True, False])
output_images, output_boxes = deepmac_meta_arch.transform_images_and_boxes(
images, boxes, tx, ty, sx, sy, flip)
expected_images = np.zeros((2, 2, 2, 1), dtype=np.float32)
expected_images[0, :, :, 0] = [[2, 1], [4, 3]]
expected_images[1, :, :, 0] = [[1, 2], [3, 4]]
self.assertAllClose(output_boxes,
[[[0.1, 0.6, 0.3, 0.8]], [[0.1, 0.2, 0.3, 0.4]]])
self.assertAllClose(expected_images, output_images)
def test_transform_images_and_boxes_tf_function(self):
func = tf.function(deepmac_meta_arch.transform_images_and_boxes)
output, _ = func(images=tf.zeros((2, 32, 32, 3)), boxes=tf.zeros((2, 5, 4)),
tx=tf.zeros(2), ty=tf.zeros(2),
sx=tf.ones(2), sy=tf.ones(2),
flip=tf.zeros(2, dtype=tf.bool))
self.assertEqual(output.shape, (2, 32, 32, 3))
def test_transform_instance_masks(self):
instance_masks = np.zeros((2, 10, 32, 32), dtype=np.float32)
instance_masks[0, 0, 1, 1] = 1
instance_masks[0, 1, 1, 1] = 1
instance_masks[1, 0, 2, 2] = 1
instance_masks[1, 1, 2, 2] = 1
tx = ty = tf.constant([1., 2.]) / 32.0
sx = sy = tf.ones(2, dtype=tf.float32)
flip = tf.zeros(2, dtype=tf.bool)
instance_masks = deepmac_meta_arch.transform_instance_masks(
instance_masks, tx, ty, sx, sy, flip=flip)
self.assertEqual(instance_masks.shape, (2, 10, 32, 32))
self.assertAlmostEqual(
instance_masks[0].numpy().sum(), 2.0)
self.assertGreater(
instance_masks[0, 0, 2, 2].numpy(), 0.5)
self.assertGreater(
instance_masks[0, 1, 2, 2].numpy(), 0.5)
self.assertAlmostEqual(
instance_masks[1].numpy().sum(), 2.0)
self.assertGreater(
instance_masks[1, 0, 4, 4].numpy(), 0.5)
self.assertGreater(
instance_masks[1, 1, 4, 4].numpy(), 0.5)
def test_augment_image_and_deaugment_mask(self):
img = np.zeros((1, 32, 32, 3), dtype=np.float32)
img[0, 10:12, 10:12, :] = 1.0
tx = ty = tf.constant([1.]) / 32.0
sx = sy = tf.constant([1.0 / 2.0])
flip = tf.constant([False])
img = tf.constant(img)
img_t, _ = deepmac_meta_arch.transform_images_and_boxes(
images=img, boxes=None, tx=tx, ty=ty, sx=sx, sy=sy, flip=flip)
self.assertAlmostEqual(img_t.numpy().sum(), 16 * 3)
# Converting channels of the image to instances.
masks = tf.transpose(img_t, (0, 3, 1, 2))
masks_t = deepmac_meta_arch.transform_instance_masks(
masks, tx=-tx, ty=-ty, sx=1.0/sx, sy=1.0/sy, flip=flip)
self.assertAlmostEqual(masks_t.numpy().sum(), 4 * 3)
coords = np.argwhere(masks_t[0, 0, :, :] >= 0.5)
self.assertAlmostEqual(np.min(coords[:, 0]), 10, delta=1)
self.assertAlmostEqual(np.max(coords[:, 0]), 12, delta=1)
self.assertAlmostEqual(np.min(coords[:, 1]), 10, delta=1)
self.assertAlmostEqual(np.max(coords[:, 1]), 12, delta=1)
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
......@@ -716,6 +937,24 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(prediction['MASK_LOGITS_GT_BOXES'][0].shape,
(1, 5, 16, 16))
def test_predict_self_supervised_deaugmented_mask_logits(self):
model = build_meta_arch(
augmented_self_supervision_loss_weight=1.0,
predict_full_resolution_masks=True)
model.provide_groundtruth(
groundtruth_boxes_list=[tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)],
groundtruth_classes_list=[tf.one_hot([1, 0, 1, 1, 1], depth=6)],
groundtruth_weights_list=[tf.ones(5)],
groundtruth_masks_list=[tf.ones((5, 32, 32))])
prediction = model.predict(tf.zeros((1, 32, 32, 3)), None)
self.assertEqual(prediction['MASK_LOGITS_GT_BOXES'][0].shape,
(1, 5, 8, 8))
self.assertEqual(
prediction['SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS'][0].shape,
(1, 5, 8, 8))
def test_loss(self):
model = build_meta_arch()
......@@ -1036,14 +1275,182 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
output = self.model._get_lab_image(tf.zeros((2, 4, 4, 3)))
self.assertEqual(output.shape, (2, 4, 4, 3))
def test_self_supervised_augmented_loss_identity(self):
model = build_meta_arch(predict_full_resolution_masks=True,
augmented_self_supervision_max_translation=0.0)
x = tf.random.uniform((2, 3, 32, 32), 0, 1)
boxes = tf.constant([[0., 0., 1., 1.]] * 6)
boxes = tf.reshape(boxes, [2, 3, 4])
x = tf.cast(x > 0, tf.float32)
x = (x - 0.5) * 2e40 # x is a tensor or large +ve or -ve values.
loss = model._compute_self_supervised_augmented_loss(x, x, boxes)
self.assertAlmostEqual(loss.numpy().sum(), 0.0)
def test_self_supervised_mse_augmented_loss_0(self):
model = build_meta_arch(predict_full_resolution_masks=True,
augmented_self_supervision_max_translation=0.0,
augmented_self_supervision_loss='loss_mse')
x = tf.random.uniform((2, 3, 32, 32), 0, 1)
boxes = tf.constant([[0., 0., 1., 1.]] * 6)
boxes = tf.reshape(boxes, [2, 3, 4])
loss = model._compute_self_supervised_augmented_loss(x, x, boxes)
self.assertAlmostEqual(loss.numpy().min(), 0.0)
self.assertAlmostEqual(loss.numpy().max(), 0.0)
def test_self_supervised_mse_loss_scale_equivalent(self):
model = build_meta_arch(predict_full_resolution_masks=True,
augmented_self_supervision_max_translation=0.0,
augmented_self_supervision_loss='loss_mse')
x = np.zeros((1, 3, 32, 32), dtype=np.float32) + 100.0
y = 0.0 * x.copy()
x[0, 0, :8, :8] = 0.0
y[0, 0, :8, :8] = 1.0
x[0, 1, :16, :16] = 0.0
y[0, 1, :16, :16] = 1.0
x[0, 2, :16, :16] = 0.0
x[0, 2, :8, :8] = 1.0
y[0, 2, :16, :16] = 0.0
boxes = np.array([[0., 0., 0.22, 0.22], [0., 0., 0.47, 0.47],
[0., 0., 0.47, 0.47]],
dtype=np.float32)
boxes = tf.reshape(tf.constant(boxes), [1, 3, 4])
loss = model._compute_self_supervised_augmented_loss(x, y, boxes)
self.assertEqual(loss.shape, (1, 3))
mse_1_minus_0 = (tf.nn.sigmoid(1.0) - tf.nn.sigmoid(0.0)).numpy()**2
self.assertAlmostEqual(loss.numpy()[0, 0], mse_1_minus_0)
self.assertAlmostEqual(loss.numpy()[0, 1], mse_1_minus_0)
self.assertAlmostEqual(loss.numpy()[0, 2], mse_1_minus_0 / 4.0)
def test_self_supervised_kldiv_augmented_loss_0(self):
model = build_meta_arch(predict_full_resolution_masks=True,
augmented_self_supervision_max_translation=0.0,
augmented_self_supervision_loss='loss_kl_div')
x = tf.random.uniform((2, 3, 32, 32), 0, 1)
boxes = tf.constant([[0., 0., 1., 1.]] * 6)
boxes = tf.reshape(boxes, [2, 3, 4])
loss = model._compute_self_supervised_augmented_loss(x, x, boxes)
self.assertAlmostEqual(loss.numpy().min(), 0.0)
self.assertAlmostEqual(loss.numpy().max(), 0.0)
def test_self_supervised_kldiv_scale_equivalent(self):
model = build_meta_arch(predict_full_resolution_masks=True,
augmented_self_supervision_max_translation=0.0,
augmented_self_supervision_loss='loss_kl_div')
pred = np.zeros((1, 2, 32, 32), dtype=np.float32) + 100.0
true = 0.0 * pred.copy()
pred[0, 0, :8, :8] = LOGIT_HALF
true[0, 0, :8, :8] = LOGIT_QUARTER
pred[0, 1, :16, :16] = LOGIT_HALF
true[0, 1, :16, :16] = LOGIT_QUARTER
boxes = np.array([[0., 0., 0.22, 0.22], [0., 0., 0.47, 0.47]],
dtype=np.float32)
boxes = tf.reshape(tf.constant(boxes), [1, 2, 4])
loss = model._compute_self_supervised_augmented_loss(
original_logits=pred, deaugmented_logits=true, boxes=boxes)
self.assertEqual(loss.shape, (1, 2))
expected = (3 * math.log(3) - 4 * math.log(2)) / 4.0
self.assertAlmostEqual(loss.numpy()[0, 0], expected, places=4)
self.assertAlmostEqual(loss.numpy()[0, 1], expected, places=4)
def test_self_supervision_warmup(self):
tf.keras.backend.set_learning_phase(True)
model = build_meta_arch(
use_dice_loss=True,
predict_full_resolution_masks=True,
network_type='cond_inst1',
dim=9,
pixel_embedding_dim=8,
use_instance_embedding=False,
use_xy=False,
augmented_self_supervision_loss_weight=1.0,
augmented_self_supervision_max_translation=0.5,
augmented_self_supervision_warmup_start=10,
augmented_self_supervision_warmup_steps=40)
num_stages = 1
prediction = {
'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
'MASK_LOGITS_GT_BOXES': [tf.random.normal((1, 5, 8, 8))] * num_stages,
'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS':
[tf.random.normal((1, 5, 8, 8))] * num_stages,
'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages
}
boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)]
classes = [tf.one_hot([1, 0, 1, 1, 1], depth=6)]
weights = [tf.ones(5)]
masks = [tf.ones((5, 32, 32))]
model.provide_groundtruth(
groundtruth_boxes_list=boxes,
groundtruth_classes_list=classes,
groundtruth_weights_list=weights,
groundtruth_masks_list=masks,
training_step=5)
loss_at_5 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
model.provide_groundtruth(
groundtruth_boxes_list=boxes,
groundtruth_classes_list=classes,
groundtruth_weights_list=weights,
groundtruth_masks_list=masks,
training_step=20)
loss_at_20 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
model.provide_groundtruth(
groundtruth_boxes_list=boxes,
groundtruth_classes_list=classes,
groundtruth_weights_list=weights,
groundtruth_masks_list=masks,
training_step=50)
loss_at_50 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
model.provide_groundtruth(
groundtruth_boxes_list=boxes,
groundtruth_classes_list=classes,
groundtruth_weights_list=weights,
groundtruth_masks_list=masks,
training_step=100)
loss_at_100 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
loss_key = 'Loss/' + deepmac_meta_arch.DEEP_MASK_AUGMENTED_SELF_SUPERVISION
self.assertAlmostEqual(loss_at_5[loss_key].numpy(), 0.0)
self.assertGreater(loss_at_20[loss_key], 0.0)
self.assertAlmostEqual(loss_at_20[loss_key].numpy(),
loss_at_50[loss_key].numpy() / 4.0)
self.assertAlmostEqual(loss_at_50[loss_key].numpy(),
loss_at_100[loss_key].numpy())
def test_loss_keys(self):
model = build_meta_arch(use_dice_loss=True)
model = build_meta_arch(use_dice_loss=True,
augmented_self_supervision_loss_weight=1.0,
augmented_self_supervision_max_translation=0.5)
prediction = {
'preprocessed_inputs': tf.random.normal((3, 32, 32, 3)),
'MASK_LOGITS_GT_BOXES': [tf.random.normal((3, 5, 8, 8))] * 2,
'object_center': [tf.random.normal((3, 8, 8, 6))] * 2,
'box/offset': [tf.random.normal((3, 8, 8, 2))] * 2,
'box/scale': [tf.random.normal((3, 8, 8, 2))] * 2
'box/scale': [tf.random.normal((3, 8, 8, 2))] * 2,
'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS': (
[tf.random.normal((3, 5, 8, 8))] * 2)
}
model.provide_groundtruth(
groundtruth_boxes_list=[
......@@ -1061,6 +1468,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
'{} was <= 0'.format(weak_loss))
def test_loss_weight_response(self):
tf.random.set_seed(12)
model = build_meta_arch(
use_dice_loss=True,
predict_full_resolution_masks=True,
......@@ -1068,14 +1476,19 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
dim=9,
pixel_embedding_dim=8,
use_instance_embedding=False,
use_xy=False)
use_xy=False,
augmented_self_supervision_loss_weight=1.0,
augmented_self_supervision_max_translation=0.5,
)
num_stages = 1
prediction = {
'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
'MASK_LOGITS_GT_BOXES': [tf.random.normal((1, 5, 8, 8))] * num_stages,
'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages
'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'SELF_SUPERVISED_DEAUGMENTED_MASK_LOGITS': (
[tf.random.normal((1, 5, 8, 8))] * num_stages)
}
boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)]
......@@ -1098,7 +1511,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
loss_weights = {
deepmac_meta_arch.DEEP_MASK_ESTIMATION: rng.uniform(1, 5),
deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY: rng.uniform(1, 5),
deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY: rng.uniform(1, 5)
deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY: rng.uniform(1, 5),
deepmac_meta_arch.DEEP_MASK_AUGMENTED_SELF_SUPERVISION: (
rng.uniform(1, 5))
}
weighted_model = build_meta_arch(
......@@ -1113,7 +1528,11 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
box_consistency_loss_weight=(
loss_weights[deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY]),
color_consistency_loss_weight=(
loss_weights[deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY]))
loss_weights[deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY]),
augmented_self_supervision_loss_weight=(
loss_weights[deepmac_meta_arch.DEEP_MASK_AUGMENTED_SELF_SUPERVISION]
)
)
weighted_model.provide_groundtruth(
groundtruth_boxes_list=boxes,
......@@ -1188,6 +1607,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
loss_key = 'Loss/' + deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY
self.assertAlmostEqual(loss_at_5[loss_key].numpy(), 0.0)
self.assertGreater(loss_at_15[loss_key], 0.0)
self.assertAlmostEqual(loss_at_15[loss_key].numpy(),
loss_at_20[loss_key].numpy() / 2.0)
self.assertAlmostEqual(loss_at_20[loss_key].numpy(),
......
......@@ -403,7 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 25
// Next ID 33
message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1;
......@@ -505,6 +505,21 @@ message CenterNet {
optional bool use_only_last_stage = 24 [default = false];
optional float augmented_self_supervision_max_translation = 25 [default=0.0];
optional float augmented_self_supervision_flip_probability = 26 [default=0.0];
optional float augmented_self_supervision_loss_weight = 27 [default=0.0];
optional int32 augmented_self_supervision_warmup_start = 28 [default=0];
optional int32 augmented_self_supervision_warmup_steps = 29 [default=0];
optional AugmentedSelfSupervisionLoss augmented_self_supervision_loss = 30 [default=LOSS_DICE];
optional float augmented_self_supervision_scale_min = 31 [default=1.0];
optional float augmented_self_supervision_scale_max = 32 [default=1.0];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......@@ -527,6 +542,13 @@ enum LossNormalize {
NORMALIZE_BALANCED = 3;
}
// Loss used to compare the original and de-augmented mask logits for
// DeepMAC's augmented self-supervision task.
enum AugmentedSelfSupervisionLoss {
  // Unspecified; the `augmented_self_supervision_loss` field defaults to
  // LOSS_DICE, so this value is not expected in practice.
  LOSS_UNSET = 0;
  // Dice loss between the two sets of mask logits.
  LOSS_DICE = 1;
  // Mean squared error between the mask probabilities (sigmoid of logits).
  LOSS_MSE = 2;
  // KL-divergence between the per-pixel mask distributions.
  LOSS_KL_DIV = 3;
}
message CenterNetFeatureExtractor {
optional string type = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment