Commit b9dd8e7c authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Implement the color consistency loss in DeepMAC.

PiperOrigin-RevId: 406939681
parent b1f77678
......@@ -104,6 +104,32 @@ class CenterNetFeatureExtractor(tf.keras.Model):
return (inputs - channel_means)/channel_stds
def preprocess_reverse(self, preprocessed_inputs):
  """Invert `preprocess` and recover the raw image.

  Convenience method for algorithms that need access to the raw pixel
  values after the network input has already been normalized.

  Args:
    preprocessed_inputs: A [batch_size, height, width, channels] float
      tensor as produced by the preprocess function.

  Returns:
    images: A [batch_size, height, width, channels] float tensor with
      the preprocessing removed.
  """
  broadcast_shape = [1, 1, 1, -1]
  means = tf.reshape(tf.constant(self._channel_means), broadcast_shape)
  stds = tf.reshape(tf.constant(self._channel_stds), broadcast_shape)
  images = (preprocessed_inputs * stds) + means
  if self._bgr_ordering:
    # preprocess swapped RGB -> BGR, so swap the channels back here.
    blue, green, red = tf.unstack(images, axis=3)
    images = tf.stack([red, green, blue], axis=3)
  return images
@property
@abc.abstractmethod
def out_stride(self):
......
......@@ -3311,6 +3311,23 @@ class CenterNetFeatureExtractorTest(test_case.TestCase):
output = self.execute(graph_fn, [])
self.assertAlmostEqual(output.sum(), 2 * 32 * 32 * 3)
def test_preprocess_reverse(self):
  """preprocess_reverse should exactly undo preprocess."""
  extractor = DummyFeatureExtractor(
      channel_means=(1.0, 2.0, 3.0),
      channel_stds=(10., 20., 30.), bgr_ordering=False,
      num_feature_outputs=2, stride=4)
  image = np.zeros((2, 32, 32, 3))
  image[:, :, :] = 11, 22, 33

  def graph_fn():
    return extractor.preprocess_reverse(extractor.preprocess(image))

  output = self.execute(graph_fn, [])
  self.assertAllClose(image, output)
def test_bgr_ordering(self):
feature_extractor = DummyFeatureExtractor(
channel_means=(0.0, 0.0, 0.0),
......
......@@ -21,13 +21,21 @@ from object_detection.protos import losses_pb2
from object_detection.protos import preprocessor_pb2
from object_detection.utils import shape_utils
from object_detection.utils import spatial_transform_ops
from object_detection.utils import tf_version
if tf_version.is_tf2():
import tensorflow_io as tfio # pylint:disable=g-import-not-at-top
INSTANCE_EMBEDDING = 'INSTANCE_EMBEDDING'
PIXEL_EMBEDDING = 'PIXEL_EMBEDDING'
DEEP_MASK_ESTIMATION = 'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency'
DEEP_MASK_COLOR_CONSISTENCY = 'deep_mask_color_consistency'
LOSS_KEY_PREFIX = center_net_meta_arch.LOSS_KEY_PREFIX
NEIGHBORS_2D = [[-1, -1], [-1, 0], [-1, 1],
[0, -1], [0, 1],
[1, -1], [1, 0], [1, 1]]
class DeepMACParams(
......@@ -37,7 +45,8 @@ class DeepMACParams(
'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels',
'predict_full_resolution_masks', 'postprocess_crop_size',
'max_roi_jitter_ratio', 'roi_jitter_mode',
'box_consistency_loss_weight',
'box_consistency_loss_weight', 'color_consistency_threshold',
'color_consistency_dilation', 'color_consistency_loss_weight'
])):
"""Class holding the DeepMAC network configutration."""
......@@ -48,7 +57,9 @@ class DeepMACParams(
mask_num_subsamples, use_xy, network_type, use_instance_embedding,
num_init_channels, predict_full_resolution_masks,
postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode, box_consistency_loss_weight):
roi_jitter_mode, box_consistency_loss_weight,
color_consistency_threshold, color_consistency_dilation,
color_consistency_loss_weight):
return super(DeepMACParams,
cls).__new__(cls, classification_loss, dim,
task_loss_weight, pixel_embedding_dim,
......@@ -57,7 +68,10 @@ class DeepMACParams(
use_instance_embedding, num_init_channels,
predict_full_resolution_masks,
postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode, box_consistency_loss_weight)
roi_jitter_mode, box_consistency_loss_weight,
color_consistency_threshold,
color_consistency_dilation,
color_consistency_loss_weight)
def subsample_instances(classes, weights, boxes, masks, num_subsamples):
......@@ -284,6 +298,92 @@ def embedding_projection(x, y):
return dot
def _get_2d_neighbors_kenel():
  """Builds a conv kernel that extracts the 8 neighbors of each pixel.

  NOTE: the misspelling "kenel" in the name is kept for compatibility
  with existing callers.

  Returns:
    kernel: A float32 tensor of shape [3, 3, 1, 8]; output channel i
      selects the neighbor at offset NEIGHBORS_2D[i].
  """
  weights = np.zeros((3, 3, 1, 8))
  for channel, (dy, dx) in enumerate(NEIGHBORS_2D):
    weights[dy + 1, dx + 1, 0, channel] = 1.0
  return tf.constant(weights, dtype=tf.float32)
def generate_2d_neighbors(input_tensor, dilation=2):
  """Generate a feature map of 2D neighbors.

  Note: This op makes 8 (# of neighbors) the leading dimension so that
  subsequent ops on TPU won't have to pad the last dimension to 128.

  Args:
    input_tensor: A float tensor of shape [height, width, channels].
    dilation: int, the dilation factor for considering neighbors.

  Returns:
    output: A float tensor of shape [8, height, width, channels] with all
      8 2-D neighbors of every location.
  """
  # Treat each channel as an independent single-channel image:
  # [height, width, channels] -> [channels, height, width, 1].
  channels_first = tf.transpose(input_tensor, (2, 0, 1))[:, :, :, tf.newaxis]
  neighbor_kernel = _get_2d_neighbors_kenel()
  neighbors = tf.nn.atrous_conv2d(
      channels_first, neighbor_kernel, rate=dilation, padding='SAME')
  # [channels, height, width, 8] -> [8, height, width, channels].
  return tf.transpose(neighbors, [3, 1, 2, 0])
def gaussian_pixel_similarity(a, b, theta):
  """Gaussian similarity between per-pixel feature vectors.

  Args:
    a: A float tensor of pixel features; the last axis is the channel axis.
    b: A float tensor broadcast-compatible with `a`.
    theta: float, denominator controlling the similarity falloff.

  Returns:
    similarity: exp(-||a - b|| / theta), with the channel axis reduced.
  """
  distance = tf.linalg.norm(a - b, axis=-1)
  return tf.exp(-distance / theta)
def dilated_cross_pixel_similarity(feature_map, dilation=2, theta=2.0):
  """Dilated cross pixel similarity as defined in [1].

  [1]: https://arxiv.org/abs/2012.02310

  Args:
    feature_map: A float tensor of shape [height, width, channels]
    dilation: int, the dilation factor.
    theta: The denominator while taking difference inside the gaussian.

  Returns:
    dilated_similarity: A tensor of shape [8, height, width]
  """
  shifted = generate_2d_neighbors(feature_map, dilation)
  center = feature_map[tf.newaxis]
  # Broadcast the center pixel against each of its 8 dilated neighbors.
  return gaussian_pixel_similarity(center, shifted, theta=theta)
def dilated_cross_same_mask_label(instance_masks, dilation=2):
  """Probability that a pixel and its dilated neighbor share a mask label.

  Based on the pairwise term in [1]: https://arxiv.org/abs/2012.02310

  Args:
    instance_masks: A float tensor of shape [num_instances, height, width]
    dilation: int, the dilation factor.

  Returns:
    dilated_same_label: A tensor of shape [8, num_instances, height, width]
  """
  masks_hwc = tf.transpose(instance_masks, (1, 2, 0))
  shifted = generate_2d_neighbors(masks_hwc, dilation)
  center = masks_hwc[tf.newaxis]
  # P(same label) = P(both foreground) + P(both background).
  agreement = (center * shifted) + ((1 - center) * (1 - shifted))
  return tf.transpose(agreement, (0, 3, 1, 2))
class ResNetMaskNetwork(tf.keras.layers.Layer):
"""A small wrapper around ResNet blocks to predict masks."""
......@@ -557,7 +657,10 @@ def deepmac_proto_to_params(deepmac_config):
postprocess_crop_size=deepmac_config.postprocess_crop_size,
max_roi_jitter_ratio=deepmac_config.max_roi_jitter_ratio,
roi_jitter_mode=jitter_mode,
box_consistency_loss_weight=deepmac_config.box_consistency_loss_weight
box_consistency_loss_weight=deepmac_config.box_consistency_loss_weight,
color_consistency_threshold=deepmac_config.color_consistency_threshold,
color_consistency_dilation=deepmac_config.color_consistency_dilation,
color_consistency_loss_weight=deepmac_config.color_consistency_loss_weight
)
......@@ -756,6 +859,17 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
def _compute_per_instance_mask_prediction_loss(
self, boxes, mask_logits, mask_gt):
"""Compute the per-instance mask loss.
Args:
boxes: A [num_instances, 4] float tensor of GT boxes.
mask_logits: A [num_instances, height, width] float tensor of predicted
masks
mask_gt: The groundtruth mask.
Returns:
loss: A [num_instances] shaped tensor with the loss for each instance.
"""
num_instances = tf.shape(boxes)[0]
mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
......@@ -777,6 +891,18 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
def _compute_per_instance_box_consistency_loss(
self, boxes_gt, boxes_for_crop, mask_logits):
"""Compute the per-instance box consistency loss.
Args:
boxes_gt: A [num_instances, 4] float tensor of GT boxes.
boxes_for_crop: A [num_instances, 4] float tensor of augmented boxes,
to be used when using crop-and-resize based mask head.
mask_logits: A [num_instances, height, width] float tensor of predicted
masks.
Returns:
loss: A [num_instances] shaped tensor with the loss for each instance.
"""
height, width = tf.shape(mask_logits)[1], tf.shape(mask_logits)[2]
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, tf.newaxis]
......@@ -811,8 +937,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else:
return tf.reduce_mean(loss, axis=[1, 2])
def _compute_per_instance_color_consistency_loss(
    self, boxes, preprocessed_image, mask_logits):
  """Compute the per-instance color consistency loss.

  Args:
    boxes: A [num_instances, 4] float tensor of GT boxes.
    preprocessed_image: A [height, width, 3] float tensor containing the
      preprocessed image.
    mask_logits: A [num_instances, height, width] float tensor of predicted
      masks.

  Returns:
    loss: A [num_instances] shaped tensor with the loss for each instance.
  """
  params = self._deepmac_params
  image_shape = tf.shape(preprocessed_image)
  height, width = image_shape[0], image_shape[1]

  # Similarity between each pixel and its 8 dilated neighbors: [8, h, w].
  color_similarity = dilated_cross_pixel_similarity(
      preprocessed_image, dilation=params.color_consistency_dilation,
      theta=2.0)

  # Probability that each pixel pair receives the same mask label:
  # [8, num_instances, h, w]. Clip away from 0 so the log stays finite.
  mask_probs = tf.nn.sigmoid(mask_logits)
  same_label_prob = dilated_cross_same_mask_label(
      mask_probs, dilation=params.color_consistency_dilation)
  same_label_prob = tf.clip_by_value(same_label_prob, 1e-3, 1.0)

  # Only penalize neighbor pairs whose colors are similar enough.
  similar_pairs = tf.cast(
      color_similarity > params.color_consistency_threshold, tf.float32)
  per_pixel_loss = -(similar_pairs[:, tf.newaxis, :, :] *
                     tf.math.log(same_label_prob))

  # TODO(vighneshb) explore if shrinking the box by 1px helps.
  box_mask = fill_boxes(boxes, height, width)
  per_pixel_loss = per_pixel_loss * box_mask[tf.newaxis, :, :, :]
  loss = tf.reduce_sum(per_pixel_loss, axis=[0, 2, 3])

  # Normalize by the number of pixels inside each instance's box.
  num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[1, 2]))
  return loss / num_box_pixels
def _compute_per_instance_deepmac_losses(
self, boxes, masks, instance_embedding, pixel_embedding):
self, boxes, masks, instance_embedding, pixel_embedding,
image):
"""Returns the mask loss per instance.
Args:
......@@ -824,13 +995,16 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
float tensor containing the instance embeddings.
pixel_embedding: optional [output_height, output_width,
pixel_embedding_size] float tensor containing the per-pixel embeddings.
image: [output_height, output_width, channels] float tensor
denoting the input image.
Returns:
mask_prediction_loss: A [num_instances] shaped float tensor containing the
mask loss for each instance.
box_consistency_loss: A [num_instances] shaped float tensor containing
the box consistency loss for each instance.
box_consistency_loss: A [num_instances] shaped float tensor containing
the color consistency loss.
"""
if tf.keras.backend.learning_phase():
......@@ -855,7 +1029,20 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
box_consistency_loss = self._compute_per_instance_box_consistency_loss(
boxes, boxes_for_crop, mask_logits)
return mask_prediction_loss, box_consistency_loss
color_consistency_loss = self._compute_per_instance_color_consistency_loss(
boxes, image, mask_logits)
return mask_prediction_loss, box_consistency_loss, color_consistency_loss
def _get_lab_image(self, preprocessed_image):
  """Converts preprocessed images to the CIE-LAB color space.

  Args:
    preprocessed_image: A [batch_size, height, width, 3] float tensor of
      network-preprocessed images.

  Returns:
    A [batch_size, height, width, 3] float tensor of LAB images.

  Raises:
    NotImplementedError: In TF1, where tensorflow_io's RGB-to-LAB
      conversion is unavailable.
  """
  # Fail fast under TF1 before doing any tensor work; the original code
  # ran preprocess_reverse first and only then raised.
  if tf_version.is_tf1():
    raise NotImplementedError(('RGB-to-LAB conversion required for the color'
                               ' consistency loss is not supported in TF1.'))
  raw_image = self._feature_extractor.preprocess_reverse(
      preprocessed_image)
  # tfio's rgb_to_lab expects RGB values scaled to [0, 1].
  raw_image = raw_image / 255.0
  return tfio.experimental.color.rgb_to_lab(raw_image)
def _compute_instance_masks_loss(self, prediction_dict):
"""Computes the mask loss.
......@@ -879,9 +1066,19 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
loss_dict = {
DEEP_MASK_ESTIMATION: 0.0,
DEEP_MASK_BOX_CONSISTENCY: 0.0
DEEP_MASK_BOX_CONSISTENCY: 0.0,
DEEP_MASK_COLOR_CONSISTENCY: 0.0
}
prediction_shape = tf.shape(prediction_dict[INSTANCE_EMBEDDING][0])
height, width = prediction_shape[1], prediction_shape[2]
preprocessed_image = tf.image.resize(
prediction_dict['preprocessed_inputs'], (height, width))
image = self._get_lab_image(preprocessed_image)
# TODO(vighneshb) See if we can save memory by only using the final
# prediction
# Iterate over multiple predictions by backbone (for hourglass length=2)
for instance_pred, pixel_pred in zip(
prediction_dict[INSTANCE_EMBEDDING],
......@@ -896,9 +1093,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
classes, valid_mask_weights, masks = filter_masked_classes(
allowed_masked_classes_ids, classes, weights, masks)
per_instance_mask_loss, per_instance_consistency_loss = (
self._compute_per_instance_deepmac_losses(
boxes, masks, instance_pred[i], pixel_pred[i]))
(per_instance_mask_loss, per_instance_consistency_loss,
per_instance_color_consistency_loss) = (
self._compute_per_instance_deepmac_losses(
boxes, masks, instance_pred[i], pixel_pred[i],
image[i]))
per_instance_mask_loss *= valid_mask_weights
per_instance_consistency_loss *= weights
......@@ -912,6 +1111,9 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
loss_dict[DEEP_MASK_BOX_CONSISTENCY] += (
tf.reduce_sum(per_instance_consistency_loss) / num_instances)
loss_dict[DEEP_MASK_COLOR_CONSISTENCY] += (
tf.reduce_sum(per_instance_color_consistency_loss) / num_instances)
batch_size = len(gt_boxes_list)
num_predictions = len(prediction_dict[INSTANCE_EMBEDDING])
......@@ -937,6 +1139,12 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
self._deepmac_params.box_consistency_loss_weight * mask_loss_dict[
DEEP_MASK_BOX_CONSISTENCY]
)
if self._deepmac_params.color_consistency_loss_weight > 0.0:
  # Bug fix: scale by color_consistency_loss_weight; previously this
  # erroneously reused box_consistency_loss_weight.
  losses_dict[LOSS_KEY_PREFIX + '/' + DEEP_MASK_COLOR_CONSISTENCY] = (
      self._deepmac_params.color_consistency_loss_weight * mask_loss_dict[
          DEEP_MASK_COLOR_CONSISTENCY]
  )
return losses_dict
def postprocess(self, prediction_dict, true_image_shapes, **params):
......
......@@ -64,7 +64,8 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
use_instance_embedding=True, mask_num_subsamples=-1,
network_type='hourglass10', use_xy=True,
pixel_embedding_dim=2,
dice_loss_prediction_probability=False):
dice_loss_prediction_probability=False,
color_consistency_threshold=0.5):
"""Builds the DeepMAC meta architecture."""
feature_extractor = DummyFeatureExtractor(
......@@ -110,6 +111,9 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
max_roi_jitter_ratio=0.0,
roi_jitter_mode='random',
box_consistency_loss_weight=1.0,
color_consistency_threshold=color_consistency_threshold,
color_consistency_dilation=2,
color_consistency_loss_weight=1.0
)
object_detection_params = center_net_meta_arch.ObjectDetectionParams(
......@@ -206,29 +210,97 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
out = net(tf.zeros((2, 24)))
self.assertEqual(out.shape, (2, 8))
def test_generate_2d_neighbors_shape(self):
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
inp = tf.zeros((13, 14, 3))
out = deepmac_meta_arch.generate_2d_neighbors(inp)
self.assertEqual((8, 13, 14, 3), out.shape)
def test_mask_network(self):
net = deepmac_meta_arch.MaskHeadNetwork('hourglass10', 8)
def test_generate_2d_neighbors(self):
out = net(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True)
self.assertEqual(out.shape, (2, 32, 32))
inp = np.arange(16).reshape(4, 4).astype(np.float32)
inp = tf.stack([inp, inp * 2], axis=2)
out = deepmac_meta_arch.generate_2d_neighbors(inp, dilation=1)
self.assertEqual((8, 4, 4, 2), out.shape)
for i in range(2):
expected = np.array([0, 1, 2, 4, 6, 8, 9, 10]) * (i + 1)
self.assertAllEqual(out[:, 1, 1, i], expected)
def test_mask_network_hourglass20(self):
net = deepmac_meta_arch.MaskHeadNetwork('hourglass20', 8)
expected = np.array([1, 2, 3, 5, 7, 9, 10, 11]) * (i + 1)
self.assertAllEqual(out[:, 1, 2, i], expected)
out = net(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True)
self.assertEqual(out.shape, (2, 32, 32))
expected = np.array([4, 5, 6, 8, 10, 12, 13, 14]) * (i + 1)
self.assertAllEqual(out[:, 2, 1, i], expected)
expected = np.array([5, 6, 7, 9, 11, 13, 14, 15]) * (i + 1)
self.assertAllEqual(out[:, 2, 2, i], expected)
def test_generate_2d_neighbors_dilation2(self):
  """With dilation=2 neighbors come from 2 pixels away (zero padded)."""
  grid = np.arange(16).reshape(4, 4, 1).astype(np.float32)
  neighbors = deepmac_meta_arch.generate_2d_neighbors(grid, dilation=2)
  self.assertEqual((8, 4, 4, 1), neighbors.shape)
  # Pixel (0, 0): offsets landing outside the image read as 0.
  self.assertAllEqual(neighbors[:, 0, 0, 0],
                      np.array([0, 0, 0, 0, 2, 0, 8, 10]))
def test_dilated_similarity_shape(self):
  """Similarity output is shaped [8, height, width]."""
  feature_map = tf.zeros((32, 32, 9))
  out = deepmac_meta_arch.dilated_cross_pixel_similarity(feature_map)
  self.assertEqual((8, 32, 32), out.shape)
def test_dilated_similarity(self):
  """Checks the gaussian similarity value for one dilated neighbor pair."""
  feature_map = np.zeros((5, 5, 2), dtype=np.float32)
  feature_map[0, 0, :] = 1.0
  feature_map[4, 4, :] = 1.0
  similarity = deepmac_meta_arch.dilated_cross_pixel_similarity(
      feature_map, theta=1.0, dilation=2)
  # Center (2, 2) vs. its dilated top-left neighbor (0, 0): the feature
  # difference is (1, 1), so similarity = exp(-sqrt(2)).
  self.assertAlmostEqual(similarity.numpy()[0, 2, 2],
                         np.exp(-np.sqrt(2)))
def test_dilated_same_instance_mask_shape(self):
  """Same-label output is shaped [8, num_instances, height, width]."""
  masks = tf.zeros((5, 32, 32))
  out = deepmac_meta_arch.dilated_cross_same_mask_label(masks)
  self.assertEqual((8, 5, 32, 32), out.shape)
def test_mask_network_resnet(self):
def test_dilated_same_instance_mask(self):
net = deepmac_meta_arch.MaskHeadNetwork('resnet4')
instances = np.zeros((2, 5, 5), dtype=np.float32)
instances[0, 0, 0] = 1.0
instances[0, 2, 2] = 1.0
instances[0, 4, 4] = 1.0
output = deepmac_meta_arch.dilated_cross_same_mask_label(instances).numpy()
self.assertAllClose(np.ones((8, 5, 5)), output[:, 1, :, :])
self.assertAllClose([1, 0, 0, 0, 0, 0, 0, 1], output[:, 0, 2, 2])
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
    ['hourglass10', 'hourglass20', 'resnet4'])
def test_mask_network(self, head_type):
  """Every supported head type yields masks of the expected shape."""
  network = deepmac_meta_arch.MaskHeadNetwork(head_type, 8)
  masks = network(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True)
  self.assertEqual(masks.shape, (2, 32, 32))
def test_mask_network_params_resnet4(self):
  """Pins the trainable parameter count of the resnet4 head."""
  network = deepmac_meta_arch.MaskHeadNetwork('resnet4', num_init_channels=8)
  # Call once so the layer builds its weights.
  _ = network(tf.zeros((2, 16)), tf.zeros((2, 32, 32, 16)), training=True)
  num_params = tf.reduce_sum(
      [tf.reduce_prod(tf.shape(w)) for w in network.trainable_weights])
  self.assertEqual(num_params.numpy(), 8665)
def test_mask_network_resnet_tf_function(self):
net = deepmac_meta_arch.MaskHeadNetwork('resnet8')
......@@ -360,8 +432,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
masks[1, 16:, 16:] = 1.0
masks = tf.constant(masks)
loss, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
loss, _, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((16, 16, 3)))
self.assertAllClose(
loss, np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9)))
......@@ -373,8 +446,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
masks = np.ones((2, 128, 128), dtype=np.float32)
masks = tf.constant(masks)
loss, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
loss, _, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((32, 32, 3)))
self.assertAllClose(
loss, np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9)))
......@@ -387,8 +461,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
masks = np.ones((2, 128, 128), dtype=np.float32)
masks = tf.constant(masks)
loss, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
loss, _, _ = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((32, 32, 3)))
pred = tf.nn.sigmoid(0.9)
expected = (1.0 - ((2.0 * pred) / (1.0 + pred)))
self.assertAllClose(loss, [expected, expected], rtol=1e-3)
......@@ -397,8 +472,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
boxes = tf.zeros([0, 4])
masks = tf.zeros([0, 128, 128])
loss, _ = self.model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)))
loss, _, _ = self.model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((16, 16, 3)))
self.assertEqual(loss.shape, (0,))
def test_postprocess(self):
......@@ -560,6 +636,29 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(loss, [yloss + xloss])
def test_color_consistency_loss_full_res_shape(self):
  """Color consistency loss returns one value per instance."""
  model = build_meta_arch(use_dice_loss=True,
                          predict_full_resolution_masks=True)
  loss = model._compute_per_instance_color_consistency_loss(
      tf.zeros((3, 4)), tf.zeros((32, 32, 3)), tf.zeros((3, 32, 32)))
  self.assertEqual([3], loss.shape)
def test_color_consistency_1_threshold(self):
  """A near-1 threshold masks out every pixel pair, giving zero loss."""
  model = build_meta_arch(predict_full_resolution_masks=True,
                          color_consistency_threshold=0.99)
  confident_background = tf.zeros((3, 32, 32)) - 1e4
  loss = model._compute_per_instance_color_consistency_loss(
      tf.zeros((3, 4)), tf.zeros((32, 32, 3)), confident_background)
  self.assertAllClose(loss, np.zeros(3))
def test_box_consistency_dice_loss_full_res(self):
model = build_meta_arch(use_dice_loss=True,
......@@ -575,6 +674,11 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
self.assertAlmostEqual(loss[0].numpy(), 1 / 3)
def test_get_lab_image_shape(self):
  """LAB conversion preserves the input tensor shape."""
  lab = self.model._get_lab_image(tf.zeros((2, 4, 4, 3)))
  self.assertEqual(lab.shape, (2, 4, 4, 3))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class FullyConnectedMaskHeadTest(tf.test.TestCase):
......
......@@ -22,6 +22,7 @@ REQUIRED_PACKAGES = [
'scipy',
'pandas',
'tf-models-official>=2.5.1',
'tensorflow_io'
]
setup(
......
......@@ -464,6 +464,13 @@ message CenterNet {
// Weight for the box consistency loss as described in the BoxInst paper
// https://arxiv.org/abs/2012.02310
optional float box_consistency_loss_weight = 16 [default=0.0];
optional float color_consistency_threshold = 17 [default=0.4];
optional int32 color_consistency_dilation = 18 [default=2];
optional float color_consistency_loss_weight = 19 [default=0.0];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment