Commit 0eeeaf98 authored by Simon Geisler's avatar Simon Geisler
Browse files

test: saturation

parent 784cbc7d
......@@ -638,5 +638,11 @@ def random_saturation(image: tf.Tensor, saturation: Optional[float] = 0.,
assert saturation >= 0, '`saturation` must be positive'
saturation = tf.random.uniform(
[], max(0, 1 - saturation), 1 + saturation, seed=seed, dtype=tf.float32)
return augment.blend(
tf.repeat(tf.image.rgb_to_grayscale(image), 3, axis=-1), image, saturation)
return _saturation(image, saturation)
def _saturation(image: tf.Tensor,
saturation: Optional[float] = 0.) -> tf.Tensor:
return augment.blend(tf.repeat(tf.image.rgb_to_grayscale(image), 3, axis=-1),
image,
saturation)
......@@ -235,6 +235,14 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
image, color_jitter, color_jitter, color_jitter)
assert jittered_image.shape == image.shape
@parameterized.parameters(
(400, 600, 0), (400, 600, 0.4), (600, 400, 1)
)
def testSaturation(self, input_height, input_width, saturation):
image = tf.convert_to_tensor(
np.random.rand(input_height, input_width, 3))
jittered_image = preprocess_ops._saturation(image, saturation)
assert jittered_image.shape == image.shape
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment