Unverified Commit 8b641b13 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab

parents 7cffacfe 357fa547
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
from official.vision.beta.projects.video_ssl.configs import video_ssl
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
# Put the pretrained checkpoint here for linear evaluation
init_checkpoint: 'r3d_1x_k600_800ep_backbone-1'
init_checkpoint_modules: 'backbone'
model:
dropout_rate: 1.0
norm_activation:
use_sync_bn: false
backbone:
resnet_3d:
block_specs: !!python/tuple
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
model_id: 50
stem_conv_temporal_kernel_size: 5
stem_conv_temporal_stride: 2
stem_pool_temporal_stride: 1
train_data:
name: kinetics600
feature_shape: !!python/tuple
- 32
- 224
- 224
- 3
temporal_stride: 2
global_batch_size: 1024
dtype: 'bfloat16'
shuffle_buffer_size: 1024
aug_max_area_ratio: 1.0
aug_max_aspect_ratio: 2.0
aug_min_area_ratio: 0.3
aug_min_aspect_ratio: 0.5
validation_data:
name: kinetics600
feature_shape: !!python/tuple
- 32
- 256
- 256
- 3
temporal_stride: 2
num_test_clips: 10
num_test_crops: 3
global_batch_size: 64
dtype: 'bfloat16'
drop_remainder: false
losses:
l2_weight_decay: 0.0
trainer:
optimizer_config:
learning_rate:
cosine:
initial_learning_rate: 32.0
decay_steps: 35744
optimizer:
sgd:
nesterov: false
warmup:
linear:
warmup_steps: 1787
train_steps: 35744
steps_per_loop: 100
summary_interval: 100
validation_interval: 100
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
dropout_rate: 1.0
norm_activation:
use_sync_bn: true
hidden_norm_activation:
use_sync_bn: true
backbone:
resnet_3d:
block_specs: !!python/tuple
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
model_id: 50
stem_conv_temporal_kernel_size: 5
stem_conv_temporal_stride: 2
stem_pool_temporal_stride: 1
train_data:
name: kinetics600
feature_shape: !!python/tuple
- 16
- 224
- 224
- 3
temporal_stride: 2
global_batch_size: 1024
dtype: 'bfloat16'
shuffle_buffer_size: 1024
losses:
l2_weight_decay: 0.000001
trainer:
optimizer_config:
learning_rate:
cosine:
initial_learning_rate: 0.32
decay_steps: 71488
optimizer:
sgd:
nesterov: false
warmup:
linear:
warmup_steps: 1787
train_steps: 71488
steps_per_loop: 100
summary_interval: 100
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video classification configuration definition."""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.configs import common
from official.vision.beta.configs import video_classification
Losses = video_classification.Losses
VideoClassificationModel = video_classification.VideoClassificationModel
VideoClassificationTask = video_classification.VideoClassificationTask
@dataclasses.dataclass
class VideoSSLPretrainTask(VideoClassificationTask):
pass
@dataclasses.dataclass
class VideoSSLEvalTask(VideoClassificationTask):
pass
@dataclasses.dataclass
class DataConfig(video_classification.DataConfig):
"""The base configuration for building datasets."""
is_ssl: bool = False
@dataclasses.dataclass
class VideoSSLModel(VideoClassificationModel):
"""The model config."""
normalize_feature: bool = False
hidden_dim: int = 2048
hidden_layer_num: int = 3
projection_dim: int = 128
hidden_norm_activation: common.NormActivation = common.NormActivation(
use_sync_bn=False, norm_momentum=0.997, norm_epsilon=1.0e-05)
@dataclasses.dataclass
class SSLLosses(Losses):
normalize_hidden: bool = True
temperature: float = 0.1
@exp_factory.register_config_factory('video_ssl_pretrain_kinetics400')
def video_ssl_pretrain_kinetics400() -> cfg.ExperimentConfig:
"""Pretrain SSL Video classification on Kinectics 400 with resnet."""
exp = video_classification.video_classification_kinetics400()
exp.task = VideoSSLPretrainTask(**exp.task.as_dict())
exp.task.train_data = DataConfig(is_ssl=True, **exp.task.train_data.as_dict())
exp.task.train_data.feature_shape = (16, 224, 224, 3)
exp.task.train_data.temporal_stride = 2
exp.task.model = VideoSSLModel(exp.task.model)
exp.task.model.model_type = 'video_ssl_model'
exp.task.losses = SSLLosses(exp.task.losses)
return exp
@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics400')
def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig:
"""Pretrain SSL Video classification on Kinectics 400 with resnet."""
exp = video_classification.video_classification_kinetics400()
exp.task = VideoSSLEvalTask(**exp.task.as_dict())
exp.task.train_data = DataConfig(is_ssl=False,
**exp.task.train_data.as_dict())
exp.task.train_data.feature_shape = (32, 224, 224, 3)
exp.task.train_data.temporal_stride = 2
exp.task.validation_data.feature_shape = (32, 256, 256, 3)
exp.task.validation_data.temporal_stride = 2
exp.task.validation_data = DataConfig(is_ssl=False,
**exp.task.validation_data.as_dict())
exp.task.validation_data.min_image_size = 256
exp.task.validation_data.num_test_clips = 10
exp.task.validation_data.num_test_crops = 3
exp.task.model = VideoSSLModel(exp.task.model)
exp.task.model.model_type = 'video_ssl_model'
exp.task.model.normalize_feature = True
exp.task.model.hidden_layer_num = 0
exp.task.model.projection_dim = 400
return exp
@exp_factory.register_config_factory('video_ssl_pretrain_kinetics600')
def video_ssl_pretrain_kinetics600() -> cfg.ExperimentConfig:
"""Pretrain SSL Video classification on Kinectics 400 with resnet."""
exp = video_classification.video_classification_kinetics600()
exp.task = VideoSSLPretrainTask(**exp.task.as_dict())
exp.task.train_data = DataConfig(is_ssl=True, **exp.task.train_data.as_dict())
exp.task.train_data.feature_shape = (16, 224, 224, 3)
exp.task.train_data.temporal_stride = 2
exp.task.model = VideoSSLModel(exp.task.model)
exp.task.model.model_type = 'video_ssl_model'
exp.task.losses = SSLLosses(exp.task.losses)
return exp
@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics600')
def video_ssl_linear_eval_kinetics600() -> cfg.ExperimentConfig:
"""Pretrain SSL Video classification on Kinectics 400 with resnet."""
exp = video_classification.video_classification_kinetics600()
exp.task = VideoSSLEvalTask(**exp.task.as_dict())
exp.task.train_data = DataConfig(is_ssl=False,
**exp.task.train_data.as_dict())
exp.task.train_data.feature_shape = (32, 224, 224, 3)
exp.task.train_data.temporal_stride = 2
exp.task.validation_data = DataConfig(is_ssl=False,
**exp.task.validation_data.as_dict())
exp.task.validation_data.feature_shape = (32, 256, 256, 3)
exp.task.validation_data.temporal_stride = 2
exp.task.validation_data.min_image_size = 256
exp.task.validation_data.num_test_clips = 10
exp.task.validation_data.num_test_crops = 3
exp.task.model = VideoSSLModel(exp.task.model)
exp.task.model.model_type = 'video_ssl_model'
exp.task.model.normalize_feature = True
exp.task.model.hidden_layer_num = 0
exp.task.model.projection_dim = 600
return exp
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision import beta
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('video_ssl_pretrain_kinetics400',),
('video_ssl_pretrain_kinetics600',))
def test_video_ssl_pretrain_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.VideoSSLPretrainTask)
self.assertIsInstance(config.task.model, exp_cfg.VideoSSLModel)
self.assertIsInstance(config.task.losses, exp_cfg.SSLLosses)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
@parameterized.parameters(('video_ssl_linear_eval_kinetics400',),
('video_ssl_linear_eval_kinetics600',))
def test_video_ssl_linear_eval_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.VideoSSLEvalTask)
self.assertIsInstance(config.task.model, exp_cfg.VideoSSLModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Parser for video and label datasets."""
from typing import Dict, Optional, Tuple
from absl import logging
import tensorflow as tf
from official.vision.beta.dataloaders import video_input
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops
IMAGE_KEY = 'image/encoded'
LABEL_KEY = 'clip/label/index'
Decoder = video_input.Decoder
def _process_image(image: tf.Tensor,
is_training: bool = True,
is_ssl: bool = False,
num_frames: int = 32,
stride: int = 1,
num_test_clips: int = 1,
min_resize: int = 256,
crop_size: int = 224,
num_crops: int = 1,
zero_centering_image: bool = False,
seed: Optional[int] = None) -> tf.Tensor:
"""Processes a serialized image tensor.
Args:
image: Input Tensor of shape [timesteps] and type tf.string of serialized
frames.
is_training: Whether or not in training mode. If True, random sample, crop
and left right flip is used.
is_ssl: Whether or not in self-supervised pre-training mode.
num_frames: Number of frames per subclip.
stride: Temporal stride to sample frames.
num_test_clips: Number of test clips (1 by default). If more than 1, this
will sample multiple linearly spaced clips within each video at test time.
If 1, then a single clip in the middle of the video is sampled. The clips
are aggreagated in the batch dimension.
min_resize: Frames are resized so that min(height, width) is min_resize.
crop_size: Final size of the frame after cropping the resized frames. Both
height and width are the same.
num_crops: Number of crops to perform on the resized frames.
zero_centering_image: If True, frames are normalized to values in [-1, 1].
If False, values in [0, 1].
seed: A deterministic seed to use when sampling.
Returns:
Processed frames. Tensor of shape
[num_frames * num_test_clips, crop_size, crop_size, 3].
"""
# Validate parameters.
if is_training and num_test_clips != 1:
logging.warning(
'`num_test_clips` %d is ignored since `is_training` is `True`.',
num_test_clips)
# Temporal sampler.
if is_training:
# Sampler for training.
if is_ssl:
# Sample two clips from linear decreasing distribution.
image = video_ssl_preprocess_ops.sample_ssl_sequence(
image, num_frames, True, stride)
else:
# Sample random clip.
image = preprocess_ops_3d.sample_sequence(image, num_frames, True, stride)
else:
# Sampler for evaluation.
if num_test_clips > 1:
# Sample linspace clips.
image = preprocess_ops_3d.sample_linspace_sequence(image, num_test_clips,
num_frames, stride)
else:
# Sample middle clip.
image = preprocess_ops_3d.sample_sequence(image, num_frames, False,
stride)
# Decode JPEG string to tf.uint8.
image = preprocess_ops_3d.decode_jpeg(image, 3)
if is_training:
# Standard image data augmentation: random resized crop and random flip.
if is_ssl:
image_1, image_2 = tf.split(image, num_or_size_splits=2, axis=0)
image_1 = preprocess_ops_3d.random_crop_resize(
image_1, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
image_1 = preprocess_ops_3d.random_flip_left_right(image_1, seed)
image_2 = preprocess_ops_3d.random_crop_resize(
image_2, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
image_2 = preprocess_ops_3d.random_flip_left_right(image_2, seed)
else:
image = preprocess_ops_3d.random_crop_resize(
image, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
image = preprocess_ops_3d.random_flip_left_right(image, seed)
else:
# Resize images (resize happens only if necessary to save compute).
image = preprocess_ops_3d.resize_smallest(image, min_resize)
# Three-crop of the frames.
image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
num_crops)
# Cast the frames in float32, normalizing according to zero_centering_image.
if is_training and is_ssl:
image_1 = preprocess_ops_3d.normalize_image(image_1, zero_centering_image)
image_2 = preprocess_ops_3d.normalize_image(image_2, zero_centering_image)
else:
image = preprocess_ops_3d.normalize_image(image, zero_centering_image)
# Self-supervised pre-training augmentations.
if is_training and is_ssl:
# Temporally consistent color jittering.
image_1 = video_ssl_preprocess_ops.random_color_jitter_3d(image_1)
image_2 = video_ssl_preprocess_ops.random_color_jitter_3d(image_2)
# Temporally consistent gaussian blurring.
image_1 = video_ssl_preprocess_ops.random_blur(image_1, crop_size,
crop_size, 1.0)
image_2 = video_ssl_preprocess_ops.random_blur(image_2, crop_size,
crop_size, 0.1)
image_2 = video_ssl_preprocess_ops.random_solarization(image_2)
image = tf.concat([image_1, image_2], axis=0)
image = tf.clip_by_value(image, 0., 1.)
return image
def _postprocess_image(image: tf.Tensor,
is_training: bool = True,
is_ssl: bool = False,
num_frames: int = 32,
num_test_clips: int = 1,
num_test_crops: int = 1) -> tf.Tensor:
"""Processes a batched Tensor of frames.
The same parameters used in process should be used here.
Args:
image: Input Tensor of shape [batch, timesteps, height, width, 3].
is_training: Whether or not in training mode. If True, random sample, crop
and left right flip is used.
is_ssl: Whether or not in self-supervised pre-training mode.
num_frames: Number of frames per subclip.
num_test_clips: Number of test clips (1 by default). If more than 1, this
will sample multiple linearly spaced clips within each video at test time.
If 1, then a single clip in the middle of the video is sampled. The clips
are aggreagated in the batch dimension.
num_test_crops: Number of test crops (1 by default). If more than 1, there
are multiple crops for each clip at test time. If 1, there is a single
central crop. The crops are aggreagated in the batch dimension.
Returns:
Processed frames. Tensor of shape
[batch * num_test_clips * num_test_crops, num_frames, height, width, 3].
"""
if is_ssl and is_training:
# In this case, two clips of self-supervised pre-training are merged
# together in batch dimenstion which will be 2 * batch.
image = tf.concat(tf.split(image, num_or_size_splits=2, axis=1), axis=0)
num_views = num_test_clips * num_test_crops
if num_views > 1 and not is_training:
# In this case, multiple views are merged together in batch dimenstion which
# will be batch * num_views.
image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())
return image
def _process_label(label: tf.Tensor,
one_hot_label: bool = True,
num_classes: Optional[int] = None) -> tf.Tensor:
"""Processes label Tensor."""
# Validate parameters.
if one_hot_label and not num_classes:
raise ValueError(
'`num_classes` should be given when requesting one hot label.')
# Cast to tf.int32.
label = tf.cast(label, dtype=tf.int32)
if one_hot_label:
# Replace label index by one hot representation.
label = tf.one_hot(label, num_classes)
if len(label.shape.as_list()) > 1:
label = tf.reduce_sum(label, axis=0)
if num_classes == 1:
# The trick for single label.
label = 1 - label
return label
class Parser(video_input.Parser):
"""Parses a video and label dataset."""
def __init__(self,
input_params: exp_cfg.DataConfig,
image_key: str = IMAGE_KEY,
label_key: str = LABEL_KEY):
super(Parser, self).__init__(input_params, image_key, label_key)
self._is_ssl = input_params.is_ssl
def _parse_train_data(
self, decoded_tensors: Dict[str, tf.Tensor]
) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
"""Parses data for training."""
# Process image and label.
image = decoded_tensors[self._image_key]
image = _process_image(
image=image,
is_training=True,
is_ssl=self._is_ssl,
num_frames=self._num_frames,
stride=self._stride,
num_test_clips=self._num_test_clips,
min_resize=self._min_resize,
crop_size=self._crop_size)
image = tf.cast(image, dtype=self._dtype)
features = {'image': image}
label = decoded_tensors[self._label_key]
label = _process_label(label, self._one_hot_label, self._num_classes)
return features, label
def _parse_eval_data(
self, decoded_tensors: Dict[str, tf.Tensor]
) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
"""Parses data for evaluation."""
image = decoded_tensors[self._image_key]
image = _process_image(
image=image,
is_training=False,
num_frames=self._num_frames,
stride=self._stride,
num_test_clips=self._num_test_clips,
min_resize=self._min_resize,
crop_size=self._crop_size,
num_crops=self._num_crops)
image = tf.cast(image, dtype=self._dtype)
features = {'image': image}
label = decoded_tensors[self._label_key]
label = _process_label(label, self._one_hot_label, self._num_classes)
if self._output_audio:
audio = decoded_tensors[self._audio_feature]
audio = tf.cast(audio, dtype=self._dtype)
audio = preprocess_ops_3d.sample_sequence(
audio, 20, random=False, stride=1)
audio = tf.ensure_shape(audio, [20, 2048])
features['audio'] = audio
return features, label
def parse_fn(self, is_training):
"""Returns a parse fn that reads and parses raw tensors from the decoder.
Args:
is_training: a `bool` to indicate whether it is in training mode.
Returns:
parse: a `callable` that takes the serialized examle and generate the
images, labels tuple where labels is a dict of Tensors that contains
labels.
"""
def parse(decoded_tensors):
"""Parses the serialized example data."""
if is_training:
return self._parse_train_data(decoded_tensors)
else:
return self._parse_eval_data(decoded_tensors)
return parse
class PostBatchProcessor(object):
"""Processes a video and label dataset which is batched."""
def __init__(self, input_params: exp_cfg.DataConfig):
self._is_training = input_params.is_training
self._is_ssl = input_params.is_ssl
self._num_frames = input_params.feature_shape[0]
self._num_test_clips = input_params.num_test_clips
self._num_test_crops = input_params.num_test_crops
def __call__(self, features: Dict[str, tf.Tensor],
label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
"""Parses a single tf.Example into image and label tensors."""
for key in ['image', 'audio']:
if key in features:
features[key] = _postprocess_image(
image=features[key],
is_training=self._is_training,
is_ssl=self._is_ssl,
num_frames=self._num_frames,
num_test_clips=self._num_test_clips,
num_test_crops=self._num_test_crops)
return features, label
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
import io
# Import libraries
import numpy as np
from PIL import Image
import tensorflow as tf
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input
AUDIO_KEY = 'features/audio'
def fake_seq_example():
# Create fake data.
random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
random_image = Image.fromarray(random_image)
label = 42
with io.BytesIO() as buffer:
random_image.save(buffer, format='JPEG')
raw_image_bytes = buffer.getvalue()
seq_example = tf.train.SequenceExample()
seq_example.feature_lists.feature_list.get_or_create(
video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
raw_image_bytes
]
seq_example.feature_lists.feature_list.get_or_create(
video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
raw_image_bytes
]
seq_example.context.feature[video_ssl_input.LABEL_KEY].int64_list.value[:] = [
label
]
random_audio = np.random.normal(size=(10, 256)).tolist()
for s in random_audio:
seq_example.feature_lists.feature_list.get_or_create(
AUDIO_KEY).feature.add().float_list.value[:] = s
return seq_example, label
class VideoAndLabelParserTest(tf.test.TestCase):
def test_video_ssl_input_pretrain(self):
params = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data
decoder = video_ssl_input.Decoder()
parser = video_ssl_input.Parser(params).parse_fn(params.is_training)
seq_example, _ = fake_seq_example()
input_tensor = tf.constant(seq_example.SerializeToString())
decoded_tensors = decoder.decode(input_tensor)
output_tensor = parser(decoded_tensors)
image_features, _ = output_tensor
image = image_features['image']
self.assertAllEqual(image.shape, (32, 224, 224, 3))
def test_video_ssl_input_linear_train(self):
params = exp_cfg.video_ssl_linear_eval_kinetics600().task.train_data
decoder = video_ssl_input.Decoder()
parser = video_ssl_input.Parser(params).parse_fn(params.is_training)
seq_example, label = fake_seq_example()
input_tensor = tf.constant(seq_example.SerializeToString())
decoded_tensors = decoder.decode(input_tensor)
output_tensor = parser(decoded_tensors)
image_features, label = output_tensor
image = image_features['image']
self.assertAllEqual(image.shape, (32, 224, 224, 3))
self.assertAllEqual(label.shape, (600,))
def test_video_ssl_input_linear_eval(self):
params = exp_cfg.video_ssl_linear_eval_kinetics600().task.validation_data
print('!!!', params)
decoder = video_ssl_input.Decoder()
parser = video_ssl_input.Parser(params).parse_fn(params.is_training)
seq_example, label = fake_seq_example()
input_tensor = tf.constant(seq_example.SerializeToString())
decoded_tensors = decoder.decode(input_tensor)
output_tensor = parser(decoded_tensors)
image_features, label = output_tensor
image = image_features['image']
self.assertAllEqual(image.shape, (960, 256, 256, 3))
self.assertAllEqual(label.shape, (600,))
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Define losses."""
# Import libraries
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla
def contrastive_loss(hidden,
num_replicas,
normalize_hidden,
temperature,
model,
weight_decay):
"""Computes contrastive loss.
Args:
hidden: embedding of video clips after projection head.
num_replicas: number of distributed replicas.
normalize_hidden: whether or not to l2 normalize the hidden vector.
temperature: temperature in the InfoNCE contrastive loss.
model: keras model for calculating weight decay.
weight_decay: weight decay parameter.
Returns:
A loss scalar.
The logits for contrastive prediction task.
The labels for contrastive prediction task.
"""
large_num = 1e9
hidden1, hidden2 = tf.split(hidden, num_or_size_splits=2, axis=0)
if normalize_hidden:
hidden1 = tf.math.l2_normalize(hidden1, -1)
hidden2 = tf.math.l2_normalize(hidden2, -1)
batch_size = tf.shape(hidden1)[0]
if num_replicas == 1:
# This is the local version
hidden1_large = hidden1
hidden2_large = hidden2
labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
masks = tf.one_hot(tf.range(batch_size), batch_size)
else:
# This is the cross-tpu version.
hidden1_large = tpu_cross_replica_concat(hidden1, num_replicas)
hidden2_large = tpu_cross_replica_concat(hidden2, num_replicas)
enlarged_batch_size = tf.shape(hidden1_large)[0]
replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
labels_idx = tf.range(batch_size) + replica_id * batch_size
labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
masks = tf.one_hot(labels_idx, enlarged_batch_size)
logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
logits_aa = logits_aa - tf.cast(masks, logits_aa.dtype) * large_num
logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
logits_bb = logits_bb - tf.cast(masks, logits_bb.dtype) * large_num
logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature
loss_a = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
labels, tf.concat([logits_ab, logits_aa], 1)))
loss_b = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
labels, tf.concat([logits_ba, logits_bb], 1)))
loss = loss_a + loss_b
l2_loss = weight_decay * tf.add_n([
tf.nn.l2_loss(v)
for v in model.trainable_variables
if 'kernel' in v.name
])
total_loss = loss + tf.cast(l2_loss, loss.dtype)
contrast_prob = tf.nn.softmax(logits_ab)
contrast_entropy = - tf.reduce_mean(
tf.reduce_sum(contrast_prob * tf.math.log(contrast_prob + 1e-8), -1))
contrast_acc = tf.equal(tf.argmax(labels, 1), tf.argmax(logits_ab, axis=1))
contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))
return {
'total_loss': total_loss,
'contrastive_loss': loss,
'reg_loss': l2_loss,
'contrast_acc': contrast_acc,
'contrast_entropy': contrast_entropy,
}
def tpu_cross_replica_concat(tensor, num_replicas):
"""Reduce a concatenation of the `tensor` across TPU cores.
Args:
tensor: tensor to concatenate.
num_replicas: number of TPU device replicas.
Returns:
Tensor of the same rank as `tensor` with first dimension `num_replicas`
times larger.
"""
with tf.name_scope('tpu_cross_replica_concat'):
# This creates a tensor that is like the input tensor but has an added
# replica dimension as the outermost dimension. On each replica it will
# contain the local values and zeros for all other values that need to be
# fetched from other replicas.
ext_tensor = tf.scatter_nd(
indices=[[xla.replica_id()]],
updates=[tensor],
shape=[num_replicas] + tensor.shape.as_list())
# As every value is only present on one replica and 0 in all others, adding
# them all together will result in the full tensor on all replicas.
replica_context = tf.distribute.get_replica_context()
ext_tensor = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
ext_tensor)
# Flatten the replica dimension.
# The first dimension size will be: tensor.shape[0] * num_replicas
# Using [-1] trick to support also scalar input.
return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build video classification models."""
from typing import Mapping, Optional
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import factory_3d as model_factory
from official.vision.beta.projects.video_ssl.configs import video_ssl as video_ssl_cfg
layers = tf.keras.layers
class VideoSSLModel(tf.keras.Model):
"""A video ssl model class builder."""
def __init__(self,
backbone,
normalize_feature,
hidden_dim,
hidden_layer_num,
hidden_norm_args,
projection_dim,
input_specs: Optional[Mapping[str,
tf.keras.layers.InputSpec]] = None,
dropout_rate: float = 0.0,
aggregate_endpoints: bool = False,
kernel_initializer='random_uniform',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""Video Classification initialization function.
Args:
backbone: a 3d backbone network.
normalize_feature: whether normalize backbone feature.
hidden_dim: `int` number of hidden units in MLP.
hidden_layer_num: `int` number of hidden layers in MLP.
hidden_norm_args: `dict` for batchnorm arguments in MLP.
projection_dim: `int` number of ouput dimension for MLP.
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
dropout_rate: `float` rate for dropout regularization.
aggregate_endpoints: `bool` aggregate all end ponits or only use the
final end point.
kernel_initializer: kernel initializer for the dense layer.
kernel_regularizer: tf.keras.regularizers.Regularizer object. Default to
None.
bias_regularizer: tf.keras.regularizers.Regularizer object. Default to
None.
**kwargs: keyword arguments to be passed.
"""
if not input_specs:
input_specs = {
'image': layers.InputSpec(shape=[None, None, None, None, 3])
}
self._self_setattr_tracking = False
self._config_dict = {
'backbone': backbone,
'normalize_feature': normalize_feature,
'hidden_dim': hidden_dim,
'hidden_layer_num': hidden_layer_num,
'use_sync_bn': hidden_norm_args.use_sync_bn,
'norm_momentum': hidden_norm_args.norm_momentum,
'norm_epsilon': hidden_norm_args.norm_epsilon,
'activation': hidden_norm_args.activation,
'projection_dim': projection_dim,
'input_specs': input_specs,
'dropout_rate': dropout_rate,
'aggregate_endpoints': aggregate_endpoints,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
}
self._input_specs = input_specs
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._backbone = backbone
inputs = {
k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items()
}
endpoints = backbone(inputs['image'])
if aggregate_endpoints:
pooled_feats = []
for endpoint in endpoints.values():
x_pool = tf.keras.layers.GlobalAveragePooling3D()(endpoint)
pooled_feats.append(x_pool)
x = tf.concat(pooled_feats, axis=1)
else:
x = endpoints[max(endpoints.keys())]
x = tf.keras.layers.GlobalAveragePooling3D()(x)
# L2 Normalize feature after backbone
if normalize_feature:
x = tf.nn.l2_normalize(x, axis=-1)
# MLP hidden layers
for _ in range(hidden_layer_num):
x = tf.keras.layers.Dense(hidden_dim)(x)
if self._config_dict['use_sync_bn']:
x = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])(x)
else:
x = tf.keras.layers.BatchNormalization(
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])(x)
x = tf_utils.get_activation(self._config_dict['activation'])(x)
# Projection head
x = tf.keras.layers.Dense(projection_dim)(x)
super(VideoSSLModel, self).__init__(
inputs=inputs, outputs=x, **kwargs)
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
return dict(backbone=self.backbone)
@property
def backbone(self):
return self._backbone
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@model_factory.register_model_builder('video_ssl_model')
def build_video_ssl_pretrain_model(
input_specs: tf.keras.layers.InputSpec,
model_config: video_ssl_cfg.VideoSSLModel,
num_classes: int,
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
"""Builds the video classification model."""
del num_classes
input_specs_dict = {'image': input_specs}
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
norm_activation_config=model_config.norm_activation,
l2_regularizer=l2_regularizer)
# Norm layer type in the MLP head should same with backbone
assert model_config.norm_activation.use_sync_bn == model_config.hidden_norm_activation.use_sync_bn
model = VideoSSLModel(
backbone=backbone,
normalize_feature=model_config.normalize_feature,
hidden_dim=model_config.hidden_dim,
hidden_layer_num=model_config.hidden_layer_num,
hidden_norm_args=model_config.hidden_norm_activation,
projection_dim=model_config.projection_dim,
input_specs=input_specs_dict,
dropout_rate=model_config.dropout_rate,
aggregate_endpoints=model_config.aggregate_endpoints,
kernel_regularizer=l2_regularizer)
return model
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Utils for customed ops for video ssl."""
import functools
from typing import Optional
import tensorflow as tf
def random_apply(func, p, x):
"""Randomly apply function func to x with probability p."""
return tf.cond(
tf.less(tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
tf.cast(p, tf.float32)),
lambda: func(x),
lambda: x)
def random_brightness(image, max_delta):
"""Distort brightness of image (SimCLRv2 style)."""
factor = tf.random.uniform(
[], tf.maximum(1.0 - max_delta, 0), 1.0 + max_delta)
image = image * factor
return image
def random_solarization(image, p=0.2):
"""Random solarize image."""
def _transform(image):
image = image * tf.cast(tf.less(image, 0.5), dtype=image.dtype) + (
1.0 - image) * tf.cast(tf.greater_equal(image, 0.5), dtype=image.dtype)
return image
return random_apply(_transform, p=p, x=image)
def to_grayscale(image, keep_channels=True):
"""Turn the input image to gray scale.
Args:
image: The input image tensor.
keep_channels: Whether maintaining the channel number for the image.
If true, the transformed image will repeat three times in channel.
If false, the transformed image will only have one channel.
Returns:
The distorted image tensor.
"""
image = tf.image.rgb_to_grayscale(image)
if keep_channels:
image = tf.tile(image, [1, 1, 3])
return image
def color_jitter(image, strength, random_order=True):
"""Distorts the color of the image (SimCLRv2 style).
Args:
image: The input image tensor.
strength: The floating number for the strength of the color augmentation.
random_order: A bool, specifying whether to randomize the jittering order.
Returns:
The distorted image tensor.
"""
brightness = 0.8 * strength
contrast = 0.8 * strength
saturation = 0.8 * strength
hue = 0.2 * strength
if random_order:
return color_jitter_rand(
image, brightness, contrast, saturation, hue)
else:
return color_jitter_nonrand(
image, brightness, contrast, saturation, hue)
def color_jitter_nonrand(image,
brightness=0,
contrast=0,
saturation=0,
hue=0):
"""Distorts the color of the image (jittering order is fixed, SimCLRv2 style).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
Returns:
The distorted image tensor.
"""
with tf.name_scope('distort_color'):
def apply_transform(i, x, brightness, contrast, saturation, hue):
"""Apply the i-th transformation."""
if brightness != 0 and i == 0:
x = random_brightness(x, max_delta=brightness)
elif contrast != 0 and i == 1:
x = tf.image.random_contrast(
x, lower=1-contrast, upper=1+contrast)
elif saturation != 0 and i == 2:
x = tf.image.random_saturation(
x, lower=1-saturation, upper=1+saturation)
elif hue != 0:
x = tf.image.random_hue(x, max_delta=hue)
return x
for i in range(4):
image = apply_transform(i, image, brightness, contrast, saturation, hue)
image = tf.clip_by_value(image, 0., 1.)
return image
def color_jitter_rand(image,
brightness=0,
contrast=0,
saturation=0,
hue=0):
"""Distorts the color of the image (jittering order is random, SimCLRv2 style).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
Returns:
The distorted image tensor.
"""
with tf.name_scope('distort_color'):
def apply_transform(i, x):
"""Apply the i-th transformation."""
def brightness_transform():
if brightness == 0:
return x
else:
return random_brightness(x, max_delta=brightness)
def contrast_transform():
if contrast == 0:
return x
else:
return tf.image.random_contrast(x, lower=1-contrast, upper=1+contrast)
def saturation_transform():
if saturation == 0:
return x
else:
return tf.image.random_saturation(
x, lower=1-saturation, upper=1+saturation)
def hue_transform():
if hue == 0:
return x
else:
return tf.image.random_hue(x, max_delta=hue)
# pylint:disable=g-long-lambda
x = tf.cond(
tf.less(i, 2), lambda: tf.cond(
tf.less(i, 1), brightness_transform, contrast_transform),
lambda: tf.cond(tf.less(i, 3), saturation_transform, hue_transform))
# pylint:disable=g-long-lambda
return x
perm = tf.random.shuffle(tf.range(4))
for i in range(4):
image = apply_transform(perm[i], image)
image = tf.clip_by_value(image, 0., 1.)
return image
def random_color_jitter_3d(frames):
"""Applies temporally consistent color jittering to one video clip.
Args:
frames: `Tensor` of shape [num_frames, height, width, channels].
Returns:
A Tensor of shape [num_frames, height, width, channels] being color jittered
with the same operation.
"""
def random_color_jitter(image, p=1.0):
def _transform(image):
color_jitter_t = functools.partial(
color_jitter, strength=1.0)
image = random_apply(color_jitter_t, p=0.8, x=image)
return random_apply(to_grayscale, p=0.2, x=image)
return random_apply(_transform, p=p, x=image)
num_frames, width, height, channels = frames.shape.as_list()
big_image = tf.reshape(frames, [num_frames*width, height, channels])
big_image = random_color_jitter(big_image)
return tf.reshape(big_image, [num_frames, width, height, channels])
def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
"""Blurs the given image with separable convolution.
Args:
image: Tensor of shape [height, width, channels] and dtype float to blur.
kernel_size: Integer Tensor for the size of the blur kernel. This is should
be an odd number. If it is an even number, the actual kernel size will be
size + 1.
sigma: Sigma value for gaussian operator.
padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.
Returns:
A Tensor representing the blurred image.
"""
radius = tf.cast(kernel_size / 2, dtype=tf.int32)
kernel_size = radius * 2 + 1
x = tf.cast(tf.range(-radius, radius + 1), dtype=tf.float32)
blur_filter = tf.exp(
-tf.pow(x, 2.0) / (2.0 * tf.pow(tf.cast(sigma, dtype=tf.float32), 2.0)))
blur_filter /= tf.reduce_sum(blur_filter)
# One vertical and one horizontal filter.
blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
num_channels = tf.shape(image)[-1]
blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
expand_batch_dim = image.shape.ndims == 3
if expand_batch_dim:
# Tensorflow requires batched input to convolutions, which we can fake with
# an extra dimension.
image = tf.expand_dims(image, axis=0)
blurred = tf.nn.depthwise_conv2d(
image, blur_h, strides=[1, 1, 1, 1], padding=padding)
blurred = tf.nn.depthwise_conv2d(
blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
if expand_batch_dim:
blurred = tf.squeeze(blurred, axis=0)
return blurred
def random_blur(image, height, width, p=1.0):
"""Randomly blur an image.
Args:
image: `Tensor` representing an image of arbitrary size.
height: Height of output image.
width: Width of output image.
p: probability of applying this transformation.
Returns:
A preprocessed image `Tensor`.
"""
del width
def _transform(image):
sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
return gaussian_blur(
image, kernel_size=height//10, sigma=sigma, padding='SAME')
return random_apply(_transform, p=p, x=image)
def random_blur_3d(frames, height, width, blur_probability=0.5):
"""Apply efficient batch data transformations.
Args:
frames: `Tensor` of shape [timesteps, height, width, 3].
height: the height of image.
width: the width of image.
blur_probability: the probaility to apply the blur operator.
Returns:
Preprocessed feature list.
"""
def generate_selector(p, bsz):
shape = [bsz, 1, 1, 1]
selector = tf.cast(
tf.less(tf.random.uniform(shape, 0, 1, dtype=tf.float32), p),
tf.float32)
return selector
frames_new = random_blur(frames, height, width, p=1.)
selector = generate_selector(blur_probability, 1)
frames = frames_new * selector + frames * (1 - selector)
frames = tf.clip_by_value(frames, 0., 1.)
return frames
def _sample_or_pad_sequence_indices(sequence: tf.Tensor,
num_steps: int,
stride: int,
offset: tf.Tensor) -> tf.Tensor:
"""Returns indices to take for sampling or padding sequences to fixed size."""
sequence_length = tf.shape(sequence)[0]
sel_idx = tf.range(sequence_length)
# Repeats sequence until num_steps are available in total.
max_length = num_steps * stride + offset
num_repeats = tf.math.floordiv(
max_length + sequence_length - 1, sequence_length)
sel_idx = tf.tile(sel_idx, [num_repeats])
steps = tf.range(offset, offset + num_steps * stride, stride)
return tf.gather(sel_idx, steps)
def sample_ssl_sequence(sequence: tf.Tensor,
num_steps: int,
random: bool,
stride: int = 1,
num_windows: Optional[int] = 2) -> tf.Tensor:
"""Samples two segments of size num_steps randomly from a given sequence.
Currently it only supports images, and specically designed for video self-
supervised learning.
Args:
sequence: Any tensor where the first dimension is timesteps.
num_steps: Number of steps (e.g. frames) to take.
random: A boolean indicating whether to random sample the single window. If
True, the offset is randomized. Only True is supported.
stride: Distance to sample between timesteps.
num_windows: Number of sequence sampled.
Returns:
A single Tensor with first dimension num_steps with the sampled segment.
"""
sequence_length = tf.shape(sequence)[0]
sequence_length = tf.cast(sequence_length, tf.float32)
if random:
max_offset = tf.cond(
tf.greater(sequence_length, (num_steps - 1) * stride),
lambda: sequence_length - (num_steps - 1) * stride,
lambda: sequence_length)
max_offset = tf.cast(max_offset, dtype=tf.float32)
def cdf(k, power=1.0):
"""Cumulative distribution function for x^power."""
p = -tf.math.pow(k, power + 1) / (
power * tf.math.pow(max_offset, power + 1)) + k * (power + 1) / (
power * max_offset)
return p
u = tf.random.uniform(())
k_low = tf.constant(0, dtype=tf.float32)
k_up = max_offset
k = tf.math.floordiv(max_offset, 2.0)
c = lambda k_low, k_up, k: tf.greater(tf.math.abs(k_up - k_low), 1.0)
# pylint:disable=g-long-lambda
b = lambda k_low, k_up, k: tf.cond(
tf.greater(cdf(k), u),
lambda: [k_low, k, tf.math.floordiv(k + k_low, 2.0)],
lambda: [k, k_up, tf.math.floordiv(k_up + k, 2.0)])
_, _, k = tf.while_loop(c, b, [k_low, k_up, k])
delta = tf.cast(k, tf.int32)
max_offset = tf.cast(max_offset, tf.int32)
sequence_length = tf.cast(sequence_length, tf.int32)
choice_1 = tf.cond(
tf.equal(max_offset, sequence_length),
lambda: tf.random.uniform((),
maxval=tf.cast(max_offset, dtype=tf.int32),
dtype=tf.int32),
lambda: tf.random.uniform((),
maxval=tf.cast(max_offset - delta,
dtype=tf.int32),
dtype=tf.int32))
choice_2 = tf.cond(
tf.equal(max_offset, sequence_length),
lambda: tf.random.uniform((),
maxval=tf.cast(max_offset, dtype=tf.int32),
dtype=tf.int32),
lambda: choice_1 + delta)
# pylint:disable=g-long-lambda
shuffle_choice = tf.random.shuffle((choice_1, choice_2))
offset_1 = shuffle_choice[0]
offset_2 = shuffle_choice[1]
else:
raise NotImplementedError
indices_1 = _sample_or_pad_sequence_indices(
sequence=sequence,
num_steps=num_steps,
stride=stride,
offset=offset_1)
indices_2 = _sample_or_pad_sequence_indices(
sequence=sequence,
num_steps=num_steps,
stride=stride,
offset=offset_2)
indices = tf.concat([indices_1, indices_2], axis=0)
indices.set_shape((num_windows * num_steps,))
output = tf.gather(sequence, indices)
return output
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops
class VideoSslPreprocessOpsTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self._raw_frames = tf.random.uniform((250, 256, 256, 3), minval=0,
maxval=255, dtype=tf.dtypes.int32)
self._sampled_frames = self._raw_frames[:16]
self._frames = preprocess_ops_3d.normalize_image(
self._sampled_frames, False, tf.float32)
def test_sample_ssl_sequence(self):
sampled_seq = video_ssl_preprocess_ops.sample_ssl_sequence(
self._raw_frames, 16, True, 2)
self.assertAllEqual(sampled_seq.shape, (32, 256, 256, 3))
def test_random_color_jitter_3d(self):
jittered_clip = video_ssl_preprocess_ops.random_color_jitter_3d(
self._frames)
self.assertAllEqual(jittered_clip.shape, (16, 256, 256, 3))
def test_random_blur_3d(self):
blurred_clip = video_ssl_preprocess_ops.random_blur_3d(
self._frames, 256, 256)
self.assertAllEqual(blurred_clip.shape, (16, 256, 256, 3))
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tasks package definition."""
from official.vision.beta.projects.video_ssl.tasks import linear_eval
from official.vision.beta.projects.video_ssl.tasks import pretrain
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video ssl linear evaluation task definition."""
from typing import Any, Optional, List, Tuple
from absl import logging
import tensorflow as tf
# pylint: disable=unused-import
from official.core import task_factory
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.modeling import video_ssl_model
from official.vision.beta.tasks import video_classification
@task_factory.register_task_cls(exp_cfg.VideoSSLEvalTask)
class VideoSSLEvalTask(video_classification.VideoClassificationTask):
"""A task for video ssl linear evaluation."""
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
ckpt.read(ckpt_dir_or_file)
else:
raise NotImplementedError
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
model.backbone.trainable = False
logging.info('Setting the backbone to non-trainable.')
return super(video_classification.VideoClassificationTask,
self).train_step(inputs, model, optimizer, metrics)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video ssl pretrain task definition."""
from absl import logging
import tensorflow as tf
# pylint: disable=unused-import
from official.core import input_reader
from official.core import task_factory
from official.vision.beta.modeling import factory_3d
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input
from official.vision.beta.projects.video_ssl.losses import losses
from official.vision.beta.projects.video_ssl.modeling import video_ssl_model
from official.vision.beta.tasks import video_classification
@task_factory.register_task_cls(exp_cfg.VideoSSLPretrainTask)
class VideoSSLPretrainTask(video_classification.VideoClassificationTask):
"""A task for video ssl pretraining."""
def build_model(self):
"""Builds video ssl pretraining model."""
common_input_shape = [
d1 if d1 == d2 else None
for d1, d2 in zip(self.task_config.train_data.feature_shape,
self.task_config.validation_data.feature_shape)
]
input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
logging.info('Build model input %r', common_input_shape)
model = factory_3d.build_model(
self.task_config.model.model_type,
input_specs=input_specs,
model_config=self.task_config.model,
num_classes=self.task_config.train_data.num_classes)
return model
def _get_decoder_fn(self, params):
decoder = video_ssl_input.Decoder()
return decoder.decode
def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
"""Builds classification input."""
parser = video_ssl_input.Parser(input_params=params)
postprocess_fn = video_ssl_input.PostBatchProcessor(params)
reader = input_reader.InputReader(
params,
dataset_fn=self._get_dataset_fn(params),
decoder_fn=self._get_decoder_fn(params),
parser_fn=parser.parse_fn(params.is_training),
postprocess_fn=postprocess_fn)
dataset = reader.read(input_context=input_context)
return dataset
def build_losses(self, model_outputs, num_replicas, model):
"""Sparse categorical cross entropy loss.
Args:
model_outputs: Output logits of the model.
num_replicas: distributed replica number.
model: keras model for calculating weight decay.
Returns:
The total loss tensor.
"""
all_losses = {}
contrastive_metrics = {}
losses_config = self.task_config.losses
total_loss = None
contrastive_loss_dict = losses.contrastive_loss(
model_outputs, num_replicas, losses_config.normalize_hidden,
losses_config.temperature, model,
self.task_config.losses.l2_weight_decay)
total_loss = contrastive_loss_dict['total_loss']
all_losses.update({
'total_loss': total_loss
})
all_losses[self.loss] = total_loss
contrastive_metrics.update({
'contrast_acc': contrastive_loss_dict['contrast_acc'],
'contrast_entropy': contrastive_loss_dict['contrast_entropy'],
'reg_loss': contrastive_loss_dict['reg_loss']
})
return all_losses, contrastive_metrics
def build_metrics(self, training=True):
"""Gets streaming metrics for training/validation."""
metrics = [
tf.keras.metrics.Mean(name='contrast_acc'),
tf.keras.metrics.Mean(name='contrast_entropy'),
tf.keras.metrics.Mean(name='reg_loss')
]
return metrics
def process_metrics(self, metrics, contrastive_metrics):
"""Process and update metrics."""
contrastive_metric_values = contrastive_metrics.values()
for metric, contrastive_metric_value in zip(metrics,
contrastive_metric_values):
metric.update_state(contrastive_metric_value)
def train_step(self, inputs, model, optimizer, metrics=None):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, _ = inputs
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
if self.task_config.train_data.output_audio:
outputs = model(features, training=True)
else:
outputs = model(features['image'], training=True)
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
all_losses, contrastive_metrics = self.build_losses(
model_outputs=outputs, num_replicas=num_replicas,
model=model)
loss = all_losses[self.loss]
scaled_loss = loss
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = all_losses
if metrics:
self.process_metrics(metrics, contrastive_metrics)
logs.update({m.name: m.result() for m in metrics})
return logs
def validation_step(self, inputs, model, metrics=None):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
raise NotImplementedError
def inference_step(self, features, model):
"""Performs the forward step."""
raise NotImplementedError
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
import functools
import os
import random
import orbit
import tensorflow as tf
# pylint: disable=unused-import
from official.core import exp_factory
from official.core import task_factory
from official.modeling import optimization
from official.vision import beta
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.projects.video_ssl.tasks import pretrain
class VideoClassificationTaskTest(tf.test.TestCase):
def setUp(self):
super(VideoClassificationTaskTest, self).setUp()
data_dir = os.path.join(self.get_temp_dir(), 'data')
tf.io.gfile.makedirs(data_dir)
self._data_path = os.path.join(data_dir, 'data.tfrecord')
# pylint: disable=g-complex-comprehension
examples = [
tfexample_utils.make_video_test_example(
image_shape=(36, 36, 3),
audio_shape=(20, 128),
label=random.randint(0, 100)) for _ in range(2)
]
# pylint: enable=g-complex-comprehension
tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
def test_task(self):
config = exp_factory.get_exp_config('video_ssl_pretrain_kinetics600')
config.task.train_data.global_batch_size = 2
config.task.train_data.input_path = self._data_path
task = pretrain.VideoSSLPretrainTask(
config.task)
model = task.build_model()
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
dataset = orbit.utils.make_distributed_dataset(
strategy,
functools.partial(task.build_inputs),
config.task.train_data)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
self.assertIn('total_loss', logs)
self.assertIn('reg_loss', logs)
self.assertIn('contrast_acc', logs)
self.assertIn('contrast_entropy', logs)
def test_task_factory(self):
config = exp_factory.get_exp_config('video_ssl_pretrain_kinetics600')
task = task_factory.get_task(config.task)
self.assertIs(type(task), pretrain.VideoSSLPretrainTask)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Training driver."""
from absl import app
from absl import flags
import gin
# pylint: disable=unused-import
from official.common import registry_imports
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.vision.beta.projects.video_ssl.modeling import video_ssl_model
from official.vision.beta.projects.video_ssl.tasks import linear_eval
from official.vision.beta.projects.video_ssl.tasks import pretrain
# pylint: disable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
if 'train_and_eval' in FLAGS.mode:
assert (params.task.train_data.feature_shape ==
params.task.validation_data.feature_shape), (
f'train {params.task.train_data.feature_shape} != validate '
f'{params.task.validation_data.feature_shape}')
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
......@@ -16,7 +16,7 @@
# pylint: disable=unused-import
# pylint: disable=g-bad-import-order
from official.common import registry_imports
from official.vision import registry_imports
# import configs
from official.vision.beta.projects.yolo.configs import darknet_classification
......
......@@ -15,7 +15,7 @@
"""Backbones configurations."""
import dataclasses
from official.modeling import hyperparams
from official.vision.beta.configs import backbones
from official.vision.configs import backbones
@dataclasses.dataclass
......
......@@ -20,9 +20,9 @@ from typing import List, Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.configs import common
from official.vision.beta.configs import image_classification as imc
from official.vision.beta.projects.yolo.configs import backbones
from official.vision.configs import common
from official.vision.configs import image_classification as imc
@dataclasses.dataclass
......
......@@ -16,7 +16,7 @@
import dataclasses
from typing import Optional
from official.modeling import hyperparams
from official.vision.beta.configs import decoders
from official.vision.configs import decoders
@dataclasses.dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment