coco.py 5.97 KB
Newer Older
zhanggzh's avatar
zhanggzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""COCO data loader for DETR."""

import dataclasses
from typing import Optional, Tuple
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import input_reader
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops


@dataclasses.dataclass
class COCODataConfig(cfg.DataConfig):
  """Input pipeline settings for the COCO detection dataset.

  Attributes:
    output_size: Height and width the image is zero-padded to after resizing.
      The larger of the two values is also passed to the resize step as the
      maximum size.
    max_num_boxes: Fixed number of entries the per-image box, class and
      crowd tensors are clipped or padded to.
    resize_scales: Candidate short-side sizes for the final resize. Training
      picks one at random per example; evaluation always uses the last one.
  """
  output_size: Tuple[int, int] = (1333, 1333)
  max_num_boxes: int = 100
  # 480, 512, ..., 800 in steps of 32 (11 values).
  resize_scales: Tuple[int, ...] = tuple(range(480, 801, 32))


class COCODataLoader():
  """Builds the COCO detection input pipeline for DETR.

  The loader reads examples through `input_reader.InputReader`, applies
  `preprocess` to each one and batches the result.
  """

  def __init__(self, params: COCODataConfig):
    """Stores the data config.

    Args:
      params: Data config. Besides the COCO-specific fields, the loader reads
        `is_training`, `global_batch_size` and `drop_remainder` from it.
    """
    self._params = params

  def preprocess(self, inputs):
    """Preprocesses one COCO example for DETR.

    Training applies a random horizontal flip, then with probability 0.5 a
    random resize followed by a random crop, and finally resizes to a short
    side sampled uniformly from `resize_scales`. Evaluation only resizes to
    the last entry of `resize_scales`. In both modes the image is padded to
    `output_size` and the labels are clipped or padded to `max_num_boxes`.

    Args:
      inputs: Dict with keys 'image' and 'objects' (holding 'bbox', 'label'
        and 'is_crowd'); 'image/id' is additionally read in evaluation mode.
        Boxes are passed to `box_ops.denormalize_boxes`, so they are treated
        as normalized coordinates.

    Returns:
      A tuple `(image, labels)`. `labels` holds 'classes' and 'boxes' (boxes
      converted with `box_ops.yxyx_to_cycxhw`); in evaluation mode it also
      holds 'id', 'image_info', 'is_crowd' and 'gt_boxes'.
    """
    image = inputs['image']
    boxes = inputs['objects']['bbox']
    # Shift class ids up by one (presumably so that 0 can stand for
    # padding/background -- confirm against the model's label convention).
    classes = inputs['objects']['label'] + 1
    is_crowd = inputs['objects']['is_crowd']

    image = preprocess_ops.normalize_image(image)
    if self._params.is_training:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)

      do_crop = tf.greater(tf.random.uniform([]), 0.5)
      if do_crop:
        # Rescale to a random short side taken from {400, 500, 600}.
        boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
        index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
        scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
        short_side = scales[0]
        image, image_info = preprocess_ops.resize_image(image, short_side)
        # image_info[1] is the resized image size (it is reused as the crop
        # source shape below); rows 2 and 3 are forwarded as-is to the box
        # helper (presumably scale and offset -- confirm in preprocess_ops).
        boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                     image_info[2, :],
                                                     image_info[1, :],
                                                     image_info[3, :])
        boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

        # Do cropping: sample a crop size from 384 up to min(side, 600) and
        # a top-left corner that keeps the crop inside the resized image.
        shape = tf.cast(image_info[1], dtype=tf.int32)
        h = tf.random.uniform(
            [], 384, tf.math.minimum(shape[0], 600), dtype=tf.int32)
        w = tf.random.uniform(
            [], 384, tf.math.minimum(shape[1], 600), dtype=tf.int32)
        i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
        j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
        image = tf.image.crop_to_bounding_box(image, i, j, h, w)
        # Re-express the boxes relative to the crop: back to pixels of the
        # resized image, shift by the crop origin, renormalize by the crop
        # size and clip to the crop window.
        boxes = tf.clip_by_value(
            (boxes[..., :] * tf.cast(
                tf.stack([shape[0], shape[1], shape[0], shape[1]]),
                dtype=tf.float32) -
             tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
            tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
      # Sample the final short side uniformly from the configured scales.
      # The number of categories follows the config so that overriding
      # `resize_scales` with a different length stays in range and every
      # entry can be drawn.
      num_scales = len(self._params.resize_scales)
      scales = tf.constant(
          self._params.resize_scales,
          dtype=tf.float32)
      index = tf.random.categorical(tf.zeros([1, num_scales]), 1)[0]
      scales = tf.gather(scales, index, axis=0)
    else:
      # Evaluation is deterministic: always use the last configured scale.
      scales = tf.constant([self._params.resize_scales[-1]], tf.float32)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    # Boxes in pixels of the image before the final resize; only exported in
    # evaluation mode.
    gt_boxes = boxes
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(
        image,
        short_side,
        max(self._params.output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                 image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    is_crowd = tf.gather(is_crowd, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    # Pad at the bottom/right so every image has the configured output size.
    image = tf.image.pad_to_bounding_box(
        image, 0, 0, self._params.output_size[0], self._params.output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                classes, self._params.max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                boxes, self._params.max_num_boxes)
    }
    if not self._params.is_training:
      labels.update({
          'id':
              inputs['image/id'],
          'image_info':
              image_info,
          'is_crowd':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  is_crowd, self._params.max_num_boxes),
          'gt_boxes':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  gt_boxes, self._params.max_num_boxes),
      })

    return image, labels

  def _transform_and_batch_fn(
      self,
      dataset,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Maps `preprocess` over the dataset and batches it.

    Args:
      dataset: Dataset of examples accepted by `preprocess`.
      input_context: Optional distribution input context. When given, the
        global batch size is converted to a per-replica batch size.

    Returns:
      The preprocessed and batched dataset.
    """
    dataset = dataset.map(
        self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._params.global_batch_size
    ) if input_context else self._params.global_batch_size
    dataset = dataset.batch(
        per_replica_batch_size, drop_remainder=self._params.drop_remainder)
    return dataset

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset of preprocessed, batched examples.

    Args:
      input_context: Optional distribution input context forwarded to the
        reader and to the batching step.

    Returns:
      The dataset produced by `input_reader.InputReader.read`.
    """
    reader = input_reader.InputReader(
        params=self._params,
        decoder_fn=None,
        transform_and_batch_fn=self._transform_and_batch_fn)
    return reader.read(input_context)