# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""RetinaNet task definition."""

from absl import logging
import tensorflow as tf

from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.vision import keras_cv
from official.vision.beta.configs import retinanet as exp_cfg
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders import retinanet_input
from official.vision.beta.dataloaders import tf_example_decoder
from official.vision.beta.dataloaders import tfds_detection_decoders
from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.modeling import factory


@task_factory.register_task_cls(exp_cfg.RetinaNetTask)
class RetinaNetTask(base_task.Task):
  """A single-replica view of training procedure.

  RetinaNet task provides artifacts for training/evaluation procedures, including
  loading/iterating over Datasets, initializing the model, calculating the loss,
  post-processing, and customized metrics with reduction.
  """

  def build_model(self):
    """Creates the RetinaNet model described by the task config.

    Returns:
      The model produced by `factory.build_retinanet`.
    """
    model_config = self.task_config.model
    weight_decay = self.task_config.losses.l2_weight_decay

    # The decay coefficient is halved to match the implementation of
    # tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    if weight_decay:
      regularizer = tf.keras.regularizers.l2(weight_decay / 2.0)
    else:
      # No regularizer when the configured weight decay is unset or zero.
      regularizer = None

    # Leading `None` leaves the first (batch) dimension unspecified.
    specs = tf.keras.layers.InputSpec(shape=[None] + model_config.input_size)

    return factory.build_retinanet(
        input_specs=specs,
        model_config=model_config,
        l2_regularizer=regularizer)

  def initialize(self, model: tf.keras.Model):
    """Restores weights from the checkpoint named in the task config.

    Does nothing when `task_config.init_checkpoint` is empty. When the
    configured path is a directory, the latest checkpoint found in it is used.

    Args:
      model: the model to restore into. With `init_checkpoint_modules` set to
        'all' its `checkpoint_items` are restored; with 'backbone' only its
        `backbone` attribute is restored.

    Raises:
      ValueError: if `init_checkpoint_modules` is neither 'all' nor
        'backbone'.
    """
    checkpoint_path = self.task_config.init_checkpoint
    if not checkpoint_path:
      return

    if tf.io.gfile.isdir(checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

    modules = self.task_config.init_checkpoint_modules
    if modules not in ('all', 'backbone'):
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    if modules == 'all':
      # Full restore: every checkpointed value must be consumed.
      restore_status = tf.train.Checkpoint(
          **model.checkpoint_items).restore(checkpoint_path)
      restore_status.assert_consumed()
    else:
      # Backbone-only restore: the checkpoint may hold values that are not
      # used, but every object being restored must be matched.
      restore_status = tf.train.Checkpoint(
          backbone=model.backbone).restore(checkpoint_path)
      restore_status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 checkpoint_path)

  def build_inputs(self, params, input_context=None):
    """Builds the input dataset for training or evaluation.

    Args:
      params: data config; this method reads its `tfds_name`, `decoder`,
        `parser`, `dtype`, `file_type` and `is_training` fields.
      input_context: optional context forwarded to the reader's `read` call.

    Returns:
      The dataset returned by the input reader.

    Raises:
      ValueError: if `params.tfds_name` has no registered TFDS decoder, or if
        `params.decoder.type` is not a known decoder type.
    """
    tfds_name = params.tfds_name
    if tfds_name:
      # TFDS inputs use a decoder looked up by dataset name.
      decoder_map = tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP
      if tfds_name not in decoder_map:
        raise ValueError('TFDS {} is not supported'.format(tfds_name))
      decoder = decoder_map[tfds_name]()
    else:
      # Otherwise the decoder is chosen by the configured decoder type.
      decoder_cfg = params.decoder.get()
      decoder_type = params.decoder.type
      if decoder_type == 'simple_decoder':
        decoder = tf_example_decoder.TfExampleDecoder(
            regenerate_source_id=decoder_cfg.regenerate_source_id)
      elif decoder_type == 'label_map_decoder':
        decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
            label_map=decoder_cfg.label_map,
            regenerate_source_id=decoder_cfg.regenerate_source_id)
      else:
        raise ValueError('Unknown decoder type: {}!'.format(decoder_type))

    model_cfg = self.task_config.model
    anchor_cfg = model_cfg.anchor
    parser_cfg = params.parser
    parser = retinanet_input.Parser(
        output_size=model_cfg.input_size[:2],
        min_level=model_cfg.min_level,
        max_level=model_cfg.max_level,
        num_scales=anchor_cfg.num_scales,
        aspect_ratios=anchor_cfg.aspect_ratios,
        anchor_size=anchor_cfg.anchor_size,
        dtype=params.dtype,
        match_threshold=parser_cfg.match_threshold,
        unmatched_threshold=parser_cfg.unmatched_threshold,
        aug_rand_hflip=parser_cfg.aug_rand_hflip,
        aug_scale_min=parser_cfg.aug_scale_min,
        aug_scale_max=parser_cfg.aug_scale_max,
        skip_crowd_during_training=parser_cfg.skip_crowd_during_training,
        max_num_instances=parser_cfg.max_num_instances)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    return reader.read(input_context=input_context)

  def build_losses(self, outputs, labels, aux_losses=None):
    """Computes the RetinaNet classification, box and total losses.

    Args:
      outputs: model outputs; 'cls_outputs' and 'box_outputs' are read.
      labels: label dictionary; 'cls_targets', 'box_targets', 'cls_weights'
        and 'box_weights' are read.
      aux_losses: optional collection of extra loss tensors; when non-empty
        their sum is added to the total loss.

    Returns:
      A tuple `(total_loss, cls_loss, box_loss, model_loss)` where
      `model_loss` is `cls_loss + box_loss_weight * box_loss` and `total_loss`
      additionally includes the summed `aux_losses`.
    """
    loss_cfg = self.task_config.losses
    num_classes = self.task_config.model.num_classes

    focal_loss = keras_cv.losses.FocalLoss(
        alpha=loss_cfg.focal_loss_alpha,
        gamma=loss_cfg.focal_loss_gamma,
        reduction=tf.keras.losses.Reduction.SUM)
    huber_loss = tf.keras.losses.Huber(
        delta=loss_cfg.huber_loss_delta,
        reduction=tf.keras.losses.Reduction.SUM)

    # Both losses are normalized by the summed box weights of the batch. The
    # added 1.0 keeps the divisor non-zero, which would otherwise lead to an
    # inf loss during training.
    normalizer = tf.reduce_sum(labels['box_weights']) + 1.0
    cls_weights = labels['cls_weights'] / normalizer
    box_weights = labels['box_weights'] / normalizer

    # Flatten the multi-level targets/predictions; class targets are turned
    # into one-hot vectors after flattening.
    cls_targets = tf.one_hot(
        keras_cv.losses.multi_level_flatten(
            labels['cls_targets'], last_dim=None),
        num_classes)
    cls_predictions = keras_cv.losses.multi_level_flatten(
        outputs['cls_outputs'], last_dim=num_classes)
    box_targets = keras_cv.losses.multi_level_flatten(
        labels['box_targets'], last_dim=4)
    box_predictions = keras_cv.losses.multi_level_flatten(
        outputs['box_outputs'], last_dim=4)

    cls_loss = focal_loss(
        y_true=cls_targets, y_pred=cls_predictions, sample_weight=cls_weights)
    box_loss = huber_loss(
        y_true=box_targets, y_pred=box_predictions, sample_weight=box_weights)

    model_loss = cls_loss + loss_cfg.box_loss_weight * box_loss

    if aux_losses:
      total_loss = model_loss + tf.reduce_sum(aux_losses)
    else:
      total_loss = model_loss

    return total_loss, cls_loss, box_loss, model_loss

  def build_metrics(self, training=True):
    """Creates the loss metrics and, for evaluation, the COCO evaluator.

    Args:
      training: when False, `self.coco_metric` is additionally set to a new
        `COCOEvaluator`.

    Returns:
      A list of `tf.keras.metrics.Mean` metrics named 'total_loss',
      'cls_loss', 'box_loss' and 'model_loss'.

    Raises:
      ValueError: when not training and both a TFDS validation dataset and an
        annotation file are configured.
    """
    metrics = [
        tf.keras.metrics.Mean(loss_name, dtype=tf.float32)
        for loss_name in ('total_loss', 'cls_loss', 'box_loss', 'model_loss')
    ]
    if training:
      return metrics

    task_config = self.task_config
    if task_config.validation_data.tfds_name and task_config.annotation_file:
      raise ValueError(
          "Can't evaluate using annotation file when TFDS is used.")

    # The evaluator is kept on the task rather than returned in `metrics`.
    self.coco_metric = coco_evaluator.COCOEvaluator(
        annotation_file=task_config.annotation_file,
        include_mask=False,
        per_category_metrics=task_config.per_category_metrics)
    return metrics

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Runs one forward/backward pass and applies the gradients.

    Args:
      inputs: a `(features, labels)` pair of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: optional iterable of metric objects; each is updated with the
        loss value matching its name.

    Returns:
      A dictionary of logs holding the total loss under `self.loss` and, for
      each metric, its current result under the metric's name.
    """
    features, labels = inputs
    replica_count = tf.distribute.get_strategy().num_replicas_in_sync
    uses_loss_scaling = isinstance(
        optimizer, tf.keras.mixed_precision.LossScaleOptimizer)

    with tf.GradientTape() as tape:
      raw_outputs = model(features, training=True)
      # Every output tensor is cast to float32 before the loss is computed.
      outputs = tf.nest.map_structure(
          lambda tensor: tf.cast(tensor, tf.float32), raw_outputs)

      # Per-replica loss, divided by the number of replicas in sync.
      loss, cls_loss, box_loss, model_loss = self.build_losses(
          outputs=outputs, labels=labels, aux_losses=model.losses)
      scaled_loss = loss / replica_count

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if uses_loss_scaling:
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    trainable_vars = model.trainable_variables
    gradients = tape.gradient(scaled_loss, trainable_vars)
    # Undo the loss scaling on the gradients before they are applied.
    if uses_loss_scaling:
      gradients = optimizer.get_unscaled_gradients(gradients)
    optimizer.apply_gradients(list(zip(gradients, trainable_vars)))

    logs = {self.loss: loss}
    if metrics:
      loss_by_name = {
          'total_loss': loss,
          'cls_loss': cls_loss,
          'box_loss': box_loss,
          'model_loss': model_loss,
      }
      for metric in metrics:
        metric.update_state(loss_by_name[metric.name])
        logs[metric.name] = metric.result()

    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Runs one validation step.

    Args:
      inputs: a `(features, labels)` pair of input tensors. `labels` must
        provide 'anchor_boxes', 'image_info' and 'groundtruths' in addition
        to the entries used for the losses.
      model: the keras.Model.
      metrics: optional iterable of metric objects; each is updated with the
        loss value matching its name.

    Returns:
      A dictionary of logs holding the total loss under `self.loss`, a
      `(groundtruths, detections)` pair under `self.coco_metric.name`, and
      each metric's current result under the metric's name.
    """
    features, labels = inputs
    image_info = labels['image_info']

    outputs = model(
        features,
        anchor_boxes=labels['anchor_boxes'],
        image_shape=image_info[:, 1, :],
        training=False)
    loss, cls_loss, box_loss, model_loss = self.build_losses(
        outputs=outputs, labels=labels, aux_losses=model.losses)

    # Detections handed to the COCO evaluator, together with the source id
    # and image info of each example.
    detections = {
        key: outputs[key]
        for key in ('detection_boxes', 'detection_scores',
                    'detection_classes', 'num_detections')
    }
    detections['source_id'] = labels['groundtruths']['source_id']
    detections['image_info'] = image_info

    logs = {
        self.loss: loss,
        self.coco_metric.name: (labels['groundtruths'], detections),
    }

    if metrics:
      loss_by_name = {
          'total_loss': loss,
          'cls_loss': cls_loss,
          'box_loss': box_loss,
          'model_loss': model_loss,
      }
      for metric in metrics:
        metric.update_state(loss_by_name[metric.name])
        logs[metric.name] = metric.result()
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Feeds one step's outputs into the COCO evaluator.

    Args:
      state: the aggregation state from the previous call, or None on the
        first call, in which case the evaluator is reset and becomes the
        state.
      step_outputs: mapping holding, under `self.coco_metric.name`, a
        sequence whose first two items are passed to the evaluator's
        `update_state`.

    Returns:
      The aggregation state: `self.coco_metric` when `state` was None,
      otherwise `state` unchanged.
    """
    evaluator = self.coco_metric
    if state is None:
      evaluator.reset_states()
      state = evaluator

    step_output = step_outputs[evaluator.name]
    evaluator.update_state(step_output[0], step_output[1])
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Returns the final result of the COCO evaluator.

    Args:
      aggregated_logs: unused by this implementation.
      global_step: unused by this implementation.

    Returns:
      The value of `self.coco_metric.result()`.
    """
    evaluator = self.coco_metric
    return evaluator.result()