# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory method to build panoptic segmentation model."""
from typing import Optional

import tensorflow as tf

from official.projects.deepmac_maskrcnn.tasks import deep_mask_head_rcnn
from official.projects.panoptic.configs import panoptic_deeplab as panoptic_deeplab_cfg
from official.projects.panoptic.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.projects.panoptic.modeling import panoptic_deeplab_model
from official.projects.panoptic.modeling import panoptic_maskrcnn_model
from official.projects.panoptic.modeling.heads import panoptic_deeplab_heads
from official.projects.panoptic.modeling.layers import panoptic_deeplab_merge
from official.projects.panoptic.modeling.layers import panoptic_segmentation_generator
from official.vision.modeling import backbones
from official.vision.modeling.decoders import factory as decoder_factory
from official.vision.modeling.heads import segmentation_heads


def build_panoptic_maskrcnn(
    input_specs: tf.keras.layers.InputSpec,
    model_config: panoptic_maskrcnn_cfg.PanopticMaskRCNN,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds Panoptic Mask R-CNN model.

  This factory function builds the mask rcnn first, builds the non-shared
  semantic segmentation layers, and finally combines the two models to form
  the panoptic segmentation model.

  A separate segmentation backbone is built only when
  `model_config.shared_backbone` is false, and a separate segmentation decoder
  only when `model_config.shared_decoder` is false; otherwise `None` is passed
  for that component. The panoptic segmentation generator is built only when
  `model_config.generate_panoptic_masks` is true.

  Args:
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    model_config: Config instance for the panoptic maskrcnn model.
    l2_regularizer: Optional `tf.keras.regularizers.Regularizer`, if specified,
      the model is built with the provided regularization layer.

  Returns:
    tf.keras.Model for the panoptic segmentation model.
  """
  norm_activation_config = model_config.norm_activation
  segmentation_config = model_config.segmentation_model

  # Builds the maskrcnn model.
  maskrcnn_model = deep_mask_head_rcnn.build_maskrcnn(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  # Builds the semantic segmentation branch.
  if not model_config.shared_backbone:
    segmentation_backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=segmentation_config.backbone,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)
    segmentation_decoder_input_specs = segmentation_backbone.output_specs
  else:
    # Shared backbone: the segmentation decoder (if any) is sized from the
    # maskrcnn backbone outputs.
    segmentation_backbone = None
    segmentation_decoder_input_specs = maskrcnn_model.backbone.output_specs

  if not model_config.shared_decoder:
    segmentation_decoder = decoder_factory.build_decoder(
        input_specs=segmentation_decoder_input_specs,
        model_config=segmentation_config,
        l2_regularizer=l2_regularizer)
    decoder_config = segmentation_decoder.get_config()
  else:
    # Shared decoder: read the decoder settings from the maskrcnn decoder.
    segmentation_decoder = None
    decoder_config = maskrcnn_model.decoder.get_config()

  segmentation_head_config = segmentation_config.head
  detection_head_config = model_config.detection_head
  postprocessing_config = model_config.panoptic_segmentation_generator

  # NOTE(review): assumes the decoder's `get_config()` always contains a
  # 'num_filters' entry -- confirm for every supported decoder type.
  segmentation_head = segmentation_heads.SegmentationHead(
      num_classes=segmentation_config.num_classes,
      level=segmentation_head_config.level,
      num_convs=segmentation_head_config.num_convs,
      prediction_kernel_size=segmentation_head_config.prediction_kernel_size,
      num_filters=segmentation_head_config.num_filters,
      upsample_factor=segmentation_head_config.upsample_factor,
      feature_fusion=segmentation_head_config.feature_fusion,
      decoder_min_level=segmentation_head_config.decoder_min_level,
      decoder_max_level=segmentation_head_config.decoder_max_level,
      low_level=segmentation_head_config.low_level,
      low_level_num_filters=segmentation_head_config.low_level_num_filters,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      num_decoder_filters=decoder_config['num_filters'],
      kernel_regularizer=l2_regularizer)

  if model_config.generate_panoptic_masks:
    max_num_detections = model_config.detection_generator.max_num_detections
    mask_binarize_threshold = postprocessing_config.mask_binarize_threshold
    panoptic_segmentation_generator_obj = panoptic_segmentation_generator.PanopticSegmentationGenerator(
        output_size=postprocessing_config.output_size,
        max_num_detections=max_num_detections,
        stuff_classes_offset=model_config.stuff_classes_offset,
        mask_binarize_threshold=mask_binarize_threshold,
        score_threshold=postprocessing_config.score_threshold,
        things_overlap_threshold=postprocessing_config.things_overlap_threshold,
        things_class_label=postprocessing_config.things_class_label,
        stuff_area_threshold=postprocessing_config.stuff_area_threshold,
        void_class_label=postprocessing_config.void_class_label,
        void_instance_id=postprocessing_config.void_instance_id,
        rescale_predictions=postprocessing_config.rescale_predictions)
  else:
    panoptic_segmentation_generator_obj = None

  # Combines maskrcnn, and segmentation models to build panoptic segmentation
  # model.
  model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
      backbone=maskrcnn_model.backbone,
      decoder=maskrcnn_model.decoder,
      rpn_head=maskrcnn_model.rpn_head,
      detection_head=maskrcnn_model.detection_head,
      roi_generator=maskrcnn_model.roi_generator,
      roi_sampler=maskrcnn_model.roi_sampler,
      roi_aligner=maskrcnn_model.roi_aligner,
      detection_generator=maskrcnn_model.detection_generator,
      panoptic_segmentation_generator=panoptic_segmentation_generator_obj,
      mask_head=maskrcnn_model.mask_head,
      mask_sampler=maskrcnn_model.mask_sampler,
      mask_roi_aligner=maskrcnn_model.mask_roi_aligner,
      segmentation_backbone=segmentation_backbone,
      segmentation_decoder=segmentation_decoder,
      segmentation_head=segmentation_head,
      class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
      cascade_class_ensemble=detection_head_config.cascade_class_ensemble,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size)
  return model


def build_panoptic_deeplab(
    input_specs: tf.keras.layers.InputSpec,
    model_config: panoptic_deeplab_cfg.PanopticDeeplab,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds Panoptic Deeplab model.

  This factory function builds the backbone and the semantic decoder, builds
  a separate instance decoder unless `model_config.shared_decoder` is set,
  builds the semantic and instance heads, and finally combines them (plus a
  post processor when `model_config.generate_panoptic_masks` is set) into the
  panoptic deeplab model.

  Args:
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    model_config: Config instance for the panoptic deeplab model.
    l2_regularizer: Optional `tf.keras.regularizers.Regularizer`, if specified,
      the model is built with the provided regularization layer.

  Returns:
    tf.keras.Model for the panoptic segmentation model.
  """
  norm_activation_config = model_config.norm_activation
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_activation_config,
      l2_regularizer=l2_regularizer)

  semantic_decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  if model_config.shared_decoder:
    # No dedicated instance decoder is built; `None` is handed to the model.
    instance_decoder = None
  else:
    # semantic and instance share the same decoder type
    instance_decoder = decoder_factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

  semantic_head_config = model_config.semantic_head
  instance_head_config = model_config.instance_head

  semantic_head = panoptic_deeplab_heads.SemanticHead(
      num_classes=model_config.num_classes,
      level=semantic_head_config.level,
      num_convs=semantic_head_config.num_convs,
      kernel_size=semantic_head_config.kernel_size,
      prediction_kernel_size=semantic_head_config.prediction_kernel_size,
      num_filters=semantic_head_config.num_filters,
      use_depthwise_convolution=semantic_head_config.use_depthwise_convolution,
      upsample_factor=semantic_head_config.upsample_factor,
      low_level=semantic_head_config.low_level,
      low_level_num_filters=semantic_head_config.low_level_num_filters,
      fusion_num_output_filters=semantic_head_config.fusion_num_output_filters,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  instance_head = panoptic_deeplab_heads.InstanceHead(
      level=instance_head_config.level,
      num_convs=instance_head_config.num_convs,
      kernel_size=instance_head_config.kernel_size,
      prediction_kernel_size=instance_head_config.prediction_kernel_size,
      num_filters=instance_head_config.num_filters,
      use_depthwise_convolution=instance_head_config.use_depthwise_convolution,
      upsample_factor=instance_head_config.upsample_factor,
      low_level=instance_head_config.low_level,
      low_level_num_filters=instance_head_config.low_level_num_filters,
      fusion_num_output_filters=instance_head_config.fusion_num_output_filters,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  if model_config.generate_panoptic_masks:
    post_processing_config = model_config.post_processor
    post_processor = panoptic_deeplab_merge.PostProcessor(
        output_size=post_processing_config.output_size,
        center_score_threshold=post_processing_config.center_score_threshold,
        thing_class_ids=post_processing_config.thing_class_ids,
        label_divisor=post_processing_config.label_divisor,
        stuff_area_limit=post_processing_config.stuff_area_limit,
        ignore_label=post_processing_config.ignore_label,
        nms_kernel=post_processing_config.nms_kernel,
        keep_k_centers=post_processing_config.keep_k_centers,
        rescale_predictions=post_processing_config.rescale_predictions)
  else:
    post_processor = None

  model = panoptic_deeplab_model.PanopticDeeplabModel(
      backbone=backbone,
      semantic_decoder=semantic_decoder,
      instance_decoder=instance_decoder,
      semantic_head=semantic_head,
      instance_head=instance_head,
      post_processor=post_processor)

  return model