# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Panoptic Mask R-CNN configuration definition."""

import dataclasses
import os
from typing import List, Optional

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.projects.deepmac_maskrcnn.configs import deep_mask_head_rcnn as deepmac_maskrcnn
from official.vision.configs import common
from official.vision.configs import maskrcnn
from official.vision.configs import semantic_segmentation


SEGMENTATION_MODEL = semantic_segmentation.SemanticSegmentationModel
SEGMENTATION_HEAD = semantic_segmentation.SegmentationHead

_COCO_INPUT_PATH_BASE = 'coco/tfrecords'
_COCO_TRAIN_EXAMPLES = 118287
_COCO_VAL_EXAMPLES = 5000

# pytype: disable=wrong-keyword-args


@dataclasses.dataclass
class Parser(maskrcnn.Parser):
  """Panoptic Mask R-CNN parser config."""
  # If segmentation_resize_eval_groundtruth is set to False, the original
  # image sizes are used for eval. In that case,
  # segmentation_groundtruth_padded_size must also be specified so that the
  # variable-size eval groundtruth can be batched (see the sketch after this
  # class).
  segmentation_resize_eval_groundtruth: bool = True
  segmentation_groundtruth_padded_size: List[int] = dataclasses.field(
      default_factory=list)
  segmentation_ignore_label: int = 255
  panoptic_ignore_label: int = 0
  # Setting this to true will enable parsing category_mask and instance_mask.
  include_panoptic_masks: bool = True
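
# A minimal sketch of the padded-size relationship documented above
# (illustrative values): when eval groundtruth is not resized, a fixed padded
# size must be supplied so that variable-size masks can still be batched.
#
#   eval_parser = Parser(
#       segmentation_resize_eval_groundtruth=False,
#       segmentation_groundtruth_padded_size=[640, 640])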


@dataclasses.dataclass
class TfExampleDecoder(common.TfExampleDecoder):
  """A simple TF Example decoder config."""
  # Setting this to true will enable decoding category_mask and instance_mask.
  include_panoptic_masks: bool = True
  panoptic_category_mask_key: str = 'image/panoptic/category_mask'
  panoptic_instance_mask_key: str = 'image/panoptic/instance_mask'


@dataclasses.dataclass
class DataDecoder(common.DataDecoder):
  """Data decoder config."""
  simple_decoder: TfExampleDecoder = TfExampleDecoder()


@dataclasses.dataclass
class DataConfig(maskrcnn.DataConfig):
  """Input config for training."""
  decoder: DataDecoder = DataDecoder()
  parser: Parser = Parser()


@dataclasses.dataclass
class PanopticSegmentationGenerator(hyperparams.Config):
  """Panoptic segmentation generator config."""
  output_size: List[int] = dataclasses.field(
      default_factory=list)
  mask_binarize_threshold: float = 0.5
  score_threshold: float = 0.5
  things_overlap_threshold: float = 0.5
  stuff_area_threshold: float = 4096.0
  things_class_label: int = 1
  void_class_label: int = 0
  void_instance_id: int = 0
  rescale_predictions: bool = False
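
# How the generator thresholds above interact (a hedged sketch following the
# Panoptic FPN-style fusion heuristic; the exact logic lives in the panoptic
# segmentation generator layer): instance masks scoring below
# `score_threshold` are dropped; surviving masks are binarized at
# `mask_binarize_threshold` and pasted in decreasing score order, skipping a
# mask once its overlap with already-pasted masks exceeds
# `things_overlap_threshold`; stuff regions from the semantic head are kept
# only if they cover at least `stuff_area_threshold` pixels; pixels left
# unassigned receive `void_class_label` / `void_instance_id`.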


@dataclasses.dataclass
class PanopticMaskRCNN(deepmac_maskrcnn.DeepMaskHeadRCNN):
  """Panoptic Mask R-CNN model config."""
  segmentation_model: semantic_segmentation.SemanticSegmentationModel = (
      SEGMENTATION_MODEL(num_classes=2))
  include_mask: bool = True
  shared_backbone: bool = True
  shared_decoder: bool = True
  stuff_classes_offset: int = 0
  generate_panoptic_masks: bool = True
  panoptic_segmentation_generator: PanopticSegmentationGenerator = PanopticSegmentationGenerator()  # pylint:disable=line-too-long


@dataclasses.dataclass
class Losses(maskrcnn.Losses):
  """Panoptic Mask R-CNN loss config."""
  semantic_segmentation_label_smoothing: float = 0.0
  semantic_segmentation_ignore_label: int = 255
  semantic_segmentation_class_weights: List[float] = dataclasses.field(
      default_factory=list)
  semantic_segmentation_use_groundtruth_dimension: bool = True
  semantic_segmentation_top_k_percent_pixels: float = 1.0
  instance_segmentation_weight: float = 1.0
  semantic_segmentation_weight: float = 0.5
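
# A hedged note on the two weights above: the task combines the detection
# losses with the mask and semantic-segmentation losses roughly as
#
#   total_loss = detection_losses
#       + instance_segmentation_weight * instance_segmentation_loss
#       + semantic_segmentation_weight * semantic_segmentation_loss
#
# The 0.5 default for the semantic branch is consistent with the Panoptic FPN
# recipe; the exact reduction is defined in the task implementation.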


@dataclasses.dataclass
class PanopticQualityEvaluator(hyperparams.Config):
  """Panoptic Quality Evaluator config."""
  num_categories: int = 2
  ignored_label: int = 0
  max_instances_per_category: int = 256
  offset: int = 256 * 256 * 256
  is_thing: List[float] = dataclasses.field(
      default_factory=list)
  rescale_predictions: bool = False
  report_per_class_metrics: bool = False
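
# A hedged note on `offset` and `max_instances_per_category` above, assuming
# the evaluator follows the DeepLab-style panoptic quality convention: each
# segment is encoded as
#
#   segment_id = category_id * max_instances_per_category + instance_id
#
# and groundtruth/prediction pairs are packed as
# gt_segment_id * offset + pred_segment_id, so `offset` must exceed the
# largest possible segment id. With the COCO settings used below,
# 201 * 256 = 51456 << 256**3, so the packing cannot collide.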


@dataclasses.dataclass
class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
  """Panoptic Mask R-CNN task config."""
  model: PanopticMaskRCNN = PanopticMaskRCNN()
  train_data: DataConfig = DataConfig(is_training=True)
  validation_data: DataConfig = DataConfig(is_training=False,
                                           drop_remainder=False)
  segmentation_evaluation: semantic_segmentation.Evaluation = semantic_segmentation.Evaluation()  # pylint: disable=line-too-long
  losses: Losses = Losses()
  init_checkpoint: Optional[str] = None
  segmentation_init_checkpoint: Optional[str] = None

  # 'init_checkpoint_modules' controls the modules that need to be initialized
  # from checkpoint paths given by 'init_checkpoint' and/or
  # 'segmentation_init_checkpoint'. Supported modules (a usage sketch follows
  # this class):
  # 'backbone': Initialize MaskRCNN backbone
  # 'segmentation_backbone': Initialize segmentation backbone
  # 'segmentation_decoder': Initialize segmentation decoder
  # 'all': Initialize all modules
  init_checkpoint_modules: Optional[List[str]] = dataclasses.field(
      default_factory=list)
  panoptic_quality_evaluator: PanopticQualityEvaluator = PanopticQualityEvaluator()  # pylint: disable=line-too-long
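
# A usage sketch for the checkpoint fields above (hypothetical paths, shown
# for illustration only): restoring the detection backbone and the
# segmentation backbone from two separate checkpoints.
#
#   task = PanopticMaskRCNNTask(
#       init_checkpoint='gs://path/to/detection_checkpoint',
#       segmentation_init_checkpoint='gs://path/to/segmentation_checkpoint',
#       init_checkpoint_modules=['backbone', 'segmentation_backbone'])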


@exp_factory.register_config_factory('panoptic_fpn_coco')
def panoptic_fpn_coco() -> cfg.ExperimentConfig:
  """COCO panoptic segmentation with Panoptic Mask R-CNN."""
  train_batch_size = 64
  eval_batch_size = 8
  steps_per_epoch = _COCO_TRAIN_EXAMPLES // train_batch_size
  validation_steps = _COCO_VAL_EXAMPLES // eval_batch_size

  # The COCO panoptic dataset has category ids in [0, 200] inclusive:
  # id 0 is unused and represents the background class,
  # ids 1-91 are the 91 thing categories, and
  # ids 92-200 are the 109 stuff categories.
  # For the segmentation task we keep id=0 for the background and map all
  # thing categories to id=1. The remaining 109 stuff categories are shifted
  # down by an offset of 90 (num_thing_categories - 1), so the stuff
  # categories occupy ids 2 through 110 (see the worked example below).
  num_panoptic_categories = 201
  num_thing_categories = 91
  num_semantic_segmentation_classes = 111
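  # Worked example of the mapping above (illustrative only): any thing id in
  # [1, 91] collapses to semantic id 1; stuff id 92 maps to 92 - 90 = 2 and
  # stuff id 200 maps to 200 - 90 = 110, so the semantic head predicts
  # 111 classes in total (background + things + 109 stuff classes).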

  is_thing = [False]
  for idx in range(1, num_panoptic_categories):
    is_thing.append(idx <= num_thing_categories)

  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(
          mixed_precision_dtype='float32', enable_xla=True),
      task=PanopticMaskRCNNTask(
          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',  # pylint: disable=line-too-long
          init_checkpoint_modules=['backbone'],
          model=PanopticMaskRCNN(
              num_classes=91, input_size=[1024, 1024, 3],
              panoptic_segmentation_generator=PanopticSegmentationGenerator(
                  output_size=[640, 640], rescale_predictions=True),
              stuff_classes_offset=90,
              segmentation_model=SEGMENTATION_MODEL(
                  num_classes=num_semantic_segmentation_classes,
                  head=SEGMENTATION_HEAD(
                      level=2,
                      num_convs=0,
                      num_filters=128,
                      decoder_min_level=2,
                      decoder_max_level=6,
                      feature_fusion='panoptic_fpn_fusion'))),
          losses=Losses(l2_weight_decay=0.00004),
          train_data=DataConfig(
              input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size,
              parser=Parser(
                  aug_rand_hflip=True, aug_scale_min=0.8, aug_scale_max=1.25)),
          validation_data=DataConfig(
              input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'val*'),
              is_training=False,
              global_batch_size=eval_batch_size,
              parser=Parser(
                  segmentation_resize_eval_groundtruth=False,
                  segmentation_groundtruth_padded_size=[640, 640]),
              drop_remainder=False),
          annotation_file=os.path.join(_COCO_INPUT_PATH_BASE,
                                       'instances_val2017.json'),
          segmentation_evaluation=semantic_segmentation.Evaluation(
              report_per_class_iou=False, report_train_mean_iou=False),
          panoptic_quality_evaluator=PanopticQualityEvaluator(
              num_categories=num_panoptic_categories,
              ignored_label=0,
              is_thing=is_thing,
              rescale_predictions=True)),
      trainer=cfg.TrainerConfig(
          train_steps=22500,
          validation_steps=validation_steps,
          validation_interval=steps_per_epoch,
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'stepwise',
                  'stepwise': {
                      'boundaries': [15000, 20000],
                      'values': [0.12, 0.012, 0.0012],
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 500,
                      'warmup_learning_rate': 0.0067
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
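

# A minimal usage sketch (added for illustration): registered experiment
# configs can be retrieved through `exp_factory.get_exp_config`, the standard
# Model Garden accessor for factories registered above.
if __name__ == '__main__':
  experiment_config = exp_factory.get_exp_config('panoptic_fpn_coco')
  # The COCO recipe detects 91 thing classes and segments 111 semantic
  # classes (background + things + 109 stuff).
  print(experiment_config.task.model.num_classes)
  print(experiment_config.task.model.segmentation_model.num_classes)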