random_flip.py 6.83 KB
Newer Older
zhanggzh's avatar
zhanggzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from keras_cv import bounding_box
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
    BaseImageAugmentationLayer,
)

# Height and width are addressed from the end of the shape, so the same
# indices work for both unbatched (height, width, channels) and batched
# (batch, height, width, channels) inputs.
H_AXIS, W_AXIS = -3, -2

# Accepted values for the `mode` argument of `RandomFlip`.
HORIZONTAL, VERTICAL, HORIZONTAL_AND_VERTICAL = (
    "horizontal",
    "vertical",
    "horizontal_and_vertical",
)


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class RandomFlip(BaseImageAugmentationLayer):
    """A preprocessing layer which randomly flips images during training.

    This layer will flip the images horizontally and/or vertically based on the
    `mode` attribute. During inference time, the output will be identical to
    input. Call the layer with `training=True` to flip the input.

    For every enabled direction a single uniform random draw is compared
    against 0.5 to decide whether to flip. The resulting transformation is
    applied to images, bounding boxes and segmentation masks; labels are
    returned unchanged.

    Input shape:
      3D (unbatched) or 4D (batched) tensor with shape:
      `(..., height, width, channels)`, in `"channels_last"` format.

    Output shape:
      3D (unbatched) or 4D (batched) tensor with shape:
      `(..., height, width, channels)`, in `"channels_last"` format.

    Arguments:
      mode: String indicating which flip mode to use. Can be `"horizontal"`,
        `"vertical"`, or `"horizontal_and_vertical"`. Defaults to
        `"horizontal"`. `"horizontal"` is a left-right flip and `"vertical"` is
        a top-bottom flip.
      seed: Integer. Used to create a random seed.
      bounding_box_format: Optional string naming the format of bounding boxes
        passed to the layer, e.g. `"xyxy"`. Required only when the layer is
        called with bounding boxes. Columns after the four box coordinates
        (for example a class id) are carried through the flip unchanged.

    Raises:
      ValueError: if `mode` is not one of the supported values.
    """

    def __init__(self, mode=HORIZONTAL, seed=None, bounding_box_format=None, **kwargs):
        super().__init__(seed=seed, force_generator=True, **kwargs)
        self.mode = mode
        self.seed = seed
        if mode == HORIZONTAL:
            self.horizontal = True
            self.vertical = False
        elif mode == VERTICAL:
            self.horizontal = False
            self.vertical = True
        elif mode == HORIZONTAL_AND_VERTICAL:
            self.horizontal = True
            self.vertical = True
        else:
            raise ValueError(
                "RandomFlip layer {name} received an unknown mode="
                "{arg}".format(name=self.name, arg=mode)
            )
        self.auto_vectorize = True
        self.bounding_box_format = bounding_box_format

    def augment_label(self, label, transformation, **kwargs):
        # Flipping does not change class labels.
        return label

    def augment_image(self, image, transformation, **kwargs):
        return self._flip_image(image, transformation)

    def get_random_transformation(self, **kwargs):
        """Sample which flips to apply.

        Returns:
          A dict with boolean scalar tensors `"flip_horizontal"` and
          `"flip_vertical"`. A direction not enabled by `mode` is always False
          and consumes no random draw.
        """
        flip_horizontal = False
        flip_vertical = False
        if self.horizontal:
            flip_horizontal = self._random_generator.random_uniform(shape=[]) > 0.5
        if self.vertical:
            flip_vertical = self._random_generator.random_uniform(shape=[]) > 0.5
        return {
            "flip_horizontal": tf.cast(flip_horizontal, dtype=tf.bool),
            "flip_vertical": tf.cast(flip_vertical, dtype=tf.bool),
        }

    @staticmethod
    def _flip_image(image, transformation):
        """Apply the sampled left-right and/or up-down flip to `image`."""
        flipped_output = tf.cond(
            transformation["flip_horizontal"],
            lambda: tf.image.flip_left_right(image),
            lambda: image,
        )
        flipped_output = tf.cond(
            transformation["flip_vertical"],
            lambda: tf.image.flip_up_down(flipped_output),
            lambda: flipped_output,
        )
        # tf.cond can lose static shape information; restore it.
        flipped_output.set_shape(image.shape)
        return flipped_output

    @staticmethod
    def _flip_bounding_boxes_horizontal(bounding_boxes):
        """Mirror `rel_xyxy` boxes left-right: x -> 1 - x (x1/x2 swap roles).

        Uses `tf.concat` so the output keeps the input's shape for any rank
        and any number of trailing (non-coordinate) columns.
        """
        x1, y1, x2, y2, rest = tf.split(
            bounding_boxes, [1, 1, 1, 1, bounding_boxes.shape[-1] - 4], axis=-1
        )
        return tf.concat([1 - x2, y1, 1 - x1, y2, rest], axis=-1)

    @staticmethod
    def _flip_bounding_boxes_vertical(bounding_boxes):
        """Mirror `rel_xyxy` boxes top-bottom: y -> 1 - y (y1/y2 swap roles).

        Uses `tf.concat` so the output keeps the input's shape for any rank
        and any number of trailing (non-coordinate) columns.
        """
        x1, y1, x2, y2, rest = tf.split(
            bounding_boxes, [1, 1, 1, 1, bounding_boxes.shape[-1] - 4], axis=-1
        )
        return tf.concat([x1, 1 - y2, x2, 1 - y1, rest], axis=-1)

    def augment_bounding_boxes(
        self, bounding_boxes, transformation=None, image=None, **kwargs
    ):
        """Flip bounding boxes consistently with the image.

        Raises:
          ValueError: if the layer was built without `bounding_box_format`.
        """
        if self.bounding_box_format is None:
            raise ValueError(
                "`RandomFlip()` was called with bounding boxes, "
                "but no `bounding_box_format` was specified in the constructor. "
                "Please specify a bounding box format in the constructor. i.e. "
                "`RandomFlip(bounding_box_format='xyxy')`"
            )

        # Work in relative xyxy so a flip is simply `1 - coordinate`.
        bounding_boxes = bounding_box.convert_format(
            bounding_boxes,
            source=self.bounding_box_format,
            target="rel_xyxy",
            images=image,
        )
        bounding_boxes = tf.cond(
            transformation["flip_horizontal"],
            lambda: self._flip_bounding_boxes_horizontal(bounding_boxes),
            lambda: bounding_boxes,
        )
        bounding_boxes = tf.cond(
            transformation["flip_vertical"],
            lambda: self._flip_bounding_boxes_vertical(bounding_boxes),
            lambda: bounding_boxes,
        )
        bounding_boxes = bounding_box.clip_to_image(
            bounding_boxes,
            bounding_box_format="rel_xyxy",
            images=image,
        )
        # Convert back to the caller's format and the layer's compute dtype.
        bounding_boxes = bounding_box.convert_format(
            bounding_boxes,
            source="rel_xyxy",
            target=self.bounding_box_format,
            dtype=self.compute_dtype,
            images=image,
        )
        return bounding_boxes

    def augment_segmentation_mask(
        self, segmentation_mask, transformation=None, **kwargs
    ):
        # Masks are spatial like images, so they get the identical flip.
        return self._flip_image(segmentation_mask, transformation)

    def compute_output_shape(self, input_shape):
        # Flipping never changes the shape.
        return input_shape

    def get_config(self):
        config = {
            "mode": self.mode,
            "seed": self.seed,
            "bounding_box_format": self.bounding_box_format,
        }
        base_config = super().get_config()
        # Layer-specific keys take precedence over the base config.
        return {**base_config, **config}