yolo_input.py 14.3 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Vishnu Banna's avatar
Vishnu Banna committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Detection Data parser and processing for YOLO."""
Vishnu Banna's avatar
config  
Vishnu Banna committed
16
import tensorflow as tf
17

Abdullah Rashwan's avatar
Abdullah Rashwan committed
18
19
from official.projects.yolo.ops import anchor
from official.projects.yolo.ops import preprocessing_ops
Abdullah Rashwan's avatar
Abdullah Rashwan committed
20
21
22
23
from official.vision.dataloaders import parser
from official.vision.dataloaders import utils
from official.vision.ops import box_ops as bbox_ops
from official.vision.ops import preprocess_ops
Vishnu Banna's avatar
config  
Vishnu Banna committed
24
25
26


class Parser(parser.Parser):
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
  """Parse the dataset in to the YOLO model format."""

  def __init__(self,
               output_size,
               anchors,
               expanded_strides,
               level_limits=None,
               max_num_instances=200,
               area_thresh=0.1,
               aug_rand_hue=1.0,
               aug_rand_saturation=1.0,
               aug_rand_brightness=1.0,
               letter_box=False,
               random_pad=True,
               random_flip=True,
               jitter=0.0,
               aug_scale_min=1.0,
               aug_scale_max=1.0,
               aug_rand_translate=0.0,
               aug_rand_perspective=0.0,
               aug_rand_angle=0.0,
               anchor_t=4.0,
               scale_xy=None,
               best_match_only=False,
               darknet=False,
               use_tie_breaker=True,
               dtype='float32',
               seed=None):
    """Initializes parameters for parsing annotations in the dataset.

    Args:
      output_size: `Tensor` or `List` for [height, width] of output image. Both
        dimensions must be divisible by the stride of every level listed in
        `anchors`.
      anchors: `Dict[List[Union[int, float]]]` of anchor boxes to be used in
        each level.
      expanded_strides: `Dict[int]` for how much the model scales down the
        images at each level, looked up with the string form of each key in
        `anchors`. For example, level 3 down samples the image by a factor of
        16, in the expanded strides dictionary, we will pass along {'3': 16}
        indicating that relative to the original image, the shapes must be
        reduced by a factor of 16 to compute the loss.
      level_limits: `List` the box sizes that will be allowed at each FPN level
        as is done in the FCOS and YOLOX paper for anchor free box assignment.
      max_num_instances: `int` for the number of boxes to compute loss on.
      area_thresh: `float` for the minimum area of a box to allow to pass
        through for optimization.
      aug_rand_hue: `float` indicating the maximum scaling value for hue. hue
        will be scaled between 1 - value and 1 + value.
      aug_rand_saturation: `float` indicating the maximum scaling value for
        saturation. saturation will be scaled between 1/value and value.
      aug_rand_brightness: `float` indicating the maximum scaling value for
        brightness. brightness will be scaled between 1/value and value.
      letter_box: `boolean` indicating whether upon start of the data pipeline
        regardless of the preprocessing ops that are used, the aspect ratio of
        the images should be preserved.
      random_pad: `bool` indicating whether to use padding to apply random
        translation, true for darknet yolo false for scaled yolo.
      random_flip: `boolean` indicating whether or not to randomly flip the
        image horizontally.
      jitter: `float` for the maximum change in aspect ratio expected in each
        preprocessing step. `None` is treated as 0.0.
      aug_scale_min: `float` indicating the minimum scaling value for image
        scale jitter.
      aug_scale_max: `float` indicating the maximum scaling value for image
        scale jitter.
      aug_rand_translate: `float` ranging from 0 to 1 indicating the maximum
        amount to randomly translate an image.
      aug_rand_perspective: `float` ranging from 0.000 to 0.001 indicating how
        much to perspective warp the image.
      aug_rand_angle: `float` indicating the maximum angle value for angle.
        angle will be changed between 0 and value.
      anchor_t: `float` indicating the threshold over which an anchor will be
        considered for prediction, at zero, all the anchors will be used and at
        1.0 only the best will be used. for anchor thresholds larger than 1.0 we
        stop using the IOU for anchor comparison and resort directly to
        comparing the width and height, this is used for the scaled models.
      scale_xy: dictionary of `float` values indicating how far each pixel can
        see outside of its containment of 1.0. a value of 1.2 indicates there is
        a 20% extended radius around each pixel that this specific pixel can
        predict values for a center at. the center can range from 0 - value/2 to
        1 + value/2, this value is set in the yolo filter, and reused here.
        there should be one value for scale_xy for each level from min_level to
        max_level.
      best_match_only: `boolean` indicating how boxes are selected for
        optimization.
      darknet: `boolean` indicating which data pipeline to use. Setting to True
        swaps the pipeline to output images relative to Yolov4 and older.
      use_tie_breaker: `boolean` indicating whether to use the anchor threshold
        value.
      dtype: `str` indicating the output datatype of the datapipeline selecting
        from {"float32", "float16", "bfloat16"}.
      seed: `int` the seed for random number generation.

    Raises:
      ValueError: if the output height or width is not divisible by the stride
        of one of the levels listed in `anchors`.
    """
    # Validate explicitly instead of with `assert`, which is removed when
    # Python runs with optimizations enabled (-O).
    for key in anchors:
      stride = expanded_strides[str(key)]
      if output_size[1] % stride != 0 or output_size[0] % stride != 0:
        raise ValueError(
            'output_size {} must be divisible by the stride {} of level {}.'
            .format(output_size, stride, key))

    # Set the width and height properly and base init:
    self._image_w = output_size[1]
    self._image_h = output_size[0]
    self._max_num_instances = max_num_instances

    # Image scaling params
    self._jitter = 0.0 if jitter is None else jitter
    self._aug_scale_min = aug_scale_min
    self._aug_scale_max = aug_scale_max
    self._aug_rand_translate = aug_rand_translate
    self._aug_rand_perspective = aug_rand_perspective

    # Image spatial distortion
    self._random_flip = random_flip
    self._letter_box = letter_box
    self._random_pad = random_pad
    self._aug_rand_angle = aug_rand_angle

    # Color space distortion of the image
    self._aug_rand_saturation = aug_rand_saturation
    self._aug_rand_brightness = aug_rand_brightness
    self._aug_rand_hue = aug_rand_hue

    # Set the per level values needed for operation
    self._darknet = darknet
    self._area_thresh = area_thresh
    self._level_limits = level_limits

    self._seed = seed
    self._dtype = dtype

    # Callable that turns ground truth boxes and classes into the training
    # targets consumed when labels are built.
    self._label_builder = anchor.YoloAnchorLabeler(
        anchors=anchors,
        anchor_free_level_limits=level_limits,
        level_strides=expanded_strides,
        center_radius=scale_xy,
        max_num_instances=max_num_instances,
        match_threshold=anchor_t,
        best_matches_only=best_match_only,
        use_tie_breaker=use_tie_breaker,
        darknet=darknet,
        dtype=dtype)
Vishnu Banna's avatar
Vishnu Banna committed
166
167
168

  def _pad_infos_object(self, image):
    """Builds a placeholder info tensor from the image's leading dimensions.

    Args:
      image: `Tensor` whose first two dimensions are read (presumably
        [height, width] -- confirm against callers).

    Returns:
      A float32 `Tensor` stacking four rows: the first two dimensions of
      `image` (twice), a row of ones and a row of zeros.
    """
    leading_dims = tf.cast(tf.shape(image)[:2], tf.float32)
    return tf.stack([
        leading_dims,
        leading_dims,
        tf.ones_like(leading_dims),
        tf.zeros_like(leading_dims),
    ])

  def _jitter_scale(self, image, shape, letter_box, jitter, random_pad,
                    aug_scale_min, aug_scale_max, translate, angle,
                    perspective):
    """Resizes/jitters the image and then applies an affine warp.

    Args:
      image: the image `Tensor` to distort.
      shape: target size forwarded to both preprocessing ops.
      letter_box: forwarded to `resize_and_jitter_image`.
      jitter: forwarded to `resize_and_jitter_image`.
      random_pad: forwarded to both preprocessing ops.
      aug_scale_min: lower scale bound forwarded to `affine_warp_image`.
      aug_scale_max: upper scale bound forwarded to `affine_warp_image`.
      translate: forwarded to `affine_warp_image`.
      angle: forwarded to `affine_warp_image` as `degrees`.
      perspective: forwarded to `affine_warp_image`.

    Returns:
      A tuple of the distorted image, the list of info objects accumulated
      from the resize/jitter step, and the affine value returned by the warp.
    """
    # Scale augmentation is active whenever either bound differs from 1.0; in
    # that case the jitter step only crops and a placeholder info entry is
    # put in front of the infos it returns.
    crop_only = bool(aug_scale_min != 1.0 or aug_scale_max != 1.0)
    infos = [self._pad_infos_object(image)] if crop_only else []

    image, crop_info, _ = preprocessing_ops.resize_and_jitter_image(
        image,
        shape,
        letter_box=letter_box,
        jitter=jitter,
        crop_only=crop_only,
        random_pad=random_pad,
        seed=self._seed,
    )
    infos.extend(crop_info)

    warped, _, affine = preprocessing_ops.affine_warp_image(
        image,
        shape,
        scale_min=aug_scale_min,
        scale_max=aug_scale_max,
        translate=translate,
        degrees=angle,
        perspective=perspective,
        random_pad=random_pad,
        seed=self._seed,
    )
    return warped, infos, affine

  def _parse_train_data(self, data):
    """Augments one decoded example and builds its training labels.

    Args:
      data: `dict` that provides at least 'image', 'groundtruth_boxes',
        'groundtruth_classes' and 'is_mosaic'; it is also handed to the label
        builder step, which reads 'source_id'.

    Returns:
      A tuple of the processed image and the labels dictionary.
    """
    image = data['image']
    boxes = data['groundtruth_boxes']
    classes = data['groundtruth_classes']

    # Optional horizontal flip of the image together with its boxes.
    if self._random_flip:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(
          image, boxes, seed=self._seed)

    if not data['is_mosaic']:
      # Regular samples go through the full jitter/scale/warp pipeline.
      image, infos, affine = self._jitter_scale(
          image=image,
          shape=[self._image_h, self._image_w],
          letter_box=self._letter_box,
          jitter=self._jitter,
          random_pad=self._random_pad,
          aug_scale_min=self._aug_scale_min,
          aug_scale_max=self._aug_scale_max,
          translate=self._aug_rand_translate,
          angle=self._aug_rand_angle,
          perspective=self._aug_rand_perspective)

      # Move the boxes along with the image, then clip and filter them.
      boxes, inds = preprocessing_ops.transform_and_clip_boxes(
          boxes,
          infos,
          affine=affine,
          shuffle_boxes=False,
          area_thresh=self._area_thresh,
          filter_and_clip_boxes=True,
          seed=self._seed)
      classes = tf.gather(classes, inds)
      info = infos[-1]
    else:
      # Mosaic samples are only resized; boxes that are empty at the output
      # resolution are dropped.
      image = tf.image.resize(
          image, (self._image_h, self._image_w), method='nearest')
      output_size = tf.cast([self._image_h, self._image_w], tf.float32)
      pixel_boxes = bbox_ops.denormalize_boxes(boxes, output_size)
      inds = bbox_ops.get_non_empty_box_indices(pixel_boxes)
      boxes = tf.gather(boxes, inds)
      classes = tf.gather(classes, inds)
      info = self._pad_infos_object(image)

    # Cast to the configured dtype and rescale by 1/255 before the color
    # augmentation.
    image = tf.cast(image, dtype=self._dtype) / 255.0

    # Randomly perturb hue, saturation and brightness.
    use_darknet_hsv = self._darknet or self._level_limits is not None
    image = preprocessing_ops.image_rand_hsv(
        image,
        self._aug_rand_hue,
        self._aug_rand_saturation,
        self._aug_rand_brightness,
        seed=self._seed,
        darknet=use_darknet_hsv)

    return self._build_label(
        image, boxes, classes, info, inds, data, is_training=True)

  def _parse_eval_data(self, data):
    """Resizes one decoded example and builds its evaluation labels.

    Args:
      data: `dict` that provides at least 'image', 'groundtruth_boxes' and
        'groundtruth_classes'; it is also handed to the label builder step,
        which reads further groundtruth fields.

    Returns:
      A tuple of the processed image and the labels dictionary.
    """
    gt_boxes = data['groundtruth_boxes']
    gt_classes = data['groundtruth_classes']

    # Cast to the configured dtype first, then resize without any jitter or
    # random padding, using fixed 0.5 shifts.
    image = tf.cast(data['image'], dtype=self._dtype)
    image, infos, _ = preprocessing_ops.resize_and_jitter_image(
        image, [self._image_h, self._image_w],
        letter_box=self._letter_box,
        random_pad=False,
        shiftx=0.5,
        shifty=0.5,
        jitter=0.0)
    image = image / 255.0

    # Move the boxes along with the resize; filtering and clipping are
    # switched off for evaluation.
    gt_boxes, inds = preprocessing_ops.transform_and_clip_boxes(
        gt_boxes,
        infos,
        shuffle_boxes=False,
        area_thresh=0.0,
        filter_and_clip_boxes=False)
    gt_classes = tf.gather(gt_classes, inds)

    return self._build_label(
        image, gt_boxes, gt_classes, infos[-1], inds, data, is_training=False)

Vishnu Banna's avatar
Vishnu Banna committed
298
  def set_shape(self, values, pad_axis=0, pad_value=0, inds=None):
    """Pads `values` to `max_num_instances` and fixes its static shape.

    Args:
      values: `Tensor` to pad.
      pad_axis: `int` axis along which the tensor is padded.
      pad_value: value used to fill the padded entries.
      inds: optional indices; when given, `values` is gathered with them
        before padding.

    Returns:
      The padded `Tensor` whose static size along `pad_axis` is set to
      `max_num_instances`.
    """
    if inds is not None:
      values = tf.gather(values, inds)

    # Take the static shape before padding and overwrite only the padded axis.
    static_shape = values.get_shape().as_list()
    static_shape[pad_axis] = self._max_num_instances

    padded = preprocessing_ops.pad_max_instances(
        values, self._max_num_instances, pad_axis=pad_axis, pad_value=pad_value)
    padded.set_shape(static_shape)
    return padded

  def _build_label(self,
                   image,
                   gt_boxes,
                   gt_classes,
                   info,
                   inds,
                   data,
                   is_training=True):
    """Assembles the labels dictionary shared by training and evaluation.

    Args:
      image: the processed image `Tensor`; its last static dimension is set
        to 3.
      gt_boxes: ground truth boxes after preprocessing.
      gt_classes: ground truth classes matching `gt_boxes`.
      info: image info object stored under 'image_info' for evaluation.
      inds: indices used to gather 'groundtruth_is_crowd' for evaluation.
      data: the decoded example `dict`; 'source_id' is always read, the
        remaining groundtruth fields only when `is_training` is False.
      is_training: `bool`, when False a 'groundtruths' entry is added.

    Returns:
      A tuple of the image and the labels dictionary.
    """
    # Fix the channel dimension of the image's static shape.
    static_shape = image.get_shape().as_list()
    static_shape[-1] = 3
    image.set_shape(static_shape)

    # Build the training targets from the unpadded boxes and classes.
    label_inds, label_upds, true_conf = self._label_builder(
        gt_boxes, gt_classes, self._image_w, self._image_h)

    # Pad boxes and classes to a fixed number of instances.
    boxes = self.set_shape(gt_boxes, pad_axis=0, pad_value=0)
    classes = self.set_shape(gt_classes, pad_axis=0, pad_value=-1)

    labels = {
        'inds': label_inds,
        'upds': label_upds,
        'true_conf': true_conf,
        'source_id': utils.process_source_id(data['source_id']),
        'bbox': tf.cast(boxes, dtype=self._dtype),
        'classes': tf.cast(classes, dtype=self._dtype),
    }
    if is_training:
      return image, labels

    # Evaluation additionally carries the original groundtruth annotations.
    original_hw = tf.cast([data['height'], data['width']], gt_boxes.dtype)
    groundtruths = {
        'source_id': labels['source_id'],
        'height': data['height'],
        'width': data['width'],
        'num_detections': tf.shape(data['groundtruth_boxes'])[0],
        'image_info': info,
        'boxes': bbox_ops.denormalize_boxes(
            data['groundtruth_boxes'], original_hw),
        'classes': data['groundtruth_classes'],
        'areas': data['groundtruth_area'],
        'is_crowds': tf.cast(
            tf.gather(data['groundtruth_is_crowd'], inds), tf.int32),
    }
    groundtruths['source_id'] = utils.process_source_id(
        groundtruths['source_id'])
    labels['groundtruths'] = utils.pad_groundtruths_to_fixed_size(
        groundtruths, self._max_num_instances)
    return image, labels