Commit 45ecc0f9 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Adding an attribute to SSD model to indicate which fields in prediction...

Adding an attribute to SSD model to indicate which fields in prediction dictionary have a batch dimension. This will be useful for future video models.

PiperOrigin-RevId: 191743097
parent a0c3c440
......@@ -235,6 +235,7 @@ class SSDMetaArch(model.DetectionModel):
self._anchors = None
self._add_summaries = add_summaries
self._batched_prediction_tensor_names = []
@property
def anchors(self):
......@@ -244,6 +245,13 @@ class SSDMetaArch(model.DetectionModel):
raise RuntimeError('anchors should be a BoxList object, but is not.')
return self._anchors
@property
def batched_prediction_tensor_names(self):
if not self._batched_prediction_tensor_names:
raise RuntimeError('Must call predict() method to get batched prediction '
'tensor names.')
return self._batched_prediction_tensor_names
def preprocess(self, inputs):
"""Feature-extractor specific preprocessing.
......@@ -385,6 +393,8 @@ class SSDMetaArch(model.DetectionModel):
'feature_maps': feature_maps,
'anchors': self._anchors.get()
}
self._batched_prediction_tensor_names = [x for x in predictions_dict
if x != 'anchors']
return predictions_dict
def _get_feature_map_spatial_dims(self, feature_maps):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment