"...git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "c8348c052165446d40df90d340e71704534bedc5"
Commit 371f9419 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 426075338
parent 7785dec0
...@@ -2284,8 +2284,9 @@ class MixupAndCutmix: ...@@ -2284,8 +2284,9 @@ class MixupAndCutmix:
lambda x: _fill_rectangle(*x), lambda x: _fill_rectangle(*x),
(images, random_center_width, random_center_height, cut_width // 2, (images, random_center_width, random_center_height, cut_width // 2,
cut_height // 2, tf.reverse(images, [0])), cut_height // 2, tf.reverse(images, [0])),
dtype=(tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.float32), dtype=(
fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=tf.float32)) images.dtype, tf.int32, tf.int32, tf.int32, tf.int32, images.dtype),
fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=images.dtype))
return images, labels, lam return images, labels, lam
...@@ -2294,7 +2295,8 @@ class MixupAndCutmix: ...@@ -2294,7 +2295,8 @@ class MixupAndCutmix:
lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha, lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha,
labels.shape) labels.shape)
lam = tf.reshape(lam, [-1, 1, 1, 1]) lam = tf.reshape(lam, [-1, 1, 1, 1])
images = lam * images + (1. - lam) * tf.reverse(images, [0]) lam_cast = tf.cast(lam, dtype=images.dtype)
images = lam_cast * images + (1. - lam_cast) * tf.reverse(images, [0])
return images, labels, tf.squeeze(lam) return images, labels, tf.squeeze(lam)
......
...@@ -366,14 +366,19 @@ class RandomErasingTest(tf.test.TestCase, parameterized.TestCase): ...@@ -366,14 +366,19 @@ class RandomErasingTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotEqual(0, tf.reduce_max(aug_image)) self.assertNotEqual(0, tf.reduce_max(aug_image))
class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters([
('float16_images', tf.float16),
def test_mixup_and_cutmix_smoothes_labels(self): ('bfloat16_images', tf.bfloat16),
('float32_images', tf.float32),
])
class MixupAndCutmixTest(parameterized.TestCase, tf.test.TestCase):
def test_mixup_and_cutmix_smoothes_labels(self, image_dtype):
batch_size = 12 batch_size = 12
num_classes = 1000 num_classes = 1000
label_smoothing = 0.1 label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32) images = tf.random.normal((batch_size, 224, 224, 3), dtype=image_dtype)
labels = tf.range(batch_size) labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix( augmenter = augment.MixupAndCutmix(
num_classes=num_classes, label_smoothing=label_smoothing) num_classes=num_classes, label_smoothing=label_smoothing)
...@@ -388,12 +393,12 @@ class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase): ...@@ -388,12 +393,12 @@ class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes - self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4) # With tolerance 1e4) # With tolerance
def test_mixup_changes_image(self): def test_mixup_changes_image(self, image_dtype):
batch_size = 12 batch_size = 12
num_classes = 1000 num_classes = 1000
label_smoothing = 0.1 label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32) images = tf.random.normal((batch_size, 224, 224, 3), dtype=image_dtype)
labels = tf.range(batch_size) labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix( augmenter = augment.MixupAndCutmix(
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes) mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)
...@@ -409,12 +414,12 @@ class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase): ...@@ -409,12 +414,12 @@ class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
1e4) # With tolerance 1e4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images)) self.assertFalse(tf.math.reduce_all(images == aug_images))
def test_cutmix_changes_image(self): def test_cutmix_changes_image(self, image_dtype):
batch_size = 12 batch_size = 12
num_classes = 1000 num_classes = 1000
label_smoothing = 0.1 label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32) images = tf.random.normal((batch_size, 224, 224, 3), dtype=image_dtype)
labels = tf.range(batch_size) labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix( augmenter = augment.MixupAndCutmix(
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes) mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment