maskrcnn.py 15.6 KB
Newer Older
Yeqing Li's avatar
Yeqing Li committed
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Abdullah Rashwan's avatar
Abdullah Rashwan committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Yeqing Li's avatar
Yeqing Li committed
14

Abdullah Rashwan's avatar
Abdullah Rashwan committed
15
"""Mask R-CNN task definition."""
Abdullah Rashwan's avatar
Abdullah Rashwan committed
16
import os
Fan Yang's avatar
Fan Yang committed
17
from typing import Any, Optional, List, Tuple, Mapping
Abdullah Rashwan's avatar
Abdullah Rashwan committed
18
19
20

from absl import logging
import tensorflow as tf
21
from official.common import dataset_fn
Abdullah Rashwan's avatar
Abdullah Rashwan committed
22
23
24
from official.core import base_task
from official.core import task_factory
from official.vision.beta.configs import maskrcnn as exp_cfg
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
25
from official.vision.beta.dataloaders import input_reader_factory
Abdullah Rashwan's avatar
Abdullah Rashwan committed
26
27
28
29
from official.vision.beta.dataloaders import maskrcnn_input
from official.vision.beta.dataloaders import tf_example_decoder
from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.evaluation import coco_evaluator
Abdullah Rashwan's avatar
Abdullah Rashwan committed
30
from official.vision.beta.evaluation import coco_utils
Abdullah Rashwan's avatar
Abdullah Rashwan committed
31
32
33
34
from official.vision.beta.losses import maskrcnn_losses
from official.vision.beta.modeling import factory


Fan Yang's avatar
Fan Yang committed
35
36
def zero_out_disallowed_class_ids(batch_class_ids: tf.Tensor,
                                  allowed_class_ids: List[int]):
  """Masks class IDs that are not in an allow-list by setting them to 0.

  Args:
    batch_class_ids: A [batch_size, num_instances] int tensor of input
      class IDs.
    allowed_class_ids: A python list of class IDs which we want to allow.

  Returns:
    A [batch_size, num_instances] int tensor equal to `batch_class_ids`
    except that every entry whose ID is not in `allowed_class_ids` is 0.
  """
  allow_list = tf.constant(allowed_class_ids, dtype=batch_class_ids.dtype)

  # Compare every instance ID against every allowed ID via broadcasting:
  # [batch, instances, 1] vs [1, 1, num_allowed]; an instance is kept if it
  # matches any allowed ID.
  ids_expanded = batch_class_ids[:, :, tf.newaxis]
  allow_expanded = allow_list[tf.newaxis, tf.newaxis, :]
  keep = tf.reduce_any(ids_expanded == allow_expanded, axis=2)

  zeros = tf.zeros_like(batch_class_ids)
  return tf.where(keep, batch_class_ids, zeros)


Abdullah Rashwan's avatar
Abdullah Rashwan committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
@task_factory.register_task_cls(exp_cfg.MaskRCNNTask)
class MaskRCNNTask(base_task.Task):
  """A single-replica view of training procedure.

  Mask R-CNN task provides artifacts for training/evaluation procedures,
  including loading/iterating over Datasets, initializing the model, calculating
  the loss, post-processing, and customized metrics with reduction.
  """

  def build_model(self):
    """Builds the Mask R-CNN model described by `task_config.model`.

    Returns:
      The Keras model returned by `factory.build_maskrcnn`.
    """
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_maskrcnn(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loads a pretrained checkpoint into `model`, if one is configured.

    `task_config.init_checkpoint` may name a checkpoint or a directory; for a
    directory, the latest checkpoint inside it is used.
    `task_config.init_checkpoint_modules` selects what is restored: 'all'
    restores `model.checkpoint_items` and asserts the restore was fully
    consumed; 'backbone' restores only `model.backbone` and tolerates a
    partial match.

    Args:
      model: The model whose variables are restored.

    Raises:
      ValueError: If `init_checkpoint_modules` is neither 'all' nor
        'backbone'.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    elif self.task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the input dataset.

    Args:
      params: The data config providing decoder, parser and reader settings.
      input_context: Optional distribution input context, forwarded to the
        reader's `read` call.

    Returns:
      The dataset produced by the configured input reader.

    Raises:
      ValueError: If `params.decoder.type` is not 'simple_decoder' or
        'label_map_decoder'.
    """
    model_config = self.task_config.model
    decoder_cfg = params.decoder.get()
    if params.decoder.type == 'simple_decoder':
      decoder = tf_example_decoder.TfExampleDecoder(
          include_mask=model_config.include_mask,
          regenerate_source_id=decoder_cfg.regenerate_source_id,
          mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
    elif params.decoder.type == 'label_map_decoder':
      decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
          label_map=decoder_cfg.label_map,
          include_mask=model_config.include_mask,
          regenerate_source_id=decoder_cfg.regenerate_source_id,
          mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
    else:
      raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))

    parser = maskrcnn_input.Parser(
        output_size=model_config.input_size[:2],
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_scales=model_config.anchor.num_scales,
        aspect_ratios=model_config.anchor.aspect_ratios,
        anchor_size=model_config.anchor.anchor_size,
        dtype=params.dtype,
        rpn_match_threshold=params.parser.rpn_match_threshold,
        rpn_unmatched_threshold=params.parser.rpn_unmatched_threshold,
        rpn_batch_size_per_im=params.parser.rpn_batch_size_per_im,
        rpn_fg_fraction=params.parser.rpn_fg_fraction,
        aug_rand_hflip=params.parser.aug_rand_hflip,
        aug_scale_min=params.parser.aug_scale_min,
        aug_scale_max=params.parser.aug_scale_max,
        skip_crowd_during_training=params.parser.skip_crowd_during_training,
        max_num_instances=params.parser.max_num_instances,
        include_mask=model_config.include_mask,
        mask_crop_size=params.parser.mask_crop_size)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   outputs: Mapping[str, Any],
                   labels: Mapping[str, Any],
                   aux_losses: Optional[Any] = None):
    """Builds Mask R-CNN losses.

    Args:
      outputs: Model outputs from a training forward pass (RPN, per-head
        Fast R-CNN and, when masks are enabled, mask outputs and targets).
      labels: Input labels; 'rpn_score_targets' and 'rpn_box_targets' are
        read here.
      aux_losses: Optional auxiliary losses; when non-empty they are summed
        and added to the model loss to form 'total_loss'.

    Returns:
      A dict with keys 'total_loss', 'rpn_score_loss', 'rpn_box_loss',
      'frcnn_cls_loss', 'frcnn_box_loss', 'mask_loss' and 'model_loss'.
    """
    params = self.task_config
    cascade_ious = params.model.roi_sampler.cascade_iou_thresholds

    def _head_output(name, head_index):
      # Detection head 0 uses the bare key; cascade head i >= 1 uses
      # '<name>_<i>'.
      return outputs['{}_{}'.format(name, head_index) if head_index else name]

    rpn_score_loss_fn = maskrcnn_losses.RpnScoreLoss(
        tf.shape(outputs['box_outputs'])[1])
    rpn_box_loss_fn = maskrcnn_losses.RpnBoxLoss(
        params.losses.rpn_huber_loss_delta)
    rpn_score_loss = tf.reduce_mean(
        rpn_score_loss_fn(
            outputs['rpn_scores'], labels['rpn_score_targets']))
    rpn_box_loss = tf.reduce_mean(
        rpn_box_loss_fn(
            outputs['rpn_boxes'], labels['rpn_box_targets']))

    frcnn_cls_loss_fn = maskrcnn_losses.FastrcnnClassLoss()
    frcnn_box_loss_fn = maskrcnn_losses.FastrcnnBoxLoss(
        params.losses.frcnn_huber_loss_delta,
        params.model.detection_head.class_agnostic_bbox_pred)

    # Final cls/box losses are computed as an average of all detection heads.
    frcnn_cls_loss = 0.0
    frcnn_box_loss = 0.0
    num_det_heads = 1 if cascade_ious is None else 1 + len(cascade_ious)
    for head_index in range(num_det_heads):
      class_targets = _head_output('class_targets', head_index)
      frcnn_cls_loss += tf.reduce_mean(
          frcnn_cls_loss_fn(
              _head_output('class_outputs', head_index), class_targets))
      frcnn_box_loss += tf.reduce_mean(
          frcnn_box_loss_fn(
              _head_output('box_outputs', head_index),
              class_targets,
              _head_output('box_targets', head_index)))
    frcnn_cls_loss /= num_det_heads
    frcnn_box_loss /= num_det_heads

    if params.model.include_mask:
      mask_loss_fn = maskrcnn_losses.MaskrcnnLoss()
      mask_class_targets = outputs['mask_class_targets']
      if params.allowed_mask_class_ids is not None:
        # Classes with ID=0 are ignored by mask_loss_fn in loss computation.
        mask_class_targets = zero_out_disallowed_class_ids(
            mask_class_targets, params.allowed_mask_class_ids)

      mask_loss = tf.reduce_mean(
          mask_loss_fn(
              outputs['mask_outputs'],
              outputs['mask_targets'],
              mask_class_targets))
    else:
      mask_loss = 0.0

    model_loss = (
        params.losses.rpn_score_weight * rpn_score_loss +
        params.losses.rpn_box_weight * rpn_box_loss +
        params.losses.frcnn_class_weight * frcnn_cls_loss +
        params.losses.frcnn_box_weight * frcnn_box_loss +
        params.losses.mask_weight * mask_loss)

    total_loss = model_loss
    if aux_losses:
      reg_loss = tf.reduce_sum(aux_losses)
      total_loss = model_loss + reg_loss

    losses = {
        'total_loss': total_loss,
        'rpn_score_loss': rpn_score_loss,
        'rpn_box_loss': rpn_box_loss,
        'frcnn_cls_loss': frcnn_cls_loss,
        'frcnn_box_loss': frcnn_box_loss,
        'mask_loss': mask_loss,
        'model_loss': model_loss,
    }
    return losses

  def build_metrics(self, training: bool = True):
    """Builds detection metrics.

    In training mode, returns one `tf.keras.metrics.Mean` per loss produced
    by `build_losses`. In evaluation mode, returns an empty list and instead
    sets `self.coco_metric` to a `COCOEvaluator`. If masks are enabled and no
    annotation file is configured, a COCO-style 'annotation.json' is generated
    under the logging directory, or reused if it already exists.

    Args:
      training: Whether to build training (loss) or evaluation (COCO) metrics.

    Returns:
      A list of Keras metrics (empty when `training` is False).
    """
    metrics = []
    if training:
      metric_names = [
          'total_loss',
          'rpn_score_loss',
          'rpn_box_loss',
          'frcnn_cls_loss',
          'frcnn_box_loss',
          'mask_loss',
          'model_loss'
      ]
      for name in metric_names:
        metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))

    else:
      config = self.task_config
      if (not config.model.include_mask) or config.annotation_file:
        self.coco_metric = coco_evaluator.COCOEvaluator(
            annotation_file=config.annotation_file,
            include_mask=config.model.include_mask,
            per_category_metrics=config.per_category_metrics)
      else:
        # Builds COCO-style annotation file if include_mask is True, and
        # annotation_file isn't provided.
        annotation_path = os.path.join(self._logging_dir, 'annotation.json')
        if tf.io.gfile.exists(annotation_path):
          logging.info(
              'annotation.json file exists, skipping creating the annotation'
              ' file.')
        else:
          validation_data = config.validation_data
          # These are configuration problems for annotation generation; log
          # them as warnings so they are not lost among info-level logs.
          if validation_data.num_examples <= 0:
            logging.warning('validation_data.num_examples needs to be > 0')
          if not validation_data.input_path:
            logging.warning('Can not create annotation file for tfds.')
          logging.info(
              'Creating coco-style annotation file: %s', annotation_path)
          coco_utils.scan_and_generator_annotation_file(
              validation_data.input_path,
              validation_data.file_type,
              validation_data.num_examples,
              config.model.include_mask, annotation_path)
        self.coco_metric = coco_evaluator.COCOEvaluator(
            annotation_file=annotation_path,
            include_mask=config.model.include_mask,
            per_category_metrics=config.per_category_metrics)

    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: A tuple of (images, labels).
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    images, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(
          images,
          image_shape=labels['image_info'][:, 1, :],
          anchor_boxes=labels['anchor_boxes'],
          gt_boxes=labels['gt_boxes'],
          gt_classes=labels['gt_classes'],
          gt_masks=(labels['gt_masks'] if self.task_config.model.include_mask
                    else None),
          training=True)
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      losses = self.build_losses(
          outputs=outputs, labels=labels, aux_losses=model.losses)
      scaled_loss = losses['total_loss'] / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: losses['total_loss']}

    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])

    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf.keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: A tuple of (images, labels).
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs. Under the key `self.coco_metric.name`, it holds a
      (groundtruths, detections) pair for `aggregate_logs` to consume.
    """
    images, labels = inputs

    outputs = model(
        images,
        anchor_boxes=labels['anchor_boxes'],
        image_shape=labels['image_info'][:, 1, :],
        training=False)

    logs = {self.loss: 0}
    coco_model_outputs = {
        'detection_boxes': outputs['detection_boxes'],
        'detection_scores': outputs['detection_scores'],
        'detection_classes': outputs['detection_classes'],
        'num_detections': outputs['num_detections'],
        'source_id': labels['groundtruths']['source_id'],
        'image_info': labels['image_info']
    }
    if self.task_config.model.include_mask:
      coco_model_outputs.update({
          'detection_masks': outputs['detection_masks'],
      })
    logs.update({
        self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)
    })
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates one validation step's outputs into the COCO evaluator.

    Args:
      state: The aggregation state; None on the first call, in which case the
        evaluator is reset and used as the state.
      step_outputs: Logs returned by `validation_step`.

    Returns:
      The aggregation state (the COCO evaluator).
    """
    if state is None:
      self.coco_metric.reset_states()
      state = self.coco_metric
    groundtruths, detections = step_outputs[self.coco_metric.name]
    self.coco_metric.update_state(groundtruths, detections)
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Returns the COCO evaluator's result over all aggregated steps."""
    return self.coco_metric.result()