Commit 84630072 authored by Jaeyoun Kim, committed by A. Unique TensorFlower

Copybara import of the project:

--
b3be14bc by Srihari Humbarwadi <sriharihumbarwadi97@gmail.com>:

Added `PanopticMaskRCNNModel` model (#10045)

* Added `PanopticMaskRCNNModel` model

* test checkpoint loading for segmentation objects

* fixed docstring

* subclassed `MaskRCNNModel`

* always enable mask and segmentation heads

* added __init__.py

* added README.md
--
35c3a79f by Srihari Humbarwadi <sriharihumbarwadi97@gmail.com>:

fixed linting errors (#10053)

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/models/pull/10052 from tensorflow:panoptic-segmentation 35c3a79f
PiperOrigin-RevId: 379834961
parent c8348887
@@ -16,7 +16,6 @@
 from typing import Any, List, Mapping, Optional, Union
 
-# Import libraries
 import tensorflow as tf
 
 from official.vision.beta.ops import anchor
@@ -147,14 +146,18 @@ class MaskRCNNModel(tf.keras.Model):
     model_outputs = {}
 
     # Feature extraction.
-    features = self.backbone(images)
+    backbone_features = self.backbone(images)
     if self.decoder:
-      features = self.decoder(features)
+      features = self.decoder(backbone_features)
+    else:
+      features = backbone_features
 
     # Region proposal network.
     rpn_scores, rpn_boxes = self.rpn_head(features)
     model_outputs.update({
+        'backbone_features': backbone_features,
+        'decoder_features': features,
        'rpn_boxes': rpn_boxes,
        'rpn_scores': rpn_scores
     })
...
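The new `backbone_features` and `decoder_features` entries expose the intermediate features in `model_outputs`; this is what lets the subclassed panoptic model below reuse the shared networks instead of re-running them:

```python
# Inside PanopticMaskRCNNModel.call (panoptic_maskrcnn_model.py below):
backbone_features = model_outputs['backbone_features']
decoder_features = model_outputs['decoder_features']
```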
# Panoptic Segmentation
## Description
Panoptic segmentation combines two distinct vision tasks: semantic
segmentation and instance segmentation. The tasks are unified such that each
pixel in the image is assigned both the label of the class it belongs to and
the instance identifier of the object it is a part of.
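As a toy illustration of that output format (made up for exposition; this is not the model's actual API), a panoptic prediction can be viewed as a pair of per-pixel maps:

```python
import numpy as np

# Toy 4x4 prediction: every pixel gets a semantic class id...
semantic = np.array([[0, 0, 1, 1],
                     [0, 0, 1, 1],
                     [1, 1, 0, 0],
                     [0, 0, 0, 0]])  # 0 = "stuff" (e.g. road), 1 = "car"

# ...and "thing" pixels also get an instance id, so two cars of the same
# semantic class remain distinguishable.
instance = np.array([[0, 0, 1, 1],
                     [0, 0, 1, 1],
                     [2, 2, 0, 0],
                     [0, 0, 0, 0]])  # 0 = no instance ("stuff" pixels)
```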
## Environment setup
The code can be run on multiple GPUs or TPUs with different distribution
strategies. See the TensorFlow distributed training
[guide](https://www.tensorflow.org/guide/distributed_training) for an overview
of `tf.distribute`.
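For example, a minimal sketch of a multi-GPU setup (`build_panoptic_maskrcnn` is a hypothetical builder used here only for illustration):

```python
import tensorflow as tf

# MirroredStrategy replicates the model across all local GPUs; swap in
# tf.distribute.TPUStrategy for TPU training.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # Variables created inside the scope are mirrored across replicas.
  model = build_panoptic_maskrcnn()  # hypothetical builder function
  optimizer = tf.keras.optimizers.SGD(momentum=0.9)
```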
The code is compatible with TensorFlow 2.4+. See `requirements.txt` for all
prerequisites; you can install them with
`pip install -r ./official/requirements.txt`.
**DISCLAIMER**: Panoptic MaskRCNN is still under active development, stay tuned!
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Panoptic Segmentation model."""
from typing import List, Mapping, Optional, Union
import tensorflow as tf
from official.vision.beta.modeling import maskrcnn_model
@tf.keras.utils.register_keras_serializable(package='Vision')
class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
"""The Panoptic Segmentation model."""
def __init__(self,
backbone: tf.keras.Model,
decoder: tf.keras.Model,
rpn_head: tf.keras.layers.Layer,
detection_head: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
roi_generator: tf.keras.layers.Layer,
roi_sampler: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
roi_aligner: tf.keras.layers.Layer,
detection_generator: tf.keras.layers.Layer,
mask_head: Optional[tf.keras.layers.Layer] = None,
mask_sampler: Optional[tf.keras.layers.Layer] = None,
mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
segmentation_backbone: Optional[tf.keras.Model] = None,
segmentation_decoder: Optional[tf.keras.Model] = None,
               segmentation_head: Optional[tf.keras.layers.Layer] = None,
class_agnostic_bbox_pred: bool = False,
cascade_class_ensemble: bool = False,
min_level: Optional[int] = None,
max_level: Optional[int] = None,
num_scales: Optional[int] = None,
aspect_ratios: Optional[List[float]] = None,
anchor_size: Optional[float] = None,
**kwargs):
"""Initializes the Panoptic Mask R-CNN model.
Args:
backbone: `tf.keras.Model`, the backbone network.
decoder: `tf.keras.Model`, the decoder network.
rpn_head: the RPN head.
detection_head: the detection head or a list of heads.
roi_generator: the ROI generator.
roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
detection heads.
roi_aligner: the ROI aligner.
detection_generator: the detection generator.
mask_head: the mask head.
mask_sampler: the mask sampler.
      mask_roi_aligner: the ROI aligner for mask prediction.
      segmentation_backbone: `tf.keras.Model`, the backbone network for the
        segmentation head for panoptic task. Providing `segmentation_backbone`
        will allow the segmentation head to use a standalone backbone. Setting
        `segmentation_backbone=None` would enable backbone sharing between the
        MaskRCNN model and segmentation head.
      segmentation_decoder: `tf.keras.Model`, the decoder network for the
        segmentation head for panoptic task. Providing `segmentation_decoder`
        will allow the segmentation head to use a standalone decoder. Setting
        `segmentation_decoder=None` would enable decoder sharing between the
        MaskRCNN model and segmentation head. Decoders can only be shared when
        `segmentation_backbone` is shared as well.
      segmentation_head: segmentation head for panoptic task.
class_agnostic_bbox_pred: if True, perform class agnostic bounding box
prediction. Needs to be `True` for Cascade RCNN models.
cascade_class_ensemble: if True, ensemble classification scores over all
detection heads.
min_level: Minimum level in output feature maps.
max_level: Maximum level in output feature maps.
      num_scales: A number representing intermediate scales added on each
        level. For instance, num_scales=2 adds one additional intermediate
        anchor scale, yielding scales [2^0, 2^0.5] on each level.
      aspect_ratios: A list representing the aspect ratio anchors added on
        each level. The number indicates the ratio of width to height. For
        instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each
        scale level.
      anchor_size: A number representing the scale of the base anchor relative
        to the feature stride 2^level.
**kwargs: keyword arguments to be passed.
"""
super(PanopticMaskRCNNModel, self).__init__(
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=detection_head,
roi_generator=roi_generator,
roi_sampler=roi_sampler,
roi_aligner=roi_aligner,
detection_generator=detection_generator,
mask_head=mask_head,
mask_sampler=mask_sampler,
mask_roi_aligner=mask_roi_aligner,
class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
**kwargs)
self._config_dict.update({
'segmentation_backbone': segmentation_backbone,
'segmentation_decoder': segmentation_decoder,
'segmentation_head': segmentation_head
})
if not self._include_mask:
raise ValueError(
'`mask_head` needs to be provided for Panoptic Mask R-CNN.')
    if segmentation_backbone is not None and segmentation_decoder is None:
      raise ValueError(
          '`segmentation_decoder` needs to be provided for Panoptic Mask '
          'R-CNN if `backbone` is not shared.')
self.segmentation_backbone = segmentation_backbone
self.segmentation_decoder = segmentation_decoder
self.segmentation_head = segmentation_head
def call(self,
images: tf.Tensor,
image_shape: tf.Tensor,
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
gt_boxes: Optional[tf.Tensor] = None,
gt_classes: Optional[tf.Tensor] = None,
gt_masks: Optional[tf.Tensor] = None,
training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
model_outputs = super(PanopticMaskRCNNModel, self).call(
images=images,
image_shape=image_shape,
anchor_boxes=anchor_boxes,
gt_boxes=gt_boxes,
gt_classes=gt_classes,
gt_masks=gt_masks,
training=training)
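    # Semantic segmentation branch. Unless standalone segmentation networks
    # were provided, reuse the 'backbone_features'/'decoder_features' that
    # MaskRCNNModel.call exported above.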
if self.segmentation_backbone is not None:
backbone_features = self.segmentation_backbone(images, training=training)
else:
backbone_features = model_outputs['backbone_features']
if self.segmentation_decoder is not None:
decoder_features = self.segmentation_decoder(
backbone_features, training=training)
else:
decoder_features = model_outputs['decoder_features']
segmentation_outputs = self.segmentation_head(
backbone_features, decoder_features, training=training)
model_outputs.update({
'segmentation_outputs': segmentation_outputs,
})
return model_outputs
@property
def checkpoint_items(
self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
"""Returns a dictionary of items to be additionally checkpointed."""
items = super(PanopticMaskRCNNModel, self).checkpoint_items
if self.segmentation_backbone is not None:
items.update(segmentation_backbone=self.segmentation_backbone)
if self.segmentation_decoder is not None:
items.update(segmentation_decoder=self.segmentation_decoder)
items.update(segmentation_head=self.segmentation_head)
return items
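As a minimal usage sketch (the heads and layers are assumed to be built as in the tests below, so this is illustrative rather than a complete recipe), leaving both segmentation arguments as `None` makes the segmentation head reuse the detection backbone and decoder:

```python
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
    backbone, decoder, rpn_head, detection_head,
    roi_generator_obj, roi_sampler_obj, roi_aligner_obj,
    detection_generator_obj, mask_head, mask_sampler_obj,
    mask_roi_aligner_obj,
    segmentation_backbone=None,  # share the detection backbone
    segmentation_decoder=None,   # share the detection decoder
    segmentation_head=segmentation_head,
    min_level=3, max_level=7, num_scales=3,
    aspect_ratios=[1.0], anchor_size=3)
```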
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for panoptic_maskrcnn_model.py."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.vision.beta.modeling.backbones import resnet
from official.vision.beta.modeling.decoders import aspp
from official.vision.beta.modeling.decoders import fpn
from official.vision.beta.modeling.heads import dense_prediction_heads
from official.vision.beta.modeling.heads import instance_heads
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.modeling.layers import detection_generator
from official.vision.beta.modeling.layers import mask_sampler
from official.vision.beta.modeling.layers import roi_aligner
from official.vision.beta.modeling.layers import roi_generator
from official.vision.beta.modeling.layers import roi_sampler
from official.vision.beta.ops import anchor
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
@combinations.generate(
combinations.combine(
use_separable_conv=[True, False],
build_anchor_boxes=[True, False],
shared_backbone=[True, False],
shared_decoder=[True, False],
is_training=[True, False]))
def test_build_model(self,
use_separable_conv,
build_anchor_boxes,
shared_backbone,
shared_decoder,
is_training=True):
num_classes = 3
min_level = 3
max_level = 7
num_scales = 3
aspect_ratios = [1.0]
anchor_size = 3
resnet_model_id = 50
segmentation_resnet_model_id = 50
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
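    # ASPP consumes the feature level whose stride equals the desired output
    # stride: level = log2(output_stride).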
    aspp_decoder_level = int(np.log2(segmentation_output_stride))
fpn_decoder_level = 3
num_anchors_per_location = num_scales * len(aspect_ratios)
image_size = 128
images = np.random.rand(2, image_size, image_size, 3)
image_shape = np.array([[image_size, image_size], [image_size, image_size]])
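    # Decoders can only be shared when the backbone is shared as well.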
shared_decoder = shared_decoder and shared_backbone
if build_anchor_boxes:
anchor_boxes = anchor.Anchor(
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=3,
image_size=(image_size, image_size)).multilevel_boxes
for l in anchor_boxes:
anchor_boxes[l] = tf.tile(
tf.expand_dims(anchor_boxes[l], axis=0), [2, 1, 1, 1])
else:
anchor_boxes = None
backbone = resnet.ResNet(model_id=resnet_model_id)
decoder = fpn.FPN(
input_specs=backbone.output_specs,
min_level=min_level,
max_level=max_level,
use_separable_conv=use_separable_conv)
rpn_head = dense_prediction_heads.RPNHead(
min_level=min_level,
max_level=max_level,
num_anchors_per_location=num_anchors_per_location,
num_convs=1)
detection_head = instance_heads.DetectionHead(num_classes=num_classes)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
mask_head = instance_heads.MaskHead(
num_classes=num_classes, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
        num_classes=2,  # stuff and a common class for things
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size)
gt_boxes = np.array(
[[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
[[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
dtype=np.float32)
gt_classes = np.array([[2, 1, -1], [1, -1, -1]], dtype=np.int32)
gt_masks = np.ones((2, 3, 100, 100))
# Results will be checked in test_forward.
_ = model(
images,
image_shape,
anchor_boxes,
gt_boxes,
gt_classes,
gt_masks,
training=is_training)
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
shared_backbone=[True, False],
shared_decoder=[True, False],
training=[True, False],
))
def test_forward(self, strategy, training,
shared_backbone, shared_decoder):
num_classes = 3
min_level = 3
max_level = 4
num_scales = 3
aspect_ratios = [1.0]
anchor_size = 3
segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
    aspp_decoder_level = int(np.log2(segmentation_output_stride))
fpn_decoder_level = 3
class_agnostic_bbox_pred = False
cascade_class_ensemble = False
image_size = (256, 256)
images = np.random.rand(2, image_size[0], image_size[1], 3)
image_shape = np.array([[224, 100], [100, 224]])
shared_decoder = shared_decoder and shared_backbone
with strategy.scope():
anchor_boxes = anchor.Anchor(
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
image_size=image_size).multilevel_boxes
num_anchors_per_location = len(aspect_ratios) * num_scales
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=min_level,
max_level=max_level,
input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=min_level,
max_level=max_level,
num_anchors_per_location=num_anchors_per_location)
detection_head = instance_heads.DetectionHead(
num_classes=num_classes,
class_agnostic_bbox_pred=class_agnostic_bbox_pred)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_cascade = []
roi_sampler_obj = roi_sampler.ROISampler()
roi_sampler_cascade.append(roi_sampler_obj)
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
mask_head = instance_heads.MaskHead(
num_classes=num_classes, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
          num_classes=2,  # stuff and a common class for things
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size)
gt_boxes = np.array(
[[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
[[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
dtype=np.float32)
gt_classes = np.array([[2, 1, -1], [1, -1, -1]], dtype=np.int32)
gt_masks = np.ones((2, 3, 100, 100))
results = model(
images,
image_shape,
anchor_boxes,
gt_boxes,
gt_classes,
gt_masks,
training=training)
self.assertIn('rpn_boxes', results)
self.assertIn('rpn_scores', results)
if training:
self.assertIn('class_targets', results)
self.assertIn('box_targets', results)
self.assertIn('class_outputs', results)
self.assertIn('box_outputs', results)
self.assertIn('mask_outputs', results)
else:
self.assertIn('detection_boxes', results)
self.assertIn('detection_scores', results)
self.assertIn('detection_classes', results)
self.assertIn('num_detections', results)
self.assertIn('detection_masks', results)
self.assertIn('segmentation_outputs', results)
self.assertAllEqual(
[2, image_size[0] // (2**level), image_size[1] // (2**level), 2],
results['segmentation_outputs'].numpy().shape)
@combinations.generate(
combinations.combine(
shared_backbone=[True, False], shared_decoder=[True, False]))
def test_serialize_deserialize(self, shared_backbone, shared_decoder):
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=3, max_level=7, input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=3, max_level=7, num_anchors_per_location=3)
detection_head = instance_heads.DetectionHead(num_classes=2)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
    aspp_decoder_level = int(np.log2(segmentation_output_stride))
fpn_decoder_level = 3
shared_decoder = shared_decoder and shared_backbone
mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
        num_classes=2,  # stuff and a common class for things
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
min_level=3,
max_level=7,
num_scales=3,
aspect_ratios=[1.0],
anchor_size=3)
config = model.get_config()
new_model = panoptic_maskrcnn_model.PanopticMaskRCNNModel.from_config(
config)
# Validate that the config can be forced to JSON.
_ = new_model.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(model.get_config(), new_model.get_config())
@combinations.generate(
combinations.combine(
shared_backbone=[True, False], shared_decoder=[True, False]))
def test_checkpoint(self, shared_backbone, shared_decoder):
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=3, max_level=7, input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=3, max_level=7, num_anchors_per_location=3)
detection_head = instance_heads.DetectionHead(num_classes=2)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18]
    aspp_decoder_level = int(np.log2(segmentation_output_stride))
fpn_decoder_level = 3
shared_decoder = shared_decoder and shared_backbone
mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
if shared_backbone:
segmentation_backbone = None
else:
segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id)
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
        num_classes=2,  # stuff and a common class for things
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head,
min_level=3,
max_level=7,
num_scales=3,
aspect_ratios=[1.0],
anchor_size=3)
expect_checkpoint_items = dict(
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=[detection_head])
expect_checkpoint_items['mask_head'] = mask_head
if not shared_backbone:
expect_checkpoint_items['segmentation_backbone'] = segmentation_backbone
if not shared_decoder:
expect_checkpoint_items['segmentation_decoder'] = segmentation_decoder
expect_checkpoint_items['segmentation_head'] = segmentation_head
self.assertAllEqual(expect_checkpoint_items, model.checkpoint_items)
# Test save and load checkpoints.
ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
save_dir = self.create_tempdir().full_path
ckpt.save(os.path.join(save_dir, 'ckpt'))
partial_ckpt = tf.train.Checkpoint(backbone=backbone)
partial_ckpt.restore(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
partial_ckpt_mask = tf.train.Checkpoint(
backbone=backbone, mask_head=mask_head)
partial_ckpt_mask.restore(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
if not shared_backbone:
partial_ckpt_segmentation = tf.train.Checkpoint(
segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head)
elif not shared_decoder:
partial_ckpt_segmentation = tf.train.Checkpoint(
segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head)
else:
partial_ckpt_segmentation = tf.train.Checkpoint(
segmentation_head=segmentation_head)
partial_ckpt_segmentation.restore(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
if __name__ == '__main__':
tf.test.main()