"...python/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "c163bf4ff16235536832f54a7512b69bc03825c5"
Unverified Commit 52c7f869 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

generate panoptic masks at original resolution

parent 0bee3f54
......@@ -32,18 +32,12 @@ import tensorflow as tf
from official.vision.beta.evaluation import panoptic_quality
def rescale_masks(groundtruth, prediction, image_info):
rescale_size = tf.cast(
tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
def _crop_padding(groundtruth, prediction, image_info):
image_shape = tf.cast(image_info[0, :], tf.int32)
offsets = tf.cast(image_info[3, :], tf.int32)
prediction = tf.image.resize(
tf.expand_dims(prediction, axis=-1),
rescale_size,
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
prediction = tf.image.crop_to_bounding_box(
prediction,
tf.expand_dims(prediction, axis=-1),
offsets[0], offsets[1],
image_shape[0],
image_shape[1])
......@@ -52,8 +46,8 @@ def rescale_masks(groundtruth, prediction, image_info):
image_shape[0], image_shape[1])
return (
tf.expand_dims(groundtruth[:, :, 0], axis=0),
tf.expand_dims(prediction[:, :, 0], axis=0))
tf.expand_dims(groundtruth[:, :, 0], axis=0),
tf.expand_dims(prediction[:, :, 0], axis=0))
class PanopticQualityEvaluator:
"""Panoptic Quality metric class."""
......@@ -169,11 +163,11 @@ class PanopticQualityEvaluator:
if self._rescale_predictions:
for idx in range(len(groundtruths['category_mask'])):
image_info = groundtruths['image_info'][idx]
groundtruth_category_mask, prediction_category_mask = rescale_masks(
groundtruth_category_mask, prediction_category_mask = _crop_padding(
groundtruths['category_mask'][idx],
predictions['category_mask'][idx],
image_info)
groundtruth_instance_mask, prediction_instance_mask = rescale_masks(
groundtruth_instance_mask, prediction_instance_mask = _crop_padding(
groundtruths['instance_mask'][idx],
predictions['instance_mask'][idx],
image_info)
......
......@@ -182,7 +182,7 @@ def panoptic_fpn_coco() -> cfg.ExperimentConfig:
model=PanopticMaskRCNN(
num_classes=91, input_size=[1024, 1024, 3],
panoptic_segmentation_generator=PanopticSegmentationGenerator(
output_size=[1024, 1024]),
output_size=[640, 640]),
stuff_classes_offset=90,
segmentation_model=SEGMENTATION_MODEL(
num_classes=num_semantic_segmentation_classes,
......
......@@ -224,12 +224,11 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
}
return results
def call(self, inputs):
def call(self, inputs: tf.Tensor, image_shape: tf.Tensor):
detections = inputs
batched_scores = detections['detection_scores']
batched_classes = detections['detection_classes']
batched_boxes = detections['detection_boxes']
batched_detections_masks = tf.expand_dims(
detections['detection_masks'], axis=-1)
......@@ -241,6 +240,13 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
tf.argmax(batched_segmentation_masks, axis=-1),
dtype=tf.float32), axis=-1)
batched_boxes = detections['detection_boxes']
image_shape = tf.cast(image_shape, dtype=batched_boxes.dtype)
scale = tf.convert_to_tensor(
[self._output_size], dtype=batched_boxes.dtype) / image_shape
scale = tf.tile(tf.expand_dims(scale, axis=0), multiples=[1, 1, 2])
batched_boxes = batched_boxes * scale
panoptic_masks = tf.map_fn(
fn=lambda x: self._generate_panoptic_masks( # pylint:disable=g-long-lambda
x[0], x[1], x[2], x[3], x[4]),
......@@ -253,7 +259,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
fn_output_signature={
'category_mask': tf.float32,
'instance_mask': tf.float32
})
}, parallel_iterations=32)
for k, v in panoptic_masks.items():
panoptic_masks[k] = tf.cast(v, dtype=tf.int32)
......
......@@ -177,7 +177,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
})
if not training and self.panoptic_segmentation_generator is not None:
panoptic_outputs = self.panoptic_segmentation_generator(model_outputs)
panoptic_outputs = self.panoptic_segmentation_generator(
model_outputs, image_shape=image_shape)
model_outputs.update({'panoptic_outputs': panoptic_outputs})
return model_outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment