solarization.py 4.86 KB
Newer Older
zhanggzh's avatar
zhanggzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf

from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
    BaseImageAugmentationLayer,
)
from keras_cv.utils import preprocessing


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class Solarization(BaseImageAugmentationLayer):
    """Applies (max_value - pixel + min_value) for each pixel in the image.

    When created without the `threshold_factor` parameter, the layer performs
    solarization on all values. When created with a specified `threshold_factor`,
    the layer only inverts pixels that are at or above the sampled threshold;
    pixels strictly below it are left unchanged.

    Reference:
    - [AutoAugment: Learning Augmentation Policies from Data](
        https://arxiv.org/abs/1805.09501
    )
    - [RandAugment](https://arxiv.org/pdf/1909.13719.pdf)

    Args:
        value_range: (Required) a tuple or a list of two elements. The first value
            represents the lower bound for values in passed images, the second
            represents the upper bound. Images passed to the layer should have
            values within `value_range`.
        addition_factor: (Optional)  A tuple of two floats, a single float or a
            `keras_cv.FactorSampler`. For each augmented image a value is sampled
            from the provided range. If a float is passed, the range is interpreted as
            `(0, addition_factor)`. If specified, this value is added to each pixel
            before solarization and thresholding.  The addition value should be scaled
            according to the value range (0, 255). Defaults to 0.0.
        threshold_factor: (Optional)  A tuple of two floats, a single float or a
            `keras_cv.FactorSampler`. For each augmented image a value is sampled
            from the provided range. If a float is passed, the range is interpreted as
            `(0, threshold_factor)`. If specified, only pixel values at or above this
            threshold will be solarized. Like `addition_factor`, the threshold is
            expressed on the (0, 255) scale regardless of `value_range`. Defaults
            to 0.0.
        seed: Integer. Used to create a random seed.

    Usage:
    ```python
    (images, labels), _ = tf.keras.datasets.cifar10.load_data()
    print(images[0, 0, 0])
    # [59 62 63]
    # Note that images are Tensor with values in the range [0, 255]
    solarization = Solarization(value_range=(0, 255))
    images = solarization(images)
    print(images[0, 0, 0])
    # [196, 193, 192]
    ```

    Call arguments:
        images: Tensor of type int or float, with pixels in
            range `value_range` and shape [batch, height, width, channels]
            or [height, width, channels].
    """

    def __init__(
        self,
        value_range,
        addition_factor=0.0,
        threshold_factor=0.0,
        seed=None,
        **kwargs
    ):
        super().__init__(seed=seed, **kwargs)
        self.seed = seed
        # Both factors are bounded by 255 because `augment_image` applies them
        # after rescaling the image to the (0, 255) range.
        self.addition_factor = preprocessing.parse_factor(
            addition_factor, max_value=255, seed=seed, param_name="addition_factor"
        )
        self.threshold_factor = preprocessing.parse_factor(
            threshold_factor, max_value=255, seed=seed, param_name="threshold_factor"
        )
        self.value_range = value_range

    def get_random_transformation(self, **kwargs):
        """Samples the `(addition, threshold)` pair used for one image."""
        return (self.addition_factor(), self.threshold_factor())

    def augment_image(self, image, transformation=None, **kwargs):
        """Solarizes `image` using the sampled `(addition, threshold)` pair.

        Args:
            image: image with values within `self.value_range`.
            transformation: the `(addition, threshold)` tuple produced by
                `get_random_transformation`.

        Returns:
            The solarized image, rescaled back to `self.value_range`.
        """
        (addition, threshold) = transformation
        # Operate on the (0, 255) scale, which is the scale both factors use.
        image = preprocessing.transform_value_range(
            image, original_range=self.value_range, target_range=(0, 255)
        )
        result = image + addition
        # The addition may push values past the upper bound; clamp before
        # thresholding so the inversion below stays within (0, 255).
        result = tf.clip_by_value(result, 0, 255)
        # Pixels strictly below the threshold are kept; all others are inverted.
        result = tf.where(result < threshold, result, 255 - result)
        result = preprocessing.transform_value_range(
            result, original_range=(0, 255), target_range=self.value_range
        )
        return result

    def augment_bounding_boxes(self, bounding_boxes, **kwargs):
        """Returns `bounding_boxes` unchanged; solarization only alters pixels."""
        return bounding_boxes

    def augment_label(self, label, transformation=None, **kwargs):
        """Returns `label` unchanged; solarization only alters pixels."""
        return label

    def augment_segmentation_mask(
        self, segmentation_mask, transformation=None, **kwargs
    ):
        """Returns `segmentation_mask` unchanged; solarization only alters pixels."""
        return segmentation_mask

    def get_config(self):
        """Returns the layer config, including the parsed factor samplers."""
        return {
            **super().get_config(),
            "threshold_factor": self.threshold_factor,
            "addition_factor": self.addition_factor,
            "value_range": self.value_range,
            "seed": self.seed,
        }

    @classmethod
    def from_config(cls, config):
        """Rebuilds the layer from a config produced by `get_config`.

        `get_config` stores the factor samplers as objects. If a config arrives
        with them serialized as dicts, they are deserialized back into sampler
        objects before being handed to the constructor, which cannot parse a dict.
        NOTE(review): whether dicts appear here depends on the Keras saving
        path in use -- confirm against the supported Keras versions.
        """
        # Copy so the caller's config dict is not mutated.
        config = dict(config)
        for key in ("threshold_factor", "addition_factor"):
            if isinstance(config.get(key), dict):
                config[key] = tf.keras.utils.deserialize_keras_object(config[key])
        return cls(**config)