Unverified Commit a6c6089b authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

rescale predictions before generating outputs

parent e46a2497
......@@ -386,17 +386,21 @@ class PostProcessor(tf.keras.layers.Layer):
def __init__(
self,
output_size: List[int],
center_score_threshold: float,
thing_class_ids: List[int],
label_divisor: int,
stuff_area_limit: int,
ignore_label: int,
nms_kernel: int,
keep_k_centers: int,
keep_k_centers: int,
rescale_predictions: bool,
**kwargs):
"""Initializes a Panoptic-Deeplab post-processor.
Args:
output_size: A `List` of integers that represent the height and width of
the output mask.
center_threshold: A float setting the threshold for the center heatmap.
thing_class_ids: An integer list shape [N] containing N thing indices.
label_divisor: An integer specifying the label divisor of the dataset.
......@@ -407,18 +411,22 @@ class PostProcessor(tf.keras.layers.Layer):
void_label: An integer specifying the void label.
nms_kernel_size: An integer specifying the nms kernel size.
keep_k_centers: An integer specifying the number of centers to keep.
Negative values will keep all centers.
Negative values will keep all centers.
rescale_predictions: `bool`, whether to scale back prediction to original
image sizes. If True, image_info is used to rescale predictions.
"""
super(PostProcessor, self).__init__(**kwargs)
self._config_dict = {
'output_size': output_size,
'center_score_threshold': center_score_threshold,
'thing_class_ids': thing_class_ids,
'label_divisor': label_divisor,
'stuff_area_limit': stuff_area_limit,
'ignore_label': ignore_label,
'nms_kernel': nms_kernel,
'keep_k_centers': keep_k_centers
'keep_k_centers': keep_k_centers,
'rescale_predictions': rescale_predictions
}
self._post_processor = functools.partial(
_get_panoptic_predictions,
......@@ -430,15 +438,48 @@ class PostProcessor(tf.keras.layers.Layer):
nms_kernel_size=nms_kernel,
keep_k_centers=keep_k_centers)
def call(self, result_dict: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
def _resize_and_pad_masks(self, mask, image_info):
    """Rescales one mask, crops it to the image window and pads the result.

    The mask is bilinearly resized to `ceil(image_info[1] / image_info[2])`,
    cropped to a window of size `image_info[0]` whose top-left corner is at
    `image_info[3]`, and then padded (offset 0, 0) up to the configured
    `output_size`.

    Args:
      mask: a padded mask tensor.
      image_info: a tensor that holds information about original and
        preprocessed images. Rows 0 to 3 are read here; presumably they hold
        the original shape, the preprocessed shape, the scale and the offset
        -- TODO confirm against the input pipeline.

    Returns:
      resized and padded masks: tf.Tensor.
    """
    # Size of the mask once the preprocessing scale is undone.
    unscaled_hw = tf.math.ceil(image_info[1, :] / image_info[2, :])
    resize_hw = tf.cast(unscaled_hw, tf.int32)
    crop_hw = tf.cast(image_info[0, :], tf.int32)
    crop_origin = tf.cast(image_info[3, :], tf.int32)

    resized = tf.image.resize(mask, resize_hw, method='bilinear')
    cropped = tf.image.crop_to_bounding_box(
        resized,
        offset_height=crop_origin[0],
        offset_width=crop_origin[1],
        target_height=crop_hw[0],
        target_width=crop_hw[1])

    # Pad to one fixed size so per-image results can be stacked into a batch.
    padded_hw = self._config_dict['output_size']
    return tf.image.pad_to_bounding_box(
        cropped,
        offset_height=0,
        offset_width=0,
        target_height=padded_hw[0],
        target_width=padded_hw[1])
def call(
        self,
        result_dict: Dict[Text, tf.Tensor],
        image_info: tf.Tensor) -> Dict[Text, tf.Tensor]:
    """Performs the post-processing given model predicted results.

    The raw predictions are first brought to `output_size`, either by
    undoing the preprocessing per image (`rescale_predictions` is True) or by
    a plain bilinear resize, and are then passed to the panoptic
    post-processor.

    Args:
      result_dict: A dictionary of tf.Tensor containing model results. The
        dict has to contain
         - segmentation_outputs
         - instance_centers_heatmap
         - instance_centers_offset
      image_info: A tensor batched alongside the predictions. It is only read
        when `rescale_predictions` is True, in which case each of its
        elements is handed to `_resize_and_pad_masks` together with the
        matching prediction.

    Returns:
      The post-processed dict of tf.Tensor, containing the following keys:
        - panoptic_outputs
        - category_mask
        - instance_mask
        - instance_centers
        - instance_scores
        - segmentation_outputs
    """
    if self._config_dict['rescale_predictions']:
        def _to_output_size(predictions):
            # Every image has its own scale/offset, so map over the batch.
            return tf.map_fn(
                fn=lambda x: self._resize_and_pad_masks(x[0], x[1]),
                elems=(predictions, image_info),
                fn_output_signature=tf.float32,
                parallel_iterations=32)
    else:
        def _to_output_size(predictions):
            return tf.image.resize(
                predictions,
                size=self._config_dict['output_size'],
                method='bilinear')

    segmentation_outputs = _to_output_size(
        result_dict['segmentation_outputs'])
    instance_centers_heatmap = _to_output_size(
        result_dict['instance_centers_heatmap'])
    instance_centers_offset = _to_output_size(
        result_dict['instance_centers_offset'])

    processed_dict = {}
    (processed_dict['panoptic_outputs'],
     processed_dict['category_mask'],
     processed_dict['instance_mask'],
     processed_dict['instance_centers'],
     processed_dict['instance_scores']
    ) = self._post_processor(
        tf.nn.softmax(segmentation_outputs, axis=-1),
        instance_centers_heatmap,
        instance_centers_offset)

    # The segmentation logits are returned as predicted by the model, i.e.
    # without the resizing applied above.
    processed_dict.update({
        'segmentation_outputs': result_dict['segmentation_outputs']})
    return processed_dict
def get_config(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment