"example/vscode:/vscode.git/clone" did not exist on "0c51a35ea8a60adb8feb2fc7da876ea45c272730"
Unverified commit 1f8b5b27, authored by Simon Geisler, committed by GitHub

Merge branch 'master' into master

parents 0eeeaf98 8fcf177e
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Define losses."""
# Import libraries
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla
def contrastive_loss(hidden,
                     num_replicas,
                     normalize_hidden,
                     temperature,
                     model,
                     weight_decay):
"""Computes contrastive loss.
Args:
hidden: embedding of video clips after projection head.
num_replicas: number of distributed replicas.
normalize_hidden: whether or not to l2 normalize the hidden vector.
temperature: temperature in the InfoNCE contrastive loss.
model: keras model for calculating weight decay.
weight_decay: weight decay parameter.
Returns:
A loss scalar.
The logits for contrastive prediction task.
The labels for contrastive prediction task.
"""
  large_num = 1e9
  hidden1, hidden2 = tf.split(hidden, num_or_size_splits=2, axis=0)
  if normalize_hidden:
    hidden1 = tf.math.l2_normalize(hidden1, -1)
    hidden2 = tf.math.l2_normalize(hidden2, -1)
  batch_size = tf.shape(hidden1)[0]

  if num_replicas == 1:
    # This is the local version.
    hidden1_large = hidden1
    hidden2_large = hidden2
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)
  else:
    # This is the cross-TPU version.
    hidden1_large = tpu_cross_replica_concat(hidden1, num_replicas)
    hidden2_large = tpu_cross_replica_concat(hidden2, num_replicas)
    enlarged_batch_size = tf.shape(hidden1_large)[0]
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)

  logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
  logits_aa = logits_aa - tf.cast(masks, logits_aa.dtype) * large_num
  logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
  logits_bb = logits_bb - tf.cast(masks, logits_bb.dtype) * large_num
  logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
  logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature

  loss_a = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
      labels, tf.concat([logits_ab, logits_aa], 1)))
  loss_b = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
      labels, tf.concat([logits_ba, logits_bb], 1)))
  loss = loss_a + loss_b

  l2_loss = weight_decay * tf.add_n([
      tf.nn.l2_loss(v)
      for v in model.trainable_variables
      if 'kernel' in v.name
  ])
  total_loss = loss + tf.cast(l2_loss, loss.dtype)

  contrast_prob = tf.nn.softmax(logits_ab)
  contrast_entropy = -tf.reduce_mean(
      tf.reduce_sum(contrast_prob * tf.math.log(contrast_prob + 1e-8), -1))

  contrast_acc = tf.equal(tf.argmax(labels, 1), tf.argmax(logits_ab, axis=1))
  contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))

  return {
      'total_loss': total_loss,
      'contrastive_loss': loss,
      'reg_loss': l2_loss,
      'contrast_acc': contrast_acc,
      'contrast_entropy': contrast_entropy,
  }
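
# A minimal single-replica usage sketch of `contrastive_loss`. The toy
# projection `toy_model` below is an illustrative assumption, not the real
# projection head; in training, the embeddings of two augmented views of the
# same clips arrive stacked along the batch axis, i.e.
# hidden = concat([view1, view2], axis=0).
def _example_contrastive_loss_usage():
  batch_size, proj_dim = 8, 128
  toy_model = tf.keras.Sequential(
      [tf.keras.layers.Dense(proj_dim, input_shape=(proj_dim,))])
  # Embeddings for both views of `batch_size` clips, shape (16, 128).
  hidden = toy_model(tf.random.normal((2 * batch_size, proj_dim)))
  losses = contrastive_loss(
      hidden,
      num_replicas=1,
      normalize_hidden=True,
      temperature=0.1,
      model=toy_model,
      weight_decay=1e-6)
  return losses['total_loss'], losses['contrast_acc']
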
def tpu_cross_replica_concat(tensor, num_replicas):
  """Reduces a concatenation of the `tensor` across TPU cores.

  Args:
    tensor: tensor to concatenate.
    num_replicas: number of TPU device replicas.

  Returns:
    Tensor of the same rank as `tensor` with first dimension `num_replicas`
    times larger.
  """
  with tf.name_scope('tpu_cross_replica_concat'):
    # This creates a tensor that is like the input tensor but has an added
    # replica dimension as the outermost dimension. On each replica it will
    # contain the local values and zeros for all other values that need to be
    # fetched from other replicas.
    ext_tensor = tf.scatter_nd(
        indices=[[xla.replica_id()]],
        updates=[tensor],
        shape=[num_replicas] + tensor.shape.as_list())

    # As every value is only present on one replica and 0 in all others,
    # adding them all together will result in the full tensor on all replicas.
    replica_context = tf.distribute.get_replica_context()
    ext_tensor = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
                                            ext_tensor)

    # Flatten the replica dimension.
    # The first dimension size will be: tensor.shape[0] * num_replicas
    # Using [-1] trick to support also scalar input.
    return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
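
# A single-host numeric illustration (not a distributed run) of the
# scatter_nd + SUM trick used above: each replica writes its local values into
# its own slot of a zero-initialized [num_replicas, ...] buffer, so summing
# the buffers recovers the full concatenation on every replica.
def _example_cross_replica_concat_math():
  num_replicas = 2
  local = [tf.constant([[1., 2.]]), tf.constant([[3., 4.]])]  # per-replica
  ext = [
      tf.scatter_nd(indices=[[r]], updates=[local[r]],
                    shape=[num_replicas] + local[r].shape.as_list())
      for r in range(num_replicas)
  ]
  summed = tf.add_n(ext)  # stands in for ReduceOp.SUM across replicas
  # Flattening the replica dimension yields [[1., 2.], [3., 4.]],
  # i.e. the concatenated batch.
  return tf.reshape(summed, [-1] + summed.shape.as_list()[2:])
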
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build video classification models."""
from typing import Mapping, Optional
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import factory_3d as model_factory
from official.vision.beta.projects.video_ssl.configs import video_ssl as video_ssl_cfg
layers = tf.keras.layers
@tf.keras.utils.register_keras_serializable(package='Vision')
class VideoSSLModel(tf.keras.Model):
  """A video SSL model class builder."""
  def __init__(self,
               backbone,
               normalize_feature,
               hidden_dim,
               hidden_layer_num,
               hidden_norm_args,
               projection_dim,
               input_specs: Optional[Mapping[str,
                                             tf.keras.layers.InputSpec]] = None,
               dropout_rate: float = 0.0,
               aggregate_endpoints: bool = False,
               kernel_initializer='random_uniform',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
"""Video Classification initialization function.
Args:
backbone: a 3d backbone network.
normalize_feature: whether normalize backbone feature.
hidden_dim: `int` number of hidden units in MLP.
hidden_layer_num: `int` number of hidden layers in MLP.
hidden_norm_args: `dict` for batchnorm arguments in MLP.
projection_dim: `int` number of ouput dimension for MLP.
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
dropout_rate: `float` rate for dropout regularization.
aggregate_endpoints: `bool` aggregate all end ponits or only use the
final end point.
kernel_initializer: kernel initializer for the dense layer.
kernel_regularizer: tf.keras.regularizers.Regularizer object. Default to
None.
bias_regularizer: tf.keras.regularizers.Regularizer object. Default to
None.
**kwargs: keyword arguments to be passed.
"""
    if not input_specs:
      input_specs = {
          'image': layers.InputSpec(shape=[None, None, None, None, 3])
      }
    self._self_setattr_tracking = False
    self._config_dict = {
        'backbone': backbone,
        'normalize_feature': normalize_feature,
        'hidden_dim': hidden_dim,
        'hidden_layer_num': hidden_layer_num,
        'use_sync_bn': hidden_norm_args.use_sync_bn,
        'norm_momentum': hidden_norm_args.norm_momentum,
        'norm_epsilon': hidden_norm_args.norm_epsilon,
        'activation': hidden_norm_args.activation,
        'projection_dim': projection_dim,
        'input_specs': input_specs,
        'dropout_rate': dropout_rate,
        'aggregate_endpoints': aggregate_endpoints,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    self._input_specs = input_specs
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._backbone = backbone

    inputs = {
        k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items()
    }
    endpoints = backbone(inputs['image'])

    if aggregate_endpoints:
      pooled_feats = []
      for endpoint in endpoints.values():
        x_pool = tf.keras.layers.GlobalAveragePooling3D()(endpoint)
        pooled_feats.append(x_pool)
      x = tf.concat(pooled_feats, axis=1)
    else:
      x = endpoints[max(endpoints.keys())]
      x = tf.keras.layers.GlobalAveragePooling3D()(x)

    # L2-normalize the feature after the backbone.
    if normalize_feature:
      x = tf.nn.l2_normalize(x, axis=-1)

    # MLP hidden layers.
    for _ in range(hidden_layer_num):
      x = tf.keras.layers.Dense(hidden_dim)(x)
      if self._config_dict['use_sync_bn']:
        x = tf.keras.layers.experimental.SyncBatchNormalization(
            momentum=self._config_dict['norm_momentum'],
            epsilon=self._config_dict['norm_epsilon'])(x)
      else:
        x = tf.keras.layers.BatchNormalization(
            momentum=self._config_dict['norm_momentum'],
            epsilon=self._config_dict['norm_epsilon'])(x)
      x = tf_utils.get_activation(self._config_dict['activation'])(x)

    # Projection head.
    x = tf.keras.layers.Dense(projection_dim)(x)

    super(VideoSSLModel, self).__init__(
        inputs=inputs, outputs=x, **kwargs)
  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    return dict(backbone=self.backbone)

  @property
  def backbone(self):
    return self._backbone

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
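
# A minimal construction sketch for VideoSSLModel. The toy backbone and the
# SimpleNamespace standing in for `hidden_norm_args` are assumptions for
# illustration only; real configs come from video_ssl_cfg and the backbone
# factory below.
def _example_video_ssl_model_usage():
  import types

  class _ToyBackbone(tf.keras.Model):
    """Returns a dict with one 5-D endpoint, mimicking the 3-D backbones."""

    def __init__(self):
      super().__init__()
      self._conv = tf.keras.layers.Conv3D(8, kernel_size=1)

    def call(self, x):
      return {'5': self._conv(x)}

  norm_args = types.SimpleNamespace(
      use_sync_bn=False, norm_momentum=0.9, norm_epsilon=1e-5,
      activation='relu')
  model = VideoSSLModel(
      backbone=_ToyBackbone(),
      normalize_feature=True,
      hidden_dim=64,
      hidden_layer_num=2,
      hidden_norm_args=norm_args,
      projection_dim=32,
      input_specs={
          'image': layers.InputSpec(shape=[None, 8, 32, 32, 3])
      })
  # Projection embeddings for a batch of two clips, shape (2, 32).
  return model({'image': tf.random.normal((2, 8, 32, 32, 3))})
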
@model_factory.register_model_builder('video_ssl_model')
def build_video_ssl_pretrain_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: video_ssl_cfg.VideoSSLModel,
    num_classes: int,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
  """Builds the video SSL pretraining model."""
  del num_classes
  input_specs_dict = {'image': input_specs}
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=model_config.norm_activation,
      l2_regularizer=l2_regularizer)

  # The norm layer type in the MLP head should be the same as in the backbone.
  assert (model_config.norm_activation.use_sync_bn ==
          model_config.hidden_norm_activation.use_sync_bn)
  model = VideoSSLModel(
      backbone=backbone,
      normalize_feature=model_config.normalize_feature,
      hidden_dim=model_config.hidden_dim,
      hidden_layer_num=model_config.hidden_layer_num,
      hidden_norm_args=model_config.hidden_norm_activation,
      projection_dim=model_config.projection_dim,
      input_specs=input_specs_dict,
      dropout_rate=model_config.dropout_rate,
      aggregate_endpoints=model_config.aggregate_endpoints,
      kernel_regularizer=l2_regularizer)
  return model
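
# A hedged invocation sketch for the builder above. It assumes the default
# `video_ssl_cfg.VideoSSLModel()` config is instantiable and usable as-is;
# in practice fields such as the backbone type may need adjusting.
def _example_build_from_config():
  specs = tf.keras.layers.InputSpec(shape=[None, 16, 224, 224, 3])
  return build_video_ssl_pretrain_model(
      input_specs=specs,
      model_config=video_ssl_cfg.VideoSSLModel(),
      num_classes=600,  # unused: the pretraining builder deletes it
      l2_regularizer=tf.keras.regularizers.l2(1e-4))
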
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops
class VideoSslPreprocessOpsTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self._raw_frames = tf.random.uniform((250, 256, 256, 3), minval=0,
                                         maxval=255, dtype=tf.dtypes.int32)
    self._sampled_frames = self._raw_frames[:16]
    self._frames = preprocess_ops_3d.normalize_image(
        self._sampled_frames, False, tf.float32)

  def test_sample_ssl_sequence(self):
    sampled_seq = video_ssl_preprocess_ops.sample_ssl_sequence(
        self._raw_frames, 16, True, 2)
    self.assertAllEqual(sampled_seq.shape, (32, 256, 256, 3))

  def test_random_color_jitter_3d(self):
    jittered_clip = video_ssl_preprocess_ops.random_color_jitter_3d(
        self._frames)
    self.assertAllEqual(jittered_clip.shape, (16, 256, 256, 3))

  def test_random_blur_3d(self):
    blurred_clip = video_ssl_preprocess_ops.random_blur_3d(
        self._frames, 256, 256)
    self.assertAllEqual(blurred_clip.shape, (16, 256, 256, 3))


if __name__ == '__main__':
  tf.test.main()
@@ -12,3 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Tasks package definition."""
+from official.vision.beta.projects.video_ssl.tasks import linear_eval
+from official.vision.beta.projects.video_ssl.tasks import pretrain
...@@ -79,8 +79,8 @@ class SemanticSegmentation3DTask(base_task.Task): ...@@ -79,8 +79,8 @@ class SemanticSegmentation3DTask(base_task.Task):
# Restoring checkpoint. # Restoring checkpoint.
if 'all' in self.task_config.init_checkpoint_modules: if 'all' in self.task_config.init_checkpoint_modules:
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.assert_consumed() status.expect_partial().assert_existing_objects_matched()
else: else:
ckpt_items = {} ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules: if 'backbone' in self.task_config.init_checkpoint_modules:
...@@ -89,7 +89,7 @@ class SemanticSegmentation3DTask(base_task.Task): ...@@ -89,7 +89,7 @@ class SemanticSegmentation3DTask(base_task.Task):
ckpt_items.update(decoder=model.decoder) ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items) ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
......
@@ -63,11 +63,11 @@ class ImageClassificationTask(base_task.Task):
     # Restoring checkpoint.
     if self.task_config.init_checkpoint_modules == 'all':
       ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.assert_consumed()
+      status = ckpt.read(ckpt_dir_or_file)
+      status.expect_partial().assert_existing_objects_matched()
     elif self.task_config.init_checkpoint_modules == 'backbone':
       ckpt = tf.train.Checkpoint(backbone=model.backbone)
-      status = ckpt.restore(ckpt_dir_or_file)
+      status = ckpt.read(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
     else:
       raise ValueError(
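
Both task diffs make the same change: `ckpt.restore(...)` with `status.assert_consumed()` becomes `ckpt.read(...)` with `status.expect_partial().assert_existing_objects_matched()`. `read` loads the checkpoint by object graph, and the relaxed assertion only verifies that every object attached to this `Checkpoint` was restored, tolerating extra content (such as optimizer slots) in the checkpoint file. A minimal sketch of the new pattern on a toy model; the path and layer shape are illustrative assumptions:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(4,))])
ckpt = tf.train.Checkpoint(backbone=model)
path = ckpt.write('/tmp/toy_ckpt')  # save the raw checkpoint
status = ckpt.read(path)            # load by object graph
# Passes as long as everything in `ckpt` found a value in the checkpoint,
# even if the checkpoint contains objects this Checkpoint does not track.
status.expect_partial().assert_existing_objects_matched()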