retinanet_model.py 8.44 KB
Newer Older
Yeqing Li's avatar
Yeqing Li committed
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Abdullah Rashwan's avatar
Abdullah Rashwan committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Yeqing Li's avatar
Yeqing Li committed
14

Abdullah Rashwan's avatar
Abdullah Rashwan committed
15
"""RetinaNet."""
Fan Yang's avatar
Fan Yang committed
16
from typing import Any, Mapping, List, Optional, Union
Abdullah Rashwan's avatar
Abdullah Rashwan committed
17
18
19
20

# Import libraries
import tensorflow as tf

A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
21
22
from official.vision.beta.ops import anchor

Abdullah Rashwan's avatar
Abdullah Rashwan committed
23
24
25
26
27
28

@tf.keras.utils.register_keras_serializable(package='Vision')
class RetinaNetModel(tf.keras.Model):
  """The RetinaNet model class.

  Chains a backbone, an optional decoder and a prediction head. Outside of
  training mode the raw head outputs are additionally passed through the
  detection generator for post-processing.
  """

  def __init__(self,
               backbone: tf.keras.Model,
               decoder: Optional[tf.keras.Model],
               head: tf.keras.layers.Layer,
               detection_generator: tf.keras.layers.Layer,
               min_level: Optional[int] = None,
               max_level: Optional[int] = None,
               num_scales: Optional[int] = None,
               aspect_ratios: Optional[List[float]] = None,
               anchor_size: Optional[float] = None,
               **kwargs):
    """RetinaNet initialization function.

    Args:
      backbone: `tf.keras.Model` a backbone network.
      decoder: `tf.keras.Model` a decoder network, or `None` to feed the
        backbone features directly to the head.
      head: `RetinaNetHead`, the RetinaNet head.
      detection_generator: the detection generator.
      min_level: Minimum level in output feature maps.
      max_level: Maximum level in output feature maps.
      num_scales: A number representing intermediate scales added
        on each level. For instance, num_scales=2 adds one additional
        intermediate anchor scales [2^0, 2^0.5] on each level.
      aspect_ratios: A list representing the aspect ratio
        anchors added on each level. The number indicates the ratio of width to
        height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
        on each scale level.
      anchor_size: A number representing the scale of size of the base
        anchor to the feature stride 2^level.
      **kwargs: keyword arguments to be passed.
    """
    super(RetinaNetModel, self).__init__(**kwargs)
    # Constructor arguments are kept verbatim so that `get_config` /
    # `from_config` can round-trip them and `call` can build anchors on demand.
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'head': head,
        'detection_generator': detection_generator,
        'min_level': min_level,
        'max_level': max_level,
        'num_scales': num_scales,
        'aspect_ratios': aspect_ratios,
        'anchor_size': anchor_size,
    }
    self._backbone = backbone
    self._decoder = decoder
    self._head = head
    self._detection_generator = detection_generator

  def _generate_anchor_boxes(self,
                             images: tf.Tensor) -> Mapping[str, tf.Tensor]:
    """Builds multilevel anchors for `images`, tiled over the batch dimension.

    Args:
      images: `Tensor`, the input batched images, whose shape is
        [batch, height, width, 3].

    Returns:
      The `multilevel_boxes` mapping produced by `anchor.Anchor`, where every
      per-level entry has been given a leading batch dimension.
    """
    # NOTE(review): height/width are read from the static shape; presumably
    # they must be known at this point -- confirm against `anchor.Anchor`.
    _, image_height, image_width, _ = images.get_shape().as_list()
    anchor_boxes = anchor.Anchor(
        min_level=self._config_dict['min_level'],
        max_level=self._config_dict['max_level'],
        num_scales=self._config_dict['num_scales'],
        aspect_ratios=self._config_dict['aspect_ratios'],
        anchor_size=self._config_dict['anchor_size'],
        image_size=(image_height, image_width)).multilevel_boxes

    # The batch size is loop-invariant, so it is looked up once.
    batch_size = tf.shape(images)[0]
    # Values are replaced in place (the key set is unchanged) so the mapping
    # type returned by `anchor.Anchor` is preserved.
    for level in anchor_boxes:
      anchor_boxes[level] = tf.tile(
          tf.expand_dims(anchor_boxes[level], axis=0), [batch_size, 1, 1, 1])
    return anchor_boxes

  def call(self,
           images: tf.Tensor,
           image_shape: Optional[tf.Tensor] = None,
           anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
           output_intermediate_features: bool = False,
           training: Optional[bool] = None) -> Mapping[str, Any]:
    """Forward pass of the RetinaNet model.

    Args:
      images: `Tensor`, the input batched images, whose shape is
        [batch, height, width, 3].
      image_shape: `Tensor`, the actual shape of the input images, whose shape
        is [batch, 2] where the last dimension is [height, width]. Note that
        this is the actual image shape excluding paddings. For example, images
        in the batch may be resized into different shapes before padding to the
        fixed size.
      anchor_boxes: a dict of tensors which includes multilevel anchors. When
        `None` and not in training mode, anchors are generated from the
        anchor settings given to the constructor.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the anchor coordinates of a particular feature
            level, whose shape is [height_l, width_l, num_anchors_per_location].
      output_intermediate_features: `bool` indicating whether to return the
        intermediate feature maps generated by backbone and decoder.
      training: `bool`, indicating whether it is in training mode.

    Returns:
      A dict of outputs with the following entries:
      cls_outputs: a dict of tensors which includes scores of the predictions.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the box scores predicted from a particular feature
            level, whose shape is
            [batch, height_l, width_l, num_classes * num_anchors_per_location].
      box_outputs: a dict of tensors which includes coordinates of the
        predictions.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the box coordinates predicted from a particular
            feature level, whose shape is
            [batch, height_l, width_l, 4 * num_anchors_per_location].
      attribute_outputs: only present when the head returns attributes; a dict
        of (attribute_name, attribute_predictions). Each attribute prediction
        is a dict that includes:
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the attribute predictions from a particular
            feature level, whose shape is
            [batch, height_l, width_l, att_size * num_anchors_per_location].
      backbone_<key> / decoder_<key>: the intermediate feature maps, only
        present when `output_intermediate_features` is True.
      detection_boxes, detection_scores, detection_classes, num_detections:
        only when not training and the detection generator's config has
        `apply_nms` enabled; taken from the detection generator results.
      decoded_boxes, decoded_box_scores: only when not training and
        `apply_nms` is disabled; taken from the detection generator results.
      detection_attributes: only when not training and the head returns
        attributes; taken from the detection generator results.
    """
    outputs = {}

    # Feature extraction.
    features = self.backbone(images)
    if output_intermediate_features:
      outputs.update(
          {'backbone_{}'.format(k): v for k, v in features.items()})
    # Explicit `None` test, consistent with `checkpoint_items`.
    if self.decoder is not None:
      features = self.decoder(features)
      if output_intermediate_features:
        outputs.update(
            {'decoder_{}'.format(k): v for k, v in features.items()})

    # Dense prediction. `raw_attributes` can be empty.
    raw_scores, raw_boxes, raw_attributes = self.head(features)
    outputs.update({
        'cls_outputs': raw_scores,
        'box_outputs': raw_boxes,
    })

    if training:
      # Training only needs the raw head outputs; no post-processing.
      if raw_attributes:
        outputs.update({'attribute_outputs': raw_attributes})
      return outputs

    # Generate anchor boxes for this batch if not provided.
    if anchor_boxes is None:
      anchor_boxes = self._generate_anchor_boxes(images)

    # Post-processing.
    final_results = self.detection_generator(raw_boxes, raw_scores,
                                             anchor_boxes, image_shape,
                                             raw_attributes)
    if self.detection_generator.get_config()['apply_nms']:
      outputs.update({
          'detection_boxes': final_results['detection_boxes'],
          'detection_scores': final_results['detection_scores'],
          'detection_classes': final_results['detection_classes'],
          'num_detections': final_results['num_detections']
      })
    else:
      outputs.update({
          'decoded_boxes': final_results['decoded_boxes'],
          'decoded_box_scores': final_results['decoded_box_scores']
      })

    if raw_attributes:
      outputs.update({
          'attribute_outputs': raw_attributes,
          'detection_attributes': final_results['detection_attributes'],
      })
    return outputs

  @property
  def checkpoint_items(
      self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(backbone=self.backbone, head=self.head)
    # The decoder is optional and only listed when one was supplied.
    if self.decoder is not None:
      items.update(decoder=self.decoder)

    return items

  @property
  def backbone(self) -> tf.keras.Model:
    """The backbone network passed to the constructor."""
    return self._backbone

  @property
  def decoder(self) -> Optional[tf.keras.Model]:
    """The decoder network passed to the constructor; may be `None`."""
    return self._decoder

  @property
  def head(self) -> tf.keras.layers.Layer:
    """The prediction head passed to the constructor."""
    return self._head

  @property
  def detection_generator(self) -> tf.keras.layers.Layer:
    """The detection generator passed to the constructor."""
    return self._detection_generator

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments as a dict.

    A shallow copy is returned so that callers editing the result cannot
    alter the anchor settings this model reads in `call`.
    """
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config):
    """Creates a model by passing the entries of `config` to the constructor."""
    return cls(**config)