# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model definition for the RetinaNet Model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from official.vision.detection.dataloader import mode_keys
from official.vision.detection.evaluation import factory as eval_factory
from official.vision.detection.modeling import base_model
from official.vision.detection.modeling import losses
from official.vision.detection.modeling.architecture import factory
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.ops import postprocess_ops


class RetinanetModel(base_model.Model):
  """RetinaNet model function."""

  def __init__(self, params):
    """Builds the RetinaNet model components from `params`.

    Args:
      params: a ParamsDict-like config object providing `architecture`,
        `retinanet_loss`, `retinanet_parser`, `postprocess` and `train`
        sub-configs.
    """
    super(RetinanetModel, self).__init__(params)

    # Kept so eval_metrics() can build the evaluator later.
    self._params = params

    # Architecture generators.
    self._backbone_fn = factory.backbone_generator(params)
    self._fpn_fn = factory.multilevel_features_generator(params)
    self._head_fn = factory.retinanet_head_generator(params)

    # Loss functions.
    self._cls_loss_fn = losses.RetinanetClassLoss(
        params.retinanet_loss, params.architecture.num_classes)
    self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
    self._box_loss_weight = params.retinanet_loss.box_loss_weight
    # Populated lazily by build_model(); build_loss_fn() requires it.
    self._keras_model = None

    # Predict function.
    self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
        params.architecture.min_level,
        params.architecture.max_level,
        params.postprocess)

    self._transpose_input = params.train.transpose_input
    assert not self._transpose_input, 'Transpose input is not supported.'

    # Input layer. Shape is HWC from the parser config; dtype follows the
    # bfloat16 training flag set by the base class.
    input_shape = (
        params.retinanet_parser.output_size +
        [params.retinanet_parser.num_channels])
    self._input_layer = tf.keras.layers.Input(
        shape=input_shape, name='',
        dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)

  def build_outputs(self, inputs, mode):
    """Runs backbone, FPN and head; returns raw multilevel outputs.

    Args:
      inputs: input image tensor (NHWC, or HWCN when transpose_input is set).
      mode: one of `mode_keys`; only TRAIN toggles training-mode layers.

    Returns:
      Dict with `cls_outputs` and `box_outputs`, each a dict keyed by level.
    """
    # If the input image is transposed (from NHWC to HWCN), we need to revert
    # it back to the original shape before it's used in the computation.
    if self._transpose_input:
      inputs = tf.transpose(inputs, [3, 0, 1, 2])

    backbone_features = self._backbone_fn(
        inputs, is_training=(mode == mode_keys.TRAIN))
    fpn_features = self._fpn_fn(
        backbone_features, is_training=(mode == mode_keys.TRAIN))
    cls_outputs, box_outputs = self._head_fn(
        fpn_features, is_training=(mode == mode_keys.TRAIN))

    # Losses are computed in float32; cast head outputs up from bfloat16.
    if self._use_bfloat16:
      levels = cls_outputs.keys()
      for level in levels:
        cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
        box_outputs[level] = tf.cast(box_outputs[level], tf.float32)

    model_outputs = {
        'cls_outputs': cls_outputs,
        'box_outputs': box_outputs,
    }
    return model_outputs

  def build_loss_fn(self):
    """Returns a loss function `(labels, outputs) -> dict of loss scalars`.

    Raises:
      ValueError: if called before build_model() has created the Keras model
        (trainable variables are needed for the L2 regularization term).
    """
    if self._keras_model is None:
      raise ValueError('build_loss_fn() must be called after build_model().')

    filter_fn = self.make_filter_trainable_variables_fn()
    trainable_variables = filter_fn(self._keras_model.trainable_variables)

    def _total_loss_fn(labels, outputs):
      cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
                                   labels['cls_targets'],
                                   labels['num_positives'])
      box_loss = self._box_loss_fn(outputs['box_outputs'],
                                   labels['box_targets'],
                                   labels['num_positives'])
      model_loss = cls_loss + self._box_loss_weight * box_loss
      l2_regularization_loss = self.weight_decay_loss(trainable_variables)
      total_loss = model_loss + l2_regularization_loss
      return {
          'total_loss': total_loss,
          'cls_loss': cls_loss,
          'box_loss': box_loss,
          'model_loss': model_loss,
          'l2_regularization_loss': l2_regularization_loss,
      }

    return _total_loss_fn

  def build_model(self, params, mode=None):
    """Builds (once) and returns the tf.keras.Model for this detector.

    Args:
      params: config object (unused here; kept for the base-class interface).
      mode: optional `mode_keys` value forwarded to model_outputs().

    Returns:
      The cached `tf.keras.Model` instance.
    """
    if self._keras_model is None:
      with keras_utils.maybe_enter_backend_graph():
        outputs = self.model_outputs(self._input_layer, mode)

        model = tf.keras.models.Model(
            inputs=self._input_layer, outputs=outputs, name='retinanet')
        assert model is not None, 'Fail to build tf.keras.Model.'
        model.optimizer = self.build_optimizer()
        self._keras_model = model

    return self._keras_model

  def post_processing(self, labels, outputs):
    """Generates final detections and flattens groundtruth for evaluation.

    Args:
      labels: dict with at least `image_info`, `groundtruths` and
        `anchor_boxes`.
      outputs: dict with `cls_outputs` and `box_outputs` from build_outputs().

    Returns:
      Tuple of (labels, outputs) reshaped for the evaluator.

    Raises:
      ValueError: if a required field is missing from `outputs` or `labels`.
    """
    # TODO(yeqing): Moves the output related part into build_outputs.
    required_output_fields = ['cls_outputs', 'box_outputs']
    for field in required_output_fields:
      if field not in outputs:
        # Format eagerly: passing extra args to ValueError would leave the
        # message unformatted.
        raise ValueError(
            '"%s" is missing in outputs, required %s found %s' %
            (field, required_output_fields, outputs.keys()))
    required_label_fields = ['image_info', 'groundtruths']
    for field in required_label_fields:
      if field not in labels:
        raise ValueError(
            '"%s" is missing in labels, required %s found %s' %
            (field, required_label_fields, labels.keys()))
    boxes, scores, classes, valid_detections = self._generate_detections_fn(
        outputs['box_outputs'], outputs['cls_outputs'],
        labels['anchor_boxes'], labels['image_info'][:, 1:2, :])
    # Discards the old output tensors to save memory. The `cls_outputs` and
    # `box_outputs` are pretty big and could potentially lead to memory issue.
    outputs = {
        'source_id': labels['groundtruths']['source_id'],
        'image_info': labels['image_info'],
        'num_detections': valid_detections,
        'detection_boxes': boxes,
        'detection_classes': classes,
        'detection_scores': scores,
    }

    if 'groundtruths' in labels:
      labels['source_id'] = labels['groundtruths']['source_id']
      labels['boxes'] = labels['groundtruths']['boxes']
      labels['classes'] = labels['groundtruths']['classes']
      labels['areas'] = labels['groundtruths']['areas']
      labels['is_crowds'] = labels['groundtruths']['is_crowds']

    return labels, outputs

  def eval_metrics(self):
    """Returns the evaluator built from the eval config."""
    return eval_factory.evaluator_generator(self._params.eval)