"rust/git@developer.sourcefind.cn:change/sglang.git" did not exist on "d4de9a62359d1299cb639a67f39cfb40fda5d957"
Unverified Commit fa06f822 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added docstrings

parent 3982b139
...@@ -33,11 +33,16 @@ from official.vision.beta.evaluation import panoptic_quality ...@@ -33,11 +33,16 @@ from official.vision.beta.evaluation import panoptic_quality
def _crop_padding(mask, image_info): def _crop_padding(mask, image_info):
"""Crops padded masks to match original image shape.
Args:
mask: a padded mask tensor.
image_info: a tensor that holds information about original and preprocessed
images.
"""
image_shape = tf.cast(image_info[0, :], tf.int32) image_shape = tf.cast(image_info[0, :], tf.int32)
mask = tf.image.crop_to_bounding_box( mask = tf.image.crop_to_bounding_box(
tf.expand_dims(mask, axis=-1), 0, 0, tf.expand_dims(mask, axis=-1), 0, 0,
image_shape[0], image_shape[1]) image_shape[0], image_shape[1])
return tf.expand_dims(mask[:, :, 0], axis=0) return tf.expand_dims(mask[:, :, 0], axis=0)
class PanopticQualityEvaluator: class PanopticQualityEvaluator:
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
"""Contains definition for postprocessing layer to genrate panoptic segmentations.""" """Contains definition for postprocessing layer to genrate panoptic segmentations."""
from typing import List, Mapping from typing import List
import tensorflow as tf import tensorflow as tf
...@@ -57,7 +57,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -57,7 +57,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
assigned to any thing class. That is, void_instance_id are assigned to assigned to any thing class. That is, void_instance_id are assigned to
both stuff regions and empty regions. both stuff regions and empty regions.
rescale_predictions: `bool`, whether to scale back prediction to original rescale_predictions: `bool`, whether to scale back prediction to original
image sizes. If True, image_info `dict` is used to rescale predictions. image sizes. If True, image_info is used to rescale predictions.
**kwargs: additional kewargs arguments. **kwargs: additional kewargs arguments.
""" """
self._output_size = output_size self._output_size = output_size
...@@ -230,6 +230,13 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -230,6 +230,13 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
return results return results
def _resize_and_pad_masks(self, mask, image_info): def _resize_and_pad_masks(self, mask, image_info):
"""Resizes masks to match the original image shape and pads them to
`output_size`.
Args:
mask: a padded mask tensor.
image_info: a tensor that holds information about original and
preprocessed images.
"""
rescale_size = tf.cast( rescale_size = tf.cast(
tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32) tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
image_shape = tf.cast(image_info[0, :], tf.int32) image_shape = tf.cast(image_info[0, :], tf.int32)
...@@ -238,7 +245,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -238,7 +245,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
mask = tf.image.resize( mask = tf.image.resize(
mask, mask,
rescale_size, rescale_size,
method=tf.image.ResizeMethod.BILINEAR) method='bilinear')
mask = tf.image.crop_to_bounding_box( mask = tf.image.crop_to_bounding_box(
mask, mask,
offsets[0], offsets[1], offsets[0], offsets[1],
...@@ -248,7 +255,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -248,7 +255,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
mask, 0, 0, self._output_size[0], self._output_size[1]) mask, 0, 0, self._output_size[0], self._output_size[1])
return mask return mask
def call(self, inputs: tf.Tensor, image_info: Mapping[str, tf.Tensor]): def call(self, inputs: tf.Tensor, image_info: tf.Tensor):
detections = inputs detections = inputs
batched_scores = detections['detection_scores'] batched_scores = detections['detection_scores']
......
...@@ -143,7 +143,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -143,7 +143,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
def call(self, def call(self,
images: tf.Tensor, images: tf.Tensor,
image_info: Mapping[str, tf.Tensor], image_info: tf.Tensor,
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None, anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
gt_boxes: Optional[tf.Tensor] = None, gt_boxes: Optional[tf.Tensor] = None,
gt_classes: Optional[tf.Tensor] = None, gt_classes: Optional[tf.Tensor] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment