Unverified Commit ca552843 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'panoptic-segmentation' into panoptic-segmentation

parents 7e2f7a35 6b90e134
......@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common building blocks for simclr neural networks."""
from typing import Text, Optional
......
......@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized
import tensorflow as tf
......
......@@ -14,6 +14,7 @@
"""Multi-task image multi-taskSimCLR model definition."""
from typing import Dict, Text
from absl import logging
import tensorflow as tf
......@@ -52,15 +53,10 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
norm_activation_config=config.norm_activation,
l2_regularizer=self._l2_regularizer)
super().__init__(**kwargs)
def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
tasks = {}
# Build the shared projection head
norm_activation_config = self._config.norm_activation
projection_head_config = self._config.projection_head
projection_head = simclr_head.ProjectionHead(
self._projection_head = simclr_head.ProjectionHead(
proj_output_dim=projection_head_config.proj_output_dim,
num_proj_layers=projection_head_config.num_proj_layers,
ft_proj_idx=projection_head_config.ft_proj_idx,
......@@ -69,6 +65,11 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
super().__init__(**kwargs)
def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
tasks = {}
for model_config in self._config.heads:
# Build supervised head
supervised_head_config = model_config.supervised_head
......@@ -84,16 +85,41 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
else:
supervised_head = None
tasks[model_config.mode] = simclr_model.SimCLRModel(
tasks[model_config.task_name] = simclr_model.SimCLRModel(
input_specs=self._input_specs,
backbone=self._backbone,
projection_head=projection_head,
projection_head=self._projection_head,
supervised_head=supervised_head,
mode=model_config.mode,
backbone_trainable=self._config.backbone_trainable)
return tasks
# TODO(huythong): Implement initialize function to load the pretrained
# checkpoint of backbone.
# def initialize(self):
def initialize(self):
"""Loads the multi-task SimCLR model with a pretrained checkpoint."""
ckpt_dir_or_file = self._config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
logging.info('Loading pretrained %s', self._config.init_checkpoint_modules)
if self._config.init_checkpoint_modules == 'backbone':
pretrained_items = dict(backbone=self._backbone)
elif self._config.init_checkpoint_modules == 'backbone_projection':
pretrained_items = dict(
backbone=self._backbone, projection_head=self._projection_head)
else:
assert ("Only 'backbone_projection' or 'backbone' can be used to "
'initialize the model.')
ckpt = tf.train.Checkpoint(**pretrained_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
return dict(backbone=self._backbone, projection_head=self._projection_head)
......@@ -29,11 +29,13 @@ class MultitaskModelTest(tf.test.TestCase):
ckpt_dir = self.get_temp_dir()
config = multitask_config.SimCLRMTModelConfig(
input_size=[64, 64, 3],
heads=(multitask_config.SimCLRMTHeadConfig(mode=simclr_model.PRETRAIN),
multitask_config.SimCLRMTHeadConfig(mode=simclr_model.FINETUNE)))
heads=(multitask_config.SimCLRMTHeadConfig(
mode=simclr_model.PRETRAIN, task_name='pretrain_simclr'),
multitask_config.SimCLRMTHeadConfig(
mode=simclr_model.FINETUNE, task_name='finetune_simclr')))
model = multitask_model.SimCLRMTModel(config)
self.assertIn(simclr_model.PRETRAIN, model.sub_tasks)
self.assertIn(simclr_model.FINETUNE, model.sub_tasks)
self.assertIn('pretrain_simclr', model.sub_tasks)
self.assertIn('finetune_simclr', model.sub_tasks)
ckpt = tf.train.Checkpoint(backbone=model._backbone)
ckpt.save(os.path.join(ckpt_dir, 'ckpt'))
model.initialize()
......
......@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build simclr models."""
from typing import Optional
from absl import logging
......@@ -133,12 +118,12 @@ class SimCLRModel(tf.keras.Model):
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
if self._supervised_head is not None:
items = dict(backbone=self.backbone,
projection_head=self.projection_head,
supervised_head=self.supervised_head)
items = dict(
backbone=self.backbone,
projection_head=self.projection_head,
supervised_head=self.supervised_head)
else:
items = dict(backbone=self.backbone,
projection_head=self.projection_head)
items = dict(backbone=self.backbone, projection_head=self.projection_head)
return items
@property
......
......@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for SimCLR model."""
from absl.testing import parameterized
import numpy as np
......
......@@ -12,21 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image SimCLR task definition.
SimCLR training two different modes:
......@@ -39,7 +24,6 @@ the task definition:
- training loss
- projection_head and/or supervised_head
"""
from typing import Dict, Optional
from absl import logging
......@@ -67,7 +51,8 @@ RuntimeConfig = config_definitions.RuntimeConfig
class SimCLRPretrainTask(base_task.Task):
"""A task for image classification."""
def create_optimizer(self, optimizer_config: OptimizationConfig,
def create_optimizer(self,
optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations.
......@@ -78,8 +63,8 @@ class SimCLRPretrainTask(base_task.Task):
Returns:
A tf.optimizers.Optimizer object.
"""
if (optimizer_config.optimizer.type == 'lars'
and self.task_config.loss.l2_weight_decay > 0.0):
if (optimizer_config.optimizer.type == 'lars' and
self.task_config.loss.l2_weight_decay > 0.0):
raise ValueError('The l2_weight_decay cannot be used together with lars '
'optimizer. Please set it to 0.')
......@@ -97,15 +82,16 @@ class SimCLRPretrainTask(base_task.Task):
def build_model(self):
model_config = self.task_config.model
input_specs = tf.keras.layers.InputSpec(
shape=[None] + model_config.input_size)
input_specs = tf.keras.layers.InputSpec(shape=[None] +
model_config.input_size)
l2_weight_decay = self.task_config.loss.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
l2_regularizer = (
tf.keras.regularizers.l2(l2_weight_decay /
2.0) if l2_weight_decay else None)
# Build backbone
backbone = backbones.factory.build_backbone(
......@@ -164,11 +150,11 @@ class SimCLRPretrainTask(base_task.Task):
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
assert "Only 'all' or 'backbone' can be used to initialize the model."
......@@ -220,8 +206,7 @@ class SimCLRPretrainTask(base_task.Task):
projection_outputs = model_outputs[simclr_model.PROJECTION_OUTPUT_KEY]
projection1, projection2 = tf.split(projection_outputs, 2, 0)
contrast_loss, (contrast_logits, contrast_labels) = con_losses_obj(
projection1=projection1,
projection2=projection2)
projection1=projection1, projection2=projection2)
contrast_accuracy = tf.equal(
tf.argmax(contrast_labels, axis=1), tf.argmax(contrast_logits, axis=1))
......@@ -253,8 +238,8 @@ class SimCLRPretrainTask(base_task.Task):
outputs)
sup_loss = tf.reduce_mean(sup_loss)
label_acc = tf.equal(tf.argmax(labels, axis=1),
tf.argmax(outputs, axis=1))
label_acc = tf.equal(
tf.argmax(labels, axis=1), tf.argmax(outputs, axis=1))
label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
model_loss = contrast_loss + sup_loss
......@@ -278,10 +263,7 @@ class SimCLRPretrainTask(base_task.Task):
if training:
metrics = []
metric_names = [
'total_loss',
'contrast_loss',
'contrast_accuracy',
'contrast_entropy'
'total_loss', 'contrast_loss', 'contrast_accuracy', 'contrast_entropy'
]
if self.task_config.model.supervised_head:
metric_names.extend(['supervised_loss', 'accuracy'])
......@@ -293,18 +275,26 @@ class SimCLRPretrainTask(base_task.Task):
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
k=k, name='top_{}_accuracy'.format(k))
]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
k=k, name='top_{}_accuracy'.format(k))
]
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
features, labels = inputs
if (self.task_config.model.supervised_head is not None
and self.task_config.evaluation.one_hot):
# To do a sanity check that we absolutely use no labels when pretraining, we
# can set the labels here to zero.
if self.task_config.train_data.input_set_label_to_zero:
labels *= 0
if (self.task_config.model.supervised_head is not None and
self.task_config.evaluation.one_hot):
num_classes = self.task_config.model.supervised_head.num_classes
labels = tf.one_hot(labels, num_classes)
......@@ -313,8 +303,7 @@ class SimCLRPretrainTask(base_task.Task):
outputs = model(features, training=True)
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
losses = self.build_losses(
......@@ -373,7 +362,8 @@ class SimCLRPretrainTask(base_task.Task):
class SimCLRFinetuneTask(base_task.Task):
"""A task for image classification."""
def create_optimizer(self, optimizer_config: OptimizationConfig,
def create_optimizer(self,
optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations.
......@@ -384,8 +374,8 @@ class SimCLRFinetuneTask(base_task.Task):
Returns:
A tf.optimizers.Optimizer object.
"""
if (optimizer_config.optimizer.type == 'lars'
and self.task_config.loss.l2_weight_decay > 0.0):
if (optimizer_config.optimizer.type == 'lars' and
self.task_config.loss.l2_weight_decay > 0.0):
raise ValueError('The l2_weight_decay cannot be used together with lars '
'optimizer. Please set it to 0.')
......@@ -403,15 +393,16 @@ class SimCLRFinetuneTask(base_task.Task):
def build_model(self):
model_config = self.task_config.model
input_specs = tf.keras.layers.InputSpec(
shape=[None] + model_config.input_size)
input_specs = tf.keras.layers.InputSpec(shape=[None] +
model_config.input_size)
l2_weight_decay = self.task_config.loss.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
l2_regularizer = (
tf.keras.regularizers.l2(l2_weight_decay /
2.0) if l2_weight_decay else None)
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
......@@ -464,16 +455,16 @@ class SimCLRFinetuneTask(base_task.Task):
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone_projection':
ckpt = tf.train.Checkpoint(backbone=model.backbone,
projection_head=model.projection_head)
status = ckpt.restore(ckpt_dir_or_file)
ckpt = tf.train.Checkpoint(
backbone=model.backbone, projection_head=model.projection_head)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
assert "Only 'all' or 'backbone' can be used to initialize the model."
......@@ -542,12 +533,14 @@ class SimCLRFinetuneTask(base_task.Task):
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
k=k, name='top_{}_accuracy'.format(k))
]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
k=k, name='top_{}_accuracy'.format(k))
]
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
......@@ -577,16 +570,14 @@ class SimCLRFinetuneTask(base_task.Task):
# Computes per-replica loss.
loss = self.build_losses(
model_outputs=outputs,
labels=labels, aux_losses=model.losses)
model_outputs=outputs, labels=labels, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
......@@ -596,8 +587,7 @@ class SimCLRFinetuneTask(base_task.Task):
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
......@@ -626,11 +616,11 @@ class SimCLRFinetuneTask(base_task.Task):
num_classes = self.task_config.model.supervised_head.num_classes
labels = tf.one_hot(labels, num_classes)
outputs = self.inference_step(
features, model)[simclr_model.SUPERVISED_OUTPUT_KEY]
outputs = self.inference_step(features,
model)[simclr_model.SUPERVISED_OUTPUT_KEY]
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(model_outputs=outputs,
labels=labels, aux_losses=model.losses)
loss = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
......
......@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Model Garden Vision SimCLR training driver."""
"""TensorFlow Model Garden Vision SimCLR trainer."""
from absl import app
from absl import flags
import gin
......
# Spatiotemporal Contrastive Video Representation Learning
[![Paper](http://img.shields.io/badge/Paper-arXiv.2008.03800-B3181B?logo=arXiv)](https://arxiv.org/abs/2008.03800)
This repository is the official TF2 implementation of [Spatiotemporal Contrastive Video Representation Learning](https://arxiv.org/abs/2008.03800).
<p align="left">
<img src="https://storage.googleapis.com/tf_model_garden/vision/cvrl/artifacts/cvrl_overview.png" height=350>
</p>
## Description
We present a self-supervised Contrastive Video Representation Learning (CVRL)
method to learn spatiotemporal visual representations from unlabeled videos. Our
representations are learned using a contrastive loss, where two augmented clips
from the same short video are pulled together in the embedding space, while
clips from different videos are pushed away. CVRL significantly closes the gap
between unsupervised and supervised video representation learning.
We release the code and pre-trained models.
More pre-trained model checkpoints and a detailed instruction about the code
will be updated.
## Experimental Results
### Kinetics-600 top-1 linear classification accuracy
<p align="left">
<img src="https://storage.googleapis.com/tf_model_garden/vision/cvrl/artifacts/cvrl_results.png" height=350>
</p>
## Pre-trained Model Checkpoints
We provide model checkpoints pre-trained on unlabeled RGB videos from
Kinetics-400 and Kinetics-600. All models are trained scratch with random
initialization.
We also provide a baseline model checkpoint of "ImageNet inflated" we used in
the paper. The model has the same architecture as 3D-ResNet-50 (R3D-50), with
model weights inflated from a 2D ResNet-50 pre-trained on ImageNet.
| Model | Parameters | Dataset | Epochs | K400 Linear Eval. | K600 Linear Eval. | Checkpoint |
| :--------------: | :----: | :--: | :--: |:-----------: | :----------: | :----------: |
| R3D-50 (1x) | 31.7M | ImageNet | - | 53.5% | 54.7% | [ckpt (127 MB)](https://storage.googleapis.com/tf_model_garden/vision/cvrl/imagenet.tar.gz) |
| R3D-50 (1x) | 31.7M | Kinetics-400 | 200 | 63.8% | - | [ckpt (127 MB)](https://storage.googleapis.com/tf_model_garden/vision/cvrl/r3d_1x_k400_200ep.tar.gz) |
| R3D-50 (1x) | 31.7M | Kinetics-400 | 800 | 66.1% | - | [ckpt (127 MB)](https://storage.googleapis.com/tf_model_garden/vision/cvrl/r3d_1x_k400_800ep.tar.gz) |
| R3D-50 (1x) | 31.7M | Kinetics-600 | 800 | 68.5% | 70.4% | [ckpt (127 MB)](https://storage.googleapis.com/tf_model_garden/vision/cvrl/r3d_1x_k600_800ep.tar.gz) |
## Citation
```
@inproceedings{qian2021spatiotemporal,
title={Spatiotemporal contrastive video representation learning},
author={Qian, Rui and Meng, Tianjian and Gong, Boqing and Yang, Ming-Hsuan and Wang, Huisheng and Belongie, Serge and Cui, Yin},
booktitle={CVPR},
year={2021}
}
```
......@@ -12,3 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
from official.vision.beta.projects.video_ssl.configs import video_ssl
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
# Put the pretrained checkpoint here for linear evaluation
init_checkpoint: 'r3d_1x_k600_800ep_backbone-1'
init_checkpoint_modules: 'backbone'
model:
dropout_rate: 1.0
norm_activation:
use_sync_bn: false
backbone:
resnet_3d:
block_specs: !!python/tuple
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
model_id: 50
stem_conv_temporal_kernel_size: 5
stem_conv_temporal_stride: 2
stem_pool_temporal_stride: 1
train_data:
name: kinetics600
feature_shape: !!python/tuple
- 32
- 224
- 224
- 3
temporal_stride: 2
global_batch_size: 1024
dtype: 'bfloat16'
shuffle_buffer_size: 1024
aug_max_area_ratio: 1.0
aug_max_aspect_ratio: 2.0
aug_min_area_ratio: 0.3
aug_min_aspect_ratio: 0.5
validation_data:
name: kinetics600
feature_shape: !!python/tuple
- 32
- 256
- 256
- 3
temporal_stride: 2
num_test_clips: 10
num_test_crops: 3
global_batch_size: 64
dtype: 'bfloat16'
drop_remainder: false
losses:
l2_weight_decay: 0.0
trainer:
optimizer_config:
learning_rate:
cosine:
initial_learning_rate: 32.0
decay_steps: 35744
optimizer:
sgd:
nesterov: false
warmup:
linear:
warmup_steps: 1787
train_steps: 35744
steps_per_loop: 100
summary_interval: 100
validation_interval: 100
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
dropout_rate: 1.0
norm_activation:
use_sync_bn: true
hidden_norm_activation:
use_sync_bn: true
backbone:
resnet_3d:
block_specs: !!python/tuple
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
model_id: 50
stem_conv_temporal_kernel_size: 5
stem_conv_temporal_stride: 2
stem_pool_temporal_stride: 1
train_data:
name: kinetics600
feature_shape: !!python/tuple
- 16
- 224
- 224
- 3
temporal_stride: 2
global_batch_size: 1024
dtype: 'bfloat16'
shuffle_buffer_size: 1024
losses:
l2_weight_decay: 0.000001
trainer:
optimizer_config:
learning_rate:
cosine:
initial_learning_rate: 0.32
decay_steps: 71488
optimizer:
sgd:
nesterov: false
warmup:
linear:
warmup_steps: 1787
train_steps: 71488
steps_per_loop: 100
summary_interval: 100
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video classification configuration definition."""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.configs import common
from official.vision.beta.configs import video_classification
Losses = video_classification.Losses
VideoClassificationModel = video_classification.VideoClassificationModel
VideoClassificationTask = video_classification.VideoClassificationTask
@dataclasses.dataclass
class VideoSSLPretrainTask(VideoClassificationTask):
pass
@dataclasses.dataclass
class VideoSSLEvalTask(VideoClassificationTask):
pass
@dataclasses.dataclass
class DataConfig(video_classification.DataConfig):
"""The base configuration for building datasets."""
is_ssl: bool = False
@dataclasses.dataclass
class VideoSSLModel(VideoClassificationModel):
"""The model config."""
normalize_feature: bool = False
hidden_dim: int = 2048
hidden_layer_num: int = 3
projection_dim: int = 128
hidden_norm_activation: common.NormActivation = common.NormActivation(
use_sync_bn=False, norm_momentum=0.997, norm_epsilon=1.0e-05)
@dataclasses.dataclass
class SSLLosses(Losses):
normalize_hidden: bool = True
temperature: float = 0.1
@exp_factory.register_config_factory('video_ssl_pretrain_kinetics400')
def video_ssl_pretrain_kinetics400() -> cfg.ExperimentConfig:
"""Pretrain SSL Video classification on Kinectics 400 with resnet."""
exp = video_classification.video_classification_kinetics400()
exp.task = VideoSSLPretrainTask(**exp.task.as_dict())
exp.task.train_data = DataConfig(is_ssl=True, **exp.task.train_data.as_dict())
exp.task.train_data.feature_shape = (16, 224, 224, 3)
exp.task.train_data.temporal_stride = 2
exp.task.model = VideoSSLModel(exp.task.model)
exp.task.model.model_type = 'video_ssl_model'
exp.task.losses = SSLLosses(exp.task.losses)
return exp
@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics400')
def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig:
"""Pretrain SSL Video classification on Kinectics 400 with resnet."""
exp = video_classification.video_classification_kinetics400()
exp.task = VideoSSLEvalTask(**exp.task.as_dict())
exp.task.train_data = DataConfig(is_ssl=False,
**exp.task.train_data.as_dict())
exp.task.train_data.feature_shape = (32, 224, 224, 3)
exp.task.train_data.temporal_stride = 2
exp.task.validation_data.feature_shape = (32, 256, 256, 3)
exp.task.validation_data.temporal_stride = 2
exp.task.validation_data = DataConfig(is_ssl=False,
**exp.task.validation_data.as_dict())
exp.task.validation_data.min_image_size = 256
exp.task.validation_data.num_test_clips = 10
exp.task.validation_data.num_test_crops = 3
exp.task.model = VideoSSLModel(exp.task.model)
exp.task.model.model_type = 'video_ssl_model'
exp.task.model.normalize_feature = True
exp.task.model.hidden_layer_num = 0
exp.task.model.projection_dim = 400
return exp
@exp_factory.register_config_factory('video_ssl_pretrain_kinetics600')
def video_ssl_pretrain_kinetics600() -> cfg.ExperimentConfig:
"""Pretrain SSL Video classification on Kinectics 400 with resnet."""
exp = video_classification.video_classification_kinetics600()
exp.task = VideoSSLPretrainTask(**exp.task.as_dict())
exp.task.train_data = DataConfig(is_ssl=True, **exp.task.train_data.as_dict())
exp.task.train_data.feature_shape = (16, 224, 224, 3)
exp.task.train_data.temporal_stride = 2
exp.task.model = VideoSSLModel(exp.task.model)
exp.task.model.model_type = 'video_ssl_model'
exp.task.losses = SSLLosses(exp.task.losses)
return exp
@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics600')
def video_ssl_linear_eval_kinetics600() -> cfg.ExperimentConfig:
"""Pretrain SSL Video classification on Kinectics 400 with resnet."""
exp = video_classification.video_classification_kinetics600()
exp.task = VideoSSLEvalTask(**exp.task.as_dict())
exp.task.train_data = DataConfig(is_ssl=False,
**exp.task.train_data.as_dict())
exp.task.train_data.feature_shape = (32, 224, 224, 3)
exp.task.train_data.temporal_stride = 2
exp.task.validation_data = DataConfig(is_ssl=False,
**exp.task.validation_data.as_dict())
exp.task.validation_data.feature_shape = (32, 256, 256, 3)
exp.task.validation_data.temporal_stride = 2
exp.task.validation_data.min_image_size = 256
exp.task.validation_data.num_test_clips = 10
exp.task.validation_data.num_test_crops = 3
exp.task.model = VideoSSLModel(exp.task.model)
exp.task.model.model_type = 'video_ssl_model'
exp.task.model.normalize_feature = True
exp.task.model.hidden_layer_num = 0
exp.task.model.projection_dim = 600
return exp
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision import beta
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('video_ssl_pretrain_kinetics400',),
('video_ssl_pretrain_kinetics600',))
def test_video_ssl_pretrain_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.VideoSSLPretrainTask)
self.assertIsInstance(config.task.model, exp_cfg.VideoSSLModel)
self.assertIsInstance(config.task.losses, exp_cfg.SSLLosses)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
@parameterized.parameters(('video_ssl_linear_eval_kinetics400',),
('video_ssl_linear_eval_kinetics600',))
def test_video_ssl_linear_eval_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.VideoSSLEvalTask)
self.assertIsInstance(config.task.model, exp_cfg.VideoSSLModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Parser for video and label datasets."""
from typing import Dict, Optional, Tuple
from absl import logging
import tensorflow as tf
from official.vision.beta.dataloaders import video_input
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops
IMAGE_KEY = 'image/encoded'
LABEL_KEY = 'clip/label/index'
Decoder = video_input.Decoder
def _process_image(image: tf.Tensor,
is_training: bool = True,
is_ssl: bool = False,
num_frames: int = 32,
stride: int = 1,
num_test_clips: int = 1,
min_resize: int = 256,
crop_size: int = 224,
num_crops: int = 1,
zero_centering_image: bool = False,
seed: Optional[int] = None) -> tf.Tensor:
"""Processes a serialized image tensor.
Args:
image: Input Tensor of shape [timesteps] and type tf.string of serialized
frames.
is_training: Whether or not in training mode. If True, random sample, crop
and left right flip is used.
is_ssl: Whether or not in self-supervised pre-training mode.
num_frames: Number of frames per subclip.
stride: Temporal stride to sample frames.
num_test_clips: Number of test clips (1 by default). If more than 1, this
will sample multiple linearly spaced clips within each video at test time.
If 1, then a single clip in the middle of the video is sampled. The clips
are aggreagated in the batch dimension.
min_resize: Frames are resized so that min(height, width) is min_resize.
crop_size: Final size of the frame after cropping the resized frames. Both
height and width are the same.
num_crops: Number of crops to perform on the resized frames.
zero_centering_image: If True, frames are normalized to values in [-1, 1].
If False, values in [0, 1].
seed: A deterministic seed to use when sampling.
Returns:
Processed frames. Tensor of shape
[num_frames * num_test_clips, crop_size, crop_size, 3].
"""
# Validate parameters.
if is_training and num_test_clips != 1:
logging.warning(
'`num_test_clips` %d is ignored since `is_training` is `True`.',
num_test_clips)
# Temporal sampler.
if is_training:
# Sampler for training.
if is_ssl:
# Sample two clips from linear decreasing distribution.
image = video_ssl_preprocess_ops.sample_ssl_sequence(
image, num_frames, True, stride)
else:
# Sample random clip.
image = preprocess_ops_3d.sample_sequence(image, num_frames, True, stride)
else:
# Sampler for evaluation.
if num_test_clips > 1:
# Sample linspace clips.
image = preprocess_ops_3d.sample_linspace_sequence(image, num_test_clips,
num_frames, stride)
else:
# Sample middle clip.
image = preprocess_ops_3d.sample_sequence(image, num_frames, False,
stride)
# Decode JPEG string to tf.uint8.
image = preprocess_ops_3d.decode_jpeg(image, 3)
if is_training:
# Standard image data augmentation: random resized crop and random flip.
if is_ssl:
image_1, image_2 = tf.split(image, num_or_size_splits=2, axis=0)
image_1 = preprocess_ops_3d.random_crop_resize(
image_1, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
image_1 = preprocess_ops_3d.random_flip_left_right(image_1, seed)
image_2 = preprocess_ops_3d.random_crop_resize(
image_2, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
image_2 = preprocess_ops_3d.random_flip_left_right(image_2, seed)
else:
image = preprocess_ops_3d.random_crop_resize(
image, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
image = preprocess_ops_3d.random_flip_left_right(image, seed)
else:
# Resize images (resize happens only if necessary to save compute).
image = preprocess_ops_3d.resize_smallest(image, min_resize)
# Three-crop of the frames.
image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
num_crops)
# Cast the frames in float32, normalizing according to zero_centering_image.
if is_training and is_ssl:
image_1 = preprocess_ops_3d.normalize_image(image_1, zero_centering_image)
image_2 = preprocess_ops_3d.normalize_image(image_2, zero_centering_image)
else:
image = preprocess_ops_3d.normalize_image(image, zero_centering_image)
# Self-supervised pre-training augmentations.
if is_training and is_ssl:
# Temporally consistent color jittering.
image_1 = video_ssl_preprocess_ops.random_color_jitter_3d(image_1)
image_2 = video_ssl_preprocess_ops.random_color_jitter_3d(image_2)
# Temporally consistent gaussian blurring.
image_1 = video_ssl_preprocess_ops.random_blur(image_1, crop_size,
crop_size, 1.0)
image_2 = video_ssl_preprocess_ops.random_blur(image_2, crop_size,
crop_size, 0.1)
image_2 = video_ssl_preprocess_ops.random_solarization(image_2)
image = tf.concat([image_1, image_2], axis=0)
image = tf.clip_by_value(image, 0., 1.)
return image
def _postprocess_image(image: tf.Tensor,
is_training: bool = True,
is_ssl: bool = False,
num_frames: int = 32,
num_test_clips: int = 1,
num_test_crops: int = 1) -> tf.Tensor:
"""Processes a batched Tensor of frames.
The same parameters used in process should be used here.
Args:
image: Input Tensor of shape [batch, timesteps, height, width, 3].
is_training: Whether or not in training mode. If True, random sample, crop
and left right flip is used.
is_ssl: Whether or not in self-supervised pre-training mode.
num_frames: Number of frames per subclip.
num_test_clips: Number of test clips (1 by default). If more than 1, this
will sample multiple linearly spaced clips within each video at test time.
If 1, then a single clip in the middle of the video is sampled. The clips
are aggreagated in the batch dimension.
num_test_crops: Number of test crops (1 by default). If more than 1, there
are multiple crops for each clip at test time. If 1, there is a single
central crop. The crops are aggreagated in the batch dimension.
Returns:
Processed frames. Tensor of shape
[batch * num_test_clips * num_test_crops, num_frames, height, width, 3].
"""
if is_ssl and is_training:
# In this case, two clips of self-supervised pre-training are merged
# together in batch dimenstion which will be 2 * batch.
image = tf.concat(tf.split(image, num_or_size_splits=2, axis=1), axis=0)
num_views = num_test_clips * num_test_crops
if num_views > 1 and not is_training:
# In this case, multiple views are merged together in batch dimenstion which
# will be batch * num_views.
image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())
return image
def _process_label(label: tf.Tensor,
one_hot_label: bool = True,
num_classes: Optional[int] = None) -> tf.Tensor:
"""Processes label Tensor."""
# Validate parameters.
if one_hot_label and not num_classes:
raise ValueError(
'`num_classes` should be given when requesting one hot label.')
# Cast to tf.int32.
label = tf.cast(label, dtype=tf.int32)
if one_hot_label:
# Replace label index by one hot representation.
label = tf.one_hot(label, num_classes)
if len(label.shape.as_list()) > 1:
label = tf.reduce_sum(label, axis=0)
if num_classes == 1:
# The trick for single label.
label = 1 - label
return label
class Parser(video_input.Parser):
"""Parses a video and label dataset."""
def __init__(self,
input_params: exp_cfg.DataConfig,
image_key: str = IMAGE_KEY,
label_key: str = LABEL_KEY):
super(Parser, self).__init__(input_params, image_key, label_key)
self._is_ssl = input_params.is_ssl
def _parse_train_data(
self, decoded_tensors: Dict[str, tf.Tensor]
) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
"""Parses data for training."""
# Process image and label.
image = decoded_tensors[self._image_key]
image = _process_image(
image=image,
is_training=True,
is_ssl=self._is_ssl,
num_frames=self._num_frames,
stride=self._stride,
num_test_clips=self._num_test_clips,
min_resize=self._min_resize,
crop_size=self._crop_size)
image = tf.cast(image, dtype=self._dtype)
features = {'image': image}
label = decoded_tensors[self._label_key]
label = _process_label(label, self._one_hot_label, self._num_classes)
return features, label
def _parse_eval_data(
self, decoded_tensors: Dict[str, tf.Tensor]
) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
"""Parses data for evaluation."""
image = decoded_tensors[self._image_key]
image = _process_image(
image=image,
is_training=False,
num_frames=self._num_frames,
stride=self._stride,
num_test_clips=self._num_test_clips,
min_resize=self._min_resize,
crop_size=self._crop_size,
num_crops=self._num_crops)
image = tf.cast(image, dtype=self._dtype)
features = {'image': image}
label = decoded_tensors[self._label_key]
label = _process_label(label, self._one_hot_label, self._num_classes)
if self._output_audio:
audio = decoded_tensors[self._audio_feature]
audio = tf.cast(audio, dtype=self._dtype)
audio = preprocess_ops_3d.sample_sequence(
audio, 20, random=False, stride=1)
audio = tf.ensure_shape(audio, [20, 2048])
features['audio'] = audio
return features, label
def parse_fn(self, is_training):
"""Returns a parse fn that reads and parses raw tensors from the decoder.
Args:
is_training: a `bool` to indicate whether it is in training mode.
Returns:
parse: a `callable` that takes the serialized examle and generate the
images, labels tuple where labels is a dict of Tensors that contains
labels.
"""
def parse(decoded_tensors):
"""Parses the serialized example data."""
if is_training:
return self._parse_train_data(decoded_tensors)
else:
return self._parse_eval_data(decoded_tensors)
return parse
class PostBatchProcessor(object):
"""Processes a video and label dataset which is batched."""
def __init__(self, input_params: exp_cfg.DataConfig):
self._is_training = input_params.is_training
self._is_ssl = input_params.is_ssl
self._num_frames = input_params.feature_shape[0]
self._num_test_clips = input_params.num_test_clips
self._num_test_crops = input_params.num_test_crops
def __call__(self, features: Dict[str, tf.Tensor],
label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
"""Parses a single tf.Example into image and label tensors."""
for key in ['image', 'audio']:
if key in features:
features[key] = _postprocess_image(
image=features[key],
is_training=self._is_training,
is_ssl=self._is_ssl,
num_frames=self._num_frames,
num_test_clips=self._num_test_clips,
num_test_crops=self._num_test_crops)
return features, label
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
import io
# Import libraries
import numpy as np
from PIL import Image
import tensorflow as tf
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input
AUDIO_KEY = 'features/audio'
def fake_seq_example():
# Create fake data.
random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
random_image = Image.fromarray(random_image)
label = 42
with io.BytesIO() as buffer:
random_image.save(buffer, format='JPEG')
raw_image_bytes = buffer.getvalue()
seq_example = tf.train.SequenceExample()
seq_example.feature_lists.feature_list.get_or_create(
video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
raw_image_bytes
]
seq_example.feature_lists.feature_list.get_or_create(
video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
raw_image_bytes
]
seq_example.context.feature[video_ssl_input.LABEL_KEY].int64_list.value[:] = [
label
]
random_audio = np.random.normal(size=(10, 256)).tolist()
for s in random_audio:
seq_example.feature_lists.feature_list.get_or_create(
AUDIO_KEY).feature.add().float_list.value[:] = s
return seq_example, label
class VideoAndLabelParserTest(tf.test.TestCase):
def test_video_ssl_input_pretrain(self):
params = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data
decoder = video_ssl_input.Decoder()
parser = video_ssl_input.Parser(params).parse_fn(params.is_training)
seq_example, _ = fake_seq_example()
input_tensor = tf.constant(seq_example.SerializeToString())
decoded_tensors = decoder.decode(input_tensor)
output_tensor = parser(decoded_tensors)
image_features, _ = output_tensor
image = image_features['image']
self.assertAllEqual(image.shape, (32, 224, 224, 3))
def test_video_ssl_input_linear_train(self):
params = exp_cfg.video_ssl_linear_eval_kinetics600().task.train_data
decoder = video_ssl_input.Decoder()
parser = video_ssl_input.Parser(params).parse_fn(params.is_training)
seq_example, label = fake_seq_example()
input_tensor = tf.constant(seq_example.SerializeToString())
decoded_tensors = decoder.decode(input_tensor)
output_tensor = parser(decoded_tensors)
image_features, label = output_tensor
image = image_features['image']
self.assertAllEqual(image.shape, (32, 224, 224, 3))
self.assertAllEqual(label.shape, (600,))
def test_video_ssl_input_linear_eval(self):
params = exp_cfg.video_ssl_linear_eval_kinetics600().task.validation_data
print('!!!', params)
decoder = video_ssl_input.Decoder()
parser = video_ssl_input.Parser(params).parse_fn(params.is_training)
seq_example, label = fake_seq_example()
input_tensor = tf.constant(seq_example.SerializeToString())
decoded_tensors = decoder.decode(input_tensor)
output_tensor = parser(decoded_tensors)
image_features, label = output_tensor
image = image_features['image']
self.assertAllEqual(image.shape, (960, 256, 256, 3))
self.assertAllEqual(label.shape, (600,))
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Define losses."""
# Import libraries
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla
def contrastive_loss(hidden,
num_replicas,
normalize_hidden,
temperature,
model,
weight_decay):
"""Computes contrastive loss.
Args:
hidden: embedding of video clips after projection head.
num_replicas: number of distributed replicas.
normalize_hidden: whether or not to l2 normalize the hidden vector.
temperature: temperature in the InfoNCE contrastive loss.
model: keras model for calculating weight decay.
weight_decay: weight decay parameter.
Returns:
A loss scalar.
The logits for contrastive prediction task.
The labels for contrastive prediction task.
"""
large_num = 1e9
hidden1, hidden2 = tf.split(hidden, num_or_size_splits=2, axis=0)
if normalize_hidden:
hidden1 = tf.math.l2_normalize(hidden1, -1)
hidden2 = tf.math.l2_normalize(hidden2, -1)
batch_size = tf.shape(hidden1)[0]
if num_replicas == 1:
# This is the local version
hidden1_large = hidden1
hidden2_large = hidden2
labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
masks = tf.one_hot(tf.range(batch_size), batch_size)
else:
# This is the cross-tpu version.
hidden1_large = tpu_cross_replica_concat(hidden1, num_replicas)
hidden2_large = tpu_cross_replica_concat(hidden2, num_replicas)
enlarged_batch_size = tf.shape(hidden1_large)[0]
replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
labels_idx = tf.range(batch_size) + replica_id * batch_size
labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
masks = tf.one_hot(labels_idx, enlarged_batch_size)
logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
logits_aa = logits_aa - tf.cast(masks, logits_aa.dtype) * large_num
logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
logits_bb = logits_bb - tf.cast(masks, logits_bb.dtype) * large_num
logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature
loss_a = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
labels, tf.concat([logits_ab, logits_aa], 1)))
loss_b = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
labels, tf.concat([logits_ba, logits_bb], 1)))
loss = loss_a + loss_b
l2_loss = weight_decay * tf.add_n([
tf.nn.l2_loss(v)
for v in model.trainable_variables
if 'kernel' in v.name
])
total_loss = loss + tf.cast(l2_loss, loss.dtype)
contrast_prob = tf.nn.softmax(logits_ab)
contrast_entropy = - tf.reduce_mean(
tf.reduce_sum(contrast_prob * tf.math.log(contrast_prob + 1e-8), -1))
contrast_acc = tf.equal(tf.argmax(labels, 1), tf.argmax(logits_ab, axis=1))
contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))
return {
'total_loss': total_loss,
'contrastive_loss': loss,
'reg_loss': l2_loss,
'contrast_acc': contrast_acc,
'contrast_entropy': contrast_entropy,
}
def tpu_cross_replica_concat(tensor, num_replicas):
"""Reduce a concatenation of the `tensor` across TPU cores.
Args:
tensor: tensor to concatenate.
num_replicas: number of TPU device replicas.
Returns:
Tensor of the same rank as `tensor` with first dimension `num_replicas`
times larger.
"""
with tf.name_scope('tpu_cross_replica_concat'):
# This creates a tensor that is like the input tensor but has an added
# replica dimension as the outermost dimension. On each replica it will
# contain the local values and zeros for all other values that need to be
# fetched from other replicas.
ext_tensor = tf.scatter_nd(
indices=[[xla.replica_id()]],
updates=[tensor],
shape=[num_replicas] + tensor.shape.as_list())
# As every value is only present on one replica and 0 in all others, adding
# them all together will result in the full tensor on all replicas.
replica_context = tf.distribute.get_replica_context()
ext_tensor = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
ext_tensor)
# Flatten the replica dimension.
# The first dimension size will be: tensor.shape[0] * num_replicas
# Using [-1] trick to support also scalar input.
return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build video classification models."""
from typing import Mapping, Optional
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import factory_3d as model_factory
from official.vision.beta.projects.video_ssl.configs import video_ssl as video_ssl_cfg
layers = tf.keras.layers
@tf.keras.utils.register_keras_serializable(package='Vision')
class VideoSSLModel(tf.keras.Model):
"""A video ssl model class builder."""
def __init__(self,
backbone,
normalize_feature,
hidden_dim,
hidden_layer_num,
hidden_norm_args,
projection_dim,
input_specs: Optional[Mapping[str,
tf.keras.layers.InputSpec]] = None,
dropout_rate: float = 0.0,
aggregate_endpoints: bool = False,
kernel_initializer='random_uniform',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""Video Classification initialization function.
Args:
backbone: a 3d backbone network.
normalize_feature: whether normalize backbone feature.
hidden_dim: `int` number of hidden units in MLP.
hidden_layer_num: `int` number of hidden layers in MLP.
hidden_norm_args: `dict` for batchnorm arguments in MLP.
projection_dim: `int` number of ouput dimension for MLP.
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
dropout_rate: `float` rate for dropout regularization.
aggregate_endpoints: `bool` aggregate all end ponits or only use the
final end point.
kernel_initializer: kernel initializer for the dense layer.
kernel_regularizer: tf.keras.regularizers.Regularizer object. Default to
None.
bias_regularizer: tf.keras.regularizers.Regularizer object. Default to
None.
**kwargs: keyword arguments to be passed.
"""
if not input_specs:
input_specs = {
'image': layers.InputSpec(shape=[None, None, None, None, 3])
}
self._self_setattr_tracking = False
self._config_dict = {
'backbone': backbone,
'normalize_feature': normalize_feature,
'hidden_dim': hidden_dim,
'hidden_layer_num': hidden_layer_num,
'use_sync_bn': hidden_norm_args.use_sync_bn,
'norm_momentum': hidden_norm_args.norm_momentum,
'norm_epsilon': hidden_norm_args.norm_epsilon,
'activation': hidden_norm_args.activation,
'projection_dim': projection_dim,
'input_specs': input_specs,
'dropout_rate': dropout_rate,
'aggregate_endpoints': aggregate_endpoints,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
}
self._input_specs = input_specs
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._backbone = backbone
inputs = {
k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items()
}
endpoints = backbone(inputs['image'])
if aggregate_endpoints:
pooled_feats = []
for endpoint in endpoints.values():
x_pool = tf.keras.layers.GlobalAveragePooling3D()(endpoint)
pooled_feats.append(x_pool)
x = tf.concat(pooled_feats, axis=1)
else:
x = endpoints[max(endpoints.keys())]
x = tf.keras.layers.GlobalAveragePooling3D()(x)
# L2 Normalize feature after backbone
if normalize_feature:
x = tf.nn.l2_normalize(x, axis=-1)
# MLP hidden layers
for _ in range(hidden_layer_num):
x = tf.keras.layers.Dense(hidden_dim)(x)
if self._config_dict['use_sync_bn']:
x = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])(x)
else:
x = tf.keras.layers.BatchNormalization(
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])(x)
x = tf_utils.get_activation(self._config_dict['activation'])(x)
# Projection head
x = tf.keras.layers.Dense(projection_dim)(x)
super(VideoSSLModel, self).__init__(
inputs=inputs, outputs=x, **kwargs)
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
return dict(backbone=self.backbone)
@property
def backbone(self):
return self._backbone
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@model_factory.register_model_builder('video_ssl_model')
def build_video_ssl_pretrain_model(
input_specs: tf.keras.layers.InputSpec,
model_config: video_ssl_cfg.VideoSSLModel,
num_classes: int,
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
"""Builds the video classification model."""
del num_classes
input_specs_dict = {'image': input_specs}
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
norm_activation_config=model_config.norm_activation,
l2_regularizer=l2_regularizer)
# Norm layer type in the MLP head should same with backbone
assert model_config.norm_activation.use_sync_bn == model_config.hidden_norm_activation.use_sync_bn
model = VideoSSLModel(
backbone=backbone,
normalize_feature=model_config.normalize_feature,
hidden_dim=model_config.hidden_dim,
hidden_layer_num=model_config.hidden_layer_num,
hidden_norm_args=model_config.hidden_norm_activation,
projection_dim=model_config.projection_dim,
input_specs=input_specs_dict,
dropout_rate=model_config.dropout_rate,
aggregate_endpoints=model_config.aggregate_endpoints,
kernel_regularizer=l2_regularizer)
return model
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Utils for customed ops for video ssl."""
import functools
from typing import Optional
import tensorflow as tf
def random_apply(func, p, x):
"""Randomly apply function func to x with probability p."""
return tf.cond(
tf.less(tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
tf.cast(p, tf.float32)),
lambda: func(x),
lambda: x)
def random_brightness(image, max_delta):
"""Distort brightness of image (SimCLRv2 style)."""
factor = tf.random.uniform(
[], tf.maximum(1.0 - max_delta, 0), 1.0 + max_delta)
image = image * factor
return image
def random_solarization(image, p=0.2):
"""Random solarize image."""
def _transform(image):
image = image * tf.cast(tf.less(image, 0.5), dtype=image.dtype) + (
1.0 - image) * tf.cast(tf.greater_equal(image, 0.5), dtype=image.dtype)
return image
return random_apply(_transform, p=p, x=image)
def to_grayscale(image, keep_channels=True):
"""Turn the input image to gray scale.
Args:
image: The input image tensor.
keep_channels: Whether maintaining the channel number for the image.
If true, the transformed image will repeat three times in channel.
If false, the transformed image will only have one channel.
Returns:
The distorted image tensor.
"""
image = tf.image.rgb_to_grayscale(image)
if keep_channels:
image = tf.tile(image, [1, 1, 3])
return image
def color_jitter(image, strength, random_order=True):
"""Distorts the color of the image (SimCLRv2 style).
Args:
image: The input image tensor.
strength: The floating number for the strength of the color augmentation.
random_order: A bool, specifying whether to randomize the jittering order.
Returns:
The distorted image tensor.
"""
brightness = 0.8 * strength
contrast = 0.8 * strength
saturation = 0.8 * strength
hue = 0.2 * strength
if random_order:
return color_jitter_rand(
image, brightness, contrast, saturation, hue)
else:
return color_jitter_nonrand(
image, brightness, contrast, saturation, hue)
def color_jitter_nonrand(image,
brightness=0,
contrast=0,
saturation=0,
hue=0):
"""Distorts the color of the image (jittering order is fixed, SimCLRv2 style).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
Returns:
The distorted image tensor.
"""
with tf.name_scope('distort_color'):
def apply_transform(i, x, brightness, contrast, saturation, hue):
"""Apply the i-th transformation."""
if brightness != 0 and i == 0:
x = random_brightness(x, max_delta=brightness)
elif contrast != 0 and i == 1:
x = tf.image.random_contrast(
x, lower=1-contrast, upper=1+contrast)
elif saturation != 0 and i == 2:
x = tf.image.random_saturation(
x, lower=1-saturation, upper=1+saturation)
elif hue != 0:
x = tf.image.random_hue(x, max_delta=hue)
return x
for i in range(4):
image = apply_transform(i, image, brightness, contrast, saturation, hue)
image = tf.clip_by_value(image, 0., 1.)
return image
def color_jitter_rand(image,
brightness=0,
contrast=0,
saturation=0,
hue=0):
"""Distorts the color of the image (jittering order is random, SimCLRv2 style).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
Returns:
The distorted image tensor.
"""
with tf.name_scope('distort_color'):
def apply_transform(i, x):
"""Apply the i-th transformation."""
def brightness_transform():
if brightness == 0:
return x
else:
return random_brightness(x, max_delta=brightness)
def contrast_transform():
if contrast == 0:
return x
else:
return tf.image.random_contrast(x, lower=1-contrast, upper=1+contrast)
def saturation_transform():
if saturation == 0:
return x
else:
return tf.image.random_saturation(
x, lower=1-saturation, upper=1+saturation)
def hue_transform():
if hue == 0:
return x
else:
return tf.image.random_hue(x, max_delta=hue)
# pylint:disable=g-long-lambda
x = tf.cond(
tf.less(i, 2), lambda: tf.cond(
tf.less(i, 1), brightness_transform, contrast_transform),
lambda: tf.cond(tf.less(i, 3), saturation_transform, hue_transform))
# pylint:disable=g-long-lambda
return x
perm = tf.random.shuffle(tf.range(4))
for i in range(4):
image = apply_transform(perm[i], image)
image = tf.clip_by_value(image, 0., 1.)
return image
def random_color_jitter_3d(frames):
"""Applies temporally consistent color jittering to one video clip.
Args:
frames: `Tensor` of shape [num_frames, height, width, channels].
Returns:
A Tensor of shape [num_frames, height, width, channels] being color jittered
with the same operation.
"""
def random_color_jitter(image, p=1.0):
def _transform(image):
color_jitter_t = functools.partial(
color_jitter, strength=1.0)
image = random_apply(color_jitter_t, p=0.8, x=image)
return random_apply(to_grayscale, p=0.2, x=image)
return random_apply(_transform, p=p, x=image)
num_frames, width, height, channels = frames.shape.as_list()
big_image = tf.reshape(frames, [num_frames*width, height, channels])
big_image = random_color_jitter(big_image)
return tf.reshape(big_image, [num_frames, width, height, channels])
def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
"""Blurs the given image with separable convolution.
Args:
image: Tensor of shape [height, width, channels] and dtype float to blur.
kernel_size: Integer Tensor for the size of the blur kernel. This is should
be an odd number. If it is an even number, the actual kernel size will be
size + 1.
sigma: Sigma value for gaussian operator.
padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.
Returns:
A Tensor representing the blurred image.
"""
radius = tf.cast(kernel_size / 2, dtype=tf.int32)
kernel_size = radius * 2 + 1
x = tf.cast(tf.range(-radius, radius + 1), dtype=tf.float32)
blur_filter = tf.exp(
-tf.pow(x, 2.0) / (2.0 * tf.pow(tf.cast(sigma, dtype=tf.float32), 2.0)))
blur_filter /= tf.reduce_sum(blur_filter)
# One vertical and one horizontal filter.
blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
num_channels = tf.shape(image)[-1]
blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
expand_batch_dim = image.shape.ndims == 3
if expand_batch_dim:
# Tensorflow requires batched input to convolutions, which we can fake with
# an extra dimension.
image = tf.expand_dims(image, axis=0)
blurred = tf.nn.depthwise_conv2d(
image, blur_h, strides=[1, 1, 1, 1], padding=padding)
blurred = tf.nn.depthwise_conv2d(
blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
if expand_batch_dim:
blurred = tf.squeeze(blurred, axis=0)
return blurred
def random_blur(image, height, width, p=1.0):
"""Randomly blur an image.
Args:
image: `Tensor` representing an image of arbitrary size.
height: Height of output image.
width: Width of output image.
p: probability of applying this transformation.
Returns:
A preprocessed image `Tensor`.
"""
del width
def _transform(image):
sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
return gaussian_blur(
image, kernel_size=height//10, sigma=sigma, padding='SAME')
return random_apply(_transform, p=p, x=image)
def random_blur_3d(frames, height, width, blur_probability=0.5):
"""Apply efficient batch data transformations.
Args:
frames: `Tensor` of shape [timesteps, height, width, 3].
height: the height of image.
width: the width of image.
blur_probability: the probaility to apply the blur operator.
Returns:
Preprocessed feature list.
"""
def generate_selector(p, bsz):
shape = [bsz, 1, 1, 1]
selector = tf.cast(
tf.less(tf.random.uniform(shape, 0, 1, dtype=tf.float32), p),
tf.float32)
return selector
frames_new = random_blur(frames, height, width, p=1.)
selector = generate_selector(blur_probability, 1)
frames = frames_new * selector + frames * (1 - selector)
frames = tf.clip_by_value(frames, 0., 1.)
return frames
def _sample_or_pad_sequence_indices(sequence: tf.Tensor,
num_steps: int,
stride: int,
offset: tf.Tensor) -> tf.Tensor:
"""Returns indices to take for sampling or padding sequences to fixed size."""
sequence_length = tf.shape(sequence)[0]
sel_idx = tf.range(sequence_length)
# Repeats sequence until num_steps are available in total.
max_length = num_steps * stride + offset
num_repeats = tf.math.floordiv(
max_length + sequence_length - 1, sequence_length)
sel_idx = tf.tile(sel_idx, [num_repeats])
steps = tf.range(offset, offset + num_steps * stride, stride)
return tf.gather(sel_idx, steps)
def sample_ssl_sequence(sequence: tf.Tensor,
num_steps: int,
random: bool,
stride: int = 1,
num_windows: Optional[int] = 2) -> tf.Tensor:
"""Samples two segments of size num_steps randomly from a given sequence.
Currently it only supports images, and specically designed for video self-
supervised learning.
Args:
sequence: Any tensor where the first dimension is timesteps.
num_steps: Number of steps (e.g. frames) to take.
random: A boolean indicating whether to random sample the single window. If
True, the offset is randomized. Only True is supported.
stride: Distance to sample between timesteps.
num_windows: Number of sequence sampled.
Returns:
A single Tensor with first dimension num_steps with the sampled segment.
"""
sequence_length = tf.shape(sequence)[0]
sequence_length = tf.cast(sequence_length, tf.float32)
if random:
max_offset = tf.cond(
tf.greater(sequence_length, (num_steps - 1) * stride),
lambda: sequence_length - (num_steps - 1) * stride,
lambda: sequence_length)
max_offset = tf.cast(max_offset, dtype=tf.float32)
def cdf(k, power=1.0):
"""Cumulative distribution function for x^power."""
p = -tf.math.pow(k, power + 1) / (
power * tf.math.pow(max_offset, power + 1)) + k * (power + 1) / (
power * max_offset)
return p
u = tf.random.uniform(())
k_low = tf.constant(0, dtype=tf.float32)
k_up = max_offset
k = tf.math.floordiv(max_offset, 2.0)
c = lambda k_low, k_up, k: tf.greater(tf.math.abs(k_up - k_low), 1.0)
# pylint:disable=g-long-lambda
b = lambda k_low, k_up, k: tf.cond(
tf.greater(cdf(k), u),
lambda: [k_low, k, tf.math.floordiv(k + k_low, 2.0)],
lambda: [k, k_up, tf.math.floordiv(k_up + k, 2.0)])
_, _, k = tf.while_loop(c, b, [k_low, k_up, k])
delta = tf.cast(k, tf.int32)
max_offset = tf.cast(max_offset, tf.int32)
sequence_length = tf.cast(sequence_length, tf.int32)
choice_1 = tf.cond(
tf.equal(max_offset, sequence_length),
lambda: tf.random.uniform((),
maxval=tf.cast(max_offset, dtype=tf.int32),
dtype=tf.int32),
lambda: tf.random.uniform((),
maxval=tf.cast(max_offset - delta,
dtype=tf.int32),
dtype=tf.int32))
choice_2 = tf.cond(
tf.equal(max_offset, sequence_length),
lambda: tf.random.uniform((),
maxval=tf.cast(max_offset, dtype=tf.int32),
dtype=tf.int32),
lambda: choice_1 + delta)
# pylint:disable=g-long-lambda
shuffle_choice = tf.random.shuffle((choice_1, choice_2))
offset_1 = shuffle_choice[0]
offset_2 = shuffle_choice[1]
else:
raise NotImplementedError
indices_1 = _sample_or_pad_sequence_indices(
sequence=sequence,
num_steps=num_steps,
stride=stride,
offset=offset_1)
indices_2 = _sample_or_pad_sequence_indices(
sequence=sequence,
num_steps=num_steps,
stride=stride,
offset=offset_2)
indices = tf.concat([indices_1, indices_2], axis=0)
indices.set_shape((num_windows * num_steps,))
output = tf.gather(sequence, indices)
return output
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops
class VideoSslPreprocessOpsTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self._raw_frames = tf.random.uniform((250, 256, 256, 3), minval=0,
maxval=255, dtype=tf.dtypes.int32)
self._sampled_frames = self._raw_frames[:16]
self._frames = preprocess_ops_3d.normalize_image(
self._sampled_frames, False, tf.float32)
def test_sample_ssl_sequence(self):
sampled_seq = video_ssl_preprocess_ops.sample_ssl_sequence(
self._raw_frames, 16, True, 2)
self.assertAllEqual(sampled_seq.shape, (32, 256, 256, 3))
def test_random_color_jitter_3d(self):
jittered_clip = video_ssl_preprocess_ops.random_color_jitter_3d(
self._frames)
self.assertAllEqual(jittered_clip.shape, (16, 256, 256, 3))
def test_random_blur_3d(self):
blurred_clip = video_ssl_preprocess_ops.random_blur_3d(
self._frames, 256, 256)
self.assertAllEqual(blurred_clip.shape, (16, 256, 256, 3))
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment