# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Detection input and model functions for serving/inference."""

from typing import Mapping, Text, Tuple
import tensorflow as tf

from official.vision.beta import configs
from official.vision.beta.modeling import factory
from official.vision.beta.ops import anchor
from official.vision.beta.ops import box_ops
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.serving import export_base


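# ImageNet per-channel mean and stddev pixel statistics (0.485/0.456/0.406 and
# 0.229/0.224/0.225), scaled from the [0, 1] range to [0, 255].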
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


class DetectionModule(export_base.ExportModule):
  """Detection Module."""

  def _build_model(self):
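    """Builds and returns the detection model from the task config."""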

    if self._batch_size is None:
      raise ValueError('batch_size cannot be None for detection models.')
    input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
                                            self._input_image_size + [3])

    if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
      model = factory.build_maskrcnn(
          input_specs=input_specs, model_config=self.params.task.model)
    elif isinstance(self.params.task.model, configs.retinanet.RetinaNet):
      model = factory.build_retinanet(
          input_specs=input_specs, model_config=self.params.task.model)
    else:
      raise ValueError('Detection module not implemented for {} model.'.format(
          type(self.params.task.model)))

    return model

  def _build_inputs(self, image):
    """Builds detection model inputs for serving."""
    model_params = self.params.task.model
    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(image,
                                           offset=MEAN_RGB,
                                           scale=STDDEV_RGB)

    image, image_info = preprocess_ops.resize_and_crop_image(
        image,
        self._input_image_size,
        padded_size=preprocess_ops.compute_padded_size(
            self._input_image_size, 2**model_params.max_level),
        aug_scale_min=1.0,
        aug_scale_max=1.0)

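    # Generate multi-level anchor boxes matching the model's pyramid levels.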
    input_anchor = anchor.build_anchor_generator(
        min_level=model_params.min_level,
        max_level=model_params.max_level,
        num_scales=model_params.anchor.num_scales,
        aspect_ratios=model_params.anchor.aspect_ratios,
        anchor_size=model_params.anchor.anchor_size)
    anchor_boxes = input_anchor(image_size=(self._input_image_size[0],
                                            self._input_image_size[1]))

    return image, anchor_boxes, image_info

  def preprocess(self, images: tf.Tensor) -> Tuple[
      tf.Tensor, Mapping[Text, tf.Tensor], tf.Tensor]:
    """Preprocesses inputs to be suitable for the model.

    Args:
      images: A uint8 image tensor of shape [batch_size, height, width, 3].
    Returns:
      images: The images tensor, cast to float, normalized, and resized.
      anchor_boxes: Dict mapping anchor levels to anchor boxes.
      image_info: A [batch_size, 4, 2] tensor holding the original image size,
        resized image size, scale, and offset.

    """
    model_params = self.params.task.model
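    # Run the per-image input preprocessing on the CPU.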
    with tf.device('cpu:0'):
      images = tf.cast(images, dtype=tf.float32)

      # Tensor Specs for map_fn outputs (images, anchor_boxes, and image_info).
      images_spec = tf.TensorSpec(shape=self._input_image_size + [3],
                                  dtype=tf.float32)

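      # Each spatial location stores num_scales * len(aspect_ratios) anchors,
      # each encoded as 4 box coordinates, which sets the channel size of the
      # per-level anchor tensors below.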
      num_anchors = model_params.anchor.num_scales * len(
          model_params.anchor.aspect_ratios) * 4
      anchor_shapes = []
      for level in range(model_params.min_level, model_params.max_level + 1):
        anchor_level_spec = tf.TensorSpec(
            shape=[
                self._input_image_size[0] // 2**level,
                self._input_image_size[1] // 2**level, num_anchors
            ],
            dtype=tf.float32)
        anchor_shapes.append((str(level), anchor_level_spec))

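      # image_info packs, per image: original size, resized size, scale, and
      # offset, each as an (h, w) pair.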
      image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)

      images, anchor_boxes, image_info = tf.nest.map_structure(
          tf.identity,
          tf.map_fn(
              self._build_inputs,
              elems=images,
              fn_output_signature=(images_spec, dict(anchor_shapes),
                                   image_info_spec),
              parallel_iterations=32))

      return images, anchor_boxes, image_info

  def serve(self, images: tf.Tensor):
    """Cast image to float and run inference.

    Args:
      images: A uint8 tensor of shape [batch_size, None, None, 3].

    Returns:
      A mapping of detection output tensors.
    """

    images, anchor_boxes, image_info = self.preprocess(images)
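    # Row 1 of image_info is the resized (pre-padding) image size, used by the
    # model to clip detections to the valid image region.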
    input_image_shape = image_info[:, 1, :]

    # To work around a keras.Model limitation on saving models whose layers
    # take multiple inputs, we call `model.call` directly here to trigger the
    # forward pass. Note that this bypasses some of the Keras magic that
    # normally happens in `__call__`.
    detections = self.model.call(
        images=images,
        image_shape=input_image_shape,
        anchor_boxes=anchor_boxes,
        training=False)

    if self.params.task.model.detection_generator.apply_nms:
      # For RetinaNet model, apply export_config.
      # TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed.
      if isinstance(self.params.task.model, configs.retinanet.RetinaNet):
        export_config = self.params.task.export_config
        # Normalize detection box coordinates to [0, 1].
        if export_config.output_normalized_coordinates:
          detection_boxes = (
              detections['detection_boxes'] /
              tf.tile(image_info[:, 2:3, :], [1, 1, 2]))
          detections['detection_boxes'] = box_ops.normalize_boxes(
              detection_boxes, image_info[:, 0:1, :])

        # Cast num_detections and detection_classes to float. This allows the
        # model inference to work on chain (go/chain) as chain requires floating
        # point outputs.
        if export_config.cast_num_detections_to_float:
          detections['num_detections'] = tf.cast(
              detections['num_detections'], dtype=tf.float32)
        if export_config.cast_detection_classes_to_float:
          detections['detection_classes'] = tf.cast(
              detections['detection_classes'], dtype=tf.float32)

      final_outputs = {
          'detection_boxes': detections['detection_boxes'],
          'detection_scores': detections['detection_scores'],
          'detection_classes': detections['detection_classes'],
          'num_detections': detections['num_detections']
      }
    else:
      final_outputs = {
          'decoded_boxes': detections['decoded_boxes'],
          'decoded_box_scores': detections['decoded_box_scores']
      }

    if 'detection_masks' in detections:
      final_outputs['detection_masks'] = detections['detection_masks']

    final_outputs.update({'image_info': image_info})
    return final_outputs
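

# Example usage, a minimal sketch: assuming `export_base.ExportModule` accepts
# the experiment `params`, a fixed `batch_size`, and `input_image_size` (the
# attributes referenced above), serving one uint8 image could look like:
#
#   module = DetectionModule(
#       params=params, batch_size=1, input_image_size=[640, 640])
#   outputs = module.serve(tf.zeros([1, 640, 640, 3], dtype=tf.uint8))
#
# The authoritative constructor signature lives in
# `official.vision.beta.serving.export_base`.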