# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mask R-CNN model."""

from typing import Any, List, Mapping, Optional, Union

# Import libraries
import tensorflow as tf

from official.vision.beta.ops import box_ops


@tf.keras.utils.register_keras_serializable(package='Vision')
class MaskRCNNModel(tf.keras.Model):
  """The Mask R-CNN model (optionally with Cascade R-CNN detection stages)."""

  def __init__(self,
               backbone: tf.keras.Model,
               decoder: Optional[tf.keras.Model],
               rpn_head: tf.keras.layers.Layer,
               detection_head: tf.keras.layers.Layer,
               roi_generator: tf.keras.layers.Layer,
               roi_sampler: Union[tf.keras.layers.Layer,
                                  List[tf.keras.layers.Layer]],
               roi_aligner: tf.keras.layers.Layer,
               detection_generator: tf.keras.layers.Layer,
               mask_head: Optional[tf.keras.layers.Layer] = None,
               mask_sampler: Optional[tf.keras.layers.Layer] = None,
               mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
               class_agnostic_bbox_pred: bool = False,
               cascade_class_ensemble: bool = False,
               **kwargs):
    """Initializes the Mask R-CNN model.

    Args:
      backbone: `tf.keras.Model`, the backbone network.
      decoder: `tf.keras.Model`, the decoder network, or None to feed the
        backbone features to the heads directly.
      rpn_head: the RPN head.
      detection_head: the detection head. The same head is run for every
        cascade stage.
      roi_generator: the ROI generator.
      roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
        detection heads. One detection stage is run per sampler.
      roi_aligner: the ROI aligner.
      detection_generator: the detection generator.
      mask_head: the mask head. If None, no mask branch is built.
      mask_sampler: the mask sampler. Required when `mask_head` is given.
      mask_roi_aligner: the ROI aligner for mask prediction. Required when
        `mask_head` is given.
      class_agnostic_bbox_pred: if True, perform class agnostic bounding box
        prediction. Needs to be `True` for Cascade RCNN models.
      cascade_class_ensemble: if True, ensemble classification scores over
        all detection heads.
      **kwargs: keyword arguments to be passed.

    Raises:
      ValueError: if `roi_sampler` is an empty list; if multiple ROI samplers
        are given without `class_agnostic_bbox_pred`; if `mask_head` is given
        without `mask_sampler` or `mask_roi_aligner`; or if more ROI samplers
        are given than there are cascade regression weight sets.
    """
    super(MaskRCNNModel, self).__init__(**kwargs)
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'rpn_head': rpn_head,
        'detection_head': detection_head,
        'roi_generator': roi_generator,
        'roi_sampler': roi_sampler,
        'roi_aligner': roi_aligner,
        'detection_generator': detection_generator,
        'mask_head': mask_head,
        'mask_sampler': mask_sampler,
        'mask_roi_aligner': mask_roi_aligner,
        'class_agnostic_bbox_pred': class_agnostic_bbox_pred,
        'cascade_class_ensemble': cascade_class_ensemble,
    }
    self.backbone = backbone
    self.decoder = decoder
    self.rpn_head = rpn_head
    self.detection_head = detection_head
    self.roi_generator = roi_generator
    if not isinstance(roi_sampler, (list, tuple)):
      self.roi_sampler = [roi_sampler]
    else:
      self.roi_sampler = roi_sampler
    # `call` runs one detection stage per sampler; with zero stages it would
    # fail later on unbound stage outputs, so reject that up front.
    if len(self.roi_sampler) == 0:  # pylint: disable=g-explicit-length-test
      raise ValueError('At least one `roi_sampler` needs to be specified.')
    if len(self.roi_sampler) > 1 and not class_agnostic_bbox_pred:
      raise ValueError(
          '`class_agnostic_bbox_pred` needs to be True if multiple detection '
          'heads are specified.')
    self.roi_aligner = roi_aligner
    self.detection_generator = detection_generator
    self._include_mask = mask_head is not None
    self.mask_head = mask_head
    if self._include_mask and mask_sampler is None:
      raise ValueError('`mask_sampler` is not provided in Mask R-CNN.')
    self.mask_sampler = mask_sampler
    if self._include_mask and mask_roi_aligner is None:
      raise ValueError('`mask_roi_aligner` is not provided in Mask R-CNN.')
    self.mask_roi_aligner = mask_roi_aligner
    # Weights for the regression losses for each FRCNN layer. Later cascade
    # stages use larger weights because the deltas they predict become
    # smaller and smaller.
    # TODO(xianzhi): Make the weights configurable.
    self._cascade_layer_to_weights = [
        [10.0, 10.0, 5.0, 5.0],
        [20.0, 20.0, 10.0, 10.0],
        [30.0, 30.0, 15.0, 15.0],
    ]
    # Each stage indexes `_cascade_layer_to_weights` by its stage number, so
    # fail here rather than with an IndexError inside `call`.
    if len(self.roi_sampler) > len(self._cascade_layer_to_weights):
      raise ValueError(
          'At most {} cascade stages are supported, but {} ROI samplers were '
          'specified.'.format(
              len(self._cascade_layer_to_weights), len(self.roi_sampler)))

  def call(self,
           images: tf.Tensor,
           image_shape: tf.Tensor,
           anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
           gt_boxes: Optional[tf.Tensor] = None,
           gt_classes: Optional[tf.Tensor] = None,
           gt_masks: Optional[tf.Tensor] = None,
           training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
    """Runs the Mask R-CNN forward pass.

    Args:
      images: the input images, fed to the backbone.
      image_shape: the image shapes. They are passed to the ROI generator and
        the detection generator, and used to clip cascade ROIs.
      anchor_boxes: anchors passed to the ROI generator (presumably keyed by
        feature level -- confirm against the ROI generator).
      gt_boxes: groundtruth boxes, consumed by the ROI samplers when training.
      gt_classes: groundtruth classes, consumed by the ROI samplers when
        training.
      gt_masks: groundtruth masks, consumed by the mask sampler when training.
      training: whether the model is training. Any falsy value (including
        None) runs the inference path.

    Returns:
      A dict of outputs. It always contains `rpn_boxes`, `rpn_scores` and the
      per-stage `class_outputs` / `box_outputs`. Stages after the first use
      keys suffixed with `_<stage>`.
      When training, it also contains the per-stage `class_targets` /
      `box_targets` and, if a mask head is configured, `mask_class_targets`,
      `mask_targets` and `mask_outputs`.
      Otherwise it contains `detection_boxes`, `detection_scores`,
      `detection_classes`, `num_detections` and, if a mask head is
      configured, `detection_masks`.
    """
    model_outputs = {}

    # Feature extraction.
    features = self.backbone(images)
    if self.decoder is not None:
      features = self.decoder(features)

    # Region proposal network.
    rpn_scores, rpn_boxes = self.rpn_head(features)

    model_outputs.update({
        'rpn_boxes': rpn_boxes,
        'rpn_scores': rpn_scores
    })

    # Generate RoIs.
    current_rois, _ = self.roi_generator(rpn_boxes, rpn_scores, anchor_boxes,
                                         image_shape, training)

    next_rois = current_rois
    all_class_outputs = []
    for cascade_num in range(len(self.roi_sampler)):
      # In cascade RCNN we want the higher layers to have different regression
      # weights as the predicted deltas become smaller and smaller.
      regression_weights = self._cascade_layer_to_weights[cascade_num]
      current_rois = next_rois

      (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
       matched_gt_classes, matched_gt_indices,
       current_rois) = self._run_frcnn_head(
           features=features,
           rois=current_rois,
           gt_boxes=gt_boxes,
           gt_classes=gt_classes,
           training=training,
           model_outputs=model_outputs,
           layer_num=cascade_num,
           regression_weights=regression_weights)
      all_class_outputs.append(class_outputs)

      # Generate ROIs for the next cascade head if there is any.
      if cascade_num < len(self.roi_sampler) - 1:
        next_rois = box_ops.decode_boxes(
            tf.cast(box_outputs, tf.float32),
            current_rois,
            weights=regression_weights)
        next_rois = box_ops.clip_boxes(next_rois,
                                       tf.expand_dims(image_shape, axis=1))

    if not training:
      if self._config_dict['cascade_class_ensemble']:
        # Average the class outputs of all stages.
        class_outputs = tf.add_n(all_class_outputs) / len(all_class_outputs)

      # Final detections use the last stage's boxes, ROIs and weights.
      detections = self.detection_generator(
          box_outputs,
          class_outputs,
          current_rois,
          image_shape,
          regression_weights,
          bbox_per_class=(not self._config_dict['class_agnostic_bbox_pred']))
      model_outputs.update({
          'detection_boxes': detections['detection_boxes'],
          'detection_scores': detections['detection_scores'],
          'detection_classes': detections['detection_classes'],
          'num_detections': detections['num_detections'],
      })

    if not self._include_mask:
      return model_outputs

    if training:
      # Mask targets come from the matches of the last cascade stage.
      current_rois, roi_classes, roi_masks = self.mask_sampler(
          current_rois, matched_gt_boxes, matched_gt_classes,
          matched_gt_indices, gt_masks)
      roi_masks = tf.stop_gradient(roi_masks)

      model_outputs.update({
          'mask_class_targets': roi_classes,
          'mask_targets': roi_masks,
      })
    else:
      current_rois = model_outputs['detection_boxes']
      roi_classes = model_outputs['detection_classes']

    # Mask RoI align.
    mask_roi_features = self.mask_roi_aligner(features, current_rois)

    # Mask head.
    raw_masks = self.mask_head([mask_roi_features, roi_classes])

    if training:
      model_outputs.update({
          'mask_outputs': raw_masks,
      })
    else:
      model_outputs.update({
          'detection_masks': tf.math.sigmoid(raw_masks),
      })
    return model_outputs

  def _run_frcnn_head(self, features, rois, gt_boxes, gt_classes, training,
                      model_outputs, layer_num, regression_weights):
    """Runs the frcnn head that does both class and box prediction.

    Args:
      features: `list` of features from the feature extractor.
      rois: `list` of current rois that will be used to predict bbox refinement
        and classes from.
      gt_boxes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES, 4].
        This tensor might have paddings with a negative value.
      gt_classes: [batch_size, MAX_INSTANCES] representing the groundtruth box
        classes. It is padded with -1s to indicate the invalid classes.
      training: `bool`, if model is training or being evaluated.
      model_outputs: `dict`, used for storing outputs used for eval and losses.
        It is updated in place and also returned.
      layer_num: `int`, the current frcnn layer in the cascade. Output keys
        get a `_<layer_num>` suffix for every layer except the first.
      regression_weights: `list`, weights used to encode the bounding box
        regression targets.

    Returns:
      class_outputs: Class predictions for rois.
      box_outputs: Box predictions for rois. These are formatted for the
        regression loss and need to be converted before being used as rois
        in the next stage.
      model_outputs: Updated dict with predictions used for losses and eval.
      matched_gt_boxes: If `training` is true, then these give the gt box
        location of its positive match. None otherwise.
      matched_gt_classes: If `training` is true, then these give the gt class
        of the predicted box. None otherwise.
      matched_gt_indices: If `training` is true, then gives the index of
        the positive box match. Used for mask prediction. None otherwise.
      rois: The sampled rois used for this layer (the input rois unchanged
        when not training).
    """
    # Only used during training.
    matched_gt_boxes, matched_gt_classes, matched_gt_indices = (None, None,
                                                                None)
    if training:
      # Sampling and target encoding must not backprop into earlier stages.
      rois = tf.stop_gradient(rois)

      current_roi_sampler = self.roi_sampler[layer_num]
      rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
          current_roi_sampler(rois, gt_boxes, gt_classes))
      # Create bounding box training targets.
      box_targets = box_ops.encode_boxes(
          matched_gt_boxes, rois, weights=regression_weights)
      # If the target is background, the box target is set to all 0s.
      box_targets = tf.where(
          tf.tile(
              tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
              [1, 1, 4]), tf.zeros_like(box_targets), box_targets)
      model_outputs.update({
          'class_targets_{}'.format(layer_num)
          if layer_num else 'class_targets':
              matched_gt_classes,
          'box_targets_{}'.format(layer_num) if layer_num else 'box_targets':
              box_targets,
      })

    # Get roi features.
    roi_features = self.roi_aligner(features, rois)

    # Run frcnn head to get class and bbox predictions.
    class_outputs, box_outputs = self.detection_head(roi_features)

    model_outputs.update({
        'class_outputs_{}'.format(layer_num) if layer_num else 'class_outputs':
            class_outputs,
        'box_outputs_{}'.format(layer_num) if layer_num else 'box_outputs':
            box_outputs,
    })
    return (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
            matched_gt_classes, matched_gt_indices, rois)

  @property
  def checkpoint_items(
      self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(
        backbone=self.backbone,
        rpn_head=self.rpn_head,
        detection_head=self.detection_head)
    if self.decoder is not None:
      items.update(decoder=self.decoder)
    if self._include_mask:
      items.update(mask_head=self.mask_head)

    return items

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments this model was built with."""
    return self._config_dict

  @classmethod
  def from_config(cls, config):
    """Re-creates the model from a `get_config()` dict."""
    return cls(**config)