"vscode:/vscode.git/clone" did not exist on "acf4156ee3b2fdcb9b3604d40563d610a9786105"
Commit b3be14bc authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by Jaeyoun Kim
Browse files

Added `PanopticMaskRCNNModel` model (#10045)

* Added `PanopticMaskRCNNModel` model

* test checkpoint loading for segmentation objects

* fixed docstring

* subclassed `MaskRCNNModel`

* always enable mask and segmentation heads

* added __init__.py

* added README.md
parent fcd681d2
# Panoptic Segmentation
## Description
Panoptic Segmentation combines the two distinct vision tasks - semantic segmentation and instance segmentation.
These tasks are unified such that, each pixel in the image is assigned the label of the class it belongs to, and also the instance identifier of the object it a part of.
## Environment setup
The code can be run on multiple GPUs or TPUs with different distribution
strategies. See the TensorFlow distributed training
[guide](https://www.tensorflow.org/guide/distributed_training) for an overview
of `tf.distribute`.
The code is compatible with TensorFlow 2.4+. See requirements.txt for all
prerequisites, and you can also install them using the following command. `pip
install -r ./official/requirements.txt`
**DISCLAIMER**: Panoptic MaskRCNN is still under active development, stay tuned!
\ No newline at end of file
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Panoptic Segmentation model."""
from typing import List, Mapping, Optional, Union
# Import libraries
import tensorflow as tf
from official.vision.beta.modeling import maskrcnn_model
from official.vision.beta.ops import anchor
from official.vision.beta.ops import box_ops
@tf.keras.utils.register_keras_serializable(package='Vision')
class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
"""The Panoptic Segmentation model."""
def __init__(self,
backbone: tf.keras.Model,
decoder: tf.keras.Model,
rpn_head: tf.keras.layers.Layer,
detection_head: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
roi_generator: tf.keras.layers.Layer,
roi_sampler: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
roi_aligner: tf.keras.layers.Layer,
detection_generator: tf.keras.layers.Layer,
mask_head: Optional[tf.keras.layers.Layer] = None,
mask_sampler: Optional[tf.keras.layers.Layer] = None,
mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
segmentation_backbone: Optional[tf.keras.Model] = None,
segmentation_decoder: Optional[tf.keras.Model] = None,
segmentation_head: tf.keras.layers.Layer = None,
class_agnostic_bbox_pred: bool = False,
cascade_class_ensemble: bool = False,
min_level: Optional[int] = None,
max_level: Optional[int] = None,
num_scales: Optional[int] = None,
aspect_ratios: Optional[List[float]] = None,
anchor_size: Optional[float] = None,
**kwargs):
"""Initializes the Panoptic Mask R-CNN model.
Args:
backbone: `tf.keras.Model`, the backbone network.
decoder: `tf.keras.Model`, the decoder network.
rpn_head: the RPN head.
detection_head: the detection head or a list of heads.
roi_generator: the ROI generator.
roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
detection heads.
roi_aligner: the ROI aligner.
detection_generator: the detection generator.
mask_head: the mask head.
mask_sampler: the mask sampler.
mask_roi_aligner: the ROI alginer for mask prediction.
segmentation_backbone: `tf.keras.Model`, the backbone network for the
segmentation head for panoptic task. Providing `segmentation_backbone`
will allow the segmentation head to use a standlone backbone. Setting
`segmentation_backbone=None` would enable backbone sharing between
the MaskRCNN model and segmentation head.
segmentation_decoder: `tf.keras.Model`, the decoder network for the
segmentation head for panoptic task. Providing `segmentation_decoder`
will allow the segmentation head to use a standlone decoder. Setting
`segmentation_decoder=None` would enable decoder sharing between
the MaskRCNN model and segmentation head. Decoders can only be shared
when `segmentation_backbone` is shared as well.
segmentation_head: segmentatation head for panoptic task.
class_agnostic_bbox_pred: if True, perform class agnostic bounding box
prediction. Needs to be `True` for Cascade RCNN models.
cascade_class_ensemble: if True, ensemble classification scores over
all detection heads.
min_level: Minimum level in output feature maps.
max_level: Maximum level in output feature maps.
num_scales: A number representing intermediate scales added
on each level. For instances, num_scales=2 adds one additional
intermediate anchor scales [2^0, 2^0.5] on each level.
aspect_ratios: A list representing the aspect raito
anchors added on each level. The number indicates the ratio of width to
height. For instances, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
on each scale level.
anchor_size: A number representing the scale of size of the base
anchor to the feature stride 2^level.
**kwargs: keyword arguments to be passed.
"""
super(PanopticMaskRCNNModel, self).__init__(
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=detection_head,
roi_generator=roi_generator,
roi_sampler=roi_sampler,
roi_aligner=roi_aligner,
detection_generator=detection_generator,
mask_head=mask_head,
mask_sampler=mask_sampler,
mask_roi_aligner=mask_roi_aligner,
class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
**kwargs)
self._config_dict.update({
'segmentation_backbone':segmentation_backbone,
'segmentation_decoder': segmentation_decoder,
'segmentation_head': segmentation_head
})
if not self._include_mask:
raise ValueError(
'`mask_head` needs to be provided for Panoptic Mask R-CNN.')
if segmentation_backbone is not None and segmentation_decoder is None:
raise ValueError(
'`segmentation_decoder` needs to be provided for Panoptic Mask R-CNN if `backbone` is not shared.'
)
self.segmentation_backbone = segmentation_backbone
self.segmentation_decoder = segmentation_decoder
self.segmentation_head = segmentation_head
def call(self,
images: tf.Tensor,
image_shape: tf.Tensor,
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
gt_boxes: Optional[tf.Tensor] = None,
gt_classes: Optional[tf.Tensor] = None,
gt_masks: Optional[tf.Tensor] = None,
training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
model_outputs = {}
# Feature extraction.
backbone_features = self.backbone(images)
if self.decoder:
decoder_features = self.decoder(backbone_features)
# Region proposal network.
rpn_scores, rpn_boxes = self.rpn_head(decoder_features)
model_outputs.update({
'rpn_boxes': rpn_boxes,
'rpn_scores': rpn_scores
})
# Generate anchor boxes for this batch if not provided.
if anchor_boxes is None:
_, image_height, image_width, _ = images.get_shape().as_list()
anchor_boxes = anchor.Anchor(
min_level=self._config_dict['min_level'],
max_level=self._config_dict['max_level'],
num_scales=self._config_dict['num_scales'],
aspect_ratios=self._config_dict['aspect_ratios'],
anchor_size=self._config_dict['anchor_size'],
image_size=(image_height, image_width)).multilevel_boxes
for l in anchor_boxes:
anchor_boxes[l] = tf.tile(
tf.expand_dims(anchor_boxes[l], axis=0),
[tf.shape(images)[0], 1, 1, 1])
# Generate RoIs.
current_rois, _ = self.roi_generator(rpn_boxes, rpn_scores, anchor_boxes,
image_shape, training)
next_rois = current_rois
all_class_outputs = []
for cascade_num in range(len(self.roi_sampler)):
# In cascade RCNN we want the higher layers to have different regression
# weights as the predicted deltas become smaller and smaller.
regression_weights = self._cascade_layer_to_weights[cascade_num]
current_rois = next_rois
(class_outputs, box_outputs, model_outputs, matched_gt_boxes,
matched_gt_classes, matched_gt_indices,
current_rois) = self._run_frcnn_head(
features=decoder_features,
rois=current_rois,
gt_boxes=gt_boxes,
gt_classes=gt_classes,
training=training,
model_outputs=model_outputs,
cascade_num=cascade_num,
regression_weights=regression_weights)
all_class_outputs.append(class_outputs)
# Generate ROIs for the next cascade head if there is any.
if cascade_num < len(self.roi_sampler) - 1:
next_rois = box_ops.decode_boxes(
tf.cast(box_outputs, tf.float32),
current_rois,
weights=regression_weights)
next_rois = box_ops.clip_boxes(next_rois,
tf.expand_dims(image_shape, axis=1))
if not training:
if self._config_dict['cascade_class_ensemble']:
class_outputs = tf.add_n(all_class_outputs) / len(all_class_outputs)
detections = self.detection_generator(
box_outputs,
class_outputs,
current_rois,
image_shape,
regression_weights,
bbox_per_class=(not self._config_dict['class_agnostic_bbox_pred']))
model_outputs.update({
'cls_outputs': class_outputs,
'box_outputs': box_outputs,
})
if self.detection_generator.get_config()['apply_nms']:
model_outputs.update({
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
'detection_classes': detections['detection_classes'],
'num_detections': detections['num_detections']
})
else:
model_outputs.update({
'decoded_boxes': detections['decoded_boxes'],
'decoded_box_scores': detections['decoded_box_scores']
})
if not self._include_mask:
return model_outputs
if training:
current_rois, roi_classes, roi_masks = self.mask_sampler(
current_rois, matched_gt_boxes, matched_gt_classes,
matched_gt_indices, gt_masks)
roi_masks = tf.stop_gradient(roi_masks)
model_outputs.update({
'mask_class_targets': roi_classes,
'mask_targets': roi_masks,
})
else:
current_rois = model_outputs['detection_boxes']
roi_classes = model_outputs['detection_classes']
# Mask RoI align.
mask_roi_features = self.mask_roi_aligner(decoder_features, current_rois)
# Mask head.
raw_masks = self.mask_head([mask_roi_features, roi_classes])
if training:
model_outputs.update({
'mask_outputs': raw_masks,
})
else:
model_outputs.update({
'detection_masks': tf.math.sigmoid(raw_masks),
})
if self.segmentation_backbone is not None:
backbone_features = self.segmentation_backbone(
images,
training=training)
if self.segmentation_decoder is not None:
decoder_features = self.segmentation_decoder(
backbone_features,
training=training)
segmentation_outputs = self.segmentation_head(
backbone_features,
decoder_features,
training=training)
model_outputs.update({
'segmentation_outputs': segmentation_outputs,
})
return model_outputs
@property
def checkpoint_items(
self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
"""Returns a dictionary of items to be additionally checkpointed."""
items = super(PanopticMaskRCNNModel, self).checkpoint_items
if self.segmentation_backbone is not None:
items.update(segmentation_backbone=self.segmentation_backbone)
if self.segmentation_decoder is not None:
items.update(segmentation_decoder=self.segmentation_decoder)
items.update(segmentation_head=self.segmentation_head)
return items
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for panoptic_maskrcnn_model.py."""
import os
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.modeling.backbones import resnet
from official.vision.beta.modeling.decoders import fpn
from official.vision.beta.modeling.decoders import aspp
from official.vision.beta.modeling.heads import dense_prediction_heads
from official.vision.beta.modeling.heads import instance_heads
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.modeling.layers import detection_generator
from official.vision.beta.modeling.layers import mask_sampler
from official.vision.beta.modeling.layers import roi_aligner
from official.vision.beta.modeling.layers import roi_generator
from official.vision.beta.modeling.layers import roi_sampler
from official.vision.beta.ops import anchor
class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
@combinations.generate(
combinations.combine(
use_separable_conv=[True, False],
build_anchor_boxes=[True, False],
shared_backbone=[True, False],
shared_decoder=[True, False],
is_training=[True, False]))
def test_build_model(self, use_separable_conv, build_anchor_boxes,
shared_backbone, shared_decoder, is_training):
num_classes = 3
min_level = 3
max_level = 7
num_scales = 3
aspect_ratios = [1.0]
anchor_size = 3
resnet_model_id = 50
segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride))
fpn_decoder_level = 3
num_anchors_per_location = num_scales * len(aspect_ratios)
image_size = 384
images = np.random.rand(2, image_size, image_size, 3)
image_shape = np.array(
[[image_size, image_size], [image_size, image_size]])
shared_decoder = shared_decoder and shared_backbone
if build_anchor_boxes:
anchor_boxes = anchor.Anchor(
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=3,
image_size=(image_size, image_size)).multilevel_boxes
for l in anchor_boxes:
anchor_boxes[l] = tf.tile(
tf.expand_dims(anchor_boxes[l], axis=0), [2, 1, 1, 1])
else:
anchor_boxes = None
backbone = resnet.ResNet(model_id=resnet_model_id)
decoder = fpn.FPN(
input_specs=backbone.output_specs,
min_level=min_level,
max_level=max_level,
use_separable_conv=use_separable_conv)
rpn_head = dense_prediction_heads.RPNHead(
min_level=min_level,
max_level=max_level,
num_anchors_per_location=num_anchors_per_location,
num_convs=1)
detection_head = instance_heads.DetectionHead(num_classes=num_classes)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
mask_head = instance_heads.MaskHead(
num_classes=num_classes, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things,
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size)
gt_boxes = np.array(
[[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
[[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
dtype=np.float32)
gt_classes = np.array([[2, 1, -1], [1, -1, -1]], dtype=np.int32)
gt_masks = np.ones((2, 3, 100, 100))
# Results will be checked in test_forward.
_ = model(
images,
image_shape,
anchor_boxes,
gt_boxes,
gt_classes,
gt_masks,
training=is_training)
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
build_anchor_boxes=[True, False],
use_cascade_heads=[True, False],
shared_backbone=[True, False],
shared_decoder=[True, False],
training=[True, False],
))
def test_forward(self, strategy, build_anchor_boxes, training,
shared_backbone, shared_decoder, use_cascade_heads):
num_classes = 3
min_level = 3
max_level = 4
num_scales = 3
aspect_ratios = [1.0]
anchor_size = 3
segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride))
fpn_decoder_level = 3
if use_cascade_heads:
cascade_iou_thresholds = [0.6]
class_agnostic_bbox_pred = True
cascade_class_ensemble = True
else:
cascade_iou_thresholds = None
class_agnostic_bbox_pred = False
cascade_class_ensemble = False
image_size = (256, 256)
images = np.random.rand(2, image_size[0], image_size[1], 3)
image_shape = np.array([[224, 100], [100, 224]])
shared_decoder = shared_decoder and shared_backbone
with strategy.scope():
if build_anchor_boxes:
anchor_boxes = anchor.Anchor(
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
image_size=image_size).multilevel_boxes
else:
anchor_boxes = None
num_anchors_per_location = len(aspect_ratios) * num_scales
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=min_level,
max_level=max_level,
input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=min_level,
max_level=max_level,
num_anchors_per_location=num_anchors_per_location)
detection_head = instance_heads.DetectionHead(
num_classes=num_classes,
class_agnostic_bbox_pred=class_agnostic_bbox_pred)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_cascade = []
roi_sampler_obj = roi_sampler.ROISampler()
roi_sampler_cascade.append(roi_sampler_obj)
if cascade_iou_thresholds:
for iou in cascade_iou_thresholds:
roi_sampler_obj = roi_sampler.ROISampler(
mix_gt_boxes=False,
foreground_iou_threshold=iou,
background_iou_high_threshold=iou,
background_iou_low_threshold=0.0,
skip_subsampling=True)
roi_sampler_cascade.append(roi_sampler_obj)
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
mask_head = instance_heads.MaskHead(
num_classes=num_classes, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things,
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size)
gt_boxes = np.array(
[[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
[[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
dtype=np.float32)
gt_classes = np.array([[2, 1, -1], [1, -1, -1]], dtype=np.int32)
gt_masks = np.ones((2, 3, 100, 100))
results = model(
images,
image_shape,
anchor_boxes,
gt_boxes,
gt_classes,
gt_masks,
training=training)
self.assertIn('rpn_boxes', results)
self.assertIn('rpn_scores', results)
if training:
self.assertIn('class_targets', results)
self.assertIn('box_targets', results)
self.assertIn('class_outputs', results)
self.assertIn('box_outputs', results)
self.assertIn('mask_outputs', results)
else:
self.assertIn('detection_boxes', results)
self.assertIn('detection_scores', results)
self.assertIn('detection_classes', results)
self.assertIn('num_detections', results)
self.assertIn('detection_masks', results)
self.assertIn('segmentation_outputs', results)
self.assertAllEqual(
[2,
image_size[0] // (2**level),
image_size[1] // (2**level), 2],
results['segmentation_outputs'].numpy().shape)
@combinations.generate(
combinations.combine(
shared_backbone=[True, False],
shared_decoder=[True, False]))
def test_serialize_deserialize(self, shared_backbone, shared_decoder):
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=3, max_level=7, input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=3, max_level=7, num_anchors_per_location=3)
detection_head = instance_heads.DetectionHead(num_classes=2)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride))
fpn_decoder_level = 3
shared_decoder = shared_decoder and shared_backbone
mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things,
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
min_level=3,
max_level=7,
num_scales=3,
aspect_ratios=[1.0],
anchor_size=3)
config = model.get_config()
new_model = \
panoptic_maskrcnn_model.PanopticMaskRCNNModel.from_config(config)
# Validate that the config can be forced to JSON.
_ = new_model.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(model.get_config(), new_model.get_config())
@combinations.generate(
combinations.combine(
shared_backbone=[True, False],
shared_decoder=[True, False]))
def test_checkpoint(self, shared_backbone, shared_decoder):
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=3, max_level=7, input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=3, max_level=7, num_anchors_per_location=3)
detection_head = instance_heads.DetectionHead(num_classes=2)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride))
fpn_decoder_level = 3
shared_decoder = shared_decoder and shared_backbone
mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things,
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
min_level=3,
max_level=7,
num_scales=3,
aspect_ratios=[1.0],
anchor_size=3)
expect_checkpoint_items = dict(
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=[detection_head])
expect_checkpoint_items['mask_head'] = mask_head
if not shared_backbone:
expect_checkpoint_items['segmentation_backbone'] = segmentation_backbone
if not shared_decoder:
expect_checkpoint_items['segmentation_decoder'] = segmentation_decoder
expect_checkpoint_items['segmentation_head'] = segmentation_head
self.assertAllEqual(expect_checkpoint_items, model.checkpoint_items)
# Test save and load checkpoints.
ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
save_dir = self.create_tempdir().full_path
ckpt.save(os.path.join(save_dir, 'ckpt'))
partial_ckpt = tf.train.Checkpoint(backbone=backbone)
partial_ckpt.restore(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
partial_ckpt_mask = tf.train.Checkpoint(
backbone=backbone, mask_head=mask_head)
partial_ckpt_mask.restore(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
if not shared_backbone:
partial_ckpt_segmentation = tf.train.Checkpoint(
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head)
elif not shared_decoder:
partial_ckpt_segmentation = tf.train.Checkpoint(
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head)
else:
partial_ckpt_segmentation = tf.train.Checkpoint(
segmentation_head=segmentation_head)
partial_ckpt_segmentation.restore(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment