Commit f51452cb authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 434795814
parent 068d3a39
...@@ -1663,6 +1663,7 @@ class AutoAugment(ImageAugment): ...@@ -1663,6 +1663,7 @@ class AutoAugment(ImageAugment):
tf_policies = self._make_tf_policies() tf_policies = self._make_tf_policies()
image, _ = select_and_apply_random_policy(tf_policies, image, bboxes=None) image, _ = select_and_apply_random_policy(tf_policies, image, bboxes=None)
image = tf.cast(image, dtype=input_image_type)
return image return image
def distort_with_boxes(self, image: tf.Tensor, def distort_with_boxes(self, image: tf.Tensor,
...@@ -2259,7 +2260,7 @@ class MixupAndCutmix: ...@@ -2259,7 +2260,7 @@ class MixupAndCutmix:
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""Apply cutmix.""" """Apply cutmix."""
lam = MixupAndCutmix._sample_from_beta(self.cutmix_alpha, self.cutmix_alpha, lam = MixupAndCutmix._sample_from_beta(self.cutmix_alpha, self.cutmix_alpha,
labels.shape) tf.shape(labels))
ratio = tf.math.sqrt(1 - lam) ratio = tf.math.sqrt(1 - lam)
...@@ -2284,17 +2285,19 @@ class MixupAndCutmix: ...@@ -2284,17 +2285,19 @@ class MixupAndCutmix:
lambda x: _fill_rectangle(*x), lambda x: _fill_rectangle(*x),
(images, random_center_width, random_center_height, cut_width // 2, (images, random_center_width, random_center_height, cut_width // 2,
cut_height // 2, tf.reverse(images, [0])), cut_height // 2, tf.reverse(images, [0])),
dtype=(tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.float32), dtype=(
fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=tf.float32)) images.dtype, tf.int32, tf.int32, tf.int32, tf.int32, images.dtype),
fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=images.dtype))
return images, labels, lam return images, labels, lam
def _mixup(self, images: tf.Tensor, def _mixup(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha, lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha,
labels.shape) tf.shape(labels))
lam = tf.reshape(lam, [-1, 1, 1, 1]) lam = tf.reshape(lam, [-1, 1, 1, 1])
images = lam * images + (1. - lam) * tf.reverse(images, [0]) lam_cast = tf.cast(lam, dtype=images.dtype)
images = lam_cast * images + (1. - lam_cast) * tf.reverse(images, [0])
return images, labels, tf.squeeze(lam) return images, labels, tf.squeeze(lam)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment