Commit 17b3db9f authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 424463592
parent d55ee951
...@@ -25,6 +25,7 @@ from official.modeling import optimization ...@@ -25,6 +25,7 @@ from official.modeling import optimization
from official.vision.beta.configs import common from official.vision.beta.configs import common
from official.vision.beta.configs import maskrcnn from official.vision.beta.configs import maskrcnn
from official.vision.beta.configs import semantic_segmentation from official.vision.beta.configs import semantic_segmentation
from official.vision.beta.projects.deepmac_maskrcnn.configs import deep_mask_head_rcnn as deepmac_maskrcnn
SEGMENTATION_MODEL = semantic_segmentation.SemanticSegmentationModel SEGMENTATION_MODEL = semantic_segmentation.SemanticSegmentationModel
...@@ -89,7 +90,7 @@ class PanopticSegmentationGenerator(hyperparams.Config): ...@@ -89,7 +90,7 @@ class PanopticSegmentationGenerator(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class PanopticMaskRCNN(maskrcnn.MaskRCNN): class PanopticMaskRCNN(deepmac_maskrcnn.DeepMaskHeadRCNN):
"""Panoptic Mask R-CNN model config.""" """Panoptic Mask R-CNN model config."""
segmentation_model: semantic_segmentation.SemanticSegmentationModel = ( segmentation_model: semantic_segmentation.SemanticSegmentationModel = (
SEGMENTATION_MODEL(num_classes=2)) SEGMENTATION_MODEL(num_classes=2))
......
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
import tensorflow as tf import tensorflow as tf
from official.vision.beta.modeling import backbones from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import factory as models_factory
from official.vision.beta.modeling.decoders import factory as decoder_factory from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import segmentation_heads from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.projects.deepmac_maskrcnn.tasks import deep_mask_head_rcnn
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator
...@@ -47,7 +47,7 @@ def build_panoptic_maskrcnn( ...@@ -47,7 +47,7 @@ def build_panoptic_maskrcnn(
segmentation_config = model_config.segmentation_model segmentation_config = model_config.segmentation_model
# Builds the maskrcnn model. # Builds the maskrcnn model.
maskrcnn_model = models_factory.build_maskrcnn( maskrcnn_model = deep_mask_head_rcnn.build_maskrcnn(
input_specs=input_specs, input_specs=input_specs,
model_config=model_config, model_config=model_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
...@@ -117,6 +117,7 @@ def build_panoptic_maskrcnn( ...@@ -117,6 +117,7 @@ def build_panoptic_maskrcnn(
# Combines maskrcnn, and segmentation models to build panoptic segmentation # Combines maskrcnn, and segmentation models to build panoptic segmentation
# model. # model.
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel( model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone=maskrcnn_model.backbone, backbone=maskrcnn_model.backbone,
decoder=maskrcnn_model.decoder, decoder=maskrcnn_model.decoder,
......
...@@ -18,10 +18,10 @@ from typing import List, Mapping, Optional, Union ...@@ -18,10 +18,10 @@ from typing import List, Mapping, Optional, Union
import tensorflow as tf import tensorflow as tf
from official.vision.beta.modeling import maskrcnn_model from official.vision.beta.projects.deepmac_maskrcnn.modeling import maskrcnn_model
class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): class PanopticMaskRCNNModel(maskrcnn_model.DeepMaskRCNNModel):
"""The Panoptic Segmentation model.""" """The Panoptic Segmentation model."""
def __init__( def __init__(
...@@ -49,7 +49,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -49,7 +49,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
max_level: Optional[int] = None, max_level: Optional[int] = None,
num_scales: Optional[int] = None, num_scales: Optional[int] = None,
aspect_ratios: Optional[List[float]] = None, aspect_ratios: Optional[List[float]] = None,
anchor_size: Optional[float] = None, # pytype: disable=annotation-type-mismatch # typed-keras anchor_size: Optional[float] = None,
use_gt_boxes_for_masks: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Initializes the Panoptic Mask R-CNN model. """Initializes the Panoptic Mask R-CNN model.
...@@ -94,6 +95,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -94,6 +95,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each scale level. aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each scale level.
anchor_size: A number representing the scale of size of the base anchor to anchor_size: A number representing the scale of size of the base anchor to
the feature stride 2^level. the feature stride 2^level.
use_gt_boxes_for_masks: `bool`, whether to use only gt boxes for masks.
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
super(PanopticMaskRCNNModel, self).__init__( super(PanopticMaskRCNNModel, self).__init__(
...@@ -115,6 +117,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -115,6 +117,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
num_scales=num_scales, num_scales=num_scales,
aspect_ratios=aspect_ratios, aspect_ratios=aspect_ratios,
anchor_size=anchor_size, anchor_size=anchor_size,
use_gt_boxes_for_masks=use_gt_boxes_for_masks,
**kwargs) **kwargs)
self._config_dict.update({ self._config_dict.update({
......
...@@ -97,6 +97,20 @@ class PanopticSegmentationModule(detection.DetectionModule): ...@@ -97,6 +97,20 @@ class PanopticSegmentationModule(detection.DetectionModule):
anchor_boxes=anchor_boxes, anchor_boxes=anchor_boxes,
training=False) training=False)
detections.pop('rpn_boxes')
detections.pop('rpn_scores')
detections.pop('cls_outputs')
detections.pop('box_outputs')
detections.pop('backbone_features')
detections.pop('decoder_features')
# Normalize detection boxes to [0, 1]. Here we first map them to the
# original image size, then normalize them to [0, 1].
detections['detection_boxes'] = (
detections['detection_boxes'] /
tf.tile(image_info[:, 2:3, :], [1, 1, 2]) /
tf.tile(image_info[:, 0:1, :], [1, 1, 2]))
if model_params.detection_generator.apply_nms: if model_params.detection_generator.apply_nms:
final_outputs = { final_outputs = {
'detection_boxes': detections['detection_boxes'], 'detection_boxes': detections['detection_boxes'],
...@@ -109,10 +123,15 @@ class PanopticSegmentationModule(detection.DetectionModule): ...@@ -109,10 +123,15 @@ class PanopticSegmentationModule(detection.DetectionModule):
'decoded_boxes': detections['decoded_boxes'], 'decoded_boxes': detections['decoded_boxes'],
'decoded_box_scores': detections['decoded_box_scores'] 'decoded_box_scores': detections['decoded_box_scores']
} }
masks = detections['segmentation_outputs']
masks = tf.image.resize(masks, self._input_image_size, method='bilinear')
classes = tf.math.argmax(masks, axis=-1)
scores = tf.nn.softmax(masks, axis=-1)
final_outputs.update({ final_outputs.update({
'detection_masks': detections['detection_masks'], 'detection_masks': detections['detection_masks'],
'segmentation_outputs': detections['segmentation_outputs'], 'masks': masks,
'scores': scores,
'classes': classes,
'image_info': image_info 'image_info': image_info
}) })
if model_params.generate_panoptic_masks: if model_params.generate_panoptic_masks:
......
...@@ -61,7 +61,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -61,7 +61,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
def initialize(self, model: tf.keras.Model) -> None: def initialize(self, model: tf.keras.Model) -> None:
"""Loading pretrained checkpoint.""" """Loading pretrained checkpoint."""
if not self.task_config.init_checkpoint_modules: if not self.task_config.init_checkpoint:
return return
def _get_checkpoint_path(checkpoint_dir_or_file): def _get_checkpoint_path(checkpoint_dir_or_file):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment