"docs/vscode:/vscode.git/clone" did not exist on "6f3cf1297e7600f4b2ba8dd3af3a5cc2e33de6ef"
Commit 784cbc7d authored by Simon Geisler's avatar Simon Geisler
Browse files

fix random erase probability

parent fee3ca79
...@@ -1358,8 +1358,8 @@ class RandomErasing(ImageAugment): ...@@ -1358,8 +1358,8 @@ class RandomErasing(ImageAugment):
tf.Tensor: The augmented version of `image`. tf.Tensor: The augmented version of `image`.
""" """
uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0) uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0)
mirror_cond = tf.less(uniform_random, .5) mirror_cond = tf.less(uniform_random, self._probability)
tf.cond(mirror_cond, lambda: self._erase(image), lambda: image) image = tf.cond(mirror_cond, lambda: self._erase(image), lambda: image)
return image return image
@tf.function @tf.function
......
...@@ -263,7 +263,7 @@ class RandomErasingTest(tf.test.TestCase, parameterized.TestCase): ...@@ -263,7 +263,7 @@ class RandomErasingTest(tf.test.TestCase, parameterized.TestCase):
aug_image = augmenter.distort(image) aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape) self.assertEqual((224, 224, 3), aug_image.shape)
self.assertLess(0, tf.reduce_max(aug_image.shape)) self.assertNotEqual(0, tf.reduce_max(aug_image))
class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase): class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment