Commit 9400d73c authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Implement bounding box tightness prior and CC loss warmup.

PiperOrigin-RevId: 415265931
parent 3b7bc268
......@@ -89,6 +89,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
"""
self._num_classes = num_classes
self._groundtruth_lists = {}
self._training_step = None
super(DetectionModel, self).__init__()
......@@ -132,6 +133,13 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
"""
return field in self._groundtruth_lists
@property
def training_step(self):
  """Returns the most recently provided training step.

  The step is set via `provide_groundtruth(..., training_step=...)`; models
  use it to anneal loss terms (e.g. warmup schedules).

  Returns:
    int: the current training step.

  Raises:
    ValueError: If no training step was ever provided to the model.
  """
  if self._training_step is None:
    raise ValueError('Training step was not provided to the model.')
  return self._training_step
@staticmethod
def get_side_inputs(features):
"""Get side inputs from input features.
......@@ -318,7 +326,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
groundtruth_verified_neg_classes=None,
groundtruth_not_exhaustive_classes=None,
groundtruth_keypoint_depths_list=None,
groundtruth_keypoint_depth_weights_list=None):
groundtruth_keypoint_depth_weights_list=None,
training_step=None):
"""Provide groundtruth tensors.
Args:
......@@ -389,6 +398,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
groundtruth_keypoint_depth_weights_list: a list of 2-D tf.float32 tensors
of shape [num_boxes, num_keypoints] containing the weights of the
relative depths.
training_step: An integer denoting the current training step. This is
useful when models want to anneal loss terms.
"""
self._groundtruth_lists[fields.BoxListFields.boxes] = groundtruth_boxes_list
self._groundtruth_lists[
......@@ -468,6 +479,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
fields.InputDataFields
.groundtruth_not_exhaustive_classes] = (
groundtruth_not_exhaustive_classes)
if training_step is not None:
self._training_step = training_step
@abc.abstractmethod
def regularization_losses(self):
......
......@@ -12,12 +12,12 @@ import tensorflow as tf
from object_detection.builders import losses_builder
from object_detection.core import box_list
from object_detection.core import box_list_ops
from object_detection.core import losses
from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields
from object_detection.meta_architectures import center_net_meta_arch
from object_detection.models.keras_models import hourglass_network
from object_detection.models.keras_models import resnet_v1
from object_detection.protos import center_net_pb2
from object_detection.protos import losses_pb2
from object_detection.protos import preprocessor_pb2
from object_detection.utils import shape_utils
......@@ -38,46 +38,26 @@ NEIGHBORS_2D = [[-1, -1], [-1, 0], [-1, 1],
[0, -1], [0, 1],
[1, -1], [1, 0], [1, 1]]
WEAK_LOSSES = [DEEP_MASK_BOX_CONSISTENCY, DEEP_MASK_COLOR_CONSISTENCY]
MASK_LOSSES = WEAK_LOSSES + [DEEP_MASK_ESTIMATION]
class DeepMACParams(
collections.namedtuple('DeepMACParams', [
DeepMACParams = collections.namedtuple('DeepMACParams', [
'classification_loss', 'dim', 'task_loss_weight', 'pixel_embedding_dim',
'allowed_masked_classes_ids', 'mask_size', 'mask_num_subsamples',
'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels',
'predict_full_resolution_masks', 'postprocess_crop_size',
'max_roi_jitter_ratio', 'roi_jitter_mode',
'box_consistency_loss_weight', 'color_consistency_threshold',
'color_consistency_dilation', 'color_consistency_loss_weight'
])):
"""Class holding the DeepMAC network configutration."""
__slots__ = ()
def __new__(cls, classification_loss, dim, task_loss_weight,
pixel_embedding_dim, allowed_masked_classes_ids, mask_size,
mask_num_subsamples, use_xy, network_type, use_instance_embedding,
num_init_channels, predict_full_resolution_masks,
postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode, box_consistency_loss_weight,
color_consistency_threshold, color_consistency_dilation,
color_consistency_loss_weight):
return super(DeepMACParams,
cls).__new__(cls, classification_loss, dim,
task_loss_weight, pixel_embedding_dim,
allowed_masked_classes_ids, mask_size,
mask_num_subsamples, use_xy, network_type,
use_instance_embedding, num_init_channels,
predict_full_resolution_masks,
postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode, box_consistency_loss_weight,
color_consistency_threshold,
color_consistency_dilation,
color_consistency_loss_weight)
def _get_weak_loss_weight(loss_name, config):
if loss_name == DEEP_MASK_COLOR_CONSISTENCY:
'color_consistency_dilation', 'color_consistency_loss_weight',
'box_consistency_loss_normalize', 'box_consistency_tightness',
'color_consistency_warmup_steps', 'color_consistency_warmup_start'
])
def _get_loss_weight(loss_name, config):
if loss_name == DEEP_MASK_ESTIMATION:
return config.task_loss_weight
elif loss_name == DEEP_MASK_COLOR_CONSISTENCY:
return config.color_consistency_loss_weight
elif loss_name == DEEP_MASK_BOX_CONSISTENCY:
return config.box_consistency_loss_weight
......@@ -755,6 +735,9 @@ def deepmac_proto_to_params(deepmac_config):
jitter_mode = preprocessor_pb2.RandomJitterBoxes.JitterMode.Name(
deepmac_config.jitter_mode).lower()
box_consistency_loss_normalize = center_net_pb2.LossNormalize.Name(
deepmac_config.box_consistency_loss_normalize).lower()
return DeepMACParams(
dim=deepmac_config.dim,
classification_loss=classification_loss,
......@@ -775,7 +758,14 @@ def deepmac_proto_to_params(deepmac_config):
box_consistency_loss_weight=deepmac_config.box_consistency_loss_weight,
color_consistency_threshold=deepmac_config.color_consistency_threshold,
color_consistency_dilation=deepmac_config.color_consistency_dilation,
color_consistency_loss_weight=deepmac_config.color_consistency_loss_weight
color_consistency_loss_weight=
deepmac_config.color_consistency_loss_weight,
box_consistency_loss_normalize=box_consistency_loss_normalize,
box_consistency_tightness=deepmac_config.box_consistency_tightness,
color_consistency_warmup_steps=
deepmac_config.color_consistency_warmup_steps,
color_consistency_warmup_start=
deepmac_config.color_consistency_warmup_start
)
......@@ -972,6 +962,60 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return resize_instance_masks(logits, (height, width))
def _aggregate_classification_loss(self, loss, gt, pred, method):
  """Aggregates loss at a per-instance level.

  When this function is used with mask-heads, num_classes is usually 1.

  Args:
    loss: A [num_instances, num_pixels, num_classes] or
      [num_instances, num_classes] tensor. If the tensor is of rank 2, i.e.,
      of the form [num_instances, num_classes], we will assume that the
      number of pixels have already been normalized.
    gt: A [num_instances, num_pixels, num_classes] float tensor of
      groundtruths.
    pred: A [num_instances, num_pixels, num_classes] float tensor of
      predictions. Currently unused; kept so all aggregation methods share
      one call signature.
    method: A string in ['normalize_auto', 'normalize_groundtruth_count',
      'normalize_balanced'].
      'normalize_auto': When `loss` is rank 2, aggregates by sum. Otherwise,
        aggregates by mean.
      'normalize_groundtruth_count': Aggregates the loss by computing sum
        and dividing by the number of positive (1) groundtruth pixels.
      'normalize_balanced': Normalizes each pixel by the number of positive
        or negative pixels depending on the groundtruth.

  Returns:
    per_instance_loss: A [num_instances] float tensor.

  Raises:
    ValueError: If `method` is unknown, or 'normalize_balanced' is requested
      for a loss that is not rank 3.
  """
  rank = len(loss.get_shape().as_list())
  # Rank-2 losses have no pixel dimension, so we only reduce over classes.
  if rank == 2:
    axes = [1]
  else:
    axes = [1, 2]

  if method == 'normalize_auto':
    # Sum is appropriate for losses (e.g. dice) whose denominator already
    # normalizes by pixel count; mean otherwise.
    if rank == 2:
      return tf.reduce_sum(loss, axis=axes)
    else:
      return tf.reduce_mean(loss, axis=axes)

  elif method == 'normalize_groundtruth_count':
    normalization = tf.reduce_sum(gt, axis=axes)
    return tf.reduce_sum(loss, axis=axes) / normalization

  elif method == 'normalize_balanced':
    if rank != 3:
      raise ValueError('Cannot apply normalize_balanced aggregation '
                       f'to loss of rank {rank}')
    # Each pixel's loss is divided by the count of same-labeled pixels, so
    # positives and negatives contribute equally per instance.
    normalization = (
        (gt * tf.reduce_sum(gt, keepdims=True, axis=axes)) +
        (1 - gt) * tf.reduce_sum(1 - gt, keepdims=True, axis=axes))
    return tf.reduce_sum(loss / normalization, axis=axes)

  else:
    raise ValueError('Unknown loss aggregation - {}'.format(method))
def _compute_per_instance_mask_prediction_loss(
self, boxes, mask_logits, mask_gt):
"""Compute the per-instance mask loss.
......@@ -995,14 +1039,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
target_tensor=mask_gt,
weights=tf.ones_like(mask_logits))
# TODO(vighneshb) Make this configurable via config.
# Skip normalization for dice loss because the denominator term already
# does normalization.
if isinstance(self._deepmac_params.classification_loss,
losses.WeightedDiceClassificationLoss):
return tf.reduce_sum(loss, axis=1)
else:
return tf.reduce_mean(loss, axis=[1, 2])
return self._aggregate_classification_loss(
loss, mask_gt, mask_logits, 'normalize_auto')
def _compute_per_instance_box_consistency_loss(
self, boxes_gt, boxes_for_crop, mask_logits):
......@@ -1034,23 +1072,30 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
loss = 0.0
for axis in [1, 2]:
pred_max = tf.reduce_max(pred_crop, axis=axis)[:, :, tf.newaxis]
if self._deepmac_params.box_consistency_tightness:
pred_max_raw = tf.reduce_max(pred_crop, axis=axis)
pred_max_within_box = tf.reduce_max(pred_crop * gt_crop, axis=axis)
box_1d = tf.reduce_max(gt_crop, axis=axis)
pred_max = ((box_1d * pred_max_within_box) +
((1 - box_1d) * pred_max_raw))
else:
pred_max = tf.reduce_max(pred_crop, axis=axis)
pred_max = pred_max[:, :, tf.newaxis]
gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, tf.newaxis]
axis_loss = self._deepmac_params.classification_loss(
raw_loss = self._deepmac_params.classification_loss(
prediction_tensor=pred_max,
target_tensor=gt_max,
weights=tf.ones_like(pred_max))
loss += axis_loss
# Skip normalization for dice loss because the denominator term already
# does normalization.
# TODO(vighneshb) Make this configurable via config.
if isinstance(self._deepmac_params.classification_loss,
losses.WeightedDiceClassificationLoss):
return tf.reduce_sum(loss, axis=1)
else:
return tf.reduce_mean(loss, axis=[1, 2])
loss += self._aggregate_classification_loss(
raw_loss, gt_max, pred_max,
self._deepmac_params.box_consistency_loss_normalize)
return loss
def _compute_per_instance_color_consistency_loss(
self, boxes, preprocessed_image, mask_logits):
......@@ -1099,6 +1144,17 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[1, 2]))
loss = loss / num_box_pixels
if ((self._deepmac_params.color_consistency_warmup_steps > 0) and
self._is_training):
training_step = tf.cast(self.training_step, tf.float32)
warmup_steps = tf.cast(
self._deepmac_params.color_consistency_warmup_steps, tf.float32)
start_step = tf.cast(
self._deepmac_params.color_consistency_warmup_start, tf.float32)
warmup_weight = (training_step - start_step) / warmup_steps
warmup_weight = tf.clip_by_value(warmup_weight, 0.0, 1.0)
loss *= warmup_weight
return loss
def _compute_per_instance_deepmac_losses(
......@@ -1188,11 +1244,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
allowed_masked_classes_ids = (
self._deepmac_params.allowed_masked_classes_ids)
loss_dict = {
DEEP_MASK_ESTIMATION: 0.0,
}
for loss_name in WEAK_LOSSES:
loss_dict = {}
for loss_name in MASK_LOSSES:
loss_dict[loss_name] = 0.0
prediction_shape = tf.shape(prediction_dict[INSTANCE_EMBEDDING][0])
......@@ -1252,13 +1305,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_loss_dict = self._compute_instance_masks_loss(
prediction_dict=prediction_dict)
losses_dict[LOSS_KEY_PREFIX + '/' + DEEP_MASK_ESTIMATION] = (
self._deepmac_params.task_loss_weight * mask_loss_dict[
DEEP_MASK_ESTIMATION]
)
for loss_name in WEAK_LOSSES:
loss_weight = _get_weak_loss_weight(loss_name, self._deepmac_params)
for loss_name in MASK_LOSSES:
loss_weight = _get_loss_weight(loss_name, self._deepmac_params)
if loss_weight > 0.0:
losses_dict[LOSS_KEY_PREFIX + '/' + loss_name] = (
loss_weight * mask_loss_dict[loss_name])
......
"""Tests for google3.third_party.tensorflow_models.object_detection.meta_architectures.deepmac_meta_arch."""
import functools
import random
import unittest
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from object_detection.core import losses
from object_detection.core import preprocessor
from object_detection.meta_architectures import center_net_meta_arch
from object_detection.meta_architectures import deepmac_meta_arch
from object_detection.protos import center_net_pb2
from object_detection.utils import tf_version
DEEPMAC_PROTO_TEXT = """
dim: 153
task_loss_weight: 5.0
pixel_embedding_dim: 8
use_xy: false
use_instance_embedding: false
network_type: "cond_inst3"
num_init_channels: 8
classification_loss {
weighted_dice_classification_loss {
squared_normalization: false
is_prediction_probability: false
}
}
jitter_mode: EXPAND_SYMMETRIC_XY
max_roi_jitter_ratio: 0.0
predict_full_resolution_masks: true
allowed_masked_classes_ids: [99]
box_consistency_loss_weight: 1.0
color_consistency_loss_weight: 1.0
color_consistency_threshold: 0.1
box_consistency_tightness: false
box_consistency_loss_normalize: NORMALIZE_AUTO
color_consistency_warmup_steps: 20
color_consistency_warmup_start: 10
"""
class DummyFeatureExtractor(center_net_meta_arch.CenterNetFeatureExtractor):
def __init__(self,
......@@ -60,14 +93,37 @@ class MockMaskNet(tf.keras.layers.Layer):
return tf.zeros_like(pixel_embedding[:, :, :, 0]) + 0.9
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
use_instance_embedding=True, mask_num_subsamples=-1,
network_type='hourglass10', use_xy=True,
pixel_embedding_dim=2,
dice_loss_prediction_probability=False,
color_consistency_threshold=0.5):
def build_meta_arch(**override_params):
"""Builds the DeepMAC meta architecture."""
params = dict(
predict_full_resolution_masks=False,
use_instance_embedding=True,
mask_num_subsamples=-1,
network_type='hourglass10',
use_xy=True,
pixel_embedding_dim=2,
dice_loss_prediction_probability=False,
color_consistency_threshold=0.5,
use_dice_loss=False,
box_consistency_loss_normalize='normalize_auto',
box_consistency_tightness=False,
task_loss_weight=1.0,
color_consistency_loss_weight=1.0,
box_consistency_loss_weight=1.0,
num_init_channels=8,
dim=8,
allowed_masked_classes_ids=[],
mask_size=16,
postprocess_crop_size=128,
max_roi_jitter_ratio=0.0,
roi_jitter_mode='random',
color_consistency_dilation=2,
color_consistency_warmup_steps=0,
color_consistency_warmup_start=0)
params.update(override_params)
feature_extractor = DummyFeatureExtractor(
channel_means=(1.0, 2.0, 3.0),
channel_stds=(10., 20., 30.),
......@@ -87,33 +143,18 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
max_box_predictions=5,
use_labeled_classes=False)
use_dice_loss = params.pop('use_dice_loss')
dice_loss_prediction_prob = params.pop('dice_loss_prediction_probability')
if use_dice_loss:
classification_loss = losses.WeightedDiceClassificationLoss(
squared_normalization=False,
is_prediction_probability=dice_loss_prediction_probability)
is_prediction_probability=dice_loss_prediction_prob)
else:
classification_loss = losses.WeightedSigmoidClassificationLoss()
deepmac_params = deepmac_meta_arch.DeepMACParams(
classification_loss=classification_loss,
dim=8,
task_loss_weight=1.0,
pixel_embedding_dim=pixel_embedding_dim,
allowed_masked_classes_ids=[],
mask_size=16,
mask_num_subsamples=mask_num_subsamples,
use_xy=use_xy,
network_type=network_type,
use_instance_embedding=use_instance_embedding,
num_init_channels=8,
predict_full_resolution_masks=predict_full_resolution_masks,
postprocess_crop_size=128,
max_roi_jitter_ratio=0.0,
roi_jitter_mode='random',
box_consistency_loss_weight=1.0,
color_consistency_threshold=color_consistency_threshold,
color_consistency_dilation=2,
color_consistency_loss_weight=1.0
**params
)
object_detection_params = center_net_meta_arch.ObjectDetectionParams(
......@@ -136,6 +177,15 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
def test_proto_parse(self):
  """Parses DEEPMAC_PROTO_TEXT and checks the resulting DeepMACParams."""
  proto = center_net_pb2.CenterNet().DeepMACMaskEstimation()
  text_format.Parse(DEEPMAC_PROTO_TEXT, proto)
  params = deepmac_meta_arch.deepmac_proto_to_params(proto)
  self.assertIsInstance(params, deepmac_meta_arch.DeepMACParams)
  # Spot-check a plain numeric field and an enum field (converted to a
  # lowercased string name by deepmac_proto_to_params).
  self.assertEqual(params.dim, 153)
  self.assertEqual(params.box_consistency_loss_normalize, 'normalize_auto')
def test_subsample_trivial(self):
"""Test subsampling masks."""
......@@ -781,8 +831,85 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
xloss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.constant([1.0] * 16),
logits=[1.0] * 12 + [0.0] * 4)
yloss_mean = tf.reduce_mean(yloss)
xloss_mean = tf.reduce_mean(xloss)
self.assertAllClose(loss, [yloss_mean + xloss_mean])
def test_box_consistency_loss_with_tightness(self):
  """With the tightness prior, a mask exactly filling the box has 0 loss."""
  gt_boxes = tf.constant([[0., 0., 0.49, 0.49]])
  jittered_boxes = None
  # Strongly negative logits everywhere except the top-left 4x4 region,
  # which matches the groundtruth box on an 8x8 feature map.
  logits = np.full((1, 8, 8), -1e10, dtype=np.float32)
  logits[0, :4, :4] = 1e10
  arch = build_meta_arch(box_consistency_tightness=True,
                         predict_full_resolution_masks=True)
  per_instance_loss = arch._compute_per_instance_box_consistency_loss(
      gt_boxes, jittered_boxes, tf.constant(logits))
  self.assertAllClose(per_instance_loss, [0.0])
def test_box_consistency_loss_gt_count(self):
  """Checks 'normalize_groundtruth_count' divides by positive GT pixels.

  Instance 0's box spans all 32 rows/cols (normalizer 32); instance 1's box
  spans 16 (normalizer 16). The x and y axis losses are symmetric because
  the predicted masks are square.
  """
  boxes_gt = tf.constant([
      [0., 0., 1.0, 1.0],
      [0., 0., 0.49, 0.49]])
  boxes_jittered = None
  mask_prediction = np.zeros((2, 32, 32)).astype(np.float32)
  mask_prediction[0, :16, :16] = 1.0
  mask_prediction[1, :8, :8] = 1.0
  model = build_meta_arch(
      box_consistency_loss_normalize='normalize_groundtruth_count',
      predict_full_resolution_masks=True)
  loss_func = tf.function(
      model._compute_per_instance_box_consistency_loss)
  loss = loss_func(
      boxes_gt, boxes_jittered, tf.constant(mask_prediction))

  # Instance 0: all 32 1-D max entries are positive groundtruth.
  yloss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.constant([1.0] * 32),
      logits=[1.0] * 16 + [0.0] * 16) / 32.0
  yloss_mean = tf.reduce_sum(yloss)
  xloss = yloss  # Square mask: x-axis loss equals y-axis loss.
  xloss_mean = tf.reduce_sum(xloss)
  self.assertAllClose(loss[0], yloss_mean + xloss_mean)

  # Instance 1: 16 positive groundtruth entries out of 32.
  yloss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.constant([1.0] * 16 + [0.0] * 16),
      logits=[1.0] * 8 + [0.0] * 24) / 16.0
  yloss_mean = tf.reduce_sum(yloss)
  xloss = yloss
  xloss_mean = tf.reduce_sum(xloss)
  self.assertAllClose(loss[1], yloss_mean + xloss_mean)
def test_box_consistency_loss_balanced(self):
  """Balanced normalization weights each pixel by its pos/neg group size."""
  gt_boxes = tf.constant([
      [0., 0., 0.49, 0.49]])
  jittered_boxes = None
  # All-ones prediction over a 32x32 map; the groundtruth box covers half
  # of each axis, giving 16 positive and 16 negative 1-D max entries.
  logits = np.ones((1, 32, 32), dtype=np.float32)
  arch = build_meta_arch(box_consistency_loss_normalize='normalize_balanced',
                         predict_full_resolution_masks=True)
  compiled_loss = tf.function(
      arch._compute_per_instance_box_consistency_loss)
  actual = compiled_loss(
      gt_boxes, jittered_boxes, tf.constant(logits))

  one_axis = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=[0.] * 16 + [1.0] * 16,
      logits=[1.0] * 32)
  one_axis_total = tf.reduce_sum(one_axis) / 16.0
  # The loss sums both spatial axes; they are identical here by symmetry.
  self.assertAllClose(actual[0], one_axis_total + one_axis_total)
def test_box_consistency_dice_loss(self):
......@@ -863,34 +990,145 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
loss = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0)
for weak_loss in deepmac_meta_arch.WEAK_LOSSES:
for weak_loss in deepmac_meta_arch.MASK_LOSSES:
if weak_loss == deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY:
continue
self.assertGreater(loss['Loss/' + weak_loss], 0.0,
'{} was <= 0'.format(weak_loss))
def test_loss_keys_full_res(self):
model = build_meta_arch(use_dice_loss=True,
predict_full_resolution_masks=True)
def test_loss_weight_response(self):
  """Each mask loss term must scale linearly with its configured weight.

  Runs the same prediction/groundtruth through a unit-weight model and a
  randomly-weighted model, and checks every MASK_LOSSES entry scales by
  exactly its weight.
  """
  model = build_meta_arch(
      use_dice_loss=True,
      predict_full_resolution_masks=True,
      network_type='cond_inst1',
      dim=9,
      pixel_embedding_dim=8,
      use_instance_embedding=False,
      use_xy=False)
  num_stages = 1
  prediction = {
      'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
      'INSTANCE_EMBEDDING': [tf.random.normal((1, 8, 8, 9))] * num_stages,
      'PIXEL_EMBEDDING': [tf.random.normal((1, 8, 8, 8))] * num_stages,
      'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
      'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
      'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages
  }
  boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)]
  classes = [tf.one_hot([1, 0, 1, 1, 1], depth=6)]
  weights = [tf.ones(5)]
  masks = [tf.ones((5, 32, 32))]
  model.provide_groundtruth(
      groundtruth_boxes_list=boxes,
      groundtruth_classes_list=classes,
      groundtruth_weights_list=weights,
      groundtruth_masks_list=masks)
  loss = model.loss(prediction, tf.constant([[32, 32, 3.0]]))

  # Baseline with all weights 1.0 — each term must be strictly positive,
  # otherwise the scaling comparison below would be vacuous.
  self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0)
  for mask_loss in deepmac_meta_arch.MASK_LOSSES:
    self.assertGreater(loss['Loss/' + mask_loss], 0.0,
                       '{} was <= 0'.format(mask_loss))

  # Same model configuration, but with random per-loss weights in [1, 5).
  rng = random.Random(0)
  loss_weights = {
      deepmac_meta_arch.DEEP_MASK_ESTIMATION: rng.uniform(1, 5),
      deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY: rng.uniform(1, 5),
      deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY: rng.uniform(1, 5)
  }
  weighted_model = build_meta_arch(
      use_dice_loss=True,
      predict_full_resolution_masks=True,
      network_type='cond_inst1',
      dim=9,
      pixel_embedding_dim=8,
      use_instance_embedding=False,
      use_xy=False,
      task_loss_weight=loss_weights[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
      box_consistency_loss_weight=(
          loss_weights[deepmac_meta_arch.DEEP_MASK_BOX_CONSISTENCY]),
      color_consistency_loss_weight=(
          loss_weights[deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY]))
  weighted_model.provide_groundtruth(
      groundtruth_boxes_list=boxes,
      groundtruth_classes_list=classes,
      groundtruth_weights_list=weights,
      groundtruth_masks_list=masks)
  weighted_loss = weighted_model.loss(prediction, tf.constant([[32, 32, 3]]))

  for mask_loss in deepmac_meta_arch.MASK_LOSSES:
    loss_key = 'Loss/' + mask_loss
    self.assertAllEqual(
        weighted_loss[loss_key], loss[loss_key] * loss_weights[mask_loss],
        f'{mask_loss} did not respond to change in weight.')
def test_color_consistency_warmup(self):
  """Color consistency loss warms up linearly over the configured window.

  With warmup_start=10 and warmup_steps=10 the loss weight is 0 before
  step 10, ramps linearly to 1.0 at step 20, and stays at 1.0 afterwards.
  The same prediction is used at every step so only the warmup weight
  changes the loss value.
  """
  model = build_meta_arch(
      use_dice_loss=True,
      predict_full_resolution_masks=True,
      network_type='cond_inst1',
      dim=9,
      pixel_embedding_dim=8,
      use_instance_embedding=False,
      use_xy=False,
      color_consistency_warmup_steps=10,
      color_consistency_warmup_start=10)
  num_stages = 1
  prediction = {
      'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
      'INSTANCE_EMBEDDING': [tf.random.normal((1, 8, 8, 9))] * num_stages,
      'PIXEL_EMBEDDING': [tf.random.normal((1, 8, 8, 8))] * num_stages,
      'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
      'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
      'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages
  }
  boxes = [tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)]
  classes = [tf.one_hot([1, 0, 1, 1, 1], depth=6)]
  weights = [tf.ones(5)]
  masks = [tf.ones((5, 32, 32))]

  # Step 5: before warmup_start, weight should be 0.
  model.provide_groundtruth(
      groundtruth_boxes_list=boxes,
      groundtruth_classes_list=classes,
      groundtruth_weights_list=weights,
      groundtruth_masks_list=masks,
      training_step=5)
  loss_at_5 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))

  # Step 15: halfway through the warmup window, weight should be 0.5.
  model.provide_groundtruth(
      groundtruth_boxes_list=boxes,
      groundtruth_classes_list=classes,
      groundtruth_weights_list=weights,
      groundtruth_masks_list=masks,
      training_step=15)
  loss_at_15 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))

  # Step 20: end of warmup, weight should be 1.0.
  model.provide_groundtruth(
      groundtruth_boxes_list=boxes,
      groundtruth_classes_list=classes,
      groundtruth_weights_list=weights,
      groundtruth_masks_list=masks,
      training_step=20)
  loss_at_20 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))

  # Step 100: well past warmup, weight clipped to 1.0.
  model.provide_groundtruth(
      groundtruth_boxes_list=boxes,
      groundtruth_classes_list=classes,
      groundtruth_weights_list=weights,
      groundtruth_masks_list=masks,
      training_step=100)
  loss_at_100 = model.loss(prediction, tf.constant([[32, 32, 3.0]]))

  loss_key = 'Loss/' + deepmac_meta_arch.DEEP_MASK_COLOR_CONSISTENCY
  self.assertAlmostEqual(loss_at_5[loss_key].numpy(), 0.0)
  self.assertAlmostEqual(loss_at_15[loss_key].numpy(),
                         loss_at_20[loss_key].numpy() / 2.0)
  self.assertAlmostEqual(loss_at_20[loss_key].numpy(),
                         loss_at_100[loss_key].numpy())
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
......
......@@ -303,7 +303,7 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
return unbatched_tensor_dict
def provide_groundtruth(model, labels):
def provide_groundtruth(model, labels, training_step=None):
"""Provides the labels to a model as groundtruth.
This helper function extracts the corresponding boxes, classes,
......@@ -313,6 +313,8 @@ def provide_groundtruth(model, labels):
Args:
model: The detection model to provide groundtruth to.
labels: The labels for the training or evaluation inputs.
training_step: int, optional. The training step for the model. Useful
for models which want to anneal loss weights.
"""
gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
gt_classes_list = labels[fields.InputDataFields.groundtruth_classes]
......@@ -402,7 +404,8 @@ def provide_groundtruth(model, labels):
groundtruth_verified_neg_classes=gt_verified_neg_classes,
groundtruth_not_exhaustive_classes=gt_not_exhaustive_classes,
groundtruth_keypoint_depths_list=gt_keypoint_depths_list,
groundtruth_keypoint_depth_weights_list=gt_keypoint_depth_weights_list)
groundtruth_keypoint_depth_weights_list=gt_keypoint_depth_weights_list,
training_step=training_step)
def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
......
......@@ -51,7 +51,7 @@ RESTORE_MAP_ERROR_TEMPLATE = (
def _compute_losses_and_predictions_dicts(
model, features, labels,
model, features, labels, training_step=None,
add_regularization_loss=True):
"""Computes the losses dict and predictions dict for a model on inputs.
......@@ -107,6 +107,7 @@ def _compute_losses_and_predictions_dicts(
float32 tensor containing keypoint depths information.
labels[fields.InputDataFields.groundtruth_keypoint_depth_weights] is a
float32 tensor containing the weights of the keypoint depth feature.
training_step: int, the current training step.
add_regularization_loss: Whether or not to include the model's
regularization loss in the losses dictionary.
......@@ -116,7 +117,7 @@ def _compute_losses_and_predictions_dicts(
`model.predict`.
"""
model_lib.provide_groundtruth(model, labels)
model_lib.provide_groundtruth(model, labels, training_step=training_step)
preprocessed_images = features[fields.InputDataFields.image]
prediction_dict = model.predict(
......@@ -166,7 +167,8 @@ def _ensure_model_is_built(model, input_dataset, unpad_groundtruth_tensors):
labels = model_lib.unstack_batch(
labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
return _compute_losses_and_predictions_dicts(model, features, labels)
return _compute_losses_and_predictions_dicts(model, features, labels,
training_step=0)
strategy = tf.compat.v2.distribute.get_strategy()
if hasattr(tf.distribute.Strategy, 'run'):
......@@ -208,6 +210,7 @@ def eager_train_step(detection_model,
labels,
unpad_groundtruth_tensors,
optimizer,
training_step,
add_regularization_loss=True,
clip_gradients_value=None,
num_replicas=1.0):
......@@ -280,6 +283,7 @@ def eager_train_step(detection_model,
float32 tensor containing the weights of the keypoint depth feature.
unpad_groundtruth_tensors: A parameter passed to unstack_batch.
optimizer: The training optimizer that will update the variables.
training_step: int, the training step number.
add_regularization_loss: Whether or not to include the model's
regularization loss in the losses dictionary.
clip_gradients_value: If this is present, clip the gradients global norm
......@@ -302,7 +306,9 @@ def eager_train_step(detection_model,
with tf.GradientTape() as tape:
losses_dict, _ = _compute_losses_and_predictions_dicts(
detection_model, features, labels, add_regularization_loss)
detection_model, features, labels,
training_step=training_step,
add_regularization_loss=add_regularization_loss)
losses_dict = normalize_dict(losses_dict, num_replicas)
......@@ -632,6 +638,7 @@ def train_loop(
labels,
unpad_groundtruth_tensors,
optimizer,
training_step=global_step,
add_regularization_loss=add_regularization_loss,
clip_gradients_value=clip_gradients_value,
num_replicas=strategy.num_replicas_in_sync)
......@@ -901,7 +908,8 @@ def eager_eval_loop(
labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
losses_dict, prediction_dict = _compute_losses_and_predictions_dicts(
detection_model, features, labels, add_regularization_loss)
detection_model, features, labels, training_step=None,
add_regularization_loss=add_regularization_loss)
prediction_dict = detection_model.postprocess(
prediction_dict, features[fields.InputDataFields.true_image_shape])
eval_features = {
......
......@@ -403,6 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 24
message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1;
......@@ -471,6 +472,19 @@ message CenterNet {
optional float color_consistency_loss_weight = 19 [default=0.0];
optional LossNormalize box_consistency_loss_normalize = 20 [
default=NORMALIZE_AUTO];
// If set, will use the bounding box tightness prior approach. This means
// that the max will be restricted to only be inside the box for both
// dimensions. See details here:
// https://papers.nips.cc/paper/2019/hash/e6e713296627dff6475085cc6a224464-Abstract.html
optional bool box_consistency_tightness = 21 [default=false];
optional int32 color_consistency_warmup_steps = 22 [default=0];
optional int32 color_consistency_warmup_start = 23 [default=0];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......@@ -483,6 +497,12 @@ message CenterNet {
optional PostProcessing post_processing = 24;
}
// Controls how per-pixel classification losses are aggregated into a
// per-instance loss (see box_consistency_loss_normalize above).
enum LossNormalize {
  NORMALIZE_AUTO = 0; // SUM for 2D inputs (dice loss) and MEAN for others.
  // Sum of the loss divided by the number of positive groundtruth pixels.
  NORMALIZE_GROUNDTRUTH_COUNT = 1;
  // Each pixel normalized by the count of same-labeled (pos/neg) pixels.
  // NOTE(review): value 2 is skipped here — presumably reserved or
  // abandoned; confirm before assigning it to a new entry.
  NORMALIZE_BALANCED = 3;
}
message CenterNetFeatureExtractor {
optional string type = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment