Commit 8f23bc72 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the postprocess function in CenterNetMetaArch such that the

true_image_shape is not required and can be inferred by the prediction tensors.

PiperOrigin-RevId: 390421432
parent 3512e9f8
......@@ -3742,6 +3742,16 @@ class CenterNetMetaArch(model.DetectionModel):
"""
object_center_prob = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1])
if true_image_shapes is None:
# If true_image_shapes is not provided, we assume the whole image is valid
# and infer the true_image_shapes from the object_center_prob shape.
batch_size, strided_height, strided_width, _ = _get_shape(
object_center_prob, 4)
true_image_shapes = tf.stack(
[strided_height * self._stride, strided_width * self._stride,
tf.constant(len(self._feature_extractor._channel_means))]) # pylint: disable=protected-access
true_image_shapes = tf.stack([true_image_shapes] * batch_size, axis=0)
else:
# Mask object centers by true_image_shape. [batch, h, w, 1]
object_center_mask = mask_from_true_image_shape(
_get_shape(object_center_prob, 4), true_image_shapes)
......
......@@ -2056,10 +2056,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
cnma.TEMPORAL_OFFSET)])
@parameterized.parameters(
{'target_class_id': 1},
{'target_class_id': 2},
{'target_class_id': 1, 'with_true_image_shape': True},
{'target_class_id': 2, 'with_true_image_shape': True},
{'target_class_id': 1, 'with_true_image_shape': False},
)
def test_postprocess(self, target_class_id):
def test_postprocess(self, target_class_id, with_true_image_shape):
"""Test the postprocess function."""
model = build_center_net_meta_arch()
max_detection = model._center_params.max_box_predictions
......@@ -2140,8 +2141,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
}
def graph_fn():
if with_true_image_shape:
detections = model.postprocess(prediction_dict,
tf.constant([[128, 128, 3]]))
else:
detections = model.postprocess(prediction_dict, None)
return detections
detections = self.execute_cpu(graph_fn, [])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment