Commit c7d4dd39 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Mask boxes with centers outside of the true_image_shape.

PiperOrigin-RevId: 383699499
parent 07589c47
......@@ -1802,6 +1802,31 @@ def predicted_embeddings_at_object_centers(embedding_predictions,
return embeddings
def mask_from_true_image_shape(data_shape, true_image_shapes):
"""Get a binary mask based on the true_image_shape.
Args:
data_shape: a possibly static (4,) tensor for the shape of the feature
map.
true_image_shapes: int32 tensor of shape [batch, 3] where each row is of
the form [height, width, channels] indicating the shapes of true
images in the resized images, as resized images can be padded with
zeros.
Returns:
a [batch, data_height, data_width, 1] tensor of 1.0 wherever data_height
is less than height, etc.
"""
mask_h = tf.cast(
tf.range(data_shape[1]) < true_image_shapes[:, tf.newaxis, 0],
tf.float32)
mask_w = tf.cast(
tf.range(data_shape[2]) < true_image_shapes[:, tf.newaxis, 1],
tf.float32)
mask = tf.expand_dims(
mask_h[:, :, tf.newaxis] * mask_w[:, tf.newaxis, :], 3)
return mask
class ObjectDetectionParams(
collections.namedtuple('ObjectDetectionParams', [
'localization_loss', 'scale_loss_weight', 'offset_loss_weight',
......@@ -3698,6 +3723,12 @@ class CenterNetMetaArch(model.DetectionModel):
max_detections, reid_embed_size] containing object embeddings.
"""
object_center_prob = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1])
# Mask object centers by true_image_shape. [batch, h, w, 1]
object_center_mask = mask_from_true_image_shape(
_get_shape(object_center_prob, 4), true_image_shapes)
object_center_prob *= object_center_mask
# Get x, y and channel indices corresponding to the top indices in the class
# center predictions.
detection_scores, y_indices, x_indices, channel_indices = (
......
......@@ -2518,6 +2518,75 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
self.assertAllClose(detections['detection_keypoint_scores'][0, 0],
np.array([0.9, 0.9, 0.9, 0.1]))
def test_mask_object_center_in_postprocess_by_true_image_shape(self):
"""Test the postprocess function is masked by true_image_shape."""
model = build_center_net_meta_arch(num_classes=1)
max_detection = model._center_params.max_box_predictions
num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices)
class_center = np.zeros((1, 32, 32, 1), dtype=np.float32)
height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
offset = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_heatmaps = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32)
keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)
class_probs = np.zeros(1)
class_probs[0] = _logit(0.75)
class_center[0, 16, 16] = class_probs
height_width[0, 16, 16] = [5, 10]
offset[0, 16, 16] = [.25, .5]
keypoint_regression[0, 16, 16] = [
-1., -1.,
-1., 1.,
1., -1.,
1., 1.]
keypoint_heatmaps[0, 14, 14, 0] = _logit(0.9)
keypoint_heatmaps[0, 14, 18, 1] = _logit(0.9)
keypoint_heatmaps[0, 18, 14, 2] = _logit(0.9)
keypoint_heatmaps[0, 18, 18, 3] = _logit(0.05) # Note the low score.
class_center = tf.constant(class_center)
height_width = tf.constant(height_width)
offset = tf.constant(offset)
keypoint_heatmaps = tf.constant(keypoint_heatmaps, dtype=tf.float32)
keypoint_offsets = tf.constant(keypoint_offsets, dtype=tf.float32)
keypoint_regression = tf.constant(keypoint_regression, dtype=tf.float32)
print(class_center)
prediction_dict = {
cnma.OBJECT_CENTER: [class_center],
cnma.BOX_SCALE: [height_width],
cnma.BOX_OFFSET: [offset],
cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_HEATMAP):
[keypoint_heatmaps],
cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_OFFSET):
[keypoint_offsets],
cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_REGRESSION):
[keypoint_regression],
}
def graph_fn():
detections = model.postprocess(prediction_dict,
tf.constant([[1, 1, 3]]))
return detections
detections = self.execute_cpu(graph_fn, [])
self.assertAllClose(detections['detection_boxes'][0, 0],
np.array([0, 0, 0, 0]))
# The class_center logits are initialized as 0's so it's filled with 0.5s.
# Despite that, we should only find one box.
self.assertAllClose(detections['detection_scores'][0],
[0.5, 0., 0., 0., 0.])
self.assertEqual(np.sum(detections['detection_classes']), 0)
self.assertEqual(detections['num_detections'], [1])
self.assertAllEqual([1, max_detection, num_keypoints, 2],
detections['detection_keypoints'].shape)
self.assertAllEqual([1, max_detection, num_keypoints],
detections['detection_keypoint_scores'].shape)
def test_get_instance_indices(self):
classes = tf.constant([[0, 1, 2, 0], [2, 1, 2, 2]], dtype=tf.int32)
num_detections = tf.constant([1, 3], dtype=tf.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment