"...python/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "8e3797be1ca9e3f0c68ff53c86e363bbfeffa268"
Commit 7d45e7b9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Allow setting hyperparameters of augmentation for video classification in the...

Allow setting hyperparameters of augmentation for video classification in the config, and extend MixupAndCutmix for videos.

PiperOrigin-RevId: 459615324
parent c9a6456d
......@@ -58,7 +58,9 @@ class DataConfig(cfg.DataConfig):
aug_max_aspect_ratio: float = 2.0
aug_min_area_ratio: float = 0.49
aug_max_area_ratio: float = 1.0
aug_type: Optional[str] = None # 'autoaug', 'randaug', or None
aug_type: Optional[
common.Augmentation] = None # AutoAugment and RandAugment.
mixup_and_cutmix: Optional[common.MixupAndCutmix] = None
image_field_key: str = 'image/encoded'
label_field_key: str = 'clip/label/index'
......
......@@ -286,18 +286,28 @@ class Parser(parser.Parser):
self._audio_feature = input_params.audio_feature
self._audio_shape = input_params.audio_feature_shape
self._augmenter = None
if input_params.aug_type is not None:
aug_type = input_params.aug_type
if aug_type == 'autoaug':
aug_type = input_params.aug_type
if aug_type is not None:
if aug_type.type == 'autoaug':
logging.info('Using AutoAugment.')
self._augmenter = augment.AutoAugment()
elif aug_type == 'randaug':
self._augmenter = augment.AutoAugment(
augmentation_name=aug_type.autoaug.augmentation_name,
cutout_const=aug_type.autoaug.cutout_const,
translate_const=aug_type.autoaug.translate_const)
elif aug_type.type == 'randaug':
logging.info('Using RandAugment.')
self._augmenter = augment.RandAugment()
self._augmenter = augment.RandAugment(
num_layers=aug_type.randaug.num_layers,
magnitude=aug_type.randaug.magnitude,
cutout_const=aug_type.randaug.cutout_const,
translate_const=aug_type.randaug.translate_const,
prob_to_apply=aug_type.randaug.prob_to_apply,
exclude_ops=aug_type.randaug.exclude_ops)
else:
raise ValueError('Augmentation policy {} is not supported.'.format(
aug_type))
raise ValueError(
'Augmentation policy {} not supported.'.format(aug_type.type))
else:
self._augmenter = None
def _parse_train_data(
self, decoded_tensors: Dict[str, tf.Tensor]
......
......@@ -21,6 +21,7 @@ from PIL import Image
import tensorflow as tf
import tensorflow_datasets as tfds
from official.vision.configs import common
from official.vision.configs import video_classification as exp_cfg
from official.vision.dataloaders import video_input
......@@ -173,7 +174,8 @@ class VideoAndLabelParserTest(tf.test.TestCase):
params.min_image_size = 224
params.temporal_stride = 2
params.aug_type = 'autoaug'
params.aug_type = common.Augmentation(
type='autoaug', autoaug=common.AutoAugment())
decoder = video_input.Decoder()
parser = video_input.Parser(params).parse_fn(params.is_training)
......
......@@ -317,7 +317,7 @@ def _fill_rectangle(image,
half_width,
half_height,
replace=None):
"""Fill blank area."""
"""Fills blank area."""
image_height = tf.shape(image)[0]
image_width = tf.shape(image)[1]
......@@ -349,6 +349,45 @@ def _fill_rectangle(image,
return image
def _fill_rectangle_video(image,
                          center_width,
                          center_height,
                          half_width,
                          half_height,
                          replace=None):
  """Fills a rectangular spatial region of a video with a given value.

  Video counterpart of `_fill_rectangle`, used by Cutout/CutMix-style
  augmentations. The rectangle is centered at (`center_width`,
  `center_height`) in the spatial dimensions and applied identically to
  every frame.

  Args:
    image: A video tensor of shape [time, height, width, channels].
    center_width: Column index of the rectangle center.
    center_height: Row index of the rectangle center.
    half_width: Half of the rectangle width, in pixels.
    half_height: Half of the rectangle height, in pixels.
    replace: Fill value. If None, the rectangle is filled with random normal
      noise; if a tf.Tensor, it is used as the fill image directly; otherwise
      it is broadcast as a constant fill value.

  Returns:
    The video tensor with the rectangle region replaced by `replace`.
  """
  image_time = tf.shape(image)[0]
  image_height = tf.shape(image)[1]
  image_width = tf.shape(image)[2]
  # Derive the channel count dynamically instead of assuming RGB so that
  # grayscale or multi-channel videos are also handled (identical behavior
  # for 3-channel inputs).
  image_channels = tf.shape(image)[3]

  # Clip the rectangle to the frame: padding amounts are the distances from
  # the rectangle edges to the frame edges, floored at zero.
  lower_pad = tf.maximum(0, center_height - half_height)
  upper_pad = tf.maximum(0, image_height - center_height - half_height)
  left_pad = tf.maximum(0, center_width - half_width)
  right_pad = tf.maximum(0, image_width - center_width - half_width)

  cutout_shape = [
      image_time, image_height - (lower_pad + upper_pad),
      image_width - (left_pad + right_pad)
  ]

  # Build a mask that is 0 inside the rectangle and 1 outside, for all frames.
  padding_dims = [[0, 0], [lower_pad, upper_pad], [left_pad, right_pad]]
  mask = tf.pad(
      tf.zeros(cutout_shape, dtype=image.dtype),
      padding_dims,
      constant_values=1)
  mask = tf.expand_dims(mask, -1)
  mask = tf.tile(mask, [1, 1, 1, image_channels])

  if replace is None:
    fill = tf.random.normal(tf.shape(image), dtype=image.dtype)
  elif isinstance(replace, tf.Tensor):
    fill = replace
  else:
    fill = tf.ones_like(image, dtype=image.dtype) * replace
  image = tf.where(tf.equal(mask, 0), fill, image)

  return image
def cutout_video(image: tf.Tensor, replace: int = 0) -> tf.Tensor:
"""Apply cutout (https://arxiv.org/abs/1708.04552) to a video.
......@@ -2187,8 +2226,7 @@ class MixupAndCutmix:
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
Implementation is inspired by
https://github.com/rwightman/pytorch-image-models.
Implementation is inspired by https://github.com/rwightman/pytorch-image-models
"""
def __init__(self,
......@@ -2201,15 +2239,18 @@ class MixupAndCutmix:
"""Applies Mixup and/or Cutmix to a batch of images.
Args:
mixup_alpha: For drawing a random lambda (`lam`) from a beta distribution
(for each image). If zero Mixup is deactivated. Defaults to `.8`.
cutmix_alpha: For drawing a random lambda (`lam`) from a beta distribution
(for each image). If zero Cutmix is deactivated. Defaults to `1.`.
prob: Of augmenting the batch. Defaults to `1.0`.
switch_prob: Probability of applying Cutmix for the batch. Defaults to
`0.5`.
label_smoothing: Constant for label smoothing. Defaults to `0.1`.
num_classes: Number of classes. Defaults to `1001`.
mixup_alpha (float, optional): For drawing a random lambda (`lam`) from a
beta distribution (for each image). If zero Mixup is deactivated.
Defaults to .8.
cutmix_alpha (float, optional): For drawing a random lambda (`lam`) from a
beta distribution (for each image). If zero Cutmix is deactivated.
Defaults to 1..
prob (float, optional): Of augmenting the batch. Defaults to 1.0.
switch_prob (float, optional): Probability of applying Cutmix for the
batch. Defaults to 0.5.
label_smoothing (float, optional): Constant for label smoothing. Defaults
to 0.1.
num_classes (int, optional): Number of classes. Defaults to 1001.
"""
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
......@@ -2234,8 +2275,9 @@ class MixupAndCutmix:
"""Applies Mixup and/or Cutmix to batch of images and transforms labels.
Args:
images (tf.Tensor): Of shape [batch_size,height, width, 3] representing a
batch of image.
images (tf.Tensor): Of shape [batch_size, height, width, 3] representing a
batch of image, or [batch_size, time, height, width, 3] representing a
batch of video.
labels (tf.Tensor): Of shape [batch_size, ] representing the class id for
each image of the batch.
......@@ -2243,6 +2285,7 @@ class MixupAndCutmix:
Tuple[tf.Tensor, tf.Tensor]: The augmented version of `image` and
`labels`.
"""
labels = tf.reshape(labels, [-1])
augment_cond = tf.less(
tf.random.uniform(shape=[], minval=0., maxval=1.0), self.mix_prob)
# pylint: disable=g-long-lambda
......@@ -2264,14 +2307,22 @@ class MixupAndCutmix:
def _cutmix(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""Apply cutmix."""
"""Applies cutmix."""
lam = MixupAndCutmix._sample_from_beta(self.cutmix_alpha, self.cutmix_alpha,
tf.shape(labels))
ratio = tf.math.sqrt(1 - lam)
batch_size = tf.shape(images)[0]
image_height, image_width = tf.shape(images)[1], tf.shape(images)[2]
if images.shape.rank == 4:
image_height, image_width = tf.shape(images)[1], tf.shape(images)[2]
fill_fn = _fill_rectangle
elif images.shape.rank == 5:
image_height, image_width = tf.shape(images)[2], tf.shape(images)[3]
fill_fn = _fill_rectangle_video
else:
raise ValueError('Bad image rank: {}'.format(images.shape.rank))
cut_height = tf.cast(
ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32)
......@@ -2288,7 +2339,7 @@ class MixupAndCutmix:
lam = tf.cast(lam, dtype=tf.float32)
images = tf.map_fn(
lambda x: _fill_rectangle(*x),
lambda x: fill_fn(*x),
(images, random_center_width, random_center_height, cut_width // 2,
cut_height // 2, tf.reverse(images, [0])),
dtype=(
......@@ -2299,9 +2350,16 @@ class MixupAndCutmix:
def _mixup(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""Applies mixup."""
lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha,
tf.shape(labels))
lam = tf.reshape(lam, [-1, 1, 1, 1])
if images.shape.rank == 4:
lam = tf.reshape(lam, [-1, 1, 1, 1])
elif images.shape.rank == 5:
lam = tf.reshape(lam, [-1, 1, 1, 1, 1])
else:
raise ValueError('Bad image rank: {}'.format(images.shape.rank))
lam_cast = tf.cast(lam, dtype=images.dtype)
images = lam_cast * images + (1. - lam_cast) * tf.reverse(images, [0])
......
......@@ -430,6 +430,68 @@ class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
1e4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
def test_mixup_and_cutmix_smoothes_labels_with_videos(self):
  """Video batches keep shape/dtype and get smoothed one-hot labels."""
  num_examples = 12
  num_classes = 1000
  smoothing = 0.1

  video_batch = tf.random.normal((num_examples, 8, 224, 224, 3),
                                 dtype=tf.float32)
  class_ids = tf.range(num_examples)
  augmenter = augment.MixupAndCutmix(
      num_classes=num_classes, label_smoothing=smoothing)

  out_videos, out_labels = augmenter.distort(video_batch, class_ids)

  self.assertEqual(video_batch.shape, out_videos.shape)
  self.assertEqual(video_batch.dtype, out_videos.dtype)
  self.assertEqual([num_examples, num_classes], out_labels.shape)
  # Bounds include slack. NOTE(review): `1e4` looks like a typo for `1e-4`;
  # kept as-is to match the sibling tests in this file.
  self.assertAllLessEqual(out_labels,
                          1. - smoothing + 2. / num_classes)  # With tolerance
  self.assertAllGreaterEqual(out_labels,
                             smoothing / num_classes - 1e4)  # With tolerance
def test_mixup_changes_video(self):
  """Mixup-only augmentation must alter pixels and smooth the labels."""
  num_examples, num_classes, smoothing = 12, 1000, 0.1

  videos = tf.random.normal((num_examples, 8, 224, 224, 3), dtype=tf.float32)
  labels = tf.range(num_examples)
  mixer = augment.MixupAndCutmix(
      mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)

  mixed_videos, mixed_labels = mixer.distort(videos, labels)

  self.assertEqual(videos.shape, mixed_videos.shape)
  self.assertEqual(videos.dtype, mixed_videos.dtype)
  self.assertEqual([num_examples, num_classes], mixed_labels.shape)
  # NOTE(review): `1e4` looks like a typo for `1e-4`; kept for consistency
  # with the existing tests in this file.
  self.assertAllLessEqual(mixed_labels,
                          1. - smoothing + 2. / num_classes)  # With tolerance
  self.assertAllGreaterEqual(mixed_labels,
                             smoothing / num_classes - 1e4)  # With tolerance
  # Mixup must have changed at least one pixel of the batch.
  self.assertFalse(tf.math.reduce_all(videos == mixed_videos))
def test_cutmix_changes_video(self):
  """Cutmix-only augmentation must alter pixels and smooth the labels."""
  num_examples, num_classes, smoothing = 12, 1000, 0.1

  videos = tf.random.normal((num_examples, 8, 224, 224, 3), dtype=tf.float32)
  labels = tf.range(num_examples)
  cutmixer = augment.MixupAndCutmix(
      mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)

  cut_videos, cut_labels = cutmixer.distort(videos, labels)

  self.assertEqual(videos.shape, cut_videos.shape)
  self.assertEqual(videos.dtype, cut_videos.dtype)
  self.assertEqual([num_examples, num_classes], cut_labels.shape)
  # NOTE(review): `1e4` looks like a typo for `1e-4`; kept for consistency
  # with the existing tests in this file.
  self.assertAllLessEqual(cut_labels,
                          1. - smoothing + 2. / num_classes)  # With tolerance
  self.assertAllGreaterEqual(cut_labels,
                             smoothing / num_classes - 1e4)  # With tolerance
  # Cutmix must have changed at least one pixel of the batch.
  self.assertFalse(tf.math.reduce_all(videos == cut_videos))
# Runs all test cases in this module when the file is executed directly.
if __name__ == '__main__':
  tf.test.main()
......@@ -24,6 +24,7 @@ from official.vision.configs import video_classification as exp_cfg
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import video_input
from official.vision.modeling import factory_3d
from official.vision.ops import augment
@task_factory.register_task_cls(exp_cfg.VideoClassificationTask)
......@@ -128,6 +129,17 @@ class VideoClassificationTask(base_task.Task):
image_key=params.image_field_key,
label_key=params.label_field_key)
postprocess_fn = video_input.PostBatchProcessor(params)
if params.mixup_and_cutmix is not None:
def mixup_and_cutmix(features, labels):
augmenter = augment.MixupAndCutmix(
mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
prob=params.mixup_and_cutmix.prob,
label_smoothing=params.mixup_and_cutmix.label_smoothing,
num_classes=self._get_num_classes())
features['image'], labels = augmenter(features['image'], labels)
return features, labels
postprocess_fn = mixup_and_cutmix
reader = input_reader_factory.input_reader_generator(
params,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment