"git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "c519747f3ca6841870f8686ad34cca236d552dbf"
Unverified Commit 12ef2884 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added postprocessing layer to `PanopticMaskRCNNModel`

parent 9eed3c5c
...@@ -36,6 +36,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -36,6 +36,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
List[tf.keras.layers.Layer]], List[tf.keras.layers.Layer]],
roi_aligner: tf.keras.layers.Layer, roi_aligner: tf.keras.layers.Layer,
detection_generator: tf.keras.layers.Layer, detection_generator: tf.keras.layers.Layer,
panoptic_segmentation_generator: tf.keras.layers.Layer,
mask_head: Optional[tf.keras.layers.Layer] = None, mask_head: Optional[tf.keras.layers.Layer] = None,
mask_sampler: Optional[tf.keras.layers.Layer] = None, mask_sampler: Optional[tf.keras.layers.Layer] = None,
mask_roi_aligner: Optional[tf.keras.layers.Layer] = None, mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
...@@ -62,6 +63,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -62,6 +63,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
detection heads. detection heads.
roi_aligner: the ROI aligner. roi_aligner: the ROI aligner.
detection_generator: the detection generator. detection_generator: the detection generator.
panoptic_segmentation_generator: the panoptic segmentation generator that
is used to merge instance and semantic segmentation masks.
mask_head: the mask head. mask_head: the mask head.
mask_sampler: the mask sampler. mask_sampler: the mask sampler.
mask_roi_aligner: the ROI alginer for mask prediction. mask_roi_aligner: the ROI alginer for mask prediction.
...@@ -117,7 +120,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -117,7 +120,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
self._config_dict.update({ self._config_dict.update({
'segmentation_backbone': segmentation_backbone, 'segmentation_backbone': segmentation_backbone,
'segmentation_decoder': segmentation_decoder, 'segmentation_decoder': segmentation_decoder,
'segmentation_head': segmentation_head 'segmentation_head': segmentation_head,
'panoptic_segmentation_generator': panoptic_segmentation_generator
}) })
if not self._include_mask: if not self._include_mask:
...@@ -131,6 +135,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -131,6 +135,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
self.segmentation_backbone = segmentation_backbone self.segmentation_backbone = segmentation_backbone
self.segmentation_decoder = segmentation_decoder self.segmentation_decoder = segmentation_decoder
self.segmentation_head = segmentation_head self.segmentation_head = segmentation_head
self.panoptic_segmentation_generator = panoptic_segmentation_generator
def call(self, def call(self,
images: tf.Tensor, images: tf.Tensor,
...@@ -167,6 +172,12 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -167,6 +172,12 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
'segmentation_outputs': segmentation_outputs, 'segmentation_outputs': segmentation_outputs,
}) })
if not training:
panoptic_outputs = self.panoptic_segmentation_generator(model_outputs)
model_outputs.update({
'panoptic_outputs': panoptic_outputs
})
return model_outputs return model_outputs
@property @property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment