"template/phi-3.gotmpl" did not exist on "9b6c2e6eb62c234f8a44556984bbb680d7065e01"
Commit d5ae12e1 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Modify CenterNet to output extracted features.

PiperOrigin-RevId: 400090738
parent e06d2c3a
......@@ -3805,6 +3805,7 @@ class CenterNetMetaArch(model.DetectionModel):
prediction_dict: a dictionary holding predicted tensors with
'preprocessed_inputs' - The input image after being resized and
preprocessed by the feature extractor.
'extracted_features' - The output of the feature extractor.
'object_center' - A list of size num_feature_outputs containing
float tensors of size [batch_size, output_height, output_width,
num_classes] representing the predicted object center heatmap logits.
......@@ -3848,6 +3849,7 @@ class CenterNetMetaArch(model.DetectionModel):
predictions[head_name] = [
head(feature) for (feature, head) in zip(features_list, heads)
]
predictions['extracted_features'] = features_list
predictions['preprocessed_inputs'] = preprocessed_inputs
self._batched_prediction_tensor_names = predictions.keys()
......
......@@ -3460,6 +3460,7 @@ class CenterNetMetaArch1dTest(test_case.TestCase, parameterized.TestCase):
postprocess_output = arch.postprocess(predictions, true_shapes)
losses_output = arch.loss(predictions, true_shapes)
self.assertIn('extracted_features', predictions)
self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.OBJECT_CENTER),
losses_output)
self.assertEqual((), losses_output['%s/%s' % (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment