Commit 9400d73c authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Implement bounding box tightness prior and CC loss warmup.

PiperOrigin-RevId: 415265931
parent 3b7bc268
...@@ -89,6 +89,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -89,6 +89,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
""" """
self._num_classes = num_classes self._num_classes = num_classes
self._groundtruth_lists = {} self._groundtruth_lists = {}
self._training_step = None
super(DetectionModel, self).__init__() super(DetectionModel, self).__init__()
...@@ -132,6 +133,13 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -132,6 +133,13 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
""" """
return field in self._groundtruth_lists return field in self._groundtruth_lists
  @property
  def training_step(self):
    """Returns the current training step supplied via `provide_groundtruth`.

    Useful for models that anneal loss terms over the course of training.

    Returns:
      The integer training step stored on the model.

    Raises:
      ValueError: If no training step was provided to the model.
    """
    if self._training_step is None:
      raise ValueError('Training step was not provided to the model.')
    return self._training_step
@staticmethod @staticmethod
def get_side_inputs(features): def get_side_inputs(features):
"""Get side inputs from input features. """Get side inputs from input features.
...@@ -318,7 +326,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -318,7 +326,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
groundtruth_verified_neg_classes=None, groundtruth_verified_neg_classes=None,
groundtruth_not_exhaustive_classes=None, groundtruth_not_exhaustive_classes=None,
groundtruth_keypoint_depths_list=None, groundtruth_keypoint_depths_list=None,
groundtruth_keypoint_depth_weights_list=None): groundtruth_keypoint_depth_weights_list=None,
training_step=None):
"""Provide groundtruth tensors. """Provide groundtruth tensors.
Args: Args:
...@@ -389,6 +398,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -389,6 +398,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
groundtruth_keypoint_depth_weights_list: a list of 2-D tf.float32 tensors groundtruth_keypoint_depth_weights_list: a list of 2-D tf.float32 tensors
of shape [num_boxes, num_keypoints] containing the weights of the of shape [num_boxes, num_keypoints] containing the weights of the
relative depths. relative depths.
training_step: An integer denoting the current training step. This is
useful when models want to anneal loss terms.
""" """
self._groundtruth_lists[fields.BoxListFields.boxes] = groundtruth_boxes_list self._groundtruth_lists[fields.BoxListFields.boxes] = groundtruth_boxes_list
self._groundtruth_lists[ self._groundtruth_lists[
...@@ -468,6 +479,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -468,6 +479,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
fields.InputDataFields fields.InputDataFields
.groundtruth_not_exhaustive_classes] = ( .groundtruth_not_exhaustive_classes] = (
groundtruth_not_exhaustive_classes) groundtruth_not_exhaustive_classes)
if training_step is not None:
self._training_step = training_step
@abc.abstractmethod @abc.abstractmethod
def regularization_losses(self): def regularization_losses(self):
......
...@@ -12,12 +12,12 @@ import tensorflow as tf ...@@ -12,12 +12,12 @@ import tensorflow as tf
from object_detection.builders import losses_builder from object_detection.builders import losses_builder
from object_detection.core import box_list from object_detection.core import box_list
from object_detection.core import box_list_ops from object_detection.core import box_list_ops
from object_detection.core import losses
from object_detection.core import preprocessor from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.meta_architectures import center_net_meta_arch from object_detection.meta_architectures import center_net_meta_arch
from object_detection.models.keras_models import hourglass_network from object_detection.models.keras_models import hourglass_network
from object_detection.models.keras_models import resnet_v1 from object_detection.models.keras_models import resnet_v1
from object_detection.protos import center_net_pb2
from object_detection.protos import losses_pb2 from object_detection.protos import losses_pb2
from object_detection.protos import preprocessor_pb2 from object_detection.protos import preprocessor_pb2
from object_detection.utils import shape_utils from object_detection.utils import shape_utils
...@@ -38,46 +38,26 @@ NEIGHBORS_2D = [[-1, -1], [-1, 0], [-1, 1], ...@@ -38,46 +38,26 @@ NEIGHBORS_2D = [[-1, -1], [-1, 0], [-1, 1],
[0, -1], [0, 1], [0, -1], [0, 1],
[1, -1], [1, 0], [1, 1]] [1, -1], [1, 0], [1, 1]]
WEAK_LOSSES = [DEEP_MASK_BOX_CONSISTENCY, DEEP_MASK_COLOR_CONSISTENCY] WEAK_LOSSES = [DEEP_MASK_BOX_CONSISTENCY, DEEP_MASK_COLOR_CONSISTENCY]
MASK_LOSSES = WEAK_LOSSES + [DEEP_MASK_ESTIMATION]
class DeepMACParams( DeepMACParams = collections.namedtuple('DeepMACParams', [
collections.namedtuple('DeepMACParams', [
'classification_loss', 'dim', 'task_loss_weight', 'pixel_embedding_dim', 'classification_loss', 'dim', 'task_loss_weight', 'pixel_embedding_dim',
'allowed_masked_classes_ids', 'mask_size', 'mask_num_subsamples', 'allowed_masked_classes_ids', 'mask_size', 'mask_num_subsamples',
'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels', 'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels',
'predict_full_resolution_masks', 'postprocess_crop_size', 'predict_full_resolution_masks', 'postprocess_crop_size',
'max_roi_jitter_ratio', 'roi_jitter_mode', 'max_roi_jitter_ratio', 'roi_jitter_mode',
'box_consistency_loss_weight', 'color_consistency_threshold', 'box_consistency_loss_weight', 'color_consistency_threshold',
'color_consistency_dilation', 'color_consistency_loss_weight' 'color_consistency_dilation', 'color_consistency_loss_weight',
])): 'box_consistency_loss_normalize', 'box_consistency_tightness',
"""Class holding the DeepMAC network configutration.""" 'color_consistency_warmup_steps', 'color_consistency_warmup_start'
])
__slots__ = ()
def __new__(cls, classification_loss, dim, task_loss_weight, def _get_loss_weight(loss_name, config):
pixel_embedding_dim, allowed_masked_classes_ids, mask_size, if loss_name == DEEP_MASK_ESTIMATION:
mask_num_subsamples, use_xy, network_type, use_instance_embedding, return config.task_loss_weight
num_init_channels, predict_full_resolution_masks, elif loss_name == DEEP_MASK_COLOR_CONSISTENCY:
postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode, box_consistency_loss_weight,
color_consistency_threshold, color_consistency_dilation,
color_consistency_loss_weight):
return super(DeepMACParams,
cls).__new__(cls, classification_loss, dim,
task_loss_weight, pixel_embedding_dim,
allowed_masked_classes_ids, mask_size,
mask_num_subsamples, use_xy, network_type,
use_instance_embedding, num_init_channels,
predict_full_resolution_masks,
postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode, box_consistency_loss_weight,
color_consistency_threshold,
color_consistency_dilation,
color_consistency_loss_weight)
def _get_weak_loss_weight(loss_name, config):
if loss_name == DEEP_MASK_COLOR_CONSISTENCY:
return config.color_consistency_loss_weight return config.color_consistency_loss_weight
elif loss_name == DEEP_MASK_BOX_CONSISTENCY: elif loss_name == DEEP_MASK_BOX_CONSISTENCY:
return config.box_consistency_loss_weight return config.box_consistency_loss_weight
...@@ -755,6 +735,9 @@ def deepmac_proto_to_params(deepmac_config): ...@@ -755,6 +735,9 @@ def deepmac_proto_to_params(deepmac_config):
jitter_mode = preprocessor_pb2.RandomJitterBoxes.JitterMode.Name( jitter_mode = preprocessor_pb2.RandomJitterBoxes.JitterMode.Name(
deepmac_config.jitter_mode).lower() deepmac_config.jitter_mode).lower()
box_consistency_loss_normalize = center_net_pb2.LossNormalize.Name(
deepmac_config.box_consistency_loss_normalize).lower()
return DeepMACParams( return DeepMACParams(
dim=deepmac_config.dim, dim=deepmac_config.dim,
classification_loss=classification_loss, classification_loss=classification_loss,
...@@ -775,7 +758,14 @@ def deepmac_proto_to_params(deepmac_config): ...@@ -775,7 +758,14 @@ def deepmac_proto_to_params(deepmac_config):
box_consistency_loss_weight=deepmac_config.box_consistency_loss_weight, box_consistency_loss_weight=deepmac_config.box_consistency_loss_weight,
color_consistency_threshold=deepmac_config.color_consistency_threshold, color_consistency_threshold=deepmac_config.color_consistency_threshold,
color_consistency_dilation=deepmac_config.color_consistency_dilation, color_consistency_dilation=deepmac_config.color_consistency_dilation,
color_consistency_loss_weight=deepmac_config.color_consistency_loss_weight color_consistency_loss_weight=
deepmac_config.color_consistency_loss_weight,
box_consistency_loss_normalize=box_consistency_loss_normalize,
box_consistency_tightness=deepmac_config.box_consistency_tightness,
color_consistency_warmup_steps=
deepmac_config.color_consistency_warmup_steps,
color_consistency_warmup_start=
deepmac_config.color_consistency_warmup_start
) )
...@@ -972,6 +962,60 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -972,6 +962,60 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return resize_instance_masks(logits, (height, width)) return resize_instance_masks(logits, (height, width))
def _aggregate_classification_loss(self, loss, gt, pred, method):
"""Aggregates loss at a per-instance level.
When this function is used with mask-heads, num_classes is usually 1.
Args:
loss: A [num_instances, num_pixels, num_classes] or
[num_instances, num_classes] tensor. If the tensor is of rank 2, i.e.,
of the form [num_instances, num_classes], we will assume that the
number of pixels have already been nornalized.
gt: A [num_instances, num_pixels, num_classes] float tensor of
groundtruths.
pred: A [num_instances, num_pixels, num_classes] float tensor of
preditions.
method: A string in ['auto', 'groundtruth'].
'auto': When `loss` is rank 2, aggregates by sum. Otherwise, aggregates
by mean.
'groundtruth_count': Aggreagates the loss by computing sum and dividing
by the number of positive (1) groundtruth pixels.
'balanced': Normalizes each pixel by the number of positive or negative
pixels depending on the groundtruth.
Returns:
per_instance_loss: A [num_instances] float tensor.
"""
rank = len(loss.get_shape().as_list())
if rank == 2:
axes = [1]
else:
axes = [1, 2]
if method == 'normalize_auto':
normalization = 1.0
if rank == 2:
return tf.reduce_sum(loss, axis=axes)
else:
return tf.reduce_mean(loss, axis=axes)
elif method == 'normalize_groundtruth_count':
normalization = tf.reduce_sum(gt, axis=axes)
return tf.reduce_sum(loss, axis=axes) / normalization
elif method == 'normalize_balanced':
if rank != 3:
raise ValueError('Cannot apply normalized_balanced aggregation '
f'to loss of rank {rank}')
normalization = (
(gt * tf.reduce_sum(gt, keepdims=True, axis=axes)) +
(1 - gt) * tf.reduce_sum(1 - gt, keepdims=True, axis=axes))
return tf.reduce_sum(loss / normalization, axis=axes)
else:
raise ValueError('Unknown loss aggregation - {}'.format(method))
def _compute_per_instance_mask_prediction_loss( def _compute_per_instance_mask_prediction_loss(
self, boxes, mask_logits, mask_gt): self, boxes, mask_logits, mask_gt):
"""Compute the per-instance mask loss. """Compute the per-instance mask loss.
...@@ -995,14 +1039,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -995,14 +1039,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
target_tensor=mask_gt, target_tensor=mask_gt,
weights=tf.ones_like(mask_logits)) weights=tf.ones_like(mask_logits))
# TODO(vighneshb) Make this configurable via config. return self._aggregate_classification_loss(
# Skip normalization for dice loss because the denominator term already loss, mask_gt, mask_logits, 'normalize_auto')
# does normalization.
if isinstance(self._deepmac_params.classification_loss,
losses.WeightedDiceClassificationLoss):
return tf.reduce_sum(loss, axis=1)
else:
return tf.reduce_mean(loss, axis=[1, 2])
def _compute_per_instance_box_consistency_loss( def _compute_per_instance_box_consistency_loss(
self, boxes_gt, boxes_for_crop, mask_logits): self, boxes_gt, boxes_for_crop, mask_logits):
...@@ -1034,23 +1072,30 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1034,23 +1072,30 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
loss = 0.0 loss = 0.0
for axis in [1, 2]: for axis in [1, 2]:
pred_max = tf.reduce_max(pred_crop, axis=axis)[:, :, tf.newaxis]
if self._deepmac_params.box_consistency_tightness:
pred_max_raw = tf.reduce_max(pred_crop, axis=axis)
pred_max_within_box = tf.reduce_max(pred_crop * gt_crop, axis=axis)
box_1d = tf.reduce_max(gt_crop, axis=axis)
pred_max = ((box_1d * pred_max_within_box) +
((1 - box_1d) * pred_max_raw))
else:
pred_max = tf.reduce_max(pred_crop, axis=axis)
pred_max = pred_max[:, :, tf.newaxis]
gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, tf.newaxis] gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, tf.newaxis]
axis_loss = self._deepmac_params.classification_loss( raw_loss = self._deepmac_params.classification_loss(
prediction_tensor=pred_max, prediction_tensor=pred_max,
target_tensor=gt_max, target_tensor=gt_max,
weights=tf.ones_like(pred_max)) weights=tf.ones_like(pred_max))
loss += axis_loss
loss += self._aggregate_classification_loss(
# Skip normalization for dice loss because the denominator term already raw_loss, gt_max, pred_max,
# does normalization. self._deepmac_params.box_consistency_loss_normalize)
# TODO(vighneshb) Make this configurable via config.
if isinstance(self._deepmac_params.classification_loss, return loss
losses.WeightedDiceClassificationLoss):
return tf.reduce_sum(loss, axis=1)
else:
return tf.reduce_mean(loss, axis=[1, 2])
def _compute_per_instance_color_consistency_loss( def _compute_per_instance_color_consistency_loss(
self, boxes, preprocessed_image, mask_logits): self, boxes, preprocessed_image, mask_logits):
...@@ -1099,6 +1144,17 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1099,6 +1144,17 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[1, 2])) num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[1, 2]))
loss = loss / num_box_pixels loss = loss / num_box_pixels
if ((self._deepmac_params.color_consistency_warmup_steps > 0) and
self._is_training):
training_step = tf.cast(self.training_step, tf.float32)
warmup_steps = tf.cast(
self._deepmac_params.color_consistency_warmup_steps, tf.float32)
start_step = tf.cast(
self._deepmac_params.color_consistency_warmup_start, tf.float32)
warmup_weight = (training_step - start_step) / warmup_steps
warmup_weight = tf.clip_by_value(warmup_weight, 0.0, 1.0)
loss *= warmup_weight
return loss return loss
def _compute_per_instance_deepmac_losses( def _compute_per_instance_deepmac_losses(
...@@ -1188,11 +1244,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1188,11 +1244,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
allowed_masked_classes_ids = ( allowed_masked_classes_ids = (
self._deepmac_params.allowed_masked_classes_ids) self._deepmac_params.allowed_masked_classes_ids)
loss_dict = { loss_dict = {}
DEEP_MASK_ESTIMATION: 0.0, for loss_name in MASK_LOSSES:
}
for loss_name in WEAK_LOSSES:
loss_dict[loss_name] = 0.0 loss_dict[loss_name] = 0.0
prediction_shape = tf.shape(prediction_dict[INSTANCE_EMBEDDING][0]) prediction_shape = tf.shape(prediction_dict[INSTANCE_EMBEDDING][0])
...@@ -1252,13 +1305,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1252,13 +1305,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_loss_dict = self._compute_instance_masks_loss( mask_loss_dict = self._compute_instance_masks_loss(
prediction_dict=prediction_dict) prediction_dict=prediction_dict)
losses_dict[LOSS_KEY_PREFIX + '/' + DEEP_MASK_ESTIMATION] = ( for loss_name in MASK_LOSSES:
self._deepmac_params.task_loss_weight * mask_loss_dict[ loss_weight = _get_loss_weight(loss_name, self._deepmac_params)
DEEP_MASK_ESTIMATION]
)
for loss_name in WEAK_LOSSES:
loss_weight = _get_weak_loss_weight(loss_name, self._deepmac_params)
if loss_weight > 0.0: if loss_weight > 0.0:
losses_dict[LOSS_KEY_PREFIX + '/' + loss_name] = ( losses_dict[LOSS_KEY_PREFIX + '/' + loss_name] = (
loss_weight * mask_loss_dict[loss_name]) loss_weight * mask_loss_dict[loss_name])
......
"""Tests for google3.third_party.tensorflow_models.object_detection.meta_architectures.deepmac_meta_arch.""" """Tests for google3.third_party.tensorflow_models.object_detection.meta_architectures.deepmac_meta_arch."""
import functools import functools
import random
import unittest import unittest
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from google.protobuf import text_format
from object_detection.core import losses from object_detection.core import losses
from object_detection.core import preprocessor from object_detection.core import preprocessor
from object_detection.meta_architectures import center_net_meta_arch from object_detection.meta_architectures import center_net_meta_arch
from object_detection.meta_architectures import deepmac_meta_arch from object_detection.meta_architectures import deepmac_meta_arch
from object_detection.protos import center_net_pb2
from object_detection.utils import tf_version from object_detection.utils import tf_version
DEEPMAC_PROTO_TEXT = """
dim: 153
task_loss_weight: 5.0
pixel_embedding_dim: 8
use_xy: false
use_instance_embedding: false
network_type: "cond_inst3"
num_init_channels: 8
classification_loss {
weighted_dice_classification_loss {
squared_normalization: false
is_prediction_probability: false
}
}
jitter_mode: EXPAND_SYMMETRIC_XY
max_roi_jitter_ratio: 0.0
predict_full_resolution_masks: true
allowed_masked_classes_ids: [99]
box_consistency_loss_weight: 1.0
color_consistency_loss_weight: 1.0
color_consistency_threshold: 0.1
box_consistency_tightness: false
box_consistency_loss_normalize: NORMALIZE_AUTO
color_consistency_warmup_steps: 20
color_consistency_warmup_start: 10
"""
class DummyFeatureExtractor(center_net_meta_arch.CenterNetFeatureExtractor): class DummyFeatureExtractor(center_net_meta_arch.CenterNetFeatureExtractor):
def __init__(self, def __init__(self,
...@@ -60,13 +93,36 @@ class MockMaskNet(tf.keras.layers.Layer): ...@@ -60,13 +93,36 @@ class MockMaskNet(tf.keras.layers.Layer):
return tf.zeros_like(pixel_embedding[:, :, :, 0]) + 0.9 return tf.zeros_like(pixel_embedding[:, :, :, 0]) + 0.9
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False, def build_meta_arch(**override_params):
use_instance_embedding=True, mask_num_subsamples=-1, """Builds the DeepMAC meta architecture."""
network_type='hourglass10', use_xy=True,
params = dict(
predict_full_resolution_masks=False,
use_instance_embedding=True,
mask_num_subsamples=-1,
network_type='hourglass10',
use_xy=True,
pixel_embedding_dim=2, pixel_embedding_dim=2,
dice_loss_prediction_probability=False, dice_loss_prediction_probability=False,
color_consistency_threshold=0.5): color_consistency_threshold=0.5,
"""Builds the DeepMAC meta architecture.""" use_dice_loss=False,
box_consistency_loss_normalize='normalize_auto',
box_consistency_tightness=False,
task_loss_weight=1.0,
color_consistency_loss_weight=1.0,
box_consistency_loss_weight=1.0,
num_init_channels=8,
dim=8,
allowed_masked_classes_ids=[],
mask_size=16,
postprocess_crop_size=128,
max_roi_jitter_ratio=0.0,
roi_jitter_mode='random',
color_consistency_dilation=2,
color_consistency_warmup_steps=0,
color_consistency_warmup_start=0)
params.update(override_params)
feature_extractor = DummyFeatureExtractor( feature_extractor = DummyFeatureExtractor(
channel_means=(1.0, 2.0, 3.0), channel_means=(1.0, 2.0, 3.0),
...@@ -87,33 +143,18 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False, ...@@ -87,33 +143,18 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
max_box_predictions=5, max_box_predictions=5,
use_labeled_classes=False) use_labeled_classes=False)
use_dice_loss = params.pop('use_dice_loss')
dice_loss_prediction_prob = params.pop('dice_loss_prediction_probability')
if use_dice_loss: if use_dice_loss:
classification_loss = losses.WeightedDiceClassificationLoss( classification_loss = losses.WeightedDiceClassificationLoss(
squared_normalization=False, squared_normalization=False,
is_prediction_probability=dice_loss_prediction_probability) is_prediction_probability=dice_loss_prediction_prob)
else: else:
classification_loss = losses.WeightedSigmoidClassificationLoss() classification_loss = losses.WeightedSigmoidClassificationLoss()
deepmac_params = deepmac_meta_arch.DeepMACParams( deepmac_params = deepmac_meta_arch.DeepMACParams(
classification_loss=classification_loss, classification_loss=classification_loss,
dim=8, **params
task_loss_weight=1.0,
pixel_embedding_dim=pixel_embedding_dim,
allowed_masked_classes_ids=[],
mask_size=16,
mask_num_subsamples=mask_num_subsamples,
use_xy=use_xy,
network_type=network_type,
use_instance_embedding=use_instance_embedding,
num_init_channels=8,
predict_full_resolution_masks=predict_full_resolution_masks,
postprocess_crop_size=128,
max_roi_jitter_ratio=0.0,
roi_jitter_mode='random',
box_consistency_loss_weight=1.0,
color_consistency_threshold=color_consistency_threshold,
color_consistency_dilation=2,
color_consistency_loss_weight=1.0
) )
object_detection_params = center_net_meta_arch.ObjectDetectionParams( object_detection_params = center_net_meta_arch.ObjectDetectionParams(
...@@ -136,6 +177,15 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False, ...@@ -136,6 +177,15 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase): class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
  def test_proto_parse(self):
    """DeepMAC proto text parses and converts into a DeepMACParams tuple."""
    proto = center_net_pb2.CenterNet().DeepMACMaskEstimation()
    text_format.Parse(DEEPMAC_PROTO_TEXT, proto)
    params = deepmac_meta_arch.deepmac_proto_to_params(proto)
    self.assertIsInstance(params, deepmac_meta_arch.DeepMACParams)
    # Spot-check a plain field and the enum-to-string conversion.
    self.assertEqual(params.dim, 153)
    self.assertEqual(params.box_consistency_loss_normalize, 'normalize_auto')
def test_subsample_trivial(self): def test_subsample_trivial(self):
"""Test subsampling masks.""" """Test subsampling masks."""
...@@ -781,8 +831,85 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -781,8 +831,85 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
xloss = tf.nn.sigmoid_cross_entropy_with_logits( xloss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.constant([1.0] * 16), labels=tf.constant([1.0] * 16),
logits=[1.0] * 12 + [0.0] * 4) logits=[1.0] * 12 + [0.0] * 4)
yloss_mean = tf.reduce_mean(yloss)
xloss_mean = tf.reduce_mean(xloss)
self.assertAllClose(loss, [yloss_mean + xloss_mean])
  def test_box_consistency_loss_with_tightness(self):
    """With tightness on, a mask tightly inside the box incurs zero loss."""
    # Groundtruth box covers the top-left quadrant of the 8x8 mask.
    boxes_gt = tf.constant([[0., 0., 0.49, 0.49]])
    boxes_jittered = None
    # Prediction is strongly negative everywhere except a confident 4x4
    # region that lies entirely within the groundtruth box.
    mask_prediction = np.zeros((1, 8, 8)).astype(np.float32) - 1e10
    mask_prediction[0, :4, :4] = 1e10
    model = build_meta_arch(box_consistency_tightness=True,
                            predict_full_resolution_masks=True)
    loss = model._compute_per_instance_box_consistency_loss(
        boxes_gt, boxes_jittered, tf.constant(mask_prediction))
    self.assertAllClose(loss, [0.0])
def test_box_consistency_loss_gt_count(self):
boxes_gt = tf.constant([
[0., 0., 1.0, 1.0],
[0., 0., 0.49, 0.49]])
boxes_jittered = None
mask_prediction = np.zeros((2, 32, 32)).astype(np.float32)
mask_prediction[0, :16, :16] = 1.0
mask_prediction[1, :8, :8] = 1.0
model = build_meta_arch(
box_consistency_loss_normalize='normalize_groundtruth_count',
predict_full_resolution_masks=True)
loss_func = tf.function(
model._compute_per_instance_box_consistency_loss)
loss = loss_func(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
yloss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.constant([1.0] * 32),
logits=[1.0] * 16 + [0.0] * 16) / 32.0
yloss_mean = tf.reduce_sum(yloss)
xloss = yloss
xloss_mean = tf.reduce_sum(xloss)
self.assertAllClose(loss, [tf.reduce_mean(yloss + xloss).numpy()]) self.assertAllClose(loss[0], yloss_mean + xloss_mean)
yloss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.constant([1.0] * 16 + [0.0] * 16),
logits=[1.0] * 8 + [0.0] * 24) / 16.0
yloss_mean = tf.reduce_sum(yloss)
xloss = yloss
xloss_mean = tf.reduce_sum(xloss)
self.assertAllClose(loss[1], yloss_mean + xloss_mean)
  def test_box_consistency_loss_balanced(self):
    """'normalize_balanced' divides each pixel by its per-label pixel count."""
    # Box covers the top-left quadrant; prediction is positive everywhere.
    boxes_gt = tf.constant([
        [0., 0., 0.49, 0.49]])
    boxes_jittered = None
    mask_prediction = np.zeros((1, 32, 32)).astype(np.float32)
    mask_prediction[0] = 1.0
    model = build_meta_arch(box_consistency_loss_normalize='normalize_balanced',
                            predict_full_resolution_masks=True)
    loss_func = tf.function(
        model._compute_per_instance_box_consistency_loss)
    loss = loss_func(
        boxes_gt, boxes_jittered, tf.constant(mask_prediction))

    # Expected per-axis loss: 16 positive and 16 negative 1-D max pixels,
    # each normalized by its own label count (16).
    yloss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=[0.] * 16 + [1.0] * 16,
        logits=[1.0] * 32)
    yloss_mean = tf.reduce_sum(yloss) / 16.0
    xloss_mean = yloss_mean

    self.assertAllClose(loss[0], yloss_mean + xloss_mean)
def test_box_consistency_dice_loss(self): def test_box_consistency_dice_loss(self):
...@@ -863,34 +990,145 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -863,34 +990,145 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
loss = model.loss(prediction, tf.constant([[32, 32, 3.0]])) loss = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0) self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0)
for weak_loss in deepmac_meta_arch.WEAK_LOSSES: for weak_loss in deepmac_meta_arch.MASK_LOSSES:
if weak_loss == deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY: if weak_loss == deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY:
continue continue
self.assertGreater(loss['Loss/' + weak_loss], 0.0, self.assertGreater(loss['Loss/' + weak_loss], 0.0,
'{} was <= 0'.format(weak_loss)) '{} was <= 0'.format(weak_loss))
def test_loss_keys_full_res(self): def test_loss_weight_response(self):
model = build_meta_arch(use_dice_loss=True, model = build_meta_arch(
predict_full_resolution_masks=True) use_dice_loss=True,
predict_full_resolution_masks=True,
network_type='cond_inst1',
dim=9,
pixel_embedding_dim=8,
use_instance_embedding=False,
use_xy=False)
num_stages = 1
prediction = { prediction = {
'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)), 'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
'INSTANCE_EMBEDDING': [tf.random.normal((1, 8, 8, 17))] * 2, 'INSTANCE_EMBEDDING': [tf.random.normal((1, 8, 8, 9))] * num_stages,
'PIXEL_EMBEDDING': [tf.random.normal((1, 8, 8, 19))] * 2, 'PIXEL_EMBEDDING': [tf.random.normal((1, 8, 8, 8))] * num_stages,
'object_center': [tf.random.normal((1, 8, 8, 6))] * 2, 'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
'box/offset': [tf.random.normal((1, 8, 8, 2))] * 2, 'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'box/scale': [tf.random.normal((1, 8, 8, 2))] * 2 'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages
} }
boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)]
classes = [tf.one_hot([1, 0, 1, 1, 1], depth=6)]
weights = [tf.ones(5)]
masks = [tf.ones((5, 32, 32))]
model.provide_groundtruth( model.provide_groundtruth(
groundtruth_boxes_list=[tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)], groundtruth_boxes_list=boxes,
groundtruth_classes_list=[tf.one_hot([1, 0, 1, 1, 1], depth=6)], groundtruth_classes_list=classes,
groundtruth_weights_list=[tf.ones(5)], groundtruth_weights_list=weights,
groundtruth_masks_list=[tf.ones((5, 32, 32))]) groundtruth_masks_list=masks)
loss = model.loss(prediction, tf.constant([[32, 32, 3.0]])) loss = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0) self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0)
for weak_loss in deepmac_meta_arch.WEAK_LOSSES: for mask_loss in deepmac_meta_arch.MASK_LOSSES:
self.assertGreater(loss['Loss/' + weak_loss], 0.0, self.assertGreater(loss['Loss/' + mask_loss], 0.0,
'{} was <= 0'.format(weak_loss)) '{} was <= 0'.format(mask_loss))
rng = random.Random(0)
loss_weights = {
deepmac_meta_arch.DEEP_MASK_ESTIMATION: rng.uniform(1, 5),
deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY: rng.uniform(1, 5),
deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY: rng.uniform(1, 5)
}
weighted_model = build_meta_arch(
use_dice_loss=True,
predict_full_resolution_masks=True,
network_type='cond_inst1',
dim=9,
pixel_embedding_dim=8,
use_instance_embedding=False,
use_xy=False,
task_loss_weight=loss_weights[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
box_consistency_loss_weight=(
loss_weights[deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY]),
color_consistency_loss_weight=(
loss_weights[deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY]))
weighted_model.provide_groundtruth(
groundtruth_boxes_list=boxes,
groundtruth_classes_list=classes,
groundtruth_weights_list=weights,
groundtruth_masks_list=masks)
weighted_loss = weighted_model.loss(prediction, tf.constant([[32, 32, 3]]))
for mask_loss in deepmac_meta_arch.MASK_LOSSES:
loss_key = 'Loss/' + mask_loss
self.assertAllEqual(
weighted_loss[loss_key], loss[loss_key] * loss_weights[mask_loss],
f'{mask_loss} did not respond to change in weight.')
  def test_color_consistency_warmup(self):
    """Color consistency loss ramps linearly over the configured warmup window.

    Warmup starts at step 10 and lasts 10 steps, so the loss weight should be
    0 before step 10, 0.5 at step 15, and 1.0 from step 20 onwards.
    """
    model = build_meta_arch(
        use_dice_loss=True,
        predict_full_resolution_masks=True,
        network_type='cond_inst1',
        dim=9,
        pixel_embedding_dim=8,
        use_instance_embedding=False,
        use_xy=False,
        color_consistency_warmup_steps=10,
        color_consistency_warmup_start=10)
    num_stages = 1
    prediction = {
        'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
        'INSTANCE_EMBEDDING': [tf.random.normal((1, 8, 8, 9))] * num_stages,
        'PIXEL_EMBEDDING': [tf.random.normal((1, 8, 8, 8))] * num_stages,
        'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
        'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
        'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages
    }
    boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)]
    classes = [tf.one_hot([1, 0, 1, 1, 1], depth=6)]
    weights = [tf.ones(5)]
    masks = [tf.ones((5, 32, 32))]

    # Step 5: before warmup start — weight 0.
    model.provide_groundtruth(
        groundtruth_boxes_list=boxes,
        groundtruth_classes_list=classes,
        groundtruth_weights_list=weights,
        groundtruth_masks_list=masks,
        training_step=5)
    loss_at_5 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))

    # Step 15: halfway through warmup — weight 0.5.
    model.provide_groundtruth(
        groundtruth_boxes_list=boxes,
        groundtruth_classes_list=classes,
        groundtruth_weights_list=weights,
        groundtruth_masks_list=masks,
        training_step=15)
    loss_at_15 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))

    # Step 20: warmup complete — weight 1.0.
    model.provide_groundtruth(
        groundtruth_boxes_list=boxes,
        groundtruth_classes_list=classes,
        groundtruth_weights_list=weights,
        groundtruth_masks_list=masks,
        training_step=20)
    loss_at_20 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))

    # Step 100: weight stays clipped at 1.0 after warmup.
    model.provide_groundtruth(
        groundtruth_boxes_list=boxes,
        groundtruth_classes_list=classes,
        groundtruth_weights_list=weights,
        groundtruth_masks_list=masks,
        training_step=100)
    loss_at_100 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))

    loss_key = 'Loss/' + deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY
    self.assertAlmostEqual(loss_at_5[loss_key].numpy(), 0.0)
    self.assertAlmostEqual(loss_at_15[loss_key].numpy(),
                           loss_at_20[loss_key].numpy() / 2.0)
    self.assertAlmostEqual(loss_at_20[loss_key].numpy(),
                           loss_at_100[loss_key].numpy())
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
......
...@@ -303,7 +303,7 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True): ...@@ -303,7 +303,7 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
return unbatched_tensor_dict return unbatched_tensor_dict
def provide_groundtruth(model, labels, training_step=None):
"""Provides the labels to a model as groundtruth. """Provides the labels to a model as groundtruth.
This helper function extracts the corresponding boxes, classes, This helper function extracts the corresponding boxes, classes,
...@@ -313,6 +313,8 @@ def provide_groundtruth(model, labels): ...@@ -313,6 +313,8 @@ def provide_groundtruth(model, labels):
Args: Args:
model: The detection model to provide groundtruth to. model: The detection model to provide groundtruth to.
labels: The labels for the training or evaluation inputs. labels: The labels for the training or evaluation inputs.
training_step: int, optional. The training step for the model. Useful
for models which want to anneal loss weights.
""" """
gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes] gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
gt_classes_list = labels[fields.InputDataFields.groundtruth_classes] gt_classes_list = labels[fields.InputDataFields.groundtruth_classes]
...@@ -402,7 +404,8 @@ def provide_groundtruth(model, labels): ...@@ -402,7 +404,8 @@ def provide_groundtruth(model, labels):
groundtruth_verified_neg_classes=gt_verified_neg_classes, groundtruth_verified_neg_classes=gt_verified_neg_classes,
groundtruth_not_exhaustive_classes=gt_not_exhaustive_classes, groundtruth_not_exhaustive_classes=gt_not_exhaustive_classes,
groundtruth_keypoint_depths_list=gt_keypoint_depths_list, groundtruth_keypoint_depths_list=gt_keypoint_depths_list,
groundtruth_keypoint_depth_weights_list=gt_keypoint_depth_weights_list) groundtruth_keypoint_depth_weights_list=gt_keypoint_depth_weights_list,
training_step=training_step)
def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
......
...@@ -51,7 +51,7 @@ RESTORE_MAP_ERROR_TEMPLATE = ( ...@@ -51,7 +51,7 @@ RESTORE_MAP_ERROR_TEMPLATE = (
def _compute_losses_and_predictions_dicts(
    model, features, labels, training_step=None,
    add_regularization_loss=True):
"""Computes the losses dict and predictions dict for a model on inputs. """Computes the losses dict and predictions dict for a model on inputs.
...@@ -107,6 +107,7 @@ def _compute_losses_and_predictions_dicts( ...@@ -107,6 +107,7 @@ def _compute_losses_and_predictions_dicts(
float32 tensor containing keypoint depths information. float32 tensor containing keypoint depths information.
labels[fields.InputDataFields.groundtruth_keypoint_depth_weights] is a labels[fields.InputDataFields.groundtruth_keypoint_depth_weights] is a
float32 tensor containing the weights of the keypoint depth feature. float32 tensor containing the weights of the keypoint depth feature.
training_step: int, the current training step.
add_regularization_loss: Whether or not to include the model's add_regularization_loss: Whether or not to include the model's
regularization loss in the losses dictionary. regularization loss in the losses dictionary.
...@@ -116,7 +117,7 @@ def _compute_losses_and_predictions_dicts( ...@@ -116,7 +117,7 @@ def _compute_losses_and_predictions_dicts(
`model.predict`. `model.predict`.
""" """
model_lib.provide_groundtruth(model, labels, training_step=training_step)
preprocessed_images = features[fields.InputDataFields.image] preprocessed_images = features[fields.InputDataFields.image]
prediction_dict = model.predict( prediction_dict = model.predict(
...@@ -166,7 +167,8 @@ def _ensure_model_is_built(model, input_dataset, unpad_groundtruth_tensors): ...@@ -166,7 +167,8 @@ def _ensure_model_is_built(model, input_dataset, unpad_groundtruth_tensors):
labels = model_lib.unstack_batch( labels = model_lib.unstack_batch(
labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors) labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
return _compute_losses_and_predictions_dicts(model, features, labels) return _compute_losses_and_predictions_dicts(model, features, labels,
training_step=0)
strategy = tf.compat.v2.distribute.get_strategy() strategy = tf.compat.v2.distribute.get_strategy()
if hasattr(tf.distribute.Strategy, 'run'): if hasattr(tf.distribute.Strategy, 'run'):
...@@ -208,6 +210,7 @@ def eager_train_step(detection_model, ...@@ -208,6 +210,7 @@ def eager_train_step(detection_model,
labels, labels,
unpad_groundtruth_tensors, unpad_groundtruth_tensors,
optimizer, optimizer,
training_step,
add_regularization_loss=True, add_regularization_loss=True,
clip_gradients_value=None, clip_gradients_value=None,
num_replicas=1.0): num_replicas=1.0):
...@@ -280,6 +283,7 @@ def eager_train_step(detection_model, ...@@ -280,6 +283,7 @@ def eager_train_step(detection_model,
float32 tensor containing the weights of the keypoint depth feature. float32 tensor containing the weights of the keypoint depth feature.
unpad_groundtruth_tensors: A parameter passed to unstack_batch. unpad_groundtruth_tensors: A parameter passed to unstack_batch.
optimizer: The training optimizer that will update the variables. optimizer: The training optimizer that will update the variables.
training_step: int, the training step number.
add_regularization_loss: Whether or not to include the model's add_regularization_loss: Whether or not to include the model's
regularization loss in the losses dictionary. regularization loss in the losses dictionary.
clip_gradients_value: If this is present, clip the gradients global norm clip_gradients_value: If this is present, clip the gradients global norm
...@@ -302,7 +306,9 @@ def eager_train_step(detection_model, ...@@ -302,7 +306,9 @@ def eager_train_step(detection_model,
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
losses_dict, _ = _compute_losses_and_predictions_dicts( losses_dict, _ = _compute_losses_and_predictions_dicts(
detection_model, features, labels, add_regularization_loss) detection_model, features, labels,
training_step=training_step,
add_regularization_loss=add_regularization_loss)
losses_dict = normalize_dict(losses_dict, num_replicas) losses_dict = normalize_dict(losses_dict, num_replicas)
...@@ -632,6 +638,7 @@ def train_loop( ...@@ -632,6 +638,7 @@ def train_loop(
labels, labels,
unpad_groundtruth_tensors, unpad_groundtruth_tensors,
optimizer, optimizer,
training_step=global_step,
add_regularization_loss=add_regularization_loss, add_regularization_loss=add_regularization_loss,
clip_gradients_value=clip_gradients_value, clip_gradients_value=clip_gradients_value,
num_replicas=strategy.num_replicas_in_sync) num_replicas=strategy.num_replicas_in_sync)
...@@ -901,7 +908,8 @@ def eager_eval_loop( ...@@ -901,7 +908,8 @@ def eager_eval_loop(
labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors) labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
losses_dict, prediction_dict = _compute_losses_and_predictions_dicts( losses_dict, prediction_dict = _compute_losses_and_predictions_dicts(
detection_model, features, labels, add_regularization_loss) detection_model, features, labels, training_step=None,
add_regularization_loss=add_regularization_loss)
prediction_dict = detection_model.postprocess( prediction_dict = detection_model.postprocess(
prediction_dict, features[fields.InputDataFields.true_image_shape]) prediction_dict, features[fields.InputDataFields.true_image_shape])
eval_features = { eval_features = {
......
...@@ -403,6 +403,7 @@ message CenterNet { ...@@ -403,6 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613 // Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 24
message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions. // The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1; optional ClassificationLoss classification_loss = 1;
...@@ -471,6 +472,19 @@ message CenterNet { ...@@ -471,6 +472,19 @@ message CenterNet {
optional float color_consistency_loss_weight = 19 [default=0.0]; optional float color_consistency_loss_weight = 19 [default=0.0];
optional LossNormalize box_consistency_loss_normalize = 20 [
default=NORMALIZE_AUTO];
// If set, will use the bounding box tightness prior approach. This means
// that the max will be restricted to only be inside the box for both
// dimensions. See details here:
// https://papers.nips.cc/paper/2019/hash/e6e713296627dff6475085cc6a224464-Abstract.html
optional bool box_consistency_tightness = 21 [default=false];
optional int32 color_consistency_warmup_steps = 22 [default=0];
optional int32 color_consistency_warmup_start = 23 [default=0];
} }
optional DeepMACMaskEstimation deepmac_mask_estimation = 14; optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
...@@ -483,6 +497,12 @@ message CenterNet { ...@@ -483,6 +497,12 @@ message CenterNet {
optional PostProcessing post_processing = 24; optional PostProcessing post_processing = 24;
} }
enum LossNormalize {
NORMALIZE_AUTO = 0; // SUM for 2D inputs (dice loss) and MEAN for others.
NORMALIZE_GROUNDTRUTH_COUNT = 1;
NORMALIZE_BALANCED = 3;
}
message CenterNetFeatureExtractor { message CenterNetFeatureExtractor {
optional string type = 1; optional string type = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment