ModelZoo / ResNet50_tensorflow / Commits / 8b641b13

Unverified commit 8b641b13, authored Mar 26, 2022 by Srihari Humbarwadi, committed by GitHub on Mar 26, 2022

    Merge branch 'tensorflow:master' into panoptic-deeplab

parents 7cffacfe 357fa547
Changes: 411 files. Showing 20 changed files with 5 additions and 2017 deletions (+5 -2017).
official/vision/beta/projects/video_ssl/configs/__init__.py  +0 -18
official/vision/beta/projects/video_ssl/configs/experiments/cvrl_linear_eval_k600.yaml  +0 -92
official/vision/beta/projects/video_ssl/configs/experiments/cvrl_pretrain_k600_200ep.yaml  +0 -73
official/vision/beta/projects/video_ssl/configs/video_ssl.py  +0 -138
official/vision/beta/projects/video_ssl/configs/video_ssl_test.py  +0 -56
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input.py  +0 -321
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input_test.py  +0 -111
official/vision/beta/projects/video_ssl/losses/losses.py  +0 -136
official/vision/beta/projects/video_ssl/modeling/video_ssl_model.py  +0 -179
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops.py  +0 -406
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops_test.py  +0 -47
official/vision/beta/projects/video_ssl/tasks/__init__.py  +0 -18
official/vision/beta/projects/video_ssl/tasks/linear_eval.py  +0 -71
official/vision/beta/projects/video_ssl/tasks/pretrain.py  +0 -186
official/vision/beta/projects/video_ssl/tasks/pretrain_test.py  +0 -82
official/vision/beta/projects/video_ssl/train.py  +0 -78
official/vision/beta/projects/yolo/common/registry_imports.py  +1 -1
official/vision/beta/projects/yolo/configs/backbones.py  +1 -1
official/vision/beta/projects/yolo/configs/darknet_classification.py  +2 -2
official/vision/beta/projects/yolo/configs/decoders.py  +1 -1
official/vision/beta/projects/video_ssl/configs/__init__.py  deleted 100644 → 0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
from official.vision.beta.projects.video_ssl.configs import video_ssl
official/vision/beta/projects/video_ssl/configs/experiments/cvrl_linear_eval_k600.yaml  deleted 100644 → 0
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  # Put the pretrained checkpoint here for linear evaluation
  init_checkpoint: 'r3d_1x_k600_800ep_backbone-1'
  init_checkpoint_modules: 'backbone'
  model:
    dropout_rate: 1.0
    norm_activation:
      use_sync_bn: false
    backbone:
      resnet_3d:
        block_specs: !!python/tuple
        - temporal_kernel_sizes: !!python/tuple
          - 1
          - 1
          - 1
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 1
          - 1
          - 1
          - 1
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 3
          - 3
          - 3
          - 3
          - 3
          - 3
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 3
          - 3
          - 3
          temporal_strides: 1
          use_self_gating: false
        model_id: 50
        stem_conv_temporal_kernel_size: 5
        stem_conv_temporal_stride: 2
        stem_pool_temporal_stride: 1
  train_data:
    name: kinetics600
    feature_shape: !!python/tuple
    - 32
    - 224
    - 224
    - 3
    temporal_stride: 2
    global_batch_size: 1024
    dtype: 'bfloat16'
    shuffle_buffer_size: 1024
    aug_max_area_ratio: 1.0
    aug_max_aspect_ratio: 2.0
    aug_min_area_ratio: 0.3
    aug_min_aspect_ratio: 0.5
  validation_data:
    name: kinetics600
    feature_shape: !!python/tuple
    - 32
    - 256
    - 256
    - 3
    temporal_stride: 2
    num_test_clips: 10
    num_test_crops: 3
    global_batch_size: 64
    dtype: 'bfloat16'
    drop_remainder: false
  losses:
    l2_weight_decay: 0.0
trainer:
  optimizer_config:
    learning_rate:
      cosine:
        initial_learning_rate: 32.0
        decay_steps: 35744
    optimizer:
      sgd:
        nesterov: false
    warmup:
      linear:
        warmup_steps: 1787
  train_steps: 35744
  steps_per_loop: 100
  summary_interval: 100
  validation_interval: 100
official/vision/beta/projects/video_ssl/configs/experiments/cvrl_pretrain_k600_200ep.yaml  deleted 100644 → 0
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  model:
    dropout_rate: 1.0
    norm_activation:
      use_sync_bn: true
    hidden_norm_activation:
      use_sync_bn: true
    backbone:
      resnet_3d:
        block_specs: !!python/tuple
        - temporal_kernel_sizes: !!python/tuple
          - 1
          - 1
          - 1
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 1
          - 1
          - 1
          - 1
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 3
          - 3
          - 3
          - 3
          - 3
          - 3
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 3
          - 3
          - 3
          temporal_strides: 1
          use_self_gating: false
        model_id: 50
        stem_conv_temporal_kernel_size: 5
        stem_conv_temporal_stride: 2
        stem_pool_temporal_stride: 1
  train_data:
    name: kinetics600
    feature_shape: !!python/tuple
    - 16
    - 224
    - 224
    - 3
    temporal_stride: 2
    global_batch_size: 1024
    dtype: 'bfloat16'
    shuffle_buffer_size: 1024
  losses:
    l2_weight_decay: 0.000001
trainer:
  optimizer_config:
    learning_rate:
      cosine:
        initial_learning_rate: 0.32
        decay_steps: 71488
    optimizer:
      sgd:
        nesterov: false
    warmup:
      linear:
        warmup_steps: 1787
  train_steps: 71488
  steps_per_loop: 100
  summary_interval: 100
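For context, a YAML like the two above is not consumed on its own: it is layered over the defaults of a registered experiment. Below is a minimal sketch of that flow, assuming the usual TF Model Garden helpers (exp_factory.get_exp_config and hyperparams.override_params_dict); the YAML path is a placeholder.

# Sketch only: assumes the standard Model Garden config utilities; the YAML
# path below is a placeholder for wherever the experiment file lives.
from official.core import exp_factory
from official.modeling import hyperparams
# Importing the configs module registers the video_ssl experiments.
from official.vision.beta.projects.video_ssl.configs import video_ssl  # pylint: disable=unused-import

params = exp_factory.get_exp_config('video_ssl_pretrain_kinetics600')
params = hyperparams.override_params_dict(
    params, 'cvrl_pretrain_k600_200ep.yaml', is_strict=True)
params.validate()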
official/vision/beta/projects/video_ssl/configs/video_ssl.py  deleted 100644 → 0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video classification configuration definition."""
import dataclasses

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.configs import common
from official.vision.beta.configs import video_classification

Losses = video_classification.Losses
VideoClassificationModel = video_classification.VideoClassificationModel
VideoClassificationTask = video_classification.VideoClassificationTask


@dataclasses.dataclass
class VideoSSLPretrainTask(VideoClassificationTask):
  pass


@dataclasses.dataclass
class VideoSSLEvalTask(VideoClassificationTask):
  pass


@dataclasses.dataclass
class DataConfig(video_classification.DataConfig):
  """The base configuration for building datasets."""
  is_ssl: bool = False


@dataclasses.dataclass
class VideoSSLModel(VideoClassificationModel):
  """The model config."""
  normalize_feature: bool = False
  hidden_dim: int = 2048
  hidden_layer_num: int = 3
  projection_dim: int = 128
  hidden_norm_activation: common.NormActivation = common.NormActivation(
      use_sync_bn=False, norm_momentum=0.997, norm_epsilon=1.0e-05)


@dataclasses.dataclass
class SSLLosses(Losses):
  normalize_hidden: bool = True
  temperature: float = 0.1


@exp_factory.register_config_factory('video_ssl_pretrain_kinetics400')
def video_ssl_pretrain_kinetics400() -> cfg.ExperimentConfig:
  """Pretrains SSL video classification on Kinetics 400 with a ResNet."""
  exp = video_classification.video_classification_kinetics400()
  exp.task = VideoSSLPretrainTask(**exp.task.as_dict())
  exp.task.train_data = DataConfig(is_ssl=True,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (16, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.model = VideoSSLModel(exp.task.model)
  exp.task.model.model_type = 'video_ssl_model'
  exp.task.losses = SSLLosses(exp.task.losses)
  return exp


@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics400')
def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig:
  """Linear evaluation of SSL video classification on Kinetics 400 with a ResNet."""
  exp = video_classification.video_classification_kinetics400()
  exp.task = VideoSSLEvalTask(**exp.task.as_dict())
  exp.task.train_data = DataConfig(is_ssl=False,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  exp.task.validation_data.num_test_crops = 3
  exp.task.model = VideoSSLModel(exp.task.model)
  exp.task.model.model_type = 'video_ssl_model'
  exp.task.model.normalize_feature = True
  exp.task.model.hidden_layer_num = 0
  exp.task.model.projection_dim = 400
  return exp


@exp_factory.register_config_factory('video_ssl_pretrain_kinetics600')
def video_ssl_pretrain_kinetics600() -> cfg.ExperimentConfig:
  """Pretrains SSL video classification on Kinetics 600 with a ResNet."""
  exp = video_classification.video_classification_kinetics600()
  exp.task = VideoSSLPretrainTask(**exp.task.as_dict())
  exp.task.train_data = DataConfig(is_ssl=True,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (16, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.model = VideoSSLModel(exp.task.model)
  exp.task.model.model_type = 'video_ssl_model'
  exp.task.losses = SSLLosses(exp.task.losses)
  return exp


@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics600')
def video_ssl_linear_eval_kinetics600() -> cfg.ExperimentConfig:
  """Linear evaluation of SSL video classification on Kinetics 600 with a ResNet."""
  exp = video_classification.video_classification_kinetics600()
  exp.task = VideoSSLEvalTask(**exp.task.as_dict())
  exp.task.train_data = DataConfig(is_ssl=False,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  exp.task.validation_data.num_test_crops = 3
  exp.task.model = VideoSSLModel(exp.task.model)
  exp.task.model.model_type = 'video_ssl_model'
  exp.task.model.normalize_feature = True
  exp.task.model.hidden_layer_num = 0
  exp.task.model.projection_dim = 600
  return exp
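One detail worth calling out in the linear-eval factories above: with hidden_layer_num = 0 the MLP head degenerates to a single Dense(projection_dim) layer, so setting projection_dim to the class count (400 or 600) turns the projection head into the linear classifier being evaluated. A quick sketch of pulling one of these registered configs, assuming only official.core.exp_factory as exercised by the test file below:

# Sketch of retrieving a registered config; mirrors what video_ssl_test.py does.
from official.core import exp_factory

config = exp_factory.get_exp_config('video_ssl_linear_eval_kinetics600')
assert config.task.model.hidden_layer_num == 0  # MLP head reduced to one Dense
assert config.task.model.projection_dim == 600  # one output per Kinetics-600 class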
official/vision/beta/projects/video_ssl/configs/video_ssl_test.py  deleted 100644 → 0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision import beta
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg


class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('video_ssl_pretrain_kinetics400',),
                            ('video_ssl_pretrain_kinetics600',))
  def test_video_ssl_pretrain_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.VideoSSLPretrainTask)
    self.assertIsInstance(config.task.model, exp_cfg.VideoSSLModel)
    self.assertIsInstance(config.task.losses, exp_cfg.SSLLosses)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()

  @parameterized.parameters(('video_ssl_linear_eval_kinetics400',),
                            ('video_ssl_linear_eval_kinetics600',))
  def test_video_ssl_linear_eval_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.VideoSSLEvalTask)
    self.assertIsInstance(config.task.model, exp_cfg.VideoSSLModel)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input.py  deleted 100644 → 0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Parser for video and label datasets."""
from typing import Dict, Optional, Tuple

from absl import logging
import tensorflow as tf

from official.vision.beta.dataloaders import video_input
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops

IMAGE_KEY = 'image/encoded'
LABEL_KEY = 'clip/label/index'

Decoder = video_input.Decoder


def _process_image(image: tf.Tensor,
                   is_training: bool = True,
                   is_ssl: bool = False,
                   num_frames: int = 32,
                   stride: int = 1,
                   num_test_clips: int = 1,
                   min_resize: int = 256,
                   crop_size: int = 224,
                   num_crops: int = 1,
                   zero_centering_image: bool = False,
                   seed: Optional[int] = None) -> tf.Tensor:
  """Processes a serialized image tensor.

  Args:
    image: Input Tensor of shape [timesteps] and type tf.string of serialized
      frames.
    is_training: Whether or not in training mode. If True, random sampling,
      cropping and left-right flipping are used.
    is_ssl: Whether or not in self-supervised pre-training mode.
    num_frames: Number of frames per subclip.
    stride: Temporal stride to sample frames.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test
      time. If 1, then a single clip in the middle of the video is sampled.
      The clips are aggregated in the batch dimension.
    min_resize: Frames are resized so that min(height, width) is min_resize.
    crop_size: Final size of the frame after cropping the resized frames. Both
      height and width are the same.
    num_crops: Number of crops to perform on the resized frames.
    zero_centering_image: If True, frames are normalized to values in [-1, 1].
      If False, values are in [0, 1].
    seed: A deterministic seed to use when sampling.

  Returns:
    Processed frames. Tensor of shape
      [num_frames * num_test_clips, crop_size, crop_size, 3].
  """
  # Validate parameters.
  if is_training and num_test_clips != 1:
    logging.warning(
        '`num_test_clips` %d is ignored since `is_training` is `True`.',
        num_test_clips)

  # Temporal sampler.
  if is_training:
    # Sampler for training.
    if is_ssl:
      # Sample two clips from a linearly decreasing distribution.
      image = video_ssl_preprocess_ops.sample_ssl_sequence(
          image, num_frames, True, stride)
    else:
      # Sample random clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, True,
                                                stride)
  else:
    # Sampler for evaluation.
    if num_test_clips > 1:
      # Sample linspace clips.
      image = preprocess_ops_3d.sample_linspace_sequence(
          image, num_test_clips, num_frames, stride)
    else:
      # Sample middle clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, False,
                                                stride)

  # Decode JPEG string to tf.uint8.
  image = preprocess_ops_3d.decode_jpeg(image, 3)

  if is_training:
    # Standard image data augmentation: random resized crop and random flip.
    if is_ssl:
      image_1, image_2 = tf.split(image, num_or_size_splits=2, axis=0)
      image_1 = preprocess_ops_3d.random_crop_resize(
          image_1, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_1 = preprocess_ops_3d.random_flip_left_right(image_1, seed)
      image_2 = preprocess_ops_3d.random_crop_resize(
          image_2, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_2 = preprocess_ops_3d.random_flip_left_right(image_2, seed)
    else:
      image = preprocess_ops_3d.random_crop_resize(
          image, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image = preprocess_ops_3d.random_flip_left_right(image, seed)
  else:
    # Resize images (resize happens only if necessary to save compute).
    image = preprocess_ops_3d.resize_smallest(image, min_resize)
    # Three-crop of the frames.
    image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
                                         num_crops)

  # Cast the frames to float32, normalizing according to zero_centering_image.
  if is_training and is_ssl:
    image_1 = preprocess_ops_3d.normalize_image(image_1, zero_centering_image)
    image_2 = preprocess_ops_3d.normalize_image(image_2, zero_centering_image)
  else:
    image = preprocess_ops_3d.normalize_image(image, zero_centering_image)

  # Self-supervised pre-training augmentations.
  if is_training and is_ssl:
    # Temporally consistent color jittering.
    image_1 = video_ssl_preprocess_ops.random_color_jitter_3d(image_1)
    image_2 = video_ssl_preprocess_ops.random_color_jitter_3d(image_2)
    # Temporally consistent gaussian blurring.
    image_1 = video_ssl_preprocess_ops.random_blur(image_1, crop_size,
                                                   crop_size, 1.0)
    image_2 = video_ssl_preprocess_ops.random_blur(image_2, crop_size,
                                                   crop_size, 0.1)
    image_2 = video_ssl_preprocess_ops.random_solarization(image_2)
    image = tf.concat([image_1, image_2], axis=0)
    image = tf.clip_by_value(image, 0., 1.)

  return image


def _postprocess_image(image: tf.Tensor,
                       is_training: bool = True,
                       is_ssl: bool = False,
                       num_frames: int = 32,
                       num_test_clips: int = 1,
                       num_test_crops: int = 1) -> tf.Tensor:
  """Processes a batched Tensor of frames.

  The same parameters used in process should be used here.

  Args:
    image: Input Tensor of shape [batch, timesteps, height, width, 3].
    is_training: Whether or not in training mode. If True, random sampling,
      cropping and left-right flipping are used.
    is_ssl: Whether or not in self-supervised pre-training mode.
    num_frames: Number of frames per subclip.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test
      time. If 1, then a single clip in the middle of the video is sampled.
      The clips are aggregated in the batch dimension.
    num_test_crops: Number of test crops (1 by default). If more than 1, there
      are multiple crops for each clip at test time. If 1, there is a single
      central crop. The crops are aggregated in the batch dimension.

  Returns:
    Processed frames. Tensor of shape
      [batch * num_test_clips * num_test_crops, num_frames, height, width, 3].
  """
  if is_ssl and is_training:
    # In this case, the two clips for self-supervised pre-training are merged
    # together in the batch dimension, which becomes 2 * batch.
    image = tf.concat(tf.split(image, num_or_size_splits=2, axis=1), axis=0)

  num_views = num_test_clips * num_test_crops
  if num_views > 1 and not is_training:
    # In this case, multiple views are merged together in the batch dimension,
    # which becomes batch * num_views.
    image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())

  return image


def _process_label(label: tf.Tensor,
                   one_hot_label: bool = True,
                   num_classes: Optional[int] = None) -> tf.Tensor:
  """Processes label Tensor."""
  # Validate parameters.
  if one_hot_label and not num_classes:
    raise ValueError(
        '`num_classes` should be given when requesting one hot label.')

  # Cast to tf.int32.
  label = tf.cast(label, dtype=tf.int32)

  if one_hot_label:
    # Replace label index by one-hot representation.
    label = tf.one_hot(label, num_classes)
    if len(label.shape.as_list()) > 1:
      label = tf.reduce_sum(label, axis=0)
    if num_classes == 1:
      # The trick for single label.
      label = 1 - label

  return label


class Parser(video_input.Parser):
  """Parses a video and label dataset."""

  def __init__(self,
               input_params: exp_cfg.DataConfig,
               image_key: str = IMAGE_KEY,
               label_key: str = LABEL_KEY):
    super(Parser, self).__init__(input_params, image_key, label_key)
    self._is_ssl = input_params.is_ssl

  def _parse_train_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for training."""
    # Process image and label.
    image = decoded_tensors[self._image_key]
    image = _process_image(
        image=image,
        is_training=True,
        is_ssl=self._is_ssl,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)

    return features, label

  def _parse_eval_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for evaluation."""
    image = decoded_tensors[self._image_key]
    image = _process_image(
        image=image,
        is_training=False,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size,
        num_crops=self._num_crops)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)

    if self._output_audio:
      audio = decoded_tensors[self._audio_feature]
      audio = tf.cast(audio, dtype=self._dtype)
      audio = preprocess_ops_3d.sample_sequence(
          audio, 20, random=False, stride=1)
      audio = tf.ensure_shape(audio, [20, 2048])
      features['audio'] = audio

    return features, label

  def parse_fn(self, is_training):
    """Returns a parse fn that reads and parses raw tensors from the decoder.

    Args:
      is_training: a `bool` to indicate whether it is in training mode.

    Returns:
      parse: a `callable` that takes the serialized example and generates the
        images, labels tuple where labels is a dict of Tensors that contains
        labels.
    """
    def parse(decoded_tensors):
      """Parses the serialized example data."""
      if is_training:
        return self._parse_train_data(decoded_tensors)
      else:
        return self._parse_eval_data(decoded_tensors)

    return parse


class PostBatchProcessor(object):
  """Processes a video and label dataset which is batched."""

  def __init__(self, input_params: exp_cfg.DataConfig):
    self._is_training = input_params.is_training
    self._is_ssl = input_params.is_ssl
    self._num_frames = input_params.feature_shape[0]
    self._num_test_clips = input_params.num_test_clips
    self._num_test_crops = input_params.num_test_crops

  def __call__(self, features: Dict[str, tf.Tensor],
               label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses a single tf.Example into image and label tensors."""
    for key in ['image', 'audio']:
      if key in features:
        features[key] = _postprocess_image(
            image=features[key],
            is_training=self._is_training,
            is_ssl=self._is_ssl,
            num_frames=self._num_frames,
            num_test_clips=self._num_test_clips,
            num_test_crops=self._num_test_crops)

    return features, label
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input_test.py  deleted 100644 → 0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
import io

# Import libraries
import numpy as np
from PIL import Image
import tensorflow as tf

from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input

AUDIO_KEY = 'features/audio'


def fake_seq_example():
  # Create fake data.
  random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
  random_image = Image.fromarray(random_image)
  label = 42
  with io.BytesIO() as buffer:
    random_image.save(buffer, format='JPEG')
    raw_image_bytes = buffer.getvalue()

  seq_example = tf.train.SequenceExample()
  seq_example.feature_lists.feature_list.get_or_create(
      video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
          raw_image_bytes
      ]
  seq_example.feature_lists.feature_list.get_or_create(
      video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
          raw_image_bytes
      ]
  seq_example.context.feature[
      video_ssl_input.LABEL_KEY].int64_list.value[:] = [label]

  random_audio = np.random.normal(size=(10, 256)).tolist()
  for s in random_audio:
    seq_example.feature_lists.feature_list.get_or_create(
        AUDIO_KEY).feature.add().float_list.value[:] = s
  return seq_example, label


class VideoAndLabelParserTest(tf.test.TestCase):

  def test_video_ssl_input_pretrain(self):
    params = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, _ = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, _ = output_tensor
    image = image_features['image']

    # Two 16-frame clips are concatenated along the frame axis -> 32 frames.
    self.assertAllEqual(image.shape, (32, 224, 224, 3))

  def test_video_ssl_input_linear_train(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.train_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, label = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, label = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (32, 224, 224, 3))
    self.assertAllEqual(label.shape, (600,))

  def test_video_ssl_input_linear_eval(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.validation_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, label = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, label = output_tensor
    image = image_features['image']

    # 10 test clips x 3 crops x 32 frames = 960 frames stacked on axis 0.
    self.assertAllEqual(image.shape, (960, 256, 256, 3))
    self.assertAllEqual(label.shape, (600,))


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/video_ssl/losses/losses.py  deleted 100644 → 0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Define losses."""
# Import libraries
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla


def contrastive_loss(hidden, num_replicas, normalize_hidden, temperature,
                     model, weight_decay):
  """Computes contrastive loss.

  Args:
    hidden: embedding of video clips after the projection head.
    num_replicas: number of distributed replicas.
    normalize_hidden: whether or not to l2-normalize the hidden vector.
    temperature: temperature in the InfoNCE contrastive loss.
    model: keras model, used for calculating weight decay.
    weight_decay: weight decay parameter.

  Returns:
    A dict with the total loss, the contrastive loss, the regularization
    loss, and the contrastive accuracy and entropy metrics.
  """
  large_num = 1e9
  hidden1, hidden2 = tf.split(hidden, num_or_size_splits=2, axis=0)
  if normalize_hidden:
    hidden1 = tf.math.l2_normalize(hidden1, -1)
    hidden2 = tf.math.l2_normalize(hidden2, -1)
  batch_size = tf.shape(hidden1)[0]

  if num_replicas == 1:
    # This is the local version.
    hidden1_large = hidden1
    hidden2_large = hidden2
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)
  else:
    # This is the cross-TPU version.
    hidden1_large = tpu_cross_replica_concat(hidden1, num_replicas)
    hidden2_large = tpu_cross_replica_concat(hidden2, num_replicas)
    enlarged_batch_size = tf.shape(hidden1_large)[0]
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)

  logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
  logits_aa = logits_aa - tf.cast(masks, logits_aa.dtype) * large_num
  logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
  logits_bb = logits_bb - tf.cast(masks, logits_bb.dtype) * large_num
  logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
  logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature

  loss_a = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels, tf.concat([logits_ab, logits_aa], 1)))
  loss_b = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels, tf.concat([logits_ba, logits_bb], 1)))
  loss = loss_a + loss_b

  l2_loss = weight_decay * tf.add_n([
      tf.nn.l2_loss(v)
      for v in model.trainable_variables
      if 'kernel' in v.name
  ])

  total_loss = loss + tf.cast(l2_loss, loss.dtype)

  contrast_prob = tf.nn.softmax(logits_ab)
  contrast_entropy = -tf.reduce_mean(
      tf.reduce_sum(contrast_prob * tf.math.log(contrast_prob + 1e-8), -1))

  contrast_acc = tf.equal(tf.argmax(labels, 1), tf.argmax(logits_ab, axis=1))
  contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))

  return {
      'total_loss': total_loss,
      'contrastive_loss': loss,
      'reg_loss': l2_loss,
      'contrast_acc': contrast_acc,
      'contrast_entropy': contrast_entropy,
  }


def tpu_cross_replica_concat(tensor, num_replicas):
  """Reduces a concatenation of the `tensor` across TPU cores.

  Args:
    tensor: tensor to concatenate.
    num_replicas: number of TPU device replicas.

  Returns:
    Tensor of the same rank as `tensor` with first dimension `num_replicas`
    times larger.
  """
  with tf.name_scope('tpu_cross_replica_concat'):
    # This creates a tensor that is like the input tensor but has an added
    # replica dimension as the outermost dimension. On each replica it will
    # contain the local values and zeros for all other values that need to be
    # fetched from other replicas.
    ext_tensor = tf.scatter_nd(
        indices=[[xla.replica_id()]],
        updates=[tensor],
        shape=[num_replicas] + tensor.shape.as_list())

    # As every value is only present on one replica and 0 in all others,
    # adding them all together will result in the full tensor on all replicas.
    replica_context = tf.distribute.get_replica_context()
    ext_tensor = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
                                            ext_tensor)

    # Flatten the replica dimension.
    # The first dimension size will be: tensor.shape[0] * num_replicas
    # Using [-1] trick to support also scalar input.
    return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
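To make the calling convention concrete, here is a minimal single-replica sketch (the toy projection model and shapes are illustrative assumptions, not from the diff): `hidden` stacks the two augmented views along the batch axis, so its leading dimension is twice the per-replica batch.

# Minimal single-replica sketch of calling contrastive_loss; the toy model
# and shapes below are assumptions for illustration.
import tensorflow as tf

batch_size, projection_dim = 8, 128
toy_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(projection_dim, input_shape=(projection_dim,))])
# Two views per clip, stacked along the batch axis -> [2 * batch, dim].
hidden = toy_model(tf.random.normal([2 * batch_size, projection_dim]))
outputs = contrastive_loss(
    hidden,
    num_replicas=1,
    normalize_hidden=True,
    temperature=0.1,
    model=toy_model,
    weight_decay=1e-6)
print(outputs['total_loss'], outputs['contrast_acc'])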
official/vision/beta/projects/video_ssl/modeling/video_ssl_model.py  deleted 100644 → 0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build video classification models."""
from typing import Mapping, Optional

# Import libraries
import tensorflow as tf

from official.modeling import tf_utils
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import factory_3d as model_factory
from official.vision.beta.projects.video_ssl.configs import video_ssl as video_ssl_cfg

layers = tf.keras.layers


class VideoSSLModel(tf.keras.Model):
  """A video SSL model class builder."""

  def __init__(self,
               backbone,
               normalize_feature,
               hidden_dim,
               hidden_layer_num,
               hidden_norm_args,
               projection_dim,
               input_specs: Optional[Mapping[str,
                                             tf.keras.layers.InputSpec]] = None,
               dropout_rate: float = 0.0,
               aggregate_endpoints: bool = False,
               kernel_initializer='random_uniform',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """Video classification initialization function.

    Args:
      backbone: a 3d backbone network.
      normalize_feature: whether to normalize the backbone feature.
      hidden_dim: `int` number of hidden units in the MLP.
      hidden_layer_num: `int` number of hidden layers in the MLP.
      hidden_norm_args: `dict` of batchnorm arguments for the MLP.
      projection_dim: `int` number of output dimensions for the MLP.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      dropout_rate: `float` rate for dropout regularization.
      aggregate_endpoints: `bool` whether to aggregate all endpoints or only
        use the final endpoint.
      kernel_initializer: kernel initializer for the dense layer.
      kernel_regularizer: tf.keras.regularizers.Regularizer object. Defaults
        to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object. Defaults
        to None.
      **kwargs: keyword arguments to be passed.
    """
    if not input_specs:
      input_specs = {
          'image': layers.InputSpec(shape=[None, None, None, None, 3])
      }
    self._self_setattr_tracking = False
    self._config_dict = {
        'backbone': backbone,
        'normalize_feature': normalize_feature,
        'hidden_dim': hidden_dim,
        'hidden_layer_num': hidden_layer_num,
        'use_sync_bn': hidden_norm_args.use_sync_bn,
        'norm_momentum': hidden_norm_args.norm_momentum,
        'norm_epsilon': hidden_norm_args.norm_epsilon,
        'activation': hidden_norm_args.activation,
        'projection_dim': projection_dim,
        'input_specs': input_specs,
        'dropout_rate': dropout_rate,
        'aggregate_endpoints': aggregate_endpoints,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    self._input_specs = input_specs
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._backbone = backbone

    inputs = {
        k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items()
    }
    endpoints = backbone(inputs['image'])

    if aggregate_endpoints:
      pooled_feats = []
      for endpoint in endpoints.values():
        x_pool = tf.keras.layers.GlobalAveragePooling3D()(endpoint)
        pooled_feats.append(x_pool)
      x = tf.concat(pooled_feats, axis=1)
    else:
      x = endpoints[max(endpoints.keys())]
      x = tf.keras.layers.GlobalAveragePooling3D()(x)

    # L2-normalize the feature after the backbone.
    if normalize_feature:
      x = tf.nn.l2_normalize(x, axis=-1)

    # MLP hidden layers.
    for _ in range(hidden_layer_num):
      x = tf.keras.layers.Dense(hidden_dim)(x)
      if self._config_dict['use_sync_bn']:
        x = tf.keras.layers.experimental.SyncBatchNormalization(
            momentum=self._config_dict['norm_momentum'],
            epsilon=self._config_dict['norm_epsilon'])(x)
      else:
        x = tf.keras.layers.BatchNormalization(
            momentum=self._config_dict['norm_momentum'],
            epsilon=self._config_dict['norm_epsilon'])(x)
      x = tf_utils.get_activation(self._config_dict['activation'])(x)

    # Projection head.
    x = tf.keras.layers.Dense(projection_dim)(x)

    super(VideoSSLModel, self).__init__(inputs=inputs, outputs=x, **kwargs)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    return dict(backbone=self.backbone)

  @property
  def backbone(self):
    return self._backbone

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)


@model_factory.register_model_builder('video_ssl_model')
def build_video_ssl_pretrain_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: video_ssl_cfg.VideoSSLModel,
    num_classes: int,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
  """Builds the video classification model."""
  del num_classes
  input_specs_dict = {'image': input_specs}
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=model_config.norm_activation,
      l2_regularizer=l2_regularizer)

  # The norm layer type in the MLP head should match the backbone.
  assert (model_config.norm_activation.use_sync_bn ==
          model_config.hidden_norm_activation.use_sync_bn)

  model = VideoSSLModel(
      backbone=backbone,
      normalize_feature=model_config.normalize_feature,
      hidden_dim=model_config.hidden_dim,
      hidden_layer_num=model_config.hidden_layer_num,
      hidden_norm_args=model_config.hidden_norm_activation,
      projection_dim=model_config.projection_dim,
      input_specs=input_specs_dict,
      dropout_rate=model_config.dropout_rate,
      aggregate_endpoints=model_config.aggregate_endpoints,
      kernel_regularizer=l2_regularizer)
  return model
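To illustrate the interface `VideoSSLModel` expects (a backbone mapping a clip to a dict of 5-D endpoints, which the head pools and projects), here is a toy sketch; the Conv3D backbone, shapes, and dimensions are invented stand-ins for the real ResNet-3D.

# Toy sketch of the VideoSSLModel interface; the backbone here is an
# assumed stand-in, not the ResNet-3D used by the project.
import tensorflow as tf
from official.vision.beta.configs import common

clip = tf.keras.Input(shape=(8, 32, 32, 3))
feats = tf.keras.layers.Conv3D(16, 3, padding='same')(clip)
toy_backbone = tf.keras.Model(clip, {'1': feats})

model = VideoSSLModel(
    backbone=toy_backbone,
    normalize_feature=False,
    hidden_dim=64,
    hidden_layer_num=2,
    hidden_norm_args=common.NormActivation(),
    projection_dim=16,
    input_specs={
        'image': tf.keras.layers.InputSpec(shape=[None, 8, 32, 32, 3])
    })
projections = model({'image': tf.random.normal([2, 8, 32, 32, 3])})
print(projections.shape)  # -> (2, 16)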
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops.py  deleted 100644 → 0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Utils for customed ops for video ssl."""
import
functools
from
typing
import
Optional
import
tensorflow
as
tf
def
random_apply
(
func
,
p
,
x
):
"""Randomly apply function func to x with probability p."""
return
tf
.
cond
(
tf
.
less
(
tf
.
random
.
uniform
([],
minval
=
0
,
maxval
=
1
,
dtype
=
tf
.
float32
),
tf
.
cast
(
p
,
tf
.
float32
)),
lambda
:
func
(
x
),
lambda
:
x
)
def
random_brightness
(
image
,
max_delta
):
"""Distort brightness of image (SimCLRv2 style)."""
factor
=
tf
.
random
.
uniform
(
[],
tf
.
maximum
(
1.0
-
max_delta
,
0
),
1.0
+
max_delta
)
image
=
image
*
factor
return
image
def
random_solarization
(
image
,
p
=
0.2
):
"""Random solarize image."""
def
_transform
(
image
):
image
=
image
*
tf
.
cast
(
tf
.
less
(
image
,
0.5
),
dtype
=
image
.
dtype
)
+
(
1.0
-
image
)
*
tf
.
cast
(
tf
.
greater_equal
(
image
,
0.5
),
dtype
=
image
.
dtype
)
return
image
return
random_apply
(
_transform
,
p
=
p
,
x
=
image
)
def
to_grayscale
(
image
,
keep_channels
=
True
):
"""Turn the input image to gray scale.
Args:
image: The input image tensor.
keep_channels: Whether maintaining the channel number for the image.
If true, the transformed image will repeat three times in channel.
If false, the transformed image will only have one channel.
Returns:
The distorted image tensor.
"""
image
=
tf
.
image
.
rgb_to_grayscale
(
image
)
if
keep_channels
:
image
=
tf
.
tile
(
image
,
[
1
,
1
,
3
])
return
image
def
color_jitter
(
image
,
strength
,
random_order
=
True
):
"""Distorts the color of the image (SimCLRv2 style).
Args:
image: The input image tensor.
strength: The floating number for the strength of the color augmentation.
random_order: A bool, specifying whether to randomize the jittering order.
Returns:
The distorted image tensor.
"""
brightness
=
0.8
*
strength
contrast
=
0.8
*
strength
saturation
=
0.8
*
strength
hue
=
0.2
*
strength
if
random_order
:
return
color_jitter_rand
(
image
,
brightness
,
contrast
,
saturation
,
hue
)
else
:
return
color_jitter_nonrand
(
image
,
brightness
,
contrast
,
saturation
,
hue
)
def
color_jitter_nonrand
(
image
,
brightness
=
0
,
contrast
=
0
,
saturation
=
0
,
hue
=
0
):
"""Distorts the color of the image (jittering order is fixed, SimCLRv2 style).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
Returns:
The distorted image tensor.
"""
with
tf
.
name_scope
(
'distort_color'
):
def
apply_transform
(
i
,
x
,
brightness
,
contrast
,
saturation
,
hue
):
"""Apply the i-th transformation."""
if
brightness
!=
0
and
i
==
0
:
x
=
random_brightness
(
x
,
max_delta
=
brightness
)
elif
contrast
!=
0
and
i
==
1
:
x
=
tf
.
image
.
random_contrast
(
x
,
lower
=
1
-
contrast
,
upper
=
1
+
contrast
)
elif
saturation
!=
0
and
i
==
2
:
x
=
tf
.
image
.
random_saturation
(
x
,
lower
=
1
-
saturation
,
upper
=
1
+
saturation
)
elif
hue
!=
0
:
x
=
tf
.
image
.
random_hue
(
x
,
max_delta
=
hue
)
return
x
for
i
in
range
(
4
):
image
=
apply_transform
(
i
,
image
,
brightness
,
contrast
,
saturation
,
hue
)
image
=
tf
.
clip_by_value
(
image
,
0.
,
1.
)
return
image
def
color_jitter_rand
(
image
,
brightness
=
0
,
contrast
=
0
,
saturation
=
0
,
hue
=
0
):
"""Distorts the color of the image (jittering order is random, SimCLRv2 style).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
Returns:
The distorted image tensor.
"""
with
tf
.
name_scope
(
'distort_color'
):
def
apply_transform
(
i
,
x
):
"""Apply the i-th transformation."""
def
brightness_transform
():
if
brightness
==
0
:
return
x
else
:
return
random_brightness
(
x
,
max_delta
=
brightness
)
def
contrast_transform
():
if
contrast
==
0
:
return
x
else
:
return
tf
.
image
.
random_contrast
(
x
,
lower
=
1
-
contrast
,
upper
=
1
+
contrast
)
def
saturation_transform
():
if
saturation
==
0
:
return
x
else
:
return
tf
.
image
.
random_saturation
(
x
,
lower
=
1
-
saturation
,
upper
=
1
+
saturation
)
def
hue_transform
():
if
hue
==
0
:
return
x
else
:
return
tf
.
image
.
random_hue
(
x
,
max_delta
=
hue
)
# pylint:disable=g-long-lambda
x
=
tf
.
cond
(
tf
.
less
(
i
,
2
),
lambda
:
tf
.
cond
(
tf
.
less
(
i
,
1
),
brightness_transform
,
contrast_transform
),
lambda
:
tf
.
cond
(
tf
.
less
(
i
,
3
),
saturation_transform
,
hue_transform
))
# pylint:disable=g-long-lambda
return
x
perm
=
tf
.
random
.
shuffle
(
tf
.
range
(
4
))
for
i
in
range
(
4
):
image
=
apply_transform
(
perm
[
i
],
image
)
image
=
tf
.
clip_by_value
(
image
,
0.
,
1.
)
return
image
def
random_color_jitter_3d
(
frames
):
"""Applies temporally consistent color jittering to one video clip.
Args:
frames: `Tensor` of shape [num_frames, height, width, channels].
Returns:
A Tensor of shape [num_frames, height, width, channels] being color jittered
with the same operation.
"""
def
random_color_jitter
(
image
,
p
=
1.0
):
def
_transform
(
image
):
color_jitter_t
=
functools
.
partial
(
color_jitter
,
strength
=
1.0
)
image
=
random_apply
(
color_jitter_t
,
p
=
0.8
,
x
=
image
)
return
random_apply
(
to_grayscale
,
p
=
0.2
,
x
=
image
)
return
random_apply
(
_transform
,
p
=
p
,
x
=
image
)
num_frames
,
width
,
height
,
channels
=
frames
.
shape
.
as_list
()
big_image
=
tf
.
reshape
(
frames
,
[
num_frames
*
width
,
height
,
channels
])
big_image
=
random_color_jitter
(
big_image
)
return
tf
.
reshape
(
big_image
,
[
num_frames
,
width
,
height
,
channels
])
def
gaussian_blur
(
image
,
kernel_size
,
sigma
,
padding
=
'SAME'
):
"""Blurs the given image with separable convolution.
Args:
image: Tensor of shape [height, width, channels] and dtype float to blur.
kernel_size: Integer Tensor for the size of the blur kernel. This is should
be an odd number. If it is an even number, the actual kernel size will be
size + 1.
sigma: Sigma value for gaussian operator.
padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.
Returns:
A Tensor representing the blurred image.
"""
radius
=
tf
.
cast
(
kernel_size
/
2
,
dtype
=
tf
.
int32
)
kernel_size
=
radius
*
2
+
1
x
=
tf
.
cast
(
tf
.
range
(
-
radius
,
radius
+
1
),
dtype
=
tf
.
float32
)
blur_filter
=
tf
.
exp
(
-
tf
.
pow
(
x
,
2.0
)
/
(
2.0
*
tf
.
pow
(
tf
.
cast
(
sigma
,
dtype
=
tf
.
float32
),
2.0
)))
blur_filter
/=
tf
.
reduce_sum
(
blur_filter
)
# One vertical and one horizontal filter.
blur_v
=
tf
.
reshape
(
blur_filter
,
[
kernel_size
,
1
,
1
,
1
])
blur_h
=
tf
.
reshape
(
blur_filter
,
[
1
,
kernel_size
,
1
,
1
])
num_channels
=
tf
.
shape
(
image
)[
-
1
]
blur_h
=
tf
.
tile
(
blur_h
,
[
1
,
1
,
num_channels
,
1
])
blur_v
=
tf
.
tile
(
blur_v
,
[
1
,
1
,
num_channels
,
1
])
expand_batch_dim
=
image
.
shape
.
ndims
==
3
if
expand_batch_dim
:
# Tensorflow requires batched input to convolutions, which we can fake with
# an extra dimension.
image
=
tf
.
expand_dims
(
image
,
axis
=
0
)
blurred
=
tf
.
nn
.
depthwise_conv2d
(
image
,
blur_h
,
strides
=
[
1
,
1
,
1
,
1
],
padding
=
padding
)
blurred
=
tf
.
nn
.
depthwise_conv2d
(
blurred
,
blur_v
,
strides
=
[
1
,
1
,
1
,
1
],
padding
=
padding
)
if
expand_batch_dim
:
blurred
=
tf
.
squeeze
(
blurred
,
axis
=
0
)
return
blurred
def
random_blur
(
image
,
height
,
width
,
p
=
1.0
):
"""Randomly blur an image.
Args:
image: `Tensor` representing an image of arbitrary size.
height: Height of output image.
width: Width of output image.
p: probability of applying this transformation.
Returns:
A preprocessed image `Tensor`.
"""
del
width
def
_transform
(
image
):
sigma
=
tf
.
random
.
uniform
([],
0.1
,
2.0
,
dtype
=
tf
.
float32
)
return
gaussian_blur
(
image
,
kernel_size
=
height
//
10
,
sigma
=
sigma
,
padding
=
'SAME'
)
return
random_apply
(
_transform
,
p
=
p
,
x
=
image
)
def
random_blur_3d
(
frames
,
height
,
width
,
blur_probability
=
0.5
):
"""Apply efficient batch data transformations.
Args:
frames: `Tensor` of shape [timesteps, height, width, 3].
height: the height of image.
width: the width of image.
blur_probability: the probaility to apply the blur operator.
Returns:
Preprocessed feature list.
"""
def
generate_selector
(
p
,
bsz
):
shape
=
[
bsz
,
1
,
1
,
1
]
selector
=
tf
.
cast
(
tf
.
less
(
tf
.
random
.
uniform
(
shape
,
0
,
1
,
dtype
=
tf
.
float32
),
p
),
tf
.
float32
)
return
selector
frames_new
=
random_blur
(
frames
,
height
,
width
,
p
=
1.
)
selector
=
generate_selector
(
blur_probability
,
1
)
frames
=
frames_new
*
selector
+
frames
*
(
1
-
selector
)
frames
=
tf
.
clip_by_value
(
frames
,
0.
,
1.
)
return
frames
def
_sample_or_pad_sequence_indices
(
sequence
:
tf
.
Tensor
,
num_steps
:
int
,
stride
:
int
,
offset
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Returns indices to take for sampling or padding sequences to fixed size."""
sequence_length
=
tf
.
shape
(
sequence
)[
0
]
sel_idx
=
tf
.
range
(
sequence_length
)
# Repeats sequence until num_steps are available in total.
max_length
=
num_steps
*
stride
+
offset
num_repeats
=
tf
.
math
.
floordiv
(
max_length
+
sequence_length
-
1
,
sequence_length
)
sel_idx
=
tf
.
tile
(
sel_idx
,
[
num_repeats
])
steps
=
tf
.
range
(
offset
,
offset
+
num_steps
*
stride
,
stride
)
return
tf
.
gather
(
sel_idx
,
steps
)
def
sample_ssl_sequence
(
sequence
:
tf
.
Tensor
,
num_steps
:
int
,
random
:
bool
,
stride
:
int
=
1
,
num_windows
:
Optional
[
int
]
=
2
)
->
tf
.
Tensor
:
"""Samples two segments of size num_steps randomly from a given sequence.
Currently it only supports images, and specically designed for video self-
supervised learning.
Args:
sequence: Any tensor where the first dimension is timesteps.
num_steps: Number of steps (e.g. frames) to take.
random: A boolean indicating whether to random sample the single window. If
True, the offset is randomized. Only True is supported.
stride: Distance to sample between timesteps.
num_windows: Number of sequence sampled.
Returns:
A single Tensor with first dimension num_steps with the sampled segment.
"""
sequence_length
=
tf
.
shape
(
sequence
)[
0
]
sequence_length
=
tf
.
cast
(
sequence_length
,
tf
.
float32
)
if
random
:
max_offset
=
tf
.
cond
(
tf
.
greater
(
sequence_length
,
(
num_steps
-
1
)
*
stride
),
lambda
:
sequence_length
-
(
num_steps
-
1
)
*
stride
,
lambda
:
sequence_length
)
max_offset
=
tf
.
cast
(
max_offset
,
dtype
=
tf
.
float32
)
def
cdf
(
k
,
power
=
1.0
):
"""Cumulative distribution function for x^power."""
p
=
-
tf
.
math
.
pow
(
k
,
power
+
1
)
/
(
power
*
tf
.
math
.
pow
(
max_offset
,
power
+
1
))
+
k
*
(
power
+
1
)
/
(
power
*
max_offset
)
return
p
u
=
tf
.
random
.
uniform
(())
k_low
=
tf
.
constant
(
0
,
dtype
=
tf
.
float32
)
k_up
=
max_offset
k
=
tf
.
math
.
floordiv
(
max_offset
,
2.0
)
c
=
lambda
k_low
,
k_up
,
k
:
tf
.
greater
(
tf
.
math
.
abs
(
k_up
-
k_low
),
1.0
)
# pylint:disable=g-long-lambda
b
=
lambda
k_low
,
k_up
,
k
:
tf
.
cond
(
tf
.
greater
(
cdf
(
k
),
u
),
lambda
:
[
k_low
,
k
,
tf
.
math
.
floordiv
(
k
+
k_low
,
2.0
)],
lambda
:
[
k
,
k_up
,
tf
.
math
.
floordiv
(
k_up
+
k
,
2.0
)])
_
,
_
,
k
=
tf
.
while_loop
(
c
,
b
,
[
k_low
,
k_up
,
k
])
delta
=
tf
.
cast
(
k
,
tf
.
int32
)
max_offset
=
tf
.
cast
(
max_offset
,
tf
.
int32
)
sequence_length
=
tf
.
cast
(
sequence_length
,
tf
.
int32
)
choice_1
=
tf
.
cond
(
tf
.
equal
(
max_offset
,
sequence_length
),
lambda
:
tf
.
random
.
uniform
((),
maxval
=
tf
.
cast
(
max_offset
,
dtype
=
tf
.
int32
),
dtype
=
tf
.
int32
),
lambda
:
tf
.
random
.
uniform
((),
maxval
=
tf
.
cast
(
max_offset
-
delta
,
dtype
=
tf
.
int32
),
dtype
=
tf
.
int32
))
choice_2
=
tf
.
cond
(
tf
.
equal
(
max_offset
,
sequence_length
),
lambda
:
tf
.
random
.
uniform
((),
maxval
=
tf
.
cast
(
max_offset
,
dtype
=
tf
.
int32
),
dtype
=
tf
.
int32
),
lambda
:
choice_1
+
delta
)
# pylint:disable=g-long-lambda
shuffle_choice
=
tf
.
random
.
shuffle
((
choice_1
,
choice_2
))
offset_1
=
shuffle_choice
[
0
]
offset_2
=
shuffle_choice
[
1
]
else
:
raise
NotImplementedError
indices_1
=
_sample_or_pad_sequence_indices
(
sequence
=
sequence
,
num_steps
=
num_steps
,
stride
=
stride
,
offset
=
offset_1
)
indices_2
=
_sample_or_pad_sequence_indices
(
sequence
=
sequence
,
num_steps
=
num_steps
,
stride
=
stride
,
offset
=
offset_2
)
indices
=
tf
.
concat
([
indices_1
,
indices_2
],
axis
=
0
)
indices
.
set_shape
((
num_windows
*
num_steps
,))
output
=
tf
.
gather
(
sequence
,
indices
)
return
output
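A short note on the sampler above (the derivation is added here for clarity; M denotes max_offset): with the default power = 1, the cdf in sample_ssl_sequence is

F(k) = \frac{2k}{M} - \frac{k^{2}}{M^{2}},
\qquad
f(k) = F'(k) = \frac{2}{M}\left(1 - \frac{k}{M}\right),
\qquad k \in [0, M]

so the temporal gap delta between the two clips is drawn from a linearly decreasing density on [0, M]; the while loop performs a bisection that inverts F at a uniform sample u. This matches the "sample two clips from a linearly decreasing distribution" comment in the dataloader.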
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops_test.py  deleted 100644 → 0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf

from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops


class VideoSslPreprocessOpsTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self._raw_frames = tf.random.uniform(
        (250, 256, 256, 3), minval=0, maxval=255, dtype=tf.dtypes.int32)
    self._sampled_frames = self._raw_frames[:16]
    self._frames = preprocess_ops_3d.normalize_image(
        self._sampled_frames, False, tf.float32)

  def test_sample_ssl_sequence(self):
    sampled_seq = video_ssl_preprocess_ops.sample_ssl_sequence(
        self._raw_frames, 16, True, 2)
    self.assertAllEqual(sampled_seq.shape, (32, 256, 256, 3))

  def test_random_color_jitter_3d(self):
    jittered_clip = video_ssl_preprocess_ops.random_color_jitter_3d(
        self._frames)
    self.assertAllEqual(jittered_clip.shape, (16, 256, 256, 3))

  def test_random_blur_3d(self):
    blurred_clip = video_ssl_preprocess_ops.random_blur_3d(
        self._frames, 256, 256)
    self.assertAllEqual(blurred_clip.shape, (16, 256, 256, 3))


if __name__ == '__main__':
  tf.test.main()
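Taken together, the ops exercised by this test implement the two-clip SSL augmentation: sample two temporally offset windows from one video, then jitter and blur them. One possible composition, sketched with the exact calls used in the tests above (the wrapper name and ordering are illustrative, not from the project):

import tensorflow as tf

from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops

def two_view_augment(raw_frames):
  # Two 16-frame windows at stride 2 -> a stacked (32, H, W, 3) clip pair.
  clips = video_ssl_preprocess_ops.sample_ssl_sequence(
      raw_frames, 16, True, 2)
  clips = preprocess_ops_3d.normalize_image(clips, False, tf.float32)
  clips = video_ssl_preprocess_ops.random_color_jitter_3d(clips)
  clips = video_ssl_preprocess_ops.random_blur_3d(clips, 256, 256)
  return clips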
official/vision/beta/projects/video_ssl/tasks/__init__.py
deleted
100644 → 0
View file @
7cffacfe
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tasks package definition."""
from official.vision.beta.projects.video_ssl.tasks import linear_eval
from official.vision.beta.projects.video_ssl.tasks import pretrain
official/vision/beta/projects/video_ssl/tasks/linear_eval.py
deleted
100644 → 0
View file @
7cffacfe
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video ssl linear evaluation task definition."""
from typing import Any, Optional, List, Tuple

from absl import logging
import tensorflow as tf

# pylint: disable=unused-import
from official.core import task_factory
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.modeling import video_ssl_model
from official.vision.beta.tasks import video_classification


@task_factory.register_task_cls(exp_cfg.VideoSSLEvalTask)
class VideoSSLEvalTask(video_classification.VideoClassificationTask):
  """A task for video ssl linear evaluation."""

  def initialize(self, model: tf.keras.Model):
    """Loading pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if self.task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      ckpt.read(ckpt_dir_or_file)
    else:
      raise NotImplementedError

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    model.backbone.trainable = False
    logging.info('Setting the backbone to non-trainable.')
    return super(video_classification.VideoClassificationTask,
                 self).train_step(inputs, model, optimizer, metrics)
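For context, linear evaluation trains only the classification head on top of the frozen pretrained backbone, which is why train_step sets model.backbone.trainable = False before delegating to the standard classification step. A rough standalone sketch of the same pattern, using a 2-D Keras backbone as a stand-in (none of these names come from the Model Garden task API):

import tensorflow as tf

# Hypothetical stand-in: freeze a pretrained backbone, train only the head.
backbone = tf.keras.applications.ResNet50(
    include_top=False, pooling='avg', input_shape=(224, 224, 3))
backbone.trainable = False  # mirrors model.backbone.trainable = False above

linear_probe = tf.keras.Sequential([
    backbone,
    tf.keras.layers.Dense(600),  # e.g. one logit per Kinetics-600 class
])
linear_probe.compile(
    optimizer=tf.keras.optimizers.SGD(0.1),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
# linear_probe.fit(...) would now update only the Dense layer's weights.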
official/vision/beta/projects/video_ssl/tasks/pretrain.py
deleted
100644 → 0
View file @
7cffacfe
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video ssl pretrain task definition."""
from absl import logging
import tensorflow as tf

# pylint: disable=unused-import
from official.core import input_reader
from official.core import task_factory
from official.vision.beta.modeling import factory_3d
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input
from official.vision.beta.projects.video_ssl.losses import losses
from official.vision.beta.projects.video_ssl.modeling import video_ssl_model
from official.vision.beta.tasks import video_classification


@task_factory.register_task_cls(exp_cfg.VideoSSLPretrainTask)
class VideoSSLPretrainTask(video_classification.VideoClassificationTask):
  """A task for video ssl pretraining."""

  def build_model(self):
    """Builds video ssl pretraining model."""
    common_input_shape = [
        d1 if d1 == d2 else None
        for d1, d2 in zip(self.task_config.train_data.feature_shape,
                          self.task_config.validation_data.feature_shape)
    ]
    input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
    logging.info('Build model input %r', common_input_shape)

    model = factory_3d.build_model(
        self.task_config.model.model_type,
        input_specs=input_specs,
        model_config=self.task_config.model,
        num_classes=self.task_config.train_data.num_classes)
    return model

  def _get_decoder_fn(self, params):
    decoder = video_ssl_input.Decoder()
    return decoder.decode

  def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
    """Builds classification input."""
    parser = video_ssl_input.Parser(input_params=params)
    postprocess_fn = video_ssl_input.PostBatchProcessor(params)

    reader = input_reader.InputReader(
        params,
        dataset_fn=self._get_dataset_fn(params),
        decoder_fn=self._get_decoder_fn(params),
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)

    dataset = reader.read(input_context=input_context)
    return dataset

  def build_losses(self, model_outputs, num_replicas, model):
    """Contrastive loss and metrics.

    Args:
      model_outputs: Output logits of the model.
      num_replicas: distributed replica number.
      model: keras model for calculating weight decay.

    Returns:
      A dictionary of losses and a dictionary of contrastive metrics.
    """
    all_losses = {}
    contrastive_metrics = {}
    losses_config = self.task_config.losses
    total_loss = None
    contrastive_loss_dict = losses.contrastive_loss(
        model_outputs, num_replicas, losses_config.normalize_hidden,
        losses_config.temperature, model,
        self.task_config.losses.l2_weight_decay)
    total_loss = contrastive_loss_dict['total_loss']
    all_losses.update({'total_loss': total_loss})
    all_losses[self.loss] = total_loss
    contrastive_metrics.update({
        'contrast_acc': contrastive_loss_dict['contrast_acc'],
        'contrast_entropy': contrastive_loss_dict['contrast_entropy'],
        'reg_loss': contrastive_loss_dict['reg_loss']
    })
    return all_losses, contrastive_metrics

  def build_metrics(self, training=True):
    """Gets streaming metrics for training/validation."""
    metrics = [
        tf.keras.metrics.Mean(name='contrast_acc'),
        tf.keras.metrics.Mean(name='contrast_entropy'),
        tf.keras.metrics.Mean(name='reg_loss')
    ]
    return metrics

  def process_metrics(self, metrics, contrastive_metrics):
    """Process and update metrics."""
    contrastive_metric_values = contrastive_metrics.values()
    for metric, contrastive_metric_value in zip(metrics,
                                                contrastive_metric_values):
      metric.update_state(contrastive_metric_value)

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, _ = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      if self.task_config.train_data.output_audio:
        outputs = model(features, training=True)
      else:
        outputs = model(features['image'], training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                                      outputs)

      all_losses, contrastive_metrics = self.build_losses(
          model_outputs=outputs, num_replicas=num_replicas, model=model)
      loss = all_losses[self.loss]
      scaled_loss = loss

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = all_losses
    if metrics:
      self.process_metrics(metrics, contrastive_metrics)
      logs.update({m.name: m.result() for m in metrics})
    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    raise NotImplementedError

  def inference_step(self, features, model):
    """Performs the forward step."""
    raise NotImplementedError
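build_losses defers to losses.contrastive_loss (defined in the project's losses.py, which is also deleted in this commit). That loss is a SimCLR-style NT-Xent objective over the two augmented clips of each video. Below is a self-contained sketch of that loss family, assuming embeddings are already gathered on one replica and showing only one direction of the symmetric loss; the name nt_xent and the simplifications are illustrative, not the project's implementation:

import tensorflow as tf

def nt_xent(hidden, temperature=0.1):
  """Sketch of an NT-Xent loss over 2N embeddings.

  hidden: [2N, D] float tensor where rows i and i+N are the two augmented
  views of clip i. The project's losses.contrastive_loss additionally
  handles hidden normalization flags, weight decay and multi-replica
  batches.
  """
  h1, h2 = tf.split(tf.math.l2_normalize(hidden, axis=-1), 2, axis=0)
  n = tf.shape(h1)[0]
  labels = tf.range(n)
  logits_ab = tf.matmul(h1, h2, transpose_b=True) / temperature
  logits_aa = tf.matmul(h1, h1, transpose_b=True) / temperature
  # Mask self-similarities so a view cannot match itself.
  logits_aa -= tf.one_hot(labels, n) * 1e9
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=tf.concat([logits_ab, logits_aa], axis=1))
  return tf.reduce_mean(loss)

# Example: nt_xent(tf.random.normal([8, 128])) returns a scalar loss.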
official/vision/beta/projects/video_ssl/tasks/pretrain_test.py
deleted
100644 → 0
View file @
7cffacfe
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
import functools
import os
import random

import orbit
import tensorflow as tf

# pylint: disable=unused-import
from official.core import exp_factory
from official.core import task_factory
from official.modeling import optimization
from official.vision import beta
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.projects.video_ssl.tasks import pretrain


class VideoClassificationTaskTest(tf.test.TestCase):

  def setUp(self):
    super(VideoClassificationTaskTest, self).setUp()
    data_dir = os.path.join(self.get_temp_dir(), 'data')
    tf.io.gfile.makedirs(data_dir)
    self._data_path = os.path.join(data_dir, 'data.tfrecord')
    # pylint: disable=g-complex-comprehension
    examples = [
        tfexample_utils.make_video_test_example(
            image_shape=(36, 36, 3),
            audio_shape=(20, 128),
            label=random.randint(0, 100)) for _ in range(2)
    ]
    # pylint: enable=g-complex-comprehension
    tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)

  def test_task(self):
    config = exp_factory.get_exp_config('video_ssl_pretrain_kinetics600')
    config.task.train_data.global_batch_size = 2
    config.task.train_data.input_path = self._data_path

    task = pretrain.VideoSSLPretrainTask(config.task)
    model = task.build_model()
    metrics = task.build_metrics()
    strategy = tf.distribute.get_strategy()

    dataset = orbit.utils.make_distributed_dataset(
        strategy, functools.partial(task.build_inputs),
        config.task.train_data)

    iterator = iter(dataset)
    opt_factory = optimization.OptimizerFactory(
        config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

    logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
    self.assertIn('total_loss', logs)
    self.assertIn('reg_loss', logs)
    self.assertIn('contrast_acc', logs)
    self.assertIn('contrast_entropy', logs)

  def test_task_factory(self):
    config = exp_factory.get_exp_config('video_ssl_pretrain_kinetics600')
    task = task_factory.get_task(config.task)
    self.assertIs(type(task), pretrain.VideoSSLPretrainTask)


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/video_ssl/train.py
deleted
100644 → 0
View file @
7cffacfe
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Training driver."""
from absl import app
from absl import flags
import gin

# pylint: disable=unused-import
from official.common import registry_imports
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.vision.beta.projects.video_ssl.modeling import video_ssl_model
from official.vision.beta.projects.video_ssl.tasks import linear_eval
from official.vision.beta.projects.video_ssl.tasks import pretrain
# pylint: disable=unused-import

FLAGS = flags.FLAGS


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  if 'train_and_eval' in FLAGS.mode:
    assert (params.task.train_data.feature_shape ==
            params.task.validation_data.feature_shape), (
                f'train {params.task.train_data.feature_shape} != validate '
                f'{params.task.validation_data.feature_shape}')

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case
  # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
  # when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)
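tfm_flags.define_flags() registers the standard Model Garden flags (--experiment, --mode, --model_dir, --config_file, --params_override, ...), so a pretraining run would typically have been launched along these lines (paths illustrative; the experiment name is the one registered in the tests above):

python3 train.py \
  --experiment=video_ssl_pretrain_kinetics600 \
  --mode=train \
  --model_dir=/tmp/video_ssl_pretrain \
  --config_file=path/to/experiment_overrides.yaml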
official/vision/beta/projects/yolo/common/registry_imports.py
View file @
8b641b13
...
@@ -16,7 +16,7 @@
 # pylint: disable=unused-import
 # pylint: disable=g-bad-import-order
-from official.common import registry_imports
+from official.vision import registry_imports
 # import configs
 from official.vision.beta.projects.yolo.configs import darknet_classification
...
official/vision/beta/projects/yolo/configs/backbones.py
View file @
8b641b13
...
@@ -15,7 +15,7 @@
 """Backbones configurations."""
 import dataclasses
 from official.modeling import hyperparams
-from official.vision.beta.configs import backbones
+from official.vision.configs import backbones


 @dataclasses.dataclass
...
official/vision/beta/projects/yolo/configs/darknet_classification.py
View file @
8b641b13
...
@@ -20,9 +20,9 @@ from typing import List, Optional
 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import hyperparams
-from official.vision.beta.configs import common
-from official.vision.beta.configs import image_classification as imc
 from official.vision.beta.projects.yolo.configs import backbones
+from official.vision.configs import common
+from official.vision.configs import image_classification as imc


 @dataclasses.dataclass
...
official/vision/beta/projects/yolo/configs/decoders.py
View file @
8b641b13
...
@@ -16,7 +16,7 @@
 import dataclasses
 from typing import Optional
 from official.modeling import hyperparams
-from official.vision.beta.configs import decoders
+from official.vision.configs import decoders


 @dataclasses.dataclass
...