Commit ef6a4159 authored by Yin Cui's avatar Yin Cui Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 352949436
parent e0f818da
# Video Self-supervised Learning
TF2 implementation of [CVRL](https://arxiv.org/abs/2008.03800):
[1] Qian, Rui, Tianjian Meng, Boqing Gong, Ming-Hsuan Yang, Huisheng Wang,
Serge Belongie, and Yin Cui. "Spatiotemporal contrastive video
representation learning." arXiv preprint arXiv:2008.03800 (2020).
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configs package definition."""
from official.vision.beta.projects.video_ssl.configs import video_ssl
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Video classification configuration definition."""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.configs import video_classification
# Re-export the base video classification config classes under this module's
# namespace so they can be referenced alongside the SSL-specific `DataConfig`.
Losses = video_classification.Losses
VideoClassificationModel = video_classification.VideoClassificationModel
VideoClassificationTask = video_classification.VideoClassificationTask
@dataclasses.dataclass
class DataConfig(video_classification.DataConfig):
  """The base configuration for building datasets.

  Extends the video classification data config with a single flag selecting
  the self-supervised (SSL) pre-training input pipeline.
  """
  # Whether the dataset is used for self-supervised pre-training.
  is_ssl: bool = False
@exp_factory.register_config_factory('video_ssl_pretrain_kinetics400')
def video_ssl_pretrain_kinetics400() -> cfg.ExperimentConfig:
  """Builds the SSL pre-training experiment config for Kinetics 400.

  Starts from the supervised Kinetics 400 video classification experiment and
  swaps the training data config for an SSL-enabled `DataConfig`.

  Returns:
    The experiment config, with 16-frame 224x224 training clips sampled at a
    temporal stride of 2.
  """
  exp = video_classification.video_classification_kinetics400()
  task = exp.task

  # Carry over every field of the base training data config.
  inherited_fields = task.train_data.as_dict()
  task.train_data = DataConfig(is_ssl=True, **inherited_fields)
  task.train_data.feature_shape = (16, 224, 224, 3)
  task.train_data.temporal_stride = 2
  return exp
@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics400')
def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig:
  """Builds the linear-evaluation experiment config for Kinetics 400.

  Starts from the supervised Kinetics 400 video classification experiment and
  replaces both the training and the validation data configs with non-SSL
  `DataConfig` instances.

  Returns:
    The experiment config. Training uses 32-frame 224x224 clips; validation
    uses 32-frame 256x256 clips with 10 test clips and 3 test crops. Both use
    a temporal stride of 2.
  """
  exp = video_classification.video_classification_kinetics400()

  exp.task.train_data = DataConfig(is_ssl=False,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2

  # Replace the config first and then override its fields, mirroring the
  # Kinetics 600 linear-evaluation factory.
  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  exp.task.validation_data.num_test_crops = 3
  return exp
@exp_factory.register_config_factory('video_ssl_pretrain_kinetics600')
def video_ssl_pretrain_kinetics600() -> cfg.ExperimentConfig:
  """Builds the SSL pre-training experiment config for Kinetics 600.

  Starts from the supervised Kinetics 600 video classification experiment and
  swaps the training data config for an SSL-enabled `DataConfig`.

  Returns:
    The experiment config, with 16-frame 224x224 training clips sampled at a
    temporal stride of 2.
  """
  exp = video_classification.video_classification_kinetics600()
  exp.task.train_data = DataConfig(is_ssl=True, **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (16, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  return exp
@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics600')
def video_ssl_linear_eval_kinetics600() -> cfg.ExperimentConfig:
  """Builds the linear-evaluation experiment config for Kinetics 600.

  Starts from the supervised Kinetics 600 video classification experiment and
  replaces both the training and the validation data configs with non-SSL
  `DataConfig` instances.

  Returns:
    The experiment config. Training uses 32-frame 224x224 clips; validation
    uses 32-frame 256x256 clips with 10 test clips and 3 test crops. Both use
    a temporal stride of 2.
  """
  exp = video_classification.video_classification_kinetics600()

  exp.task.train_data = DataConfig(is_ssl=False,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2

  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  exp.task.validation_data.num_test_crops = 3
  return exp
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision import beta
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
  """Sanity checks for the registered video SSL experiment configs."""

  @parameterized.parameters(
      ('video_ssl_pretrain_kinetics400',),
      ('video_ssl_linear_eval_kinetics400',),
      ('video_ssl_pretrain_kinetics600',),
      ('video_ssl_linear_eval_kinetics600',),
  )
  def test_video_classification_configs(self, config_name):
    """Checks config types and that an unset `is_training` fails validation."""
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)

    task = config.task
    self.assertIsInstance(task, exp_cfg.VideoClassificationTask)
    self.assertIsInstance(task.model, exp_cfg.VideoClassificationModel)
    self.assertIsInstance(task.train_data, exp_cfg.DataConfig)

    # With `is_training` cleared, validation is expected to raise KeyError.
    task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parser for video and label datasets."""
from typing import Dict, Optional, Tuple
from absl import logging
import tensorflow as tf
from official.vision.beta.dataloaders import video_input
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops
# Default feature keys used by `Parser` to look up the frames and the label
# in the decoded tensors.
IMAGE_KEY = 'image/encoded'
LABEL_KEY = 'clip/label/index'
# The decoder is reused unchanged from the base video input pipeline.
Decoder = video_input.Decoder
def _process_image(image: tf.Tensor,
                   is_training: bool = True,
                   is_ssl: bool = False,
                   num_frames: int = 32,
                   stride: int = 1,
                   num_test_clips: int = 1,
                   min_resize: int = 256,
                   crop_size: int = 224,
                   num_crops: int = 1,
                   zero_centering_image: bool = False,
                   seed: Optional[int] = None) -> tf.Tensor:
  """Processes a serialized image tensor.

  Args:
    image: Input Tensor of shape [timesteps] and type tf.string of serialized
      frames.
    is_training: Whether or not in training mode. If True, random sample, crop
      and left right flip is used.
    is_ssl: Whether or not in self-supervised pre-training mode. Only takes
      effect when `is_training` is True.
    num_frames: Number of frames per subclip.
    stride: Temporal stride to sample frames.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test time.
      If 1, then a single clip in the middle of the video is sampled. The clips
      are aggregated in the batch dimension.
    min_resize: Frames are resized so that min(height, width) is min_resize.
    crop_size: Final size of the frame after cropping the resized frames. Both
      height and width are the same.
    num_crops: Number of crops to perform on the resized frames.
    zero_centering_image: If True, frames are normalized to values in [-1, 1].
      If False, values in [0, 1].
    seed: A deterministic seed to use when sampling.

  Returns:
    Processed frames. Tensor of shape
      [num_frames * num_test_clips, crop_size, crop_size, 3]. In SSL training
      mode the two augmented clips are concatenated along the first axis.
  """
  # Validate parameters.
  if is_training and num_test_clips != 1:
    logging.warning(
        '`num_test_clips` %d is ignored since `is_training` is `True`.',
        num_test_clips)

  # Temporal sampler.
  if is_training:
    # Sampler for training.
    if is_ssl:
      # Sample two clips from linear decreasing distribution.
      image = video_ssl_preprocess_ops.sample_ssl_sequence(
          image, num_frames, True, stride)
    else:
      # Sample random clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, True, stride)
  else:
    # Sampler for evaluation.
    if num_test_clips > 1:
      # Sample linspace clips.
      image = preprocess_ops_3d.sample_linspace_sequence(image, num_test_clips,
                                                         num_frames, stride)
    else:
      # Sample middle clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, False,
                                                stride)

  # Decode JPEG string to tf.uint8.
  image = preprocess_ops_3d.decode_jpeg(image, 3)

  if is_training:
    # Standard image data augmentation: random resized crop and random flip.
    if is_ssl:
      # The sampler returned both clips stacked along time; augment each clip
      # independently.
      image_1, image_2 = tf.split(image, num_or_size_splits=2, axis=0)
      image_1 = preprocess_ops_3d.random_crop_resize(
          image_1, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_1 = preprocess_ops_3d.random_flip_left_right(image_1, seed)
      image_2 = preprocess_ops_3d.random_crop_resize(
          image_2, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_2 = preprocess_ops_3d.random_flip_left_right(image_2, seed)
    else:
      image = preprocess_ops_3d.random_crop_resize(
          image, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image = preprocess_ops_3d.random_flip_left_right(image, seed)
  else:
    # Resize images (resize happens only if necessary to save compute).
    image = preprocess_ops_3d.resize_smallest(image, min_resize)
    # Three-crop of the frames.
    image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
                                         num_crops)

  # Cast the frames in float32, normalizing according to zero_centering_image.
  if is_training and is_ssl:
    image_1 = preprocess_ops_3d.normalize_image(image_1, zero_centering_image)
    image_2 = preprocess_ops_3d.normalize_image(image_2, zero_centering_image)
  else:
    image = preprocess_ops_3d.normalize_image(image, zero_centering_image)

  # Self-supervised pre-training augmentations.
  if is_training and is_ssl:
    # Temporally consistent color jittering.
    image_1 = video_ssl_preprocess_ops.random_color_jitter_3d(image_1)
    image_2 = video_ssl_preprocess_ops.random_color_jitter_3d(image_2)
    # Temporally consistent gaussian blurring. `random_blur_3d` takes
    # (frames, height, width, blur_probability) and derives the kernel size
    # from the frame height, so pass the spatial crop size and keep the
    # default blur probability.
    image_1 = video_ssl_preprocess_ops.random_blur_3d(
        image_1, height=crop_size, width=crop_size)
    image_2 = video_ssl_preprocess_ops.random_blur_3d(
        image_2, height=crop_size, width=crop_size)
    image = tf.concat([image_1, image_2], axis=0)

  return image
def _postprocess_image(image: tf.Tensor,
                       is_training: bool = True,
                       is_ssl: bool = False,
                       num_frames: int = 32,
                       num_test_clips: int = 1,
                       num_test_crops: int = 1) -> tf.Tensor:
  """Rearranges a batched Tensor of frames so each view is its own example.

  Must be called with the same parameters that were used for per-example
  processing.

  Args:
    image: Input Tensor of shape [batch, timesteps, height, width, 3].
    is_training: Whether or not in training mode.
    is_ssl: Whether or not in self-supervised pre-training mode.
    num_frames: Number of frames per subclip.
    num_test_clips: Number of clips sampled per video at test time (1 by
      default).
    num_test_crops: Number of crops taken per clip at test time (1 by default).

  Returns:
    Processed frames. In SSL training mode the two clips stored along the time
    axis are moved to the batch axis, giving a batch of 2 * batch. In
    evaluation mode with more than one view the Tensor has shape
    [batch * num_test_clips * num_test_crops, num_frames, height, width, 3].
    Otherwise the input is returned unchanged.
  """
  if is_training and is_ssl:
    # Split the time axis into the two SSL clips and stack them batch-wise.
    clip_a, clip_b = tf.split(image, num_or_size_splits=2, axis=1)
    image = tf.concat([clip_a, clip_b], axis=0)

  has_multiple_views = num_test_clips * num_test_crops > 1
  if has_multiple_views and not is_training:
    # Fold the views that are stacked along time into the batch dimension.
    trailing_dims = image.shape[2:].as_list()
    image = tf.reshape(image, [-1, num_frames] + trailing_dims)

  return image
def _process_label(label: tf.Tensor,
                   one_hot_label: bool = True,
                   num_classes: Optional[int] = None) -> tf.Tensor:
  """Casts the label to int32 and optionally converts it to one-hot form.

  Args:
    label: Tensor holding the label index (or indices).
    one_hot_label: Whether to return a one-hot encoded label.
    num_classes: Number of classes; required when `one_hot_label` is True.

  Returns:
    The int32 label, or its one-hot encoding when `one_hot_label` is True.

  Raises:
    ValueError: If a one-hot label is requested without `num_classes`.
  """
  if one_hot_label and not num_classes:
    raise ValueError(
        '`num_classes` should be given when requesting one hot label.')

  label = tf.cast(label, dtype=tf.int32)
  if not one_hot_label:
    return label

  encoded = tf.one_hot(label, num_classes)
  if len(encoded.shape.as_list()) > 1:
    # Several label indices: collapse their one-hot rows into one vector.
    encoded = tf.reduce_sum(encoded, axis=0)
  if num_classes == 1:
    # The trick for single label.
    encoded = 1 - encoded
  return encoded
class Parser(video_input.Parser):
  """Parses a video and label dataset, with optional SSL pre-training mode."""

  def __init__(self,
               input_params: exp_cfg.DataConfig,
               image_key: str = IMAGE_KEY,
               label_key: str = LABEL_KEY):
    """Initializes the parser.

    Args:
      input_params: Data config; its `is_ssl` flag selects the SSL pipeline.
      image_key: Key of the encoded frames in the decoded tensors.
      label_key: Key of the label in the decoded tensors.
    """
    super().__init__(input_params, image_key, label_key)
    self._is_ssl = input_params.is_ssl

  def _parse_train_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Builds the training features and label from decoded tensors."""
    frames = _process_image(
        image=decoded_tensors[self._image_key],
        is_training=True,
        is_ssl=self._is_ssl,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size)
    features = {'image': tf.cast(frames, dtype=self._dtype)}

    label = _process_label(decoded_tensors[self._label_key],
                           self._one_hot_label, self._num_classes)
    return features, label

  def _parse_eval_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Builds the evaluation features and label from decoded tensors."""
    frames = _process_image(
        image=decoded_tensors[self._image_key],
        is_training=False,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size,
        num_crops=self._num_crops)
    features = {'image': tf.cast(frames, dtype=self._dtype)}

    label = _process_label(decoded_tensors[self._label_key],
                           self._one_hot_label, self._num_classes)

    if self._output_audio:
      # Take the first 20 audio steps deterministically and pin the shape.
      audio = tf.cast(decoded_tensors[self._audio_feature], dtype=self._dtype)
      audio = preprocess_ops_3d.sample_sequence(
          audio, 20, random=False, stride=1)
      features['audio'] = tf.ensure_shape(audio, [20, 2048])

    return features, label

  def parse_fn(self, is_training):
    """Returns a parse fn that reads and parses raw tensors from the decoder.

    Args:
      is_training: a `bool` to indicate whether it is in training mode.

    Returns:
      parse: a `callable` that takes the decoded tensors and returns a
        (features, label) tuple, where features is a dict of Tensors.
    """
    if is_training:
      selected_parser = self._parse_train_data
    else:
      selected_parser = self._parse_eval_data

    def parse(decoded_tensors):
      """Parses the serialized example data."""
      return selected_parser(decoded_tensors)

    return parse
class PostBatchProcessor:
  """Processes a video and label dataset which is batched."""

  def __init__(self, input_params: exp_cfg.DataConfig):
    """Caches the config fields that the batched post-processing needs."""
    self._is_training = input_params.is_training
    self._is_ssl = input_params.is_ssl
    # The first entry of `feature_shape` is the number of frames per clip.
    self._num_frames = input_params.feature_shape[0]
    self._num_test_clips = input_params.num_test_clips
    self._num_test_crops = input_params.num_test_crops

  def __call__(self, features: Dict[str, tf.Tensor],
               label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Post-processes the batched 'image' and 'audio' features in place.

    Args:
      features: Dict of batched feature Tensors; it is updated in place.
      label: Batched label Tensor, passed through untouched.

    Returns:
      The (features, label) tuple.
    """
    for key in ('image', 'audio'):
      if key not in features:
        continue
      features[key] = _postprocess_image(
          image=features[key],
          is_training=self._is_training,
          is_ssl=self._is_ssl,
          num_frames=self._num_frames,
          num_test_clips=self._num_test_clips,
          num_test_crops=self._num_test_crops)
    return features, label
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import io
# Import libraries
import numpy as np
from PIL import Image
import tensorflow as tf
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input
# Feature-list key under which the fake audio features are stored.
AUDIO_KEY = 'features/audio'
def fake_seq_example():
  """Builds a fake `tf.train.SequenceExample` for the parser tests.

  The example holds two copies of one random 263x320 JPEG frame, the context
  label 42, and ten 256-dimensional random audio feature vectors.

  Returns:
    A (seq_example, label) tuple.
  """
  label = 42

  # Encode one random RGB image as JPEG bytes.
  pixels = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
  with io.BytesIO() as buffer:
    Image.fromarray(pixels).save(buffer, format='JPEG')
    raw_image_bytes = buffer.getvalue()

  seq_example = tf.train.SequenceExample()
  feature_lists = seq_example.feature_lists.feature_list

  # The same encoded frame is appended twice to form a two-frame video.
  for _ in range(2):
    frame = feature_lists.get_or_create(video_ssl_input.IMAGE_KEY).feature.add()
    frame.bytes_list.value[:] = [raw_image_bytes]

  seq_example.context.feature[video_ssl_input.LABEL_KEY].int64_list.value[:] = [
      label
  ]

  for audio_step in np.random.normal(size=(10, 256)).tolist():
    step = feature_lists.get_or_create(AUDIO_KEY).feature.add()
    step.float_list.value[:] = audio_step

  return seq_example, label
class VideoAndLabelParserTest(tf.test.TestCase):
  """Checks output shapes of the video SSL parser for each experiment mode."""

  def _parse_fake_example(self, params):
    """Decodes and parses one fake example with the given data config.

    Args:
      params: The data config used to build the parser.

    Returns:
      The (features, label) tuple produced by the parser.
    """
    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)
    seq_example, _ = fake_seq_example()
    input_tensor = tf.constant(seq_example.SerializeToString())
    return parser(decoder.decode(input_tensor))

  def test_video_ssl_input_pretrain(self):
    params = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data
    image_features, _ = self._parse_fake_example(params)
    # Two 16-frame clips are concatenated along the time axis.
    self.assertAllEqual(image_features['image'].shape, (32, 224, 224, 3))

  def test_video_ssl_input_linear_train(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.train_data
    image_features, label = self._parse_fake_example(params)
    self.assertAllEqual(image_features['image'].shape, (32, 224, 224, 3))
    self.assertAllEqual(label.shape, (600,))

  def test_video_ssl_input_linear_eval(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.validation_data
    image_features, label = self._parse_fake_example(params)
    # 32 frames * 10 test clips * 3 test crops = 960 frames.
    self.assertAllEqual(image_features['image'].shape, (960, 256, 256, 3))
    self.assertAllEqual(label.shape, (600,))
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for customed ops for video ssl."""
import functools
from typing import Optional
import tensorflow as tf
def random_apply(func, p, x):
  """Applies `func` to `x` with probability `p`, otherwise returns `x`.

  Args:
    func: Callable taking `x` and returning the transformed value.
    p: Probability of applying `func`.
    x: The input passed to `func`.

  Returns:
    `func(x)` or `x`, chosen by a single uniform random draw.
  """
  draw = tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32)
  should_apply = tf.less(draw, tf.cast(p, tf.float32))
  return tf.cond(should_apply, lambda: func(x), lambda: x)
def random_brightness(image, max_delta):
  """Scales the image by a random brightness factor (SimCLRv2 style).

  Args:
    image: The input image tensor.
    max_delta: Maximum relative deviation of the factor from 1.

  Returns:
    The image multiplied by a factor drawn uniformly from
    [max(1 - max_delta, 0), 1 + max_delta).
  """
  lower_bound = tf.maximum(1.0 - max_delta, 0)
  upper_bound = 1.0 + max_delta
  scale = tf.random.uniform([], lower_bound, upper_bound)
  return image * scale
def to_grayscale(image, keep_channels=True):
  """Converts an RGB image to grayscale.

  Args:
    image: The input image tensor.
    keep_channels: If True, the grayscale channel is repeated three times
      along the last axis of a rank-3 image. If False, the result keeps a
      single channel.

  Returns:
    The grayscale image tensor.
  """
  gray = tf.image.rgb_to_grayscale(image)
  if not keep_channels:
    return gray
  return tf.tile(gray, [1, 1, 3])
def color_jitter(image, strength, random_order=True):
  """Distorts the color of the image (SimCLRv2 style).

  Brightness, contrast and saturation use 0.8 * strength; hue uses
  0.2 * strength.

  Args:
    image: The input image tensor.
    strength: The floating number for the strength of the color augmentation.
    random_order: A bool, specifying whether to randomize the jittering order.

  Returns:
    The distorted image tensor.
  """
  brightness = 0.8 * strength
  contrast = 0.8 * strength
  saturation = 0.8 * strength
  hue = 0.2 * strength

  jitter_fn = color_jitter_rand if random_order else color_jitter_nonrand
  return jitter_fn(image, brightness, contrast, saturation, hue)
def color_jitter_nonrand(image,
                         brightness=0,
                         contrast=0,
                         saturation=0,
                         hue=0):
  """Distorts the color of the image (jittering order is fixed, SimCLRv2 style).

  The transformations run in the order brightness, contrast, saturation, hue.
  A transformation whose strength is 0 is skipped, and the image is clipped to
  [0, 1] after every step.

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.

  Returns:
    The distorted image tensor.
  """
  with tf.name_scope('distort_color'):
    def apply_transform(i, x, brightness, contrast, saturation, hue):
      """Apply the i-th transformation, or return x if it is disabled."""
      if i == 0 and brightness != 0:
        return random_brightness(x, max_delta=brightness)
      if i == 1 and contrast != 0:
        return tf.image.random_contrast(
            x, lower=1-contrast, upper=1+contrast)
      if i == 2 and saturation != 0:
        return tf.image.random_saturation(
            x, lower=1-saturation, upper=1+saturation)
      # Hue is tied to its own step so that a disabled earlier transformation
      # does not fall through and trigger an extra hue jitter.
      if i == 3 and hue != 0:
        return tf.image.random_hue(x, max_delta=hue)
      return x

    for i in range(4):
      image = apply_transform(i, image, brightness, contrast, saturation, hue)
      image = tf.clip_by_value(image, 0., 1.)
    return image
def color_jitter_rand(image,
                      brightness=0,
                      contrast=0,
                      saturation=0,
                      hue=0):
  """Distorts the color of the image (jittering order is random, SimCLRv2 style).

  The four transformations (brightness, contrast, saturation, hue) are applied
  once each in a randomly shuffled order. A transformation whose strength is 0
  leaves the image unchanged, and the image is clipped to [0, 1] after every
  step.

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.

  Returns:
    The distorted image tensor.
  """
  with tf.name_scope('distort_color'):

    def _jitter_brightness(x):
      if brightness == 0:
        return x
      return random_brightness(x, max_delta=brightness)

    def _jitter_contrast(x):
      if contrast == 0:
        return x
      return tf.image.random_contrast(x, lower=1-contrast, upper=1+contrast)

    def _jitter_saturation(x):
      if saturation == 0:
        return x
      return tf.image.random_saturation(
          x, lower=1-saturation, upper=1+saturation)

    def _jitter_hue(x):
      if hue == 0:
        return x
      return tf.image.random_hue(x, max_delta=hue)

    def apply_transform(i, x):
      """Applies the transformation selected by the tensor index `i`."""
      # Index 0 or 1: brightness or contrast.
      def first_pair():
        return tf.cond(
            tf.less(i, 1),
            lambda: _jitter_brightness(x),
            lambda: _jitter_contrast(x))

      # Index 2 or 3: saturation or hue.
      def second_pair():
        return tf.cond(
            tf.less(i, 3),
            lambda: _jitter_saturation(x),
            lambda: _jitter_hue(x))

      return tf.cond(tf.less(i, 2), first_pair, second_pair)

    order = tf.random.shuffle(tf.range(4))
    for step in range(4):
      image = apply_transform(order[step], image)
      image = tf.clip_by_value(image, 0., 1.)
    return image
def random_color_jitter_3d(frames):
  """Applies temporally consistent color jittering to one video clip.

  All frames are stacked into one tall image so that a single random jitter
  (and a single random grayscale decision) is shared by the whole clip.

  Args:
    frames: `Tensor` of shape [num_frames, height, width, channels].

  Returns:
    A Tensor of shape [num_frames, height, width, channels] being color jittered
    with the same operation.
  """

  def _jitter_then_maybe_gray(image):
    # Color jitter with probability 0.8, then grayscale with probability 0.2.
    jittered = random_apply(
        lambda img: color_jitter(img, strength=1.0), p=0.8, x=image)
    return random_apply(to_grayscale, p=0.2, x=jittered)

  num_frames, height, width, channels = tf.shape(frames)
  stacked = tf.reshape(frames, [num_frames*height, width, channels])
  stacked = random_apply(_jitter_then_maybe_gray, p=1.0, x=stacked)
  return tf.reshape(stacked, [num_frames, height, width, channels])
def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
  """Blurs the given image with a separable Gaussian convolution.

  Args:
    image: Float Tensor of shape [height, width, channels], or a batched
      Tensor with a leading batch dimension.
    kernel_size: Integer Tensor for the size of the blur kernel. This should
      be an odd number. If it is an even number, the actual kernel size will
      be size + 1.
    sigma: Sigma value for gaussian operator.
    padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.

  Returns:
    A Tensor representing the blurred image.
  """
  half_width = tf.cast(kernel_size / 2, dtype=tf.int32)
  taps = half_width * 2 + 1

  # Normalized 1-D Gaussian weights centered on zero.
  offsets = tf.cast(tf.range(-half_width, half_width + 1), dtype=tf.float32)
  sigma_f = tf.cast(sigma, dtype=tf.float32)
  weights = tf.exp(-tf.pow(offsets, 2.0) / (2.0 * tf.pow(sigma_f, 2.0)))
  weights = weights / tf.reduce_sum(weights)

  # One vertical and one horizontal filter, replicated for every channel.
  vertical = tf.reshape(weights, [taps, 1, 1, 1])
  horizontal = tf.reshape(weights, [1, taps, 1, 1])
  num_channels = tf.shape(image)[-1]
  horizontal = tf.tile(horizontal, [1, 1, num_channels, 1])
  vertical = tf.tile(vertical, [1, 1, num_channels, 1])

  # Convolutions need batched input; add a temporary batch axis for rank 3.
  is_unbatched = image.shape.ndims == 3
  if is_unbatched:
    image = tf.expand_dims(image, axis=0)

  unit_strides = [1, 1, 1, 1]
  result = tf.nn.depthwise_conv2d(
      image, horizontal, strides=unit_strides, padding=padding)
  result = tf.nn.depthwise_conv2d(
      result, vertical, strides=unit_strides, padding=padding)

  if is_unbatched:
    result = tf.squeeze(result, axis=0)
  return result
def random_blur(image, height, width, p=1.0):
  """Randomly applies a Gaussian blur to an image.

  The kernel size is `height // 10` and sigma is drawn uniformly from
  [0.1, 2.0).

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of the image; determines the kernel size.
    width: Width of the image (unused).
    p: probability of applying this transformation.

  Returns:
    A preprocessed image `Tensor`.
  """
  del width  # Unused.

  def _blur(img):
    random_sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
    return gaussian_blur(
        img, kernel_size=height//10, sigma=random_sigma, padding='SAME')

  return random_apply(_blur, p=p, x=image)
def random_blur_3d(frames, height, width, blur_probability=0.5):
  """Randomly blurs a whole clip with one shared decision.

  The clip is always blurred once; a single random selector then picks either
  the blurred or the original frames for the entire clip.

  Args:
    frames: `Tensor` of shape [timesteps, height, width, 3].
    height: the height of image.
    width: the width of image.
    blur_probability: the probability to apply the blur operator.

  Returns:
    The frames, blurred or not, clipped to [0, 1].
  """
  blurred = random_blur(frames, height, width, p=1.)

  # One selector value, broadcast over every frame and pixel of the clip.
  draw = tf.random.uniform([1, 1, 1, 1], 0, 1, dtype=tf.float32)
  use_blur = tf.cast(tf.less(draw, blur_probability), tf.float32)

  mixed = blurred * use_blur + frames * (1 - use_blur)
  return tf.clip_by_value(mixed, 0., 1.)
def _sample_or_pad_sequence_indices(sequence: tf.Tensor,
                                    num_steps: int,
                                    stride: int,
                                    offset: tf.Tensor) -> tf.Tensor:
  """Returns indices to take for sampling or padding sequences to fixed size.

  Args:
    sequence: Any tensor where the first dimension is timesteps.
    num_steps: Number of indices to return.
    stride: Distance between consecutive sampled timesteps.
    offset: Timestep of the first sample.

  Returns:
    A Tensor of `num_steps` indices into `sequence`. The index range is tiled,
    so positions past the end of the sequence wrap around to its start.
  """
  sequence_length = tf.shape(sequence)[0]
  span = num_steps * stride

  # Repeat the index range often enough to cover `offset + span` positions.
  required_length = span + offset
  repeat_count = tf.math.floordiv(
      required_length + sequence_length - 1, sequence_length)
  tiled_indices = tf.tile(tf.range(sequence_length), [repeat_count])

  positions = tf.range(offset, offset + span, stride)
  return tf.gather(tiled_indices, positions)
def sample_ssl_sequence(sequence: tf.Tensor,
                        num_steps: int,
                        random: bool,
                        stride: int = 1,
                        num_windows: Optional[int] = 2) -> tf.Tensor:
  """Samples two segments of size num_steps randomly from a given sequence.

  Currently it only supports images, and is specifically designed for video
  self-supervised learning.

  The gap between the start offsets of the two segments is found by bisection
  on the `cdf` helper below, and the two resulting offsets are then randomly
  shuffled. When the sequence is not longer than `(num_steps - 1) * stride`,
  both offsets are instead drawn independently and uniformly.

  Args:
    sequence: Any tensor where the first dimension is timesteps.
    num_steps: Number of steps (e.g. frames) to take.
    random: A boolean indicating whether to random sample the single window. If
      True, the offset is randomized. Only True is supported.
    stride: Distance to sample between timesteps.
    num_windows: Number of sequence sampled. Only used to set the static shape
      of the gathered indices; exactly two windows are always built.

  Returns:
    A single Tensor with first dimension `num_windows * num_steps`, holding the
    two sampled segments concatenated along the first axis.

  Raises:
    NotImplementedError: If `random` is False.
  """
  sequence_length = tf.shape(sequence)[0]
  # Work in float32 so that the CDF below can be evaluated.
  sequence_length = tf.cast(sequence_length, tf.float32)
  if random:
    # Exclusive upper bound on the start offset. Falls back to the full
    # sequence length when the sequence is too short for one window.
    max_offset = tf.cond(
        tf.greater(sequence_length, (num_steps - 1) * stride),
        lambda: sequence_length - (num_steps - 1) * stride,
        lambda: sequence_length)
    max_offset = tf.cast(max_offset, dtype=tf.float32)
    def cdf(k, power=1.0):
      """Cumulative distribution function for x^power."""
      # With power=1 this evaluates to 1 - (1 - k / max_offset)^2, i.e. the
      # CDF of a linearly decreasing density over [0, max_offset].
      p = -tf.math.pow(k, power + 1) / (
          power * tf.math.pow(max_offset, power + 1)) + k * (power + 1) / (
              power * max_offset)
      return p
    # Invert the CDF at a uniform sample `u` by bisection over [0, max_offset].
    u = tf.random.uniform(())
    k_low = tf.constant(0, dtype=tf.float32)
    k_up = max_offset
    k = tf.math.floordiv(max_offset, 2.0)
    # Loop condition: keep going while the bracket is wider than one step.
    c = lambda k_low, k_up, k: tf.greater(tf.math.abs(k_up - k_low), 1.0)
    # pylint:disable=g-long-lambda
    # Loop body: keep the half of the bracket that contains `u`.
    b = lambda k_low, k_up, k: tf.cond(
        tf.greater(cdf(k), u),
        lambda: [k_low, k, tf.math.floordiv(k + k_low, 2.0)],
        lambda: [k, k_up, tf.math.floordiv(k_up + k, 2.0)])
    _, _, k = tf.while_loop(c, b, [k_low, k_up, k])
    # `delta` is the gap between the two window start offsets.
    delta = tf.cast(k, tf.int32)
    max_offset = tf.cast(max_offset, tf.int32)
    sequence_length = tf.cast(sequence_length, tf.int32)
    # First offset: uniform over the whole range for short sequences,
    # otherwise uniform over the range that leaves room for `delta`.
    choice_1 = tf.cond(
        tf.equal(max_offset, sequence_length),
        lambda: tf.random.uniform((),
                                  maxval=tf.cast(max_offset, dtype=tf.int32),
                                  dtype=tf.int32),
        lambda: tf.random.uniform((),
                                  maxval=tf.cast(max_offset - delta,
                                                 dtype=tf.int32),
                                  dtype=tf.int32))
    # Second offset: independent uniform draw for short sequences, otherwise
    # the first offset shifted by `delta`.
    choice_2 = tf.cond(
        tf.equal(max_offset, sequence_length),
        lambda: tf.random.uniform((),
                                  maxval=tf.cast(max_offset, dtype=tf.int32),
                                  dtype=tf.int32),
        lambda: choice_1 + delta)
    # pylint:disable=g-long-lambda
    # Randomize which of the two offsets comes first.
    shuffle_choice = tf.random.shuffle((choice_1, choice_2))
    offset_1 = shuffle_choice[0]
    offset_2 = shuffle_choice[1]
  else:
    raise NotImplementedError
  indices_1 = _sample_or_pad_sequence_indices(
      sequence=sequence,
      num_steps=num_steps,
      stride=stride,
      offset=offset_1)
  indices_2 = _sample_or_pad_sequence_indices(
      sequence=sequence,
      num_steps=num_steps,
      stride=stride,
      offset=offset_2)
  indices = tf.concat([indices_1, indices_2], axis=0)
  # Pin the static length so downstream ops see a known first dimension.
  indices.set_shape((num_windows * num_steps,))
  output = tf.gather(sequence, indices)
  return output
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops
class VideoSslPreprocessOpsTest(tf.test.TestCase):
  """Shape checks for the video SSL preprocessing ops."""

  def setUp(self):
    """Creates a random 250-frame video and a normalized 16-frame clip."""
    super().setUp()
    video_shape = (250, 256, 256, 3)
    self._raw_frames = tf.random.uniform(
        video_shape, minval=0, maxval=255, dtype=tf.dtypes.int32)
    self._sampled_frames = self._raw_frames[:16]
    self._frames = preprocess_ops_3d.normalize_image(
        self._sampled_frames, False, tf.float32)

  def test_sample_ssl_sequence(self):
    # Two 16-frame windows are concatenated along the time axis.
    clips = video_ssl_preprocess_ops.sample_ssl_sequence(
        self._raw_frames, 16, True, 2)
    self.assertAllEqual(clips.shape, (32, 256, 256, 3))

  def test_random_color_jitter_3d(self):
    jittered = video_ssl_preprocess_ops.random_color_jitter_3d(self._frames)
    self.assertAllEqual(jittered.shape, (16, 256, 256, 3))

  def test_random_blur_3d(self):
    blurred = video_ssl_preprocess_ops.random_blur_3d(self._frames, 256, 256)
    self.assertAllEqual(blurred.shape, (16, 256, 256, 3))
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment