Commit 7d45e7b9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Allow setting hyperparameters of augmentation for video classification in the...

Allow setting hyperparameters of augmentation for video classification in the config, and extend MixupAndCutmix for videos.

PiperOrigin-RevId: 459615324
parent c9a6456d
...@@ -58,7 +58,9 @@ class DataConfig(cfg.DataConfig): ...@@ -58,7 +58,9 @@ class DataConfig(cfg.DataConfig):
aug_max_aspect_ratio: float = 2.0 aug_max_aspect_ratio: float = 2.0
aug_min_area_ratio: float = 0.49 aug_min_area_ratio: float = 0.49
aug_max_area_ratio: float = 1.0 aug_max_area_ratio: float = 1.0
aug_type: Optional[str] = None # 'autoaug', 'randaug', or None aug_type: Optional[
common.Augmentation] = None # AutoAugment and RandAugment.
mixup_and_cutmix: Optional[common.MixupAndCutmix] = None
image_field_key: str = 'image/encoded' image_field_key: str = 'image/encoded'
label_field_key: str = 'clip/label/index' label_field_key: str = 'clip/label/index'
......
...@@ -286,18 +286,28 @@ class Parser(parser.Parser): ...@@ -286,18 +286,28 @@ class Parser(parser.Parser):
self._audio_feature = input_params.audio_feature self._audio_feature = input_params.audio_feature
self._audio_shape = input_params.audio_feature_shape self._audio_shape = input_params.audio_feature_shape
self._augmenter = None
if input_params.aug_type is not None:
aug_type = input_params.aug_type aug_type = input_params.aug_type
if aug_type == 'autoaug': if aug_type is not None:
if aug_type.type == 'autoaug':
logging.info('Using AutoAugment.') logging.info('Using AutoAugment.')
self._augmenter = augment.AutoAugment() self._augmenter = augment.AutoAugment(
elif aug_type == 'randaug': augmentation_name=aug_type.autoaug.augmentation_name,
cutout_const=aug_type.autoaug.cutout_const,
translate_const=aug_type.autoaug.translate_const)
elif aug_type.type == 'randaug':
logging.info('Using RandAugment.') logging.info('Using RandAugment.')
self._augmenter = augment.RandAugment() self._augmenter = augment.RandAugment(
num_layers=aug_type.randaug.num_layers,
magnitude=aug_type.randaug.magnitude,
cutout_const=aug_type.randaug.cutout_const,
translate_const=aug_type.randaug.translate_const,
prob_to_apply=aug_type.randaug.prob_to_apply,
exclude_ops=aug_type.randaug.exclude_ops)
else:
raise ValueError(
'Augmentation policy {} not supported.'.format(aug_type.type))
else: else:
raise ValueError('Augmentation policy {} is not supported.'.format( self._augmenter = None
aug_type))
def _parse_train_data( def _parse_train_data(
self, decoded_tensors: Dict[str, tf.Tensor] self, decoded_tensors: Dict[str, tf.Tensor]
......
...@@ -21,6 +21,7 @@ from PIL import Image ...@@ -21,6 +21,7 @@ from PIL import Image
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
from official.vision.configs import common
from official.vision.configs import video_classification as exp_cfg from official.vision.configs import video_classification as exp_cfg
from official.vision.dataloaders import video_input from official.vision.dataloaders import video_input
...@@ -173,7 +174,8 @@ class VideoAndLabelParserTest(tf.test.TestCase): ...@@ -173,7 +174,8 @@ class VideoAndLabelParserTest(tf.test.TestCase):
params.min_image_size = 224 params.min_image_size = 224
params.temporal_stride = 2 params.temporal_stride = 2
params.aug_type = 'autoaug' params.aug_type = common.Augmentation(
type='autoaug', autoaug=common.AutoAugment())
decoder = video_input.Decoder() decoder = video_input.Decoder()
parser = video_input.Parser(params).parse_fn(params.is_training) parser = video_input.Parser(params).parse_fn(params.is_training)
......
...@@ -317,7 +317,7 @@ def _fill_rectangle(image, ...@@ -317,7 +317,7 @@ def _fill_rectangle(image,
half_width, half_width,
half_height, half_height,
replace=None): replace=None):
"""Fill blank area.""" """Fills blank area."""
image_height = tf.shape(image)[0] image_height = tf.shape(image)[0]
image_width = tf.shape(image)[1] image_width = tf.shape(image)[1]
...@@ -349,6 +349,45 @@ def _fill_rectangle(image, ...@@ -349,6 +349,45 @@ def _fill_rectangle(image,
return image return image
def _fill_rectangle_video(image,
                          center_width,
                          center_height,
                          half_width,
                          half_height,
                          replace=None):
  """Fills a rectangular blank area in every frame of a video.

  The rectangle is applied at the same spatial position across all frames,
  which keeps the cutout temporally consistent.

  Args:
    image: A video tensor of shape [time, height, width, channels].
    center_width: Horizontal center of the rectangle, in pixels.
    center_height: Vertical center of the rectangle, in pixels.
    half_width: Half of the rectangle width, in pixels.
    half_height: Half of the rectangle height, in pixels.
    replace: Value used to fill the rectangle. If None, fills with random
      normal noise; if a `tf.Tensor`, it is used directly as the fill
      source; otherwise it is broadcast as a constant fill value.

  Returns:
    The video tensor with the rectangle filled in all frames.
  """
  image_time = tf.shape(image)[0]
  image_height = tf.shape(image)[1]
  image_width = tf.shape(image)[2]

  # Clamp the rectangle to the image bounds; pads define the area to keep.
  lower_pad = tf.maximum(0, center_height - half_height)
  upper_pad = tf.maximum(0, image_height - center_height - half_height)
  left_pad = tf.maximum(0, center_width - half_width)
  right_pad = tf.maximum(0, image_width - center_width - half_width)

  cutout_shape = [
      image_time, image_height - (lower_pad + upper_pad),
      image_width - (left_pad + right_pad)
  ]
  padding_dims = [[0, 0], [lower_pad, upper_pad], [left_pad, right_pad]]
  # mask is 0 inside the rectangle and 1 elsewhere.
  mask = tf.pad(
      tf.zeros(cutout_shape, dtype=image.dtype),
      padding_dims,
      constant_values=1)
  mask = tf.expand_dims(mask, -1)
  # Tile across channels; use the dynamic channel count instead of a
  # hard-coded 3 so non-RGB (e.g. grayscale) videos are handled too.
  mask = tf.tile(mask, [1, 1, 1, tf.shape(image)[-1]])

  if replace is None:
    fill = tf.random.normal(tf.shape(image), dtype=image.dtype)
  elif isinstance(replace, tf.Tensor):
    fill = replace
  else:
    fill = tf.ones_like(image, dtype=image.dtype) * replace
  image = tf.where(tf.equal(mask, 0), fill, image)

  return image
def cutout_video(image: tf.Tensor, replace: int = 0) -> tf.Tensor: def cutout_video(image: tf.Tensor, replace: int = 0) -> tf.Tensor:
"""Apply cutout (https://arxiv.org/abs/1708.04552) to a video. """Apply cutout (https://arxiv.org/abs/1708.04552) to a video.
...@@ -2187,8 +2226,7 @@ class MixupAndCutmix: ...@@ -2187,8 +2226,7 @@ class MixupAndCutmix:
- Mixup: https://arxiv.org/abs/1710.09412 - Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899 - Cutmix: https://arxiv.org/abs/1905.04899
Implementation is inspired by Implementation is inspired by https://github.com/rwightman/pytorch-image-models
https://github.com/rwightman/pytorch-image-models.
""" """
def __init__(self, def __init__(self,
...@@ -2201,15 +2239,18 @@ class MixupAndCutmix: ...@@ -2201,15 +2239,18 @@ class MixupAndCutmix:
"""Applies Mixup and/or Cutmix to a batch of images. """Applies Mixup and/or Cutmix to a batch of images.
Args: Args:
mixup_alpha: For drawing a random lambda (`lam`) from a beta distribution mixup_alpha (float, optional): For drawing a random lambda (`lam`) from a
(for each image). If zero Mixup is deactivated. Defaults to `.8`. beta distribution (for each image). If zero Mixup is deactivated.
cutmix_alpha: For drawing a random lambda (`lam`) from a beta distribution Defaults to .8.
(for each image). If zero Cutmix is deactivated. Defaults to `1.`. cutmix_alpha (float, optional): For drawing a random lambda (`lam`) from a
prob: Of augmenting the batch. Defaults to `1.0`. beta distribution (for each image). If zero Cutmix is deactivated.
switch_prob: Probability of applying Cutmix for the batch. Defaults to Defaults to 1..
`0.5`. prob (float, optional): Of augmenting the batch. Defaults to 1.0.
label_smoothing: Constant for label smoothing. Defaults to `0.1`. switch_prob (float, optional): Probability of applying Cutmix for the
num_classes: Number of classes. Defaults to `1001`. batch. Defaults to 0.5.
label_smoothing (float, optional): Constant for label smoothing. Defaults
to 0.1.
num_classes (int, optional): Number of classes. Defaults to 1001.
""" """
self.mixup_alpha = mixup_alpha self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha self.cutmix_alpha = cutmix_alpha
...@@ -2234,8 +2275,9 @@ class MixupAndCutmix: ...@@ -2234,8 +2275,9 @@ class MixupAndCutmix:
"""Applies Mixup and/or Cutmix to batch of images and transforms labels. """Applies Mixup and/or Cutmix to batch of images and transforms labels.
Args: Args:
images (tf.Tensor): Of shape [batch_size,height, width, 3] representing a images (tf.Tensor): Of shape [batch_size, height, width, 3] representing a
batch of image. batch of image, or [batch_size, time, height, width, 3] representing a
batch of video.
labels (tf.Tensor): Of shape [batch_size, ] representing the class id for labels (tf.Tensor): Of shape [batch_size, ] representing the class id for
each image of the batch. each image of the batch.
...@@ -2243,6 +2285,7 @@ class MixupAndCutmix: ...@@ -2243,6 +2285,7 @@ class MixupAndCutmix:
Tuple[tf.Tensor, tf.Tensor]: The augmented version of `image` and Tuple[tf.Tensor, tf.Tensor]: The augmented version of `image` and
`labels`. `labels`.
""" """
labels = tf.reshape(labels, [-1])
augment_cond = tf.less( augment_cond = tf.less(
tf.random.uniform(shape=[], minval=0., maxval=1.0), self.mix_prob) tf.random.uniform(shape=[], minval=0., maxval=1.0), self.mix_prob)
# pylint: disable=g-long-lambda # pylint: disable=g-long-lambda
...@@ -2264,14 +2307,22 @@ class MixupAndCutmix: ...@@ -2264,14 +2307,22 @@ class MixupAndCutmix:
def _cutmix(self, images: tf.Tensor, def _cutmix(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""Apply cutmix.""" """Applies cutmix."""
lam = MixupAndCutmix._sample_from_beta(self.cutmix_alpha, self.cutmix_alpha, lam = MixupAndCutmix._sample_from_beta(self.cutmix_alpha, self.cutmix_alpha,
tf.shape(labels)) tf.shape(labels))
ratio = tf.math.sqrt(1 - lam) ratio = tf.math.sqrt(1 - lam)
batch_size = tf.shape(images)[0] batch_size = tf.shape(images)[0]
if images.shape.rank == 4:
image_height, image_width = tf.shape(images)[1], tf.shape(images)[2] image_height, image_width = tf.shape(images)[1], tf.shape(images)[2]
fill_fn = _fill_rectangle
elif images.shape.rank == 5:
image_height, image_width = tf.shape(images)[2], tf.shape(images)[3]
fill_fn = _fill_rectangle_video
else:
raise ValueError('Bad image rank: {}'.format(images.shape.rank))
cut_height = tf.cast( cut_height = tf.cast(
ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32) ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32)
...@@ -2288,7 +2339,7 @@ class MixupAndCutmix: ...@@ -2288,7 +2339,7 @@ class MixupAndCutmix:
lam = tf.cast(lam, dtype=tf.float32) lam = tf.cast(lam, dtype=tf.float32)
images = tf.map_fn( images = tf.map_fn(
lambda x: _fill_rectangle(*x), lambda x: fill_fn(*x),
(images, random_center_width, random_center_height, cut_width // 2, (images, random_center_width, random_center_height, cut_width // 2,
cut_height // 2, tf.reverse(images, [0])), cut_height // 2, tf.reverse(images, [0])),
dtype=( dtype=(
...@@ -2299,9 +2350,16 @@ class MixupAndCutmix: ...@@ -2299,9 +2350,16 @@ class MixupAndCutmix:
def _mixup(self, images: tf.Tensor, def _mixup(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""Applies mixup."""
lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha, lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha,
tf.shape(labels)) tf.shape(labels))
if images.shape.rank == 4:
lam = tf.reshape(lam, [-1, 1, 1, 1]) lam = tf.reshape(lam, [-1, 1, 1, 1])
elif images.shape.rank == 5:
lam = tf.reshape(lam, [-1, 1, 1, 1, 1])
else:
raise ValueError('Bad image rank: {}'.format(images.shape.rank))
lam_cast = tf.cast(lam, dtype=images.dtype) lam_cast = tf.cast(lam, dtype=images.dtype)
images = lam_cast * images + (1. - lam_cast) * tf.reverse(images, [0]) images = lam_cast * images + (1. - lam_cast) * tf.reverse(images, [0])
......
...@@ -430,6 +430,68 @@ class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase): ...@@ -430,6 +430,68 @@ class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
1e4) # With tolerance 1e4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images)) self.assertFalse(tf.math.reduce_all(images == aug_images))
def test_mixup_and_cutmix_smoothes_labels_with_videos(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 8, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
num_classes=num_classes, label_smoothing=label_smoothing)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4) # With tolerance
def test_mixup_changes_video(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 8, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
def test_cutmix_changes_video(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 8, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -24,6 +24,7 @@ from official.vision.configs import video_classification as exp_cfg ...@@ -24,6 +24,7 @@ from official.vision.configs import video_classification as exp_cfg
from official.vision.dataloaders import input_reader_factory from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import video_input from official.vision.dataloaders import video_input
from official.vision.modeling import factory_3d from official.vision.modeling import factory_3d
from official.vision.ops import augment
@task_factory.register_task_cls(exp_cfg.VideoClassificationTask) @task_factory.register_task_cls(exp_cfg.VideoClassificationTask)
...@@ -128,6 +129,17 @@ class VideoClassificationTask(base_task.Task): ...@@ -128,6 +129,17 @@ class VideoClassificationTask(base_task.Task):
image_key=params.image_field_key, image_key=params.image_field_key,
label_key=params.label_field_key) label_key=params.label_field_key)
postprocess_fn = video_input.PostBatchProcessor(params) postprocess_fn = video_input.PostBatchProcessor(params)
if params.mixup_and_cutmix is not None:
def mixup_and_cutmix(features, labels):
augmenter = augment.MixupAndCutmix(
mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
prob=params.mixup_and_cutmix.prob,
label_smoothing=params.mixup_and_cutmix.label_smoothing,
num_classes=self._get_num_classes())
features['image'], labels = augmenter(features['image'], labels)
return features, labels
postprocess_fn = mixup_and_cutmix
reader = input_reader_factory.input_reader_generator( reader = input_reader_factory.input_reader_generator(
params, params,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment