Commit b9dd8e7c authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Implement the color consistency loss in DeepMAC.

PiperOrigin-RevId: 406939681
parent b1f77678
...@@ -104,6 +104,32 @@ class CenterNetFeatureExtractor(tf.keras.Model): ...@@ -104,6 +104,32 @@ class CenterNetFeatureExtractor(tf.keras.Model):
return (inputs - channel_means)/channel_stds return (inputs - channel_means)/channel_stds
def preprocess_reverse(self, preprocessed_inputs):
  """Undo the preprocessing and return the raw image.

  This is a convenience function for some algorithms that require access
  to the raw inputs.

  Args:
    preprocessed_inputs: A [batch_size, height, width, channels] float
      tensor preprocessed_inputs from the preprocess function.

  Returns:
    images: A [batch_size, height, width, channels] float tensor with
      the preprocessing removed.
  """
  # Broadcast per-channel statistics over batch and spatial dimensions.
  means = tf.reshape(tf.constant(self._channel_means), [1, 1, 1, -1])
  stds = tf.reshape(tf.constant(self._channel_stds), [1, 1, 1, -1])
  # Invert the (x - mean) / std normalization applied by preprocess().
  images = (preprocessed_inputs * stds) + means
  if self._bgr_ordering:
    # preprocess() swapped RGB -> BGR; swap the channels back.
    blue, green, red = tf.unstack(images, axis=3)
    images = tf.stack([red, green, blue], axis=3)
  return images
@property @property
@abc.abstractmethod @abc.abstractmethod
def out_stride(self): def out_stride(self):
......
...@@ -3311,6 +3311,23 @@ class CenterNetFeatureExtractorTest(test_case.TestCase): ...@@ -3311,6 +3311,23 @@ class CenterNetFeatureExtractorTest(test_case.TestCase):
output = self.execute(graph_fn, []) output = self.execute(graph_fn, [])
self.assertAlmostEqual(output.sum(), 2 * 32 * 32 * 3) self.assertAlmostEqual(output.sum(), 2 * 32 * 32 * 3)
def test_preprocess_reverse(self):
  """preprocess followed by preprocess_reverse is the identity."""
  feature_extractor = DummyFeatureExtractor(
      channel_means=(1.0, 2.0, 3.0),
      channel_stds=(10., 20., 30.), bgr_ordering=False,
      num_feature_outputs=2, stride=4)
  image = np.zeros((2, 32, 32, 3))
  image[:, :, :] = 11, 22, 33

  def graph_fn():
    return feature_extractor.preprocess_reverse(
        feature_extractor.preprocess(image))

  reconstructed = self.execute(graph_fn, [])
  self.assertAllClose(image, reconstructed)
def test_bgr_ordering(self): def test_bgr_ordering(self):
feature_extractor = DummyFeatureExtractor( feature_extractor = DummyFeatureExtractor(
channel_means=(0.0, 0.0, 0.0), channel_means=(0.0, 0.0, 0.0),
......
...@@ -21,13 +21,21 @@ from object_detection.protos import losses_pb2 ...@@ -21,13 +21,21 @@ from object_detection.protos import losses_pb2
from object_detection.protos import preprocessor_pb2 from object_detection.protos import preprocessor_pb2
from object_detection.utils import shape_utils from object_detection.utils import shape_utils
from object_detection.utils import spatial_transform_ops from object_detection.utils import spatial_transform_ops
from object_detection.utils import tf_version
if tf_version.is_tf2():
import tensorflow_io as tfio # pylint:disable=g-import-not-at-top
INSTANCE_EMBEDDING = 'INSTANCE_EMBEDDING' INSTANCE_EMBEDDING = 'INSTANCE_EMBEDDING'
PIXEL_EMBEDDING = 'PIXEL_EMBEDDING' PIXEL_EMBEDDING = 'PIXEL_EMBEDDING'
DEEP_MASK_ESTIMATION = 'deep_mask_estimation' DEEP_MASK_ESTIMATION = 'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency' DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency'
DEEP_MASK_COLOR_CONSISTENCY = 'deep_mask_color_consistency'
LOSS_KEY_PREFIX = center_net_meta_arch.LOSS_KEY_PREFIX LOSS_KEY_PREFIX = center_net_meta_arch.LOSS_KEY_PREFIX
NEIGHBORS_2D = [[-1, -1], [-1, 0], [-1, 1],
[0, -1], [0, 1],
[1, -1], [1, 0], [1, 1]]
class DeepMACParams( class DeepMACParams(
...@@ -37,7 +45,8 @@ class DeepMACParams( ...@@ -37,7 +45,8 @@ class DeepMACParams(
'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels', 'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels',
'predict_full_resolution_masks', 'postprocess_crop_size', 'predict_full_resolution_masks', 'postprocess_crop_size',
'max_roi_jitter_ratio', 'roi_jitter_mode', 'max_roi_jitter_ratio', 'roi_jitter_mode',
'box_consistency_loss_weight', 'box_consistency_loss_weight', 'color_consistency_threshold',
'color_consistency_dilation', 'color_consistency_loss_weight'
])): ])):
"""Class holding the DeepMAC network configutration.""" """Class holding the DeepMAC network configutration."""
...@@ -48,7 +57,9 @@ class DeepMACParams( ...@@ -48,7 +57,9 @@ class DeepMACParams(
mask_num_subsamples, use_xy, network_type, use_instance_embedding, mask_num_subsamples, use_xy, network_type, use_instance_embedding,
num_init_channels, predict_full_resolution_masks, num_init_channels, predict_full_resolution_masks,
postprocess_crop_size, max_roi_jitter_ratio, postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode, box_consistency_loss_weight): roi_jitter_mode, box_consistency_loss_weight,
color_consistency_threshold, color_consistency_dilation,
color_consistency_loss_weight):
return super(DeepMACParams, return super(DeepMACParams,
cls).__new__(cls, classification_loss, dim, cls).__new__(cls, classification_loss, dim,
task_loss_weight, pixel_embedding_dim, task_loss_weight, pixel_embedding_dim,
...@@ -57,7 +68,10 @@ class DeepMACParams( ...@@ -57,7 +68,10 @@ class DeepMACParams(
use_instance_embedding, num_init_channels, use_instance_embedding, num_init_channels,
predict_full_resolution_masks, predict_full_resolution_masks,
postprocess_crop_size, max_roi_jitter_ratio, postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode, box_consistency_loss_weight) roi_jitter_mode, box_consistency_loss_weight,
color_consistency_threshold,
color_consistency_dilation,
color_consistency_loss_weight)
def subsample_instances(classes, weights, boxes, masks, num_subsamples): def subsample_instances(classes, weights, boxes, masks, num_subsamples):
...@@ -284,6 +298,92 @@ def embedding_projection(x, y): ...@@ -284,6 +298,92 @@ def embedding_projection(x, y):
return dot return dot
def _get_2d_neighbors_kenel():
  """Returns a conv. kernel that when applied generates 2D neighbors.

  NOTE(review): "kenel" is a typo for "kernel"; the name is kept because
  in-file callers reference it.

  Returns:
    kernel: A float tensor of shape [3, 3, 1, 8] where output channel i
      selects the i-th offset listed in NEIGHBORS_2D.
  """
  weights = np.zeros((3, 3, 1, 8))
  for index, (dy, dx) in enumerate(NEIGHBORS_2D):
    # Center of the 3x3 window is (1, 1); offsets index around it.
    weights[dy + 1, dx + 1, 0, index] = 1.0
  return tf.constant(weights, dtype=tf.float32)
def generate_2d_neighbors(input_tensor, dilation=2):
  """Generate a feature map of 2D neighbors.

  Note: This op makes 8 (# of neighbors) the leading dimension so that
  subsequent ops on TPU won't have to pad the last dimension to 128.

  Args:
    input_tensor: A float tensor of shape [height, width, channels].
    dilation: int, the dilation factor for considering neighbors.

  Returns:
    output: A float tensor of shape [8, height, width, channels] containing
      all 8 2-D neighbors of every pixel.
  """
  # Treat each channel as an independent 1-channel image in the batch dim:
  # [height, width, channels] -> [channels, height, width, 1].
  channels_first = tf.transpose(input_tensor, (2, 0, 1))[..., tf.newaxis]
  # Each of the 8 kernel output channels picks out one neighbor offset,
  # `dilation` pixels away ('SAME' padding zero-fills out-of-bounds).
  neighbors = tf.nn.atrous_conv2d(
      channels_first, _get_2d_neighbors_kenel(), rate=dilation,
      padding='SAME')
  # [channels, height, width, 8] -> [8, height, width, channels].
  return tf.transpose(neighbors, [3, 1, 2, 0])
def gaussian_pixel_similarity(a, b, theta):
  """Returns exp(-||a - b|| / theta), reducing over the last axis."""
  distance = tf.linalg.norm(a - b, axis=-1)
  return tf.exp(-distance / theta)
def dilated_cross_pixel_similarity(feature_map, dilation=2, theta=2.0):
  """Dilated cross pixel similarity as defined in [1].

  [1]: https://arxiv.org/abs/2012.02310

  Args:
    feature_map: A float tensor of shape [height, width, channels]
    dilation: int, the dilation factor.
    theta: The denominator while taking difference inside the gaussian.

  Returns:
    dilated_similarity: A tensor of shape [8, height, width]
  """
  shifted = generate_2d_neighbors(feature_map, dilation)
  # Broadcast the [1, height, width, channels] map against its 8 neighbors.
  return gaussian_pixel_similarity(
      feature_map[tf.newaxis], shifted, theta=theta)
def dilated_cross_same_mask_label(instance_masks, dilation=2):
  """Probability that a pixel and its dilated neighbor share a mask label.

  Pairwise term from [1]: https://arxiv.org/abs/2012.02310

  Args:
    instance_masks: A float tensor of shape [num_instances, height, width]
    dilation: int, the dilation factor.

  Returns:
    dilated_same_label: A tensor of shape [8, num_instances, height, width]
  """
  # [num_instances, height, width] -> [height, width, num_instances] so
  # instances act as channels for generate_2d_neighbors.
  masks_hwc = tf.transpose(instance_masks, (1, 2, 0))
  neighbor_masks = generate_2d_neighbors(masks_hwc, dilation)
  masks_hwc = masks_hwc[tf.newaxis]
  # P(same label) = P(both foreground) + P(both background).
  same_label = ((masks_hwc * neighbor_masks) +
                ((1 - masks_hwc) * (1 - neighbor_masks)))
  return tf.transpose(same_label, (0, 3, 1, 2))
class ResNetMaskNetwork(tf.keras.layers.Layer): class ResNetMaskNetwork(tf.keras.layers.Layer):
"""A small wrapper around ResNet blocks to predict masks.""" """A small wrapper around ResNet blocks to predict masks."""
...@@ -557,7 +657,10 @@ def deepmac_proto_to_params(deepmac_config): ...@@ -557,7 +657,10 @@ def deepmac_proto_to_params(deepmac_config):
postprocess_crop_size=deepmac_config.postprocess_crop_size, postprocess_crop_size=deepmac_config.postprocess_crop_size,
max_roi_jitter_ratio=deepmac_config.max_roi_jitter_ratio, max_roi_jitter_ratio=deepmac_config.max_roi_jitter_ratio,
roi_jitter_mode=jitter_mode, roi_jitter_mode=jitter_mode,
box_consistency_loss_weight=deepmac_config.box_consistency_loss_weight box_consistency_loss_weight=deepmac_config.box_consistency_loss_weight,
color_consistency_threshold=deepmac_config.color_consistency_threshold,
color_consistency_dilation=deepmac_config.color_consistency_dilation,
color_consistency_loss_weight=deepmac_config.color_consistency_loss_weight
) )
...@@ -756,6 +859,17 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -756,6 +859,17 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
def _compute_per_instance_mask_prediction_loss( def _compute_per_instance_mask_prediction_loss(
self, boxes, mask_logits, mask_gt): self, boxes, mask_logits, mask_gt):
"""Compute the per-instance mask loss.
Args:
boxes: A [num_instances, 4] float tensor of GT boxes.
mask_logits: A [num_instances, height, width] float tensor of predicted
masks
mask_gt: The groundtruth mask.
Returns:
loss: A [num_instances] shaped tensor with the loss for each instance.
"""
num_instances = tf.shape(boxes)[0] num_instances = tf.shape(boxes)[0]
mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt) mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
...@@ -777,6 +891,18 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -777,6 +891,18 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
def _compute_per_instance_box_consistency_loss( def _compute_per_instance_box_consistency_loss(
self, boxes_gt, boxes_for_crop, mask_logits): self, boxes_gt, boxes_for_crop, mask_logits):
"""Compute the per-instance box consistency loss.
Args:
boxes_gt: A [num_instances, 4] float tensor of GT boxes.
boxes_for_crop: A [num_instances, 4] float tensor of augmented boxes,
to be used when using crop-and-resize based mask head.
mask_logits: A [num_instances, height, width] float tensor of predicted
masks.
Returns:
loss: A [num_instances] shaped tensor with the loss for each instance.
"""
height, width = tf.shape(mask_logits)[1], tf.shape(mask_logits)[2] height, width = tf.shape(mask_logits)[1], tf.shape(mask_logits)[2]
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, tf.newaxis] filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, tf.newaxis]
...@@ -811,8 +937,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -811,8 +937,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else: else:
return tf.reduce_mean(loss, axis=[1, 2]) return tf.reduce_mean(loss, axis=[1, 2])
def _compute_per_instance_color_consistency_loss(
    self, boxes, preprocessed_image, mask_logits):
  """Compute the per-instance color consistency loss.

  Args:
    boxes: A [num_instances, 4] float tensor of GT boxes.
    preprocessed_image: A [height, width, 3] float tensor containing the
      preprocessed image.
    mask_logits: A [num_instances, height, width] float tensor of predicted
      masks.

  Returns:
    loss: A [num_instances] shaped tensor with the loss for each instance.
  """
  dilation = self._deepmac_params.color_consistency_dilation
  img_height = tf.shape(preprocessed_image)[0]
  img_width = tf.shape(preprocessed_image)[1]

  # [8, height, width] similarity of each pixel to its dilated neighbors.
  color_similarity = dilated_cross_pixel_similarity(
      preprocessed_image, dilation=dilation, theta=2.0)
  mask_probs = tf.nn.sigmoid(mask_logits)
  # [8, num_instances, height, width] probability that a pixel and its
  # neighbor receive the same mask label; clipped to keep log() finite.
  same_label_prob = tf.clip_by_value(
      dilated_cross_same_mask_label(mask_probs, dilation=dilation),
      1e-3, 1.0)

  # Only penalize disagreement between neighbors of similar color.
  threshold = self._deepmac_params.color_consistency_threshold
  similar_neighbor = tf.cast(
      color_similarity > threshold, tf.float32)[:, tf.newaxis, :, :]
  per_pixel_loss = -similar_neighbor * tf.math.log(same_label_prob)

  # TODO(vighneshb) explore if shrinking the box by 1px helps.
  box_mask = fill_boxes(boxes, img_height, img_width)
  per_pixel_loss = per_pixel_loss * box_mask[tf.newaxis, :, :, :]

  # Sum over neighbors and pixels, normalize by box area per instance.
  loss = tf.reduce_sum(per_pixel_loss, axis=[0, 2, 3])
  num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[1, 2]))
  return loss / num_box_pixels
def _compute_per_instance_deepmac_losses( def _compute_per_instance_deepmac_losses(
self, boxes, masks, instance_embedding, pixel_embedding): self, boxes, masks, instance_embedding, pixel_embedding,
image):
"""Returns the mask loss per instance. """Returns the mask loss per instance.
Args: Args:
...@@ -824,13 +995,16 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -824,13 +995,16 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
float tensor containing the instance embeddings. float tensor containing the instance embeddings.
pixel_embedding: optional [output_height, output_width, pixel_embedding: optional [output_height, output_width,
pixel_embedding_size] float tensor containing the per-pixel embeddings. pixel_embedding_size] float tensor containing the per-pixel embeddings.
image: [output_height, output_width, channels] float tensor
denoting the input image.
Returns: Returns:
mask_prediction_loss: A [num_instances] shaped float tensor containing the mask_prediction_loss: A [num_instances] shaped float tensor containing the
mask loss for each instance. mask loss for each instance.
box_consistency_loss: A [num_instances] shaped float tensor containing box_consistency_loss: A [num_instances] shaped float tensor containing
the box consistency loss for each instance. the box consistency loss for each instance.
color_consistency_loss: A [num_instances] shaped float tensor containing
  the color consistency loss.
""" """
if tf.keras.backend.learning_phase(): if tf.keras.backend.learning_phase():
...@@ -855,7 +1029,20 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -855,7 +1029,20 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
box_consistency_loss = self._compute_per_instance_box_consistency_loss( box_consistency_loss = self._compute_per_instance_box_consistency_loss(
boxes, boxes_for_crop, mask_logits) boxes, boxes_for_crop, mask_logits)
return mask_prediction_loss, box_consistency_loss color_consistency_loss = self._compute_per_instance_color_consistency_loss(
boxes, image, mask_logits)
return mask_prediction_loss, box_consistency_loss, color_consistency_loss
def _get_lab_image(self, preprocessed_image):
  """Converts a preprocessed image into a normalized CIE-LAB image.

  Args:
    preprocessed_image: A float tensor of preprocessed pixels, as produced
      by the feature extractor's preprocess function.

  Returns:
    A float tensor of the same shape containing the image in CIE-LAB space.

  Raises:
    NotImplementedError: When running under TF1, where tensorflow_io
      (and hence RGB-to-LAB conversion) is unavailable.
  """
  # Fail fast: under TF1 `tfio` was never imported, so raise before doing
  # any work (the original raised only after running preprocess_reverse).
  if tf_version.is_tf1():
    raise NotImplementedError(('RGB-to-LAB conversion required for the color'
                               ' consistency loss is not supported in TF1.'))
  raw_image = self._feature_extractor.preprocess_reverse(
      preprocessed_image)
  # rgb_to_lab expects pixel values scaled to [0, 1].
  raw_image = raw_image / 255.0
  return tfio.experimental.color.rgb_to_lab(raw_image)
def _compute_instance_masks_loss(self, prediction_dict): def _compute_instance_masks_loss(self, prediction_dict):
"""Computes the mask loss. """Computes the mask loss.
...@@ -879,9 +1066,19 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -879,9 +1066,19 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
loss_dict = { loss_dict = {
DEEP_MASK_ESTIMATION: 0.0, DEEP_MASK_ESTIMATION: 0.0,
DEEP_MASK_BOX_CONSISTENCY: 0.0 DEEP_MASK_BOX_CONSISTENCY: 0.0,
DEEP_MASK_COLOR_CONSISTENCY: 0.0
} }
prediction_shape = tf.shape(prediction_dict[INSTANCE_EMBEDDING][0])
height, width = prediction_shape[1], prediction_shape[2]
preprocessed_image = tf.image.resize(
prediction_dict['preprocessed_inputs'], (height, width))
image = self._get_lab_image(preprocessed_image)
# TODO(vighneshb) See if we can save memory by only using the final
# prediction
# Iterate over multiple preidctions by backbone (for hourglass length=2) # Iterate over multiple preidctions by backbone (for hourglass length=2)
for instance_pred, pixel_pred in zip( for instance_pred, pixel_pred in zip(
prediction_dict[INSTANCE_EMBEDDING], prediction_dict[INSTANCE_EMBEDDING],
...@@ -896,9 +1093,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -896,9 +1093,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
classes, valid_mask_weights, masks = filter_masked_classes( classes, valid_mask_weights, masks = filter_masked_classes(
allowed_masked_classes_ids, classes, weights, masks) allowed_masked_classes_ids, classes, weights, masks)
per_instance_mask_loss, per_instance_consistency_loss = ( (per_instance_mask_loss, per_instance_consistency_loss,
per_instance_color_consistency_loss) = (
self._compute_per_instance_deepmac_losses( self._compute_per_instance_deepmac_losses(
boxes, masks, instance_pred[i], pixel_pred[i])) boxes, masks, instance_pred[i], pixel_pred[i],
image[i]))
per_instance_mask_loss *= valid_mask_weights per_instance_mask_loss *= valid_mask_weights
per_instance_consistency_loss *= weights per_instance_consistency_loss *= weights
...@@ -912,6 +1111,9 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -912,6 +1111,9 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
loss_dict[DEEP_MASK_BOX_CONSISTENCY] += ( loss_dict[DEEP_MASK_BOX_CONSISTENCY] += (
tf.reduce_sum(per_instance_consistency_loss) / num_instances) tf.reduce_sum(per_instance_consistency_loss) / num_instances)
loss_dict[DEEP_MASK_COLOR_CONSISTENCY] += (
tf.reduce_sum(per_instance_color_consistency_loss) / num_instances)
batch_size = len(gt_boxes_list) batch_size = len(gt_boxes_list)
num_predictions = len(prediction_dict[INSTANCE_EMBEDDING]) num_predictions = len(prediction_dict[INSTANCE_EMBEDDING])
...@@ -937,6 +1139,12 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -937,6 +1139,12 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
self._deepmac_params.box_consistency_loss_weight * mask_loss_dict[ self._deepmac_params.box_consistency_loss_weight * mask_loss_dict[
DEEP_MASK_BOX_CONSISTENCY] DEEP_MASK_BOX_CONSISTENCY]
) )
if self._deepmac_params.color_consistency_loss_weight > 0.0:
  # BUG FIX: previously scaled by box_consistency_loss_weight, which
  # silently applied the wrong weight to the color consistency loss.
  losses_dict[LOSS_KEY_PREFIX + '/' + DEEP_MASK_COLOR_CONSISTENCY] = (
      self._deepmac_params.color_consistency_loss_weight * mask_loss_dict[
          DEEP_MASK_COLOR_CONSISTENCY]
  )
return losses_dict return losses_dict
def postprocess(self, prediction_dict, true_image_shapes, **params): def postprocess(self, prediction_dict, true_image_shapes, **params):
......
...@@ -64,7 +64,8 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False, ...@@ -64,7 +64,8 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
use_instance_embedding=True, mask_num_subsamples=-1, use_instance_embedding=True, mask_num_subsamples=-1,
network_type='hourglass10', use_xy=True, network_type='hourglass10', use_xy=True,
pixel_embedding_dim=2, pixel_embedding_dim=2,
dice_loss_prediction_probability=False): dice_loss_prediction_probability=False,
color_consistency_threshold=0.5):
"""Builds the DeepMAC meta architecture.""" """Builds the DeepMAC meta architecture."""
feature_extractor = DummyFeatureExtractor( feature_extractor = DummyFeatureExtractor(
...@@ -110,6 +111,9 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False, ...@@ -110,6 +111,9 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
max_roi_jitter_ratio=0.0, max_roi_jitter_ratio=0.0,
roi_jitter_mode='random', roi_jitter_mode='random',
box_consistency_loss_weight=1.0, box_consistency_loss_weight=1.0,
color_consistency_threshold=color_consistency_threshold,
color_consistency_dilation=2,
color_consistency_loss_weight=1.0
) )
object_detection_params = center_net_meta_arch.ObjectDetectionParams( object_detection_params = center_net_meta_arch.ObjectDetectionParams(
...@@ -206,29 +210,97 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -206,29 +210,97 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
out = net(tf.zeros((2, 24))) out = net(tf.zeros((2, 24)))
self.assertEqual(out.shape, (2, 8)) self.assertEqual(out.shape, (2, 8))
def test_generate_2d_neighbors_shape(self):
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') inp = tf.zeros((13, 14, 3))
class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase): out = deepmac_meta_arch.generate_2d_neighbors(inp)
self.assertEqual((8, 13, 14, 3), out.shape)
def test_mask_network(self): def test_generate_2d_neighbors(self):
net = deepmac_meta_arch.MaskHeadNetwork('hourglass10', 8)
out = net(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True) inp = np.arange(16).reshape(4, 4).astype(np.float32)
self.assertEqual(out.shape, (2, 32, 32)) inp = tf.stack([inp, inp * 2], axis=2)
out = deepmac_meta_arch.generate_2d_neighbors(inp, dilation=1)
self.assertEqual((8, 4, 4, 2), out.shape)
for i in range(2):
expected = np.array([0, 1, 2, 4, 6, 8, 9, 10]) * (i + 1)
self.assertAllEqual(out[:, 1, 1, i], expected)
def test_mask_network_hourglass20(self): expected = np.array([1, 2, 3, 5, 7, 9, 10, 11]) * (i + 1)
net = deepmac_meta_arch.MaskHeadNetwork('hourglass20', 8) self.assertAllEqual(out[:, 1, 2, i], expected)
out = net(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True) expected = np.array([4, 5, 6, 8, 10, 12, 13, 14]) * (i + 1)
self.assertEqual(out.shape, (2, 32, 32)) self.assertAllEqual(out[:, 2, 1, i], expected)
expected = np.array([5, 6, 7, 9, 11, 13, 14, 15]) * (i + 1)
self.assertAllEqual(out[:, 2, 2, i], expected)
def test_generate_2d_neighbors_dilation2(self):
  """With dilation=2 neighbors are 2 pixels away; borders zero-padded."""
  image = np.arange(16).reshape(4, 4, 1).astype(np.float32)
  neighbors = deepmac_meta_arch.generate_2d_neighbors(image, dilation=2)
  self.assertEqual((8, 4, 4, 1), neighbors.shape)
  self.assertAllEqual(neighbors[:, 0, 0, 0],
                      np.array([0, 0, 0, 0, 2, 0, 8, 10]))
def test_dilated_similarity_shape(self):
  """Similarity output is [8, height, width], channels reduced away."""
  feature_map = tf.zeros((32, 32, 9))
  out = deepmac_meta_arch.dilated_cross_pixel_similarity(feature_map)
  self.assertEqual((8, 32, 32), out.shape)
def test_dilated_similarity(self):
  """Similarity to a single lit dilated neighbor equals exp(-sqrt(2))."""
  feature_map = np.zeros((5, 5, 2), dtype=np.float32)
  feature_map[0, 0, :] = 1.0
  feature_map[4, 4, :] = 1.0
  out = deepmac_meta_arch.dilated_cross_pixel_similarity(
      feature_map, theta=1.0, dilation=2)
  self.assertAlmostEqual(out.numpy()[0, 2, 2], np.exp(-np.sqrt(2)))
def test_dilated_same_instance_mask_shape(self):
  """Output is [8, num_instances, height, width]."""
  masks = tf.zeros((5, 32, 32))
  out = deepmac_meta_arch.dilated_cross_same_mask_label(masks)
  self.assertEqual((8, 5, 32, 32), out.shape)
def test_mask_network_resnet(self): def test_dilated_same_instance_mask(self):
net = deepmac_meta_arch.MaskHeadNetwork('resnet4') instances = np.zeros((2, 5, 5), dtype=np.float32)
instances[0, 0, 0] = 1.0
instances[0, 2, 2] = 1.0
instances[0, 4, 4] = 1.0
output = deepmac_meta_arch.dilated_cross_same_mask_label(instances).numpy()
self.assertAllClose(np.ones((8, 5, 5)), output[:, 1, :, :])
self.assertAllClose([1, 0, 0, 0, 0, 0, 0, 1], output[:, 0, 2, 2])
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
['hourglass10', 'hourglass20', 'resnet4'])
def test_mask_network(self, head_type):
net = deepmac_meta_arch.MaskHeadNetwork(head_type, 8)
out = net(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True) out = net(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True)
self.assertEqual(out.shape, (2, 32, 32)) self.assertEqual(out.shape, (2, 32, 32))
def test_mask_network_params_resnet4(self):
  """resnet4 head with 8 init channels has the expected parameter count."""
  net = deepmac_meta_arch.MaskHeadNetwork('resnet4', num_init_channels=8)
  # Build the variables with one forward pass before counting them.
  _ = net(tf.zeros((2, 16)), tf.zeros((2, 32, 32, 16)), training=True)
  num_params = tf.reduce_sum(
      [tf.reduce_prod(tf.shape(w)) for w in net.trainable_weights])
  self.assertEqual(num_params.numpy(), 8665)
def test_mask_network_resnet_tf_function(self): def test_mask_network_resnet_tf_function(self):
net = deepmac_meta_arch.MaskHeadNetwork('resnet8') net = deepmac_meta_arch.MaskHeadNetwork('resnet8')
...@@ -360,8 +432,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -360,8 +432,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
masks[1, 16:, 16:] = 1.0 masks[1, 16:, 16:] = 1.0
masks = tf.constant(masks) masks = tf.constant(masks)
loss, _ = model._compute_per_instance_deepmac_losses( loss, _, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2))) boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((16, 16, 3)))
self.assertAllClose( self.assertAllClose(
loss, np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9))) loss, np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9)))
...@@ -373,8 +446,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -373,8 +446,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
masks = np.ones((2, 128, 128), dtype=np.float32) masks = np.ones((2, 128, 128), dtype=np.float32)
masks = tf.constant(masks) masks = tf.constant(masks)
loss, _ = model._compute_per_instance_deepmac_losses( loss, _, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2))) boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((32, 32, 3)))
self.assertAllClose( self.assertAllClose(
loss, np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9))) loss, np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9)))
...@@ -387,8 +461,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -387,8 +461,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
masks = np.ones((2, 128, 128), dtype=np.float32) masks = np.ones((2, 128, 128), dtype=np.float32)
masks = tf.constant(masks) masks = tf.constant(masks)
loss, _ = model._compute_per_instance_deepmac_losses( loss, _, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2))) boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((32, 32, 3)))
pred = tf.nn.sigmoid(0.9) pred = tf.nn.sigmoid(0.9)
expected = (1.0 - ((2.0 * pred) / (1.0 + pred))) expected = (1.0 - ((2.0 * pred) / (1.0 + pred)))
self.assertAllClose(loss, [expected, expected], rtol=1e-3) self.assertAllClose(loss, [expected, expected], rtol=1e-3)
...@@ -397,8 +472,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -397,8 +472,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
boxes = tf.zeros([0, 4]) boxes = tf.zeros([0, 4])
masks = tf.zeros([0, 128, 128]) masks = tf.zeros([0, 128, 128])
loss, _ = self.model._compute_per_instance_deepmac_losses( loss, _, _ = self.model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2))) boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((16, 16, 3)))
self.assertEqual(loss.shape, (0,)) self.assertEqual(loss.shape, (0,))
def test_postprocess(self): def test_postprocess(self):
...@@ -560,6 +636,29 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -560,6 +636,29 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(loss, [yloss + xloss]) self.assertAllClose(loss, [yloss + xloss])
def test_color_consistency_loss_full_res_shape(self):
  """Color consistency loss yields one scalar per instance."""
  model = build_meta_arch(use_dice_loss=True,
                          predict_full_resolution_masks=True)
  boxes = tf.zeros((3, 4))
  image = tf.zeros((32, 32, 3))
  logits = tf.zeros((3, 32, 32))
  loss = model._compute_per_instance_color_consistency_loss(
      boxes, image, logits)
  self.assertEqual([3], loss.shape)
def test_color_consistency_1_threshold(self):
  """Confident, uniform mask predictions produce zero loss."""
  model = build_meta_arch(predict_full_resolution_masks=True,
                          color_consistency_threshold=0.99)
  boxes = tf.zeros((3, 4))
  image = tf.zeros((32, 32, 3))
  # Very negative logits: sigmoid ~ 0, so every neighbor pair agrees.
  logits = tf.zeros((3, 32, 32)) - 1e4
  loss = model._compute_per_instance_color_consistency_loss(
      boxes, image, logits)
  self.assertAllClose(loss, np.zeros(3))
def test_box_consistency_dice_loss_full_res(self): def test_box_consistency_dice_loss_full_res(self):
model = build_meta_arch(use_dice_loss=True, model = build_meta_arch(use_dice_loss=True,
...@@ -575,6 +674,11 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -575,6 +674,11 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
boxes_gt, boxes_jittered, tf.constant(mask_prediction)) boxes_gt, boxes_jittered, tf.constant(mask_prediction))
self.assertAlmostEqual(loss[0].numpy(), 1 / 3) self.assertAlmostEqual(loss[0].numpy(), 1 / 3)
def test_get_lab_image_shape(self):
  """RGB-to-LAB conversion preserves the input tensor shape."""
  lab_image = self.model._get_lab_image(tf.zeros((2, 4, 4, 3)))
  self.assertEqual(lab_image.shape, (2, 4, 4, 3))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class FullyConnectedMaskHeadTest(tf.test.TestCase): class FullyConnectedMaskHeadTest(tf.test.TestCase):
......
...@@ -22,6 +22,7 @@ REQUIRED_PACKAGES = [ ...@@ -22,6 +22,7 @@ REQUIRED_PACKAGES = [
'scipy', 'scipy',
'pandas', 'pandas',
'tf-models-official>=2.5.1', 'tf-models-official>=2.5.1',
'tensorflow_io'
] ]
setup( setup(
......
...@@ -464,6 +464,13 @@ message CenterNet { ...@@ -464,6 +464,13 @@ message CenterNet {
// Weight for the box consistency loss as described in the BoxInst paper // Weight for the box consistency loss as described in the BoxInst paper
// https://arxiv.org/abs/2012.02310 // https://arxiv.org/abs/2012.02310
optional float box_consistency_loss_weight = 16 [default=0.0]; optional float box_consistency_loss_weight = 16 [default=0.0];
optional float color_consistency_threshold = 17 [default=0.4];
optional int32 color_consistency_dilation = 18 [default=2];
optional float color_consistency_loss_weight = 19 [default=0.0];
} }
optional DeepMACMaskEstimation deepmac_mask_estimation = 14; optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment