# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Xianzhi Du's avatar
Xianzhi Du committed
15
"""R-CNN(-RS) models."""
Abdullah Rashwan's avatar
Abdullah Rashwan committed
16

A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
17
from typing import Any, List, Mapping, Optional, Tuple, Union
Fan Yang's avatar
Fan Yang committed
18

Abdullah Rashwan's avatar
Abdullah Rashwan committed
19
20
import tensorflow as tf

21
from official.vision.beta.ops import anchor
Abdullah Rashwan's avatar
Abdullah Rashwan committed
22
23
24
25
26
from official.vision.beta.ops import box_ops


@tf.keras.utils.register_keras_serializable(package='Vision')
class MaskRCNNModel(tf.keras.Model):
  """The Mask R-CNN(-RS) and Cascade RCNN-RS models."""

  def __init__(self,
               backbone: tf.keras.Model,
               decoder: Optional[tf.keras.Model],
               rpn_head: tf.keras.layers.Layer,
               detection_head: Union[tf.keras.layers.Layer,
                                     List[tf.keras.layers.Layer]],
               roi_generator: tf.keras.layers.Layer,
               roi_sampler: Union[tf.keras.layers.Layer,
                                  List[tf.keras.layers.Layer]],
               roi_aligner: tf.keras.layers.Layer,
               detection_generator: tf.keras.layers.Layer,
               mask_head: Optional[tf.keras.layers.Layer] = None,
               mask_sampler: Optional[tf.keras.layers.Layer] = None,
               mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
               class_agnostic_bbox_pred: bool = False,
               cascade_class_ensemble: bool = False,
               min_level: Optional[int] = None,
               max_level: Optional[int] = None,
               num_scales: Optional[int] = None,
               aspect_ratios: Optional[List[float]] = None,
               anchor_size: Optional[float] = None,
               **kwargs):
    """Initializes the R-CNN(-RS) model.

    Args:
      backbone: `tf.keras.Model`, the backbone network.
      decoder: `tf.keras.Model`, the decoder network, or None to feed the
        backbone features directly to the heads.
      rpn_head: the RPN head.
      detection_head: the detection head or a list of heads, one per cascade
        stage.
      roi_generator: the ROI generator.
      roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
        detection heads. The number of samplers defines the number of cascade
        stages.
      roi_aligner: the ROI aligner.
      detection_generator: the detection generator.
      mask_head: the mask head. If None, no mask branch is built.
      mask_sampler: the mask sampler. Required when `mask_head` is given.
      mask_roi_aligner: the ROI aligner for mask prediction. Required when
        `mask_head` is given.
      class_agnostic_bbox_pred: if True, perform class agnostic bounding box
        prediction. Needs to be `True` for Cascade RCNN models.
      cascade_class_ensemble: if True, ensemble classification scores over all
        detection heads.
      min_level: Minimum level in output feature maps.
      max_level: Maximum level in output feature maps.
      num_scales: A number representing intermediate scales added on each
        level. For instance, num_scales=2 adds one additional intermediate
        anchor scale, giving scales [2^0, 2^0.5] on each level.
      aspect_ratios: A list representing the aspect ratio anchors added on
        each level. The number indicates the ratio of width to height. For
        instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each
        scale level.
      anchor_size: A number representing the scale of size of the base anchor
        to the feature stride 2^level.
      **kwargs: keyword arguments to be passed.

    Raises:
      ValueError: if multiple ROI samplers are given while
        `class_agnostic_bbox_pred` is False; if there are more cascade stages
        (ROI samplers) than detection heads or than configured regression
        weight sets; or if `mask_head` is given without `mask_sampler` or
        `mask_roi_aligner`.
    """
    super(MaskRCNNModel, self).__init__(**kwargs)
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'rpn_head': rpn_head,
        'detection_head': detection_head,
        'roi_generator': roi_generator,
        'roi_sampler': roi_sampler,
        'roi_aligner': roi_aligner,
        'detection_generator': detection_generator,
        'mask_head': mask_head,
        'mask_sampler': mask_sampler,
        'mask_roi_aligner': mask_roi_aligner,
        'class_agnostic_bbox_pred': class_agnostic_bbox_pred,
        'cascade_class_ensemble': cascade_class_ensemble,
        'min_level': min_level,
        'max_level': max_level,
        'num_scales': num_scales,
        'aspect_ratios': aspect_ratios,
        'anchor_size': anchor_size,
    }
    self.backbone = backbone
    self.decoder = decoder
    self.rpn_head = rpn_head
    if not isinstance(detection_head, (list, tuple)):
      self.detection_head = [detection_head]
    else:
      self.detection_head = detection_head
    self.roi_generator = roi_generator
    if not isinstance(roi_sampler, (list, tuple)):
      self.roi_sampler = [roi_sampler]
    else:
      self.roi_sampler = roi_sampler
    if len(self.roi_sampler) > 1 and not class_agnostic_bbox_pred:
      raise ValueError(
          '`class_agnostic_bbox_pred` needs to be True if multiple detection heads are specified.'
      )

    # Weights for the regression losses for each FRCNN layer.
    # TODO(xianzhi): Make the weights configurable.
    self._cascade_layer_to_weights = [
        [10.0, 10.0, 5.0, 5.0],
        [20.0, 20.0, 10.0, 10.0],
        [30.0, 30.0, 15.0, 15.0],
    ]

    # `_call_box_outputs` runs one stage per ROI sampler and indexes both the
    # detection heads and the regression weights by stage number; reject
    # configurations that would otherwise fail later with an IndexError.
    num_cascade_stages = len(self.roi_sampler)
    if len(self.detection_head) < num_cascade_stages:
      raise ValueError(
          'Got {} ROI sampler(s) but only {} detection head(s); each cascade '
          'stage needs its own detection head.'.format(
              num_cascade_stages, len(self.detection_head)))
    if num_cascade_stages > len(self._cascade_layer_to_weights):
      raise ValueError(
          'At most {} cascade stages are supported, got {} ROI samplers.'
          .format(len(self._cascade_layer_to_weights), num_cascade_stages))

    self.roi_aligner = roi_aligner
    self.detection_generator = detection_generator
    self._include_mask = mask_head is not None
    self.mask_head = mask_head
    if self._include_mask and mask_sampler is None:
      raise ValueError('`mask_sampler` is not provided in Mask R-CNN.')
    self.mask_sampler = mask_sampler
    if self._include_mask and mask_roi_aligner is None:
      raise ValueError('`mask_roi_aligner` is not provided in Mask R-CNN.')
    self.mask_roi_aligner = mask_roi_aligner

  def call(self,
           images: tf.Tensor,
           image_shape: tf.Tensor,
           anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
           gt_boxes: Optional[tf.Tensor] = None,
           gt_classes: Optional[tf.Tensor] = None,
           gt_masks: Optional[tf.Tensor] = None,
           training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
    """Runs the box branch and, if a mask head is configured, the mask branch.

    Args:
      images: the input image batch.
      image_shape: per-image shape used to clip boxes; presumably
        [batch_size, 2] as (height, width) -- TODO confirm.
      anchor_boxes: optional multilevel anchors. Generated from the model
        config and the static shape of `images` when None.
      gt_boxes: groundtruth boxes, used only in training.
      gt_classes: groundtruth classes, used only in training.
      gt_masks: groundtruth masks, used only by the mask branch in training.
      training: whether the model runs in training mode.

    Returns:
      A dict of model outputs (see `_call_box_outputs` and
      `_call_mask_outputs` for the keys).
    """
    model_outputs, intermediate_outputs = self._call_box_outputs(
        images=images, image_shape=image_shape, anchor_boxes=anchor_boxes,
        gt_boxes=gt_boxes, gt_classes=gt_classes, training=training)
    if not self._include_mask:
      return model_outputs

    model_mask_outputs = self._call_mask_outputs(
        model_box_outputs=model_outputs,
        features=intermediate_outputs['features'],
        current_rois=intermediate_outputs['current_rois'],
        matched_gt_indices=intermediate_outputs['matched_gt_indices'],
        matched_gt_boxes=intermediate_outputs['matched_gt_boxes'],
        matched_gt_classes=intermediate_outputs['matched_gt_classes'],
        gt_masks=gt_masks,
        training=training)
    model_outputs.update(model_mask_outputs)
    return model_outputs

  def _call_box_outputs(
      self, images: tf.Tensor,
      image_shape: tf.Tensor,
      anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
      gt_boxes: Optional[tf.Tensor] = None,
      gt_classes: Optional[tf.Tensor] = None,
      training: Optional[bool] = None) -> Tuple[
          Mapping[str, tf.Tensor], Mapping[str, tf.Tensor]]:
    """Implementation of the (optionally cascaded) Faster-RCNN box logic.

    Returns:
      A `(model_outputs, intermediate_outputs)` tuple. `model_outputs` holds
      the backbone/decoder features, RPN outputs, per-stage class and box
      outputs, training targets (training) or detections (inference).
      `intermediate_outputs` holds what the mask branch needs: `features`,
      the last stage's `current_rois`, and its `matched_gt_*` tensors (None
      unless training with `gt_boxes`).
    """
    model_outputs = {}

    # Feature extraction.
    backbone_features = self.backbone(images)
    if self.decoder is not None:
      features = self.decoder(backbone_features)
    else:
      features = backbone_features

    # Region proposal network.
    rpn_scores, rpn_boxes = self.rpn_head(features)

    model_outputs.update({
        'backbone_features': backbone_features,
        'decoder_features': features,
        'rpn_boxes': rpn_boxes,
        'rpn_scores': rpn_scores
    })

    # Generate anchor boxes for this batch if not provided. This relies on the
    # static spatial shape of `images`.
    if anchor_boxes is None:
      _, image_height, image_width, _ = images.get_shape().as_list()
      anchor_boxes = anchor.Anchor(
          min_level=self._config_dict['min_level'],
          max_level=self._config_dict['max_level'],
          num_scales=self._config_dict['num_scales'],
          aspect_ratios=self._config_dict['aspect_ratios'],
          anchor_size=self._config_dict['anchor_size'],
          image_size=(image_height, image_width)).multilevel_boxes
      # Replicate the per-level anchors along a new leading batch dimension.
      for level in anchor_boxes:
        anchor_boxes[level] = tf.tile(
            tf.expand_dims(anchor_boxes[level], axis=0),
            [tf.shape(images)[0], 1, 1, 1])

    # Generate RoIs.
    current_rois, _ = self.roi_generator(rpn_boxes, rpn_scores, anchor_boxes,
                                         image_shape, training)

    next_rois = current_rois
    all_class_outputs = []
    for cascade_num in range(len(self.roi_sampler)):
      # In cascade RCNN we want the higher layers to have different regression
      # weights as the predicted deltas become smaller and smaller.
      regression_weights = self._cascade_layer_to_weights[cascade_num]
      current_rois = next_rois

      (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
       matched_gt_classes, matched_gt_indices,
       current_rois) = self._run_frcnn_head(
           features=features,
           rois=current_rois,
           gt_boxes=gt_boxes,
           gt_classes=gt_classes,
           training=training,
           model_outputs=model_outputs,
           cascade_num=cascade_num,
           regression_weights=regression_weights)
      all_class_outputs.append(class_outputs)

      # Generate ROIs for the next cascade head if there is any.
      if cascade_num < len(self.roi_sampler) - 1:
        next_rois = box_ops.decode_boxes(
            tf.cast(box_outputs, tf.float32),
            current_rois,
            weights=regression_weights)
        next_rois = box_ops.clip_boxes(next_rois,
                                       tf.expand_dims(image_shape, axis=1))

    if not training:
      if self._config_dict['cascade_class_ensemble']:
        class_outputs = tf.add_n(all_class_outputs) / len(all_class_outputs)

      # `box_outputs` comes from the last stage, so decode it with that
      # stage's regression weights.
      detections = self.detection_generator(
          box_outputs,
          class_outputs,
          current_rois,
          image_shape,
          regression_weights,
          bbox_per_class=(not self._config_dict['class_agnostic_bbox_pred']))
      model_outputs.update({
          'cls_outputs': class_outputs,
          'box_outputs': box_outputs,
      })
      if self.detection_generator.get_config()['apply_nms']:
        model_outputs.update({
            'detection_boxes': detections['detection_boxes'],
            'detection_scores': detections['detection_scores'],
            'detection_classes': detections['detection_classes'],
            'num_detections': detections['num_detections']
        })
      else:
        model_outputs.update({
            'decoded_boxes': detections['decoded_boxes'],
            'decoded_box_scores': detections['decoded_box_scores']
        })

    intermediate_outputs = {
        'matched_gt_boxes': matched_gt_boxes,
        'matched_gt_indices': matched_gt_indices,
        'matched_gt_classes': matched_gt_classes,
        'features': features,
        'current_rois': current_rois,
    }
    return (model_outputs, intermediate_outputs)

  def _call_mask_outputs(
      self,
      model_box_outputs: Mapping[str, tf.Tensor],
      features: tf.Tensor,
      current_rois: tf.Tensor,
      matched_gt_indices: tf.Tensor,
      matched_gt_boxes: tf.Tensor,
      matched_gt_classes: tf.Tensor,
      gt_masks: tf.Tensor,
      training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
    """Implementation of Mask-RCNN mask prediction logic.

    In training, RoIs and mask targets come from `mask_sampler`. In inference,
    the predicted `detection_boxes`/`detection_classes` are used as RoIs.

    Returns:
      A copy of `model_box_outputs` extended with `mask_outputs`,
      `mask_targets` and `mask_class_targets` (training) or `detection_masks`
      (inference, after a sigmoid).
    """
    model_outputs = dict(model_box_outputs)
    if training:
      current_rois, roi_classes, roi_masks = self.mask_sampler(
          current_rois, matched_gt_boxes, matched_gt_classes,
          matched_gt_indices, gt_masks)
      roi_masks = tf.stop_gradient(roi_masks)

      model_outputs.update({
          'mask_class_targets': roi_classes,
          'mask_targets': roi_masks,
      })
    else:
      # NOTE: `detection_boxes`/`detection_classes` are only produced by
      # `_call_box_outputs` when the detection generator has `apply_nms` set.
      current_rois = model_outputs['detection_boxes']
      roi_classes = model_outputs['detection_classes']

    # Mask RoI align.
    mask_roi_features = self.mask_roi_aligner(features, current_rois)

    # Mask head.
    raw_masks = self.mask_head([mask_roi_features, roi_classes])

    if training:
      model_outputs.update({
          'mask_outputs': raw_masks,
      })
    else:
      model_outputs.update({
          'detection_masks': tf.math.sigmoid(raw_masks),
      })
    return model_outputs

  def _run_frcnn_head(self, features, rois, gt_boxes, gt_classes, training,
                      model_outputs, cascade_num, regression_weights):
    """Runs the frcnn head that does both class and box prediction.

    Args:
      features: multilevel features from the backbone/decoder, as consumed by
        `roi_aligner`.
      rois: the current RoIs used to predict bbox refinements and classes.
      gt_boxes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES, 4].
        This tensor might have paddings with a negative value.
      gt_classes: [batch_size, MAX_INSTANCES] representing the groundtruth box
        classes. It is padded with -1s to indicate the invalid classes.
      training: `bool`, if model is training or being evaluated.
      model_outputs: `dict`, used for storing outputs used for eval and losses.
      cascade_num: `int`, the current frcnn layer in the cascade. Stage 0 uses
        unsuffixed output keys; later stages append `_{cascade_num}`.
      regression_weights: `list`, weights used to encode the box regression
        targets.

    Returns:
      class_outputs: Class predictions for rois.
      box_outputs: Box predictions for rois. These are formatted for the
        regression loss and need to be converted before being used as rois
        in the next stage.
      model_outputs: Updated dict with predictions used for losses and eval.
      matched_gt_boxes: If `training` is true and `gt_boxes` is given, the gt
        box location of each RoI's match; None otherwise.
      matched_gt_classes: If `training` is true and `gt_boxes` is given, the
        gt class of each RoI (0 for background); None otherwise.
      matched_gt_indices: If `training` is true and `gt_boxes` is given, the
        index of the matched gt box, used for mask prediction; None otherwise.
      rois: The (possibly sampled) rois used for this layer.
    """
    # Only used during training.
    matched_gt_boxes, matched_gt_classes, matched_gt_indices = (None, None,
                                                                None)
    if training and gt_boxes is not None:
      rois = tf.stop_gradient(rois)

      current_roi_sampler = self.roi_sampler[cascade_num]
      rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
          current_roi_sampler(rois, gt_boxes, gt_classes))
      # Create bounding box training targets.
      box_targets = box_ops.encode_boxes(
          matched_gt_boxes, rois, weights=regression_weights)
      # If the target is background, the box target is set to all 0s.
      box_targets = tf.where(
          tf.tile(
              tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
              [1, 1, 4]), tf.zeros_like(box_targets), box_targets)
      model_outputs.update({
          'class_targets_{}'.format(cascade_num)
          if cascade_num else 'class_targets':
              matched_gt_classes,
          'box_targets_{}'.format(cascade_num)
          if cascade_num else 'box_targets':
              box_targets,
      })

    # Get roi features.
    roi_features = self.roi_aligner(features, rois)

    # Run frcnn head to get class and bbox predictions.
    current_detection_head = self.detection_head[cascade_num]
    class_outputs, box_outputs = current_detection_head(roi_features)

    model_outputs.update({
        'class_outputs_{}'.format(cascade_num)
        if cascade_num else 'class_outputs':
            class_outputs,
        'box_outputs_{}'.format(cascade_num) if cascade_num else 'box_outputs':
            box_outputs,
    })
    return (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
            matched_gt_classes, matched_gt_indices, rois)

  @property
  def checkpoint_items(
      self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(
        backbone=self.backbone,
        rpn_head=self.rpn_head,
        detection_head=self.detection_head)
    if self.decoder is not None:
      items.update(decoder=self.decoder)
    if self._include_mask:
      items.update(mask_head=self.mask_head)

    return items

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments.

    A shallow copy is returned so that callers mutating the result cannot
    alter the live model's configuration, which `_call_box_outputs` reads
    when generating anchors.
    """
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config):
    """Rebuilds the model from a `get_config()` dictionary."""
    return cls(**config)