Unverified Commit 05541cbe authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

fixed box rescaling

parent d8eed145
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
"""Contains definition for postprocessing layer to genrate panoptic segmentations.""" """Contains definition for postprocessing layer to genrate panoptic segmentations."""
from typing import List from typing import List, Mapping
import tensorflow as tf import tensorflow as tf
...@@ -224,7 +224,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -224,7 +224,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
} }
return results return results
def call(self, inputs: tf.Tensor, image_shape: tf.Tensor): def call(self, inputs: tf.Tensor, image_info: Mapping[str, tf.Tensor]):
detections = inputs detections = inputs
batched_scores = detections['detection_scores'] batched_scores = detections['detection_scores']
...@@ -241,11 +241,10 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -241,11 +241,10 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
dtype=tf.float32), axis=-1) dtype=tf.float32), axis=-1)
batched_boxes = detections['detection_boxes'] batched_boxes = detections['detection_boxes']
image_shape = tf.cast(image_shape, dtype=batched_boxes.dtype) scale = tf.tile(
scale = tf.convert_to_tensor( tf.cast(image_info[:, 2:3, :], dtype=batched_boxes.dtype),
[self._output_size], dtype=batched_boxes.dtype) / image_shape multiples=[1, 1, 2])
scale = tf.tile(tf.expand_dims(scale, axis=1), multiples=[1, 1, 2]) batched_boxes /= scale
batched_boxes = batched_boxes * scale
panoptic_masks = tf.map_fn( panoptic_masks = tf.map_fn(
fn=lambda x: self._generate_panoptic_masks( # pylint:disable=g-long-lambda fn=lambda x: self._generate_panoptic_masks( # pylint:disable=g-long-lambda
......
...@@ -143,12 +143,13 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -143,12 +143,13 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
def call(self, def call(self,
images: tf.Tensor, images: tf.Tensor,
image_shape: tf.Tensor, image_info: Mapping[str, tf.Tensor],
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None, anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
gt_boxes: Optional[tf.Tensor] = None, gt_boxes: Optional[tf.Tensor] = None,
gt_classes: Optional[tf.Tensor] = None, gt_classes: Optional[tf.Tensor] = None,
gt_masks: Optional[tf.Tensor] = None, gt_masks: Optional[tf.Tensor] = None,
training: Optional[bool] = None) -> Mapping[str, tf.Tensor]: training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
image_shape = image_info[:, 1, :]
model_outputs = super(PanopticMaskRCNNModel, self).call( model_outputs = super(PanopticMaskRCNNModel, self).call(
images=images, images=images,
image_shape=image_shape, image_shape=image_shape,
...@@ -178,7 +179,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -178,7 +179,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
if not training and self.panoptic_segmentation_generator is not None: if not training and self.panoptic_segmentation_generator is not None:
panoptic_outputs = self.panoptic_segmentation_generator( panoptic_outputs = self.panoptic_segmentation_generator(
model_outputs, image_shape=image_shape) model_outputs, image_info=image_info)
model_outputs.update({'panoptic_outputs': panoptic_outputs}) model_outputs.update({'panoptic_outputs': panoptic_outputs})
return model_outputs return model_outputs
......
...@@ -286,7 +286,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -286,7 +286,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
outputs = model( outputs = model(
images, images,
image_shape=labels['image_info'][:, 1, :], image_info=labels['image_info'],
anchor_boxes=labels['anchor_boxes'], anchor_boxes=labels['anchor_boxes'],
gt_boxes=labels['gt_boxes'], gt_boxes=labels['gt_boxes'],
gt_classes=labels['gt_classes'], gt_classes=labels['gt_classes'],
...@@ -355,7 +355,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -355,7 +355,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
outputs = model( outputs = model(
images, images,
anchor_boxes=labels['anchor_boxes'], anchor_boxes=labels['anchor_boxes'],
image_shape=labels['image_info'][:, 1, :], image_info=labels['image_info'],
training=False) training=False)
logs = {self.loss: 0} logs = {self.loss: 0}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment