"tools/vscode:/vscode.git/clone" did not exist on "76edd498035720c1f639985928652d29893ed797"
Commit 8f23bc72 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the postprocess function in CenterNetMetaArch such that the

true_image_shape is not required and can be inferred by the prediction tensors.

PiperOrigin-RevId: 390421432
parent 3512e9f8
...@@ -3742,6 +3742,16 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3742,6 +3742,16 @@ class CenterNetMetaArch(model.DetectionModel):
""" """
object_center_prob = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1]) object_center_prob = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1])
if true_image_shapes is None:
# If true_image_shapes is not provided, we assume the whole image is valid
# and infer the true_image_shapes from the object_center_prob shape.
batch_size, strided_height, strided_width, _ = _get_shape(
object_center_prob, 4)
true_image_shapes = tf.stack(
[strided_height * self._stride, strided_width * self._stride,
tf.constant(len(self._feature_extractor._channel_means))]) # pylint: disable=protected-access
true_image_shapes = tf.stack([true_image_shapes] * batch_size, axis=0)
else:
# Mask object centers by true_image_shape. [batch, h, w, 1] # Mask object centers by true_image_shape. [batch, h, w, 1]
object_center_mask = mask_from_true_image_shape( object_center_mask = mask_from_true_image_shape(
_get_shape(object_center_prob, 4), true_image_shapes) _get_shape(object_center_prob, 4), true_image_shapes)
......
...@@ -2056,10 +2056,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -2056,10 +2056,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
cnma.TEMPORAL_OFFSET)]) cnma.TEMPORAL_OFFSET)])
@parameterized.parameters( @parameterized.parameters(
{'target_class_id': 1}, {'target_class_id': 1, 'with_true_image_shape': True},
{'target_class_id': 2}, {'target_class_id': 2, 'with_true_image_shape': True},
{'target_class_id': 1, 'with_true_image_shape': False},
) )
def test_postprocess(self, target_class_id): def test_postprocess(self, target_class_id, with_true_image_shape):
"""Test the postprocess function.""" """Test the postprocess function."""
model = build_center_net_meta_arch() model = build_center_net_meta_arch()
max_detection = model._center_params.max_box_predictions max_detection = model._center_params.max_box_predictions
...@@ -2140,8 +2141,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -2140,8 +2141,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
} }
def graph_fn(): def graph_fn():
if with_true_image_shape:
detections = model.postprocess(prediction_dict, detections = model.postprocess(prediction_dict,
tf.constant([[128, 128, 3]])) tf.constant([[128, 128, 3]]))
else:
detections = model.postprocess(prediction_dict, None)
return detections return detections
detections = self.execute_cpu(graph_fn, []) detections = self.execute_cpu(graph_fn, [])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment