ModelZoo / ResNet50_tensorflow · Commit ef6a4159

Authored Jan 20, 2021 by Yin Cui; committed by A. Unique TensorFlower, Jan 20, 2021.

Internal change

PiperOrigin-RevId: 352949436
Parent: e0f818da
Showing 8 changed files with 0 additions and 1034 deletions:

  official/vision/beta/projects/video_ssl/README.md                             +0 -8
  official/vision/beta/projects/video_ssl/configs/__init__.py                   +0 -18
  official/vision/beta/projects/video_ssl/configs/video_ssl.py                  +0 -89
  official/vision/beta/projects/video_ssl/configs/video_ssl_test.py             +0 -45
  official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input.py        +0 -319
  official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input_test.py   +0 -111
  official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops.py       +0 -397
  official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops_test.py  +0 -47
official/vision/beta/projects/video_ssl/README.md  deleted 100644 → 0
# Video Self-supervised Learning

TF2 implementation of [CVRL](https://arxiv.org/abs/2008.03800):

[1] Qian, Rui, Tianjian Meng, Boqing Gong, Ming-Hsuan Yang, Huisheng Wang,
Serge Belongie, and Yin Cui. "Spatiotemporal contrastive video
representation learning." arXiv preprint arXiv:2008.03800 (2020).
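The experiment configurations deleted below were registered with the Model Garden's `exp_factory`. As a quick orientation, here is a minimal lookup sketch; it is an illustrative assumption, not part of this commit, and presumes the `official` package is installed and the project's config module is still importable (importing it is what triggers registration):

# Hedged sketch: assumes the TF Model Garden `official` package is on the path
# and that the (now-deleted) video_ssl configs module is still importable;
# importing it runs the @exp_factory.register_config_factory decorators.
from official.core import exp_factory
from official.vision.beta.projects.video_ssl.configs import video_ssl  # noqa: F401

config = exp_factory.get_exp_config('video_ssl_pretrain_kinetics400')
print(config.task.train_data.feature_shape)  # (16, 224, 224, 3)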
official/vision/beta/projects/video_ssl/configs/__init__.py  deleted 100644 → 0
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configs package definition."""
from official.vision.beta.projects.video_ssl.configs import video_ssl
official/vision/beta/projects/video_ssl/configs/video_ssl.py  deleted 100644 → 0
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Video classification configuration definition."""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.configs import video_classification


Losses = video_classification.Losses
VideoClassificationModel = video_classification.VideoClassificationModel
VideoClassificationTask = video_classification.VideoClassificationTask


@dataclasses.dataclass
class DataConfig(video_classification.DataConfig):
  """The base configuration for building datasets."""
  is_ssl: bool = False


@exp_factory.register_config_factory('video_ssl_pretrain_kinetics400')
def video_ssl_pretrain_kinetics400() -> cfg.ExperimentConfig:
  """Pretrains SSL video classification on Kinetics 400 with ResNet."""
  exp = video_classification.video_classification_kinetics400()
  exp.task.train_data = DataConfig(is_ssl=True,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (16, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  return exp


@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics400')
def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig:
  """Linear evaluation of SSL video classification on Kinetics 400 with ResNet."""
  exp = video_classification.video_classification_kinetics400()
  exp.task.train_data = DataConfig(is_ssl=False,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  exp.task.validation_data.num_test_crops = 3
  return exp


@exp_factory.register_config_factory('video_ssl_pretrain_kinetics600')
def video_ssl_pretrain_kinetics600() -> cfg.ExperimentConfig:
  """Pretrains SSL video classification on Kinetics 600 with ResNet."""
  exp = video_classification.video_classification_kinetics600()
  exp.task.train_data = DataConfig(is_ssl=True,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (16, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  return exp


@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics600')
def video_ssl_linear_eval_kinetics600() -> cfg.ExperimentConfig:
  """Linear evaluation of SSL video classification on Kinetics 600 with ResNet."""
  exp = video_classification.video_classification_kinetics600()
  exp.task.train_data = DataConfig(is_ssl=False,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  exp.task.validation_data.num_test_crops = 3
  return exp
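The factories above lean on one pattern worth calling out: rebuilding a config object as a subclass with an extra field via `DataConfig(is_ssl=..., **old.as_dict())`. A self-contained sketch of the same pattern with plain dataclasses (the class names here are illustrative stand-ins, not from this project):

# Illustrative sketch of the subclass-rebuild pattern used above.
import dataclasses

@dataclasses.dataclass
class BaseDataConfig:
  """Stand-in for video_classification.DataConfig."""
  feature_shape: tuple = (32, 224, 224, 3)
  temporal_stride: int = 1

  def as_dict(self):
    return dataclasses.asdict(self)

@dataclasses.dataclass
class SslDataConfig(BaseDataConfig):
  """Adds an SSL flag on top of the inherited base fields."""
  is_ssl: bool = False

base = BaseDataConfig()
ssl = SslDataConfig(is_ssl=True, **base.as_dict())
print(ssl)  # SslDataConfig(feature_shape=(32, 224, 224, 3), temporal_stride=1, is_ssl=True)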
official/vision/beta/projects/video_ssl/configs/video_ssl_test.py  deleted 100644 → 0
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision import beta
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg


class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('video_ssl_pretrain_kinetics400',),
                            ('video_ssl_linear_eval_kinetics400',),
                            ('video_ssl_pretrain_kinetics600',),
                            ('video_ssl_linear_eval_kinetics600',))
  def test_video_classification_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.VideoClassificationTask)
    self.assertIsInstance(config.task.model, exp_cfg.VideoClassificationModel)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input.py  deleted 100644 → 0
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parser for video and label datasets."""
from typing import Dict, Optional, Tuple

from absl import logging
import tensorflow as tf

from official.vision.beta.dataloaders import video_input
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops

IMAGE_KEY = 'image/encoded'
LABEL_KEY = 'clip/label/index'

Decoder = video_input.Decoder


def _process_image(image: tf.Tensor,
                   is_training: bool = True,
                   is_ssl: bool = False,
                   num_frames: int = 32,
                   stride: int = 1,
                   num_test_clips: int = 1,
                   min_resize: int = 256,
                   crop_size: int = 224,
                   num_crops: int = 1,
                   zero_centering_image: bool = False,
                   seed: Optional[int] = None) -> tf.Tensor:
  """Processes a serialized image tensor.

  Args:
    image: Input Tensor of shape [timesteps] and type tf.string of serialized
      frames.
    is_training: Whether or not in training mode. If True, random sampling,
      cropping, and left-right flipping are applied.
    is_ssl: Whether or not in self-supervised pre-training mode.
    num_frames: Number of frames per subclip.
    stride: Temporal stride to sample frames.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test
      time. If 1, then a single clip in the middle of the video is sampled.
      The clips are aggregated in the batch dimension.
    min_resize: Frames are resized so that min(height, width) is min_resize.
    crop_size: Final size of the frame after cropping the resized frames. Both
      height and width are the same.
    num_crops: Number of crops to perform on the resized frames.
    zero_centering_image: If True, frames are normalized to values in [-1, 1].
      If False, values are in [0, 1].
    seed: A deterministic seed to use when sampling.

  Returns:
    Processed frames. Tensor of shape
      [num_frames * num_test_clips, crop_size, crop_size, 3].
  """
  # Validate parameters.
  if is_training and num_test_clips != 1:
    logging.warning(
        '`num_test_clips` %d is ignored since `is_training` is `True`.',
        num_test_clips)

  # Temporal sampler.
  if is_training:
    # Sampler for training.
    if is_ssl:
      # Sample two clips from a linearly decreasing distribution.
      image = video_ssl_preprocess_ops.sample_ssl_sequence(
          image, num_frames, True, stride)
    else:
      # Sample random clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, True,
                                                stride)
  else:
    # Sampler for evaluation.
    if num_test_clips > 1:
      # Sample linspace clips.
      image = preprocess_ops_3d.sample_linspace_sequence(
          image, num_test_clips, num_frames, stride)
    else:
      # Sample middle clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, False,
                                                stride)

  # Decode JPEG string to tf.uint8.
  image = preprocess_ops_3d.decode_jpeg(image, 3)

  if is_training:
    # Standard image data augmentation: random resized crop and random flip.
    if is_ssl:
      image_1, image_2 = tf.split(image, num_or_size_splits=2, axis=0)
      image_1 = preprocess_ops_3d.random_crop_resize(
          image_1, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_1 = preprocess_ops_3d.random_flip_left_right(image_1, seed)
      image_2 = preprocess_ops_3d.random_crop_resize(
          image_2, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_2 = preprocess_ops_3d.random_flip_left_right(image_2, seed)
    else:
      image = preprocess_ops_3d.random_crop_resize(
          image, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image = preprocess_ops_3d.random_flip_left_right(image, seed)
  else:
    # Resize images (resize happens only if necessary to save compute).
    image = preprocess_ops_3d.resize_smallest(image, min_resize)
    # Three-crop of the frames.
    image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
                                         num_crops)

  # Cast the frames to float32, normalizing according to zero_centering_image.
  if is_training and is_ssl:
    image_1 = preprocess_ops_3d.normalize_image(image_1, zero_centering_image)
    image_2 = preprocess_ops_3d.normalize_image(image_2, zero_centering_image)
  else:
    image = preprocess_ops_3d.normalize_image(image, zero_centering_image)

  # Self-supervised pre-training augmentations.
  if is_training and is_ssl:
    # Temporally consistent color jittering.
    image_1 = video_ssl_preprocess_ops.random_color_jitter_3d(image_1)
    image_2 = video_ssl_preprocess_ops.random_color_jitter_3d(image_2)
    # Temporally consistent gaussian blurring.
    image_1 = video_ssl_preprocess_ops.random_blur_3d(image_1, crop_size,
                                                      crop_size)
    image_2 = video_ssl_preprocess_ops.random_blur_3d(image_2, crop_size,
                                                      crop_size)
    image = tf.concat([image_1, image_2], axis=0)

  return image


def _postprocess_image(image: tf.Tensor,
                       is_training: bool = True,
                       is_ssl: bool = False,
                       num_frames: int = 32,
                       num_test_clips: int = 1,
                       num_test_crops: int = 1) -> tf.Tensor:
  """Processes a batched Tensor of frames.

  The same parameters used in process should be used here.

  Args:
    image: Input Tensor of shape [batch, timesteps, height, width, 3].
    is_training: Whether or not in training mode. If True, random sampling,
      cropping, and left-right flipping are applied.
    is_ssl: Whether or not in self-supervised pre-training mode.
    num_frames: Number of frames per subclip.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test
      time. If 1, then a single clip in the middle of the video is sampled.
      The clips are aggregated in the batch dimension.
    num_test_crops: Number of test crops (1 by default). If more than 1, there
      are multiple crops for each clip at test time. If 1, there is a single
      central crop. The crops are aggregated in the batch dimension.

  Returns:
    Processed frames. Tensor of shape
      [batch * num_test_clips * num_test_crops, num_frames, height, width, 3].
  """
  if is_ssl and is_training:
    # In this case, the two clips of self-supervised pre-training are merged
    # together in the batch dimension, which becomes 2 * batch.
    image = tf.concat(tf.split(image, num_or_size_splits=2, axis=1), axis=0)

  num_views = num_test_clips * num_test_crops
  if num_views > 1 and not is_training:
    # In this case, multiple views are merged together in the batch dimension,
    # which becomes batch * num_views.
    image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())

  return image


def _process_label(label: tf.Tensor,
                   one_hot_label: bool = True,
                   num_classes: Optional[int] = None) -> tf.Tensor:
  """Processes label Tensor."""
  # Validate parameters.
  if one_hot_label and not num_classes:
    raise ValueError(
        '`num_classes` should be given when requesting one hot label.')

  # Cast to tf.int32.
  label = tf.cast(label, dtype=tf.int32)

  if one_hot_label:
    # Replace label index by one hot representation.
    label = tf.one_hot(label, num_classes)
    if len(label.shape.as_list()) > 1:
      label = tf.reduce_sum(label, axis=0)
    if num_classes == 1:
      # The trick for single label.
      label = 1 - label

  return label


class Parser(video_input.Parser):
  """Parses a video and label dataset."""

  def __init__(self,
               input_params: exp_cfg.DataConfig,
               image_key: str = IMAGE_KEY,
               label_key: str = LABEL_KEY):
    super(Parser, self).__init__(input_params, image_key, label_key)
    self._is_ssl = input_params.is_ssl

  def _parse_train_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for training."""
    # Process image and label.
    image = decoded_tensors[self._image_key]
    image = _process_image(
        image=image,
        is_training=True,
        is_ssl=self._is_ssl,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)

    return features, label

  def _parse_eval_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for evaluation."""
    image = decoded_tensors[self._image_key]
    image = _process_image(
        image=image,
        is_training=False,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size,
        num_crops=self._num_crops)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)

    if self._output_audio:
      audio = decoded_tensors[self._audio_feature]
      audio = tf.cast(audio, dtype=self._dtype)
      audio = preprocess_ops_3d.sample_sequence(audio, 20, random=False,
                                                stride=1)
      audio = tf.ensure_shape(audio, [20, 2048])
      features['audio'] = audio

    return features, label

  def parse_fn(self, is_training):
    """Returns a parse fn that reads and parses raw tensors from the decoder.

    Args:
      is_training: a `bool` to indicate whether it is in training mode.

    Returns:
      parse: a `callable` that takes the serialized example and generates the
        (images, labels) tuple, where labels is a dict of Tensors that
        contains labels.
    """
    def parse(decoded_tensors):
      """Parses the serialized example data."""
      if is_training:
        return self._parse_train_data(decoded_tensors)
      else:
        return self._parse_eval_data(decoded_tensors)

    return parse


class PostBatchProcessor(object):
  """Processes a video and label dataset which is batched."""

  def __init__(self, input_params: exp_cfg.DataConfig):
    self._is_training = input_params.is_training
    self._is_ssl = input_params.is_ssl
    self._num_frames = input_params.feature_shape[0]
    self._num_test_clips = input_params.num_test_clips
    self._num_test_crops = input_params.num_test_crops

  def __call__(
      self, features: Dict[str, tf.Tensor], label: tf.Tensor
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Postprocesses batched image (and audio) features and label tensors."""
    for key in ['image', 'audio']:
      if key in features:
        features[key] = _postprocess_image(
            image=features[key],
            is_training=self._is_training,
            is_ssl=self._is_ssl,
            num_frames=self._num_frames,
            num_test_clips=self._num_test_clips,
            num_test_crops=self._num_test_crops)

    return features, label
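One subtlety in `_postprocess_image` above is the two-clip merge: during SSL pre-training each example holds both clips stacked along the time axis, and the split/concat moves the second clip into the batch dimension. A minimal shape check (the sizes here are illustrative only):

# Shape-only sketch of the two-clip merge in `_postprocess_image`.
import tensorflow as tf

batch, num_frames = 4, 16
# Each example carries two clips stacked along the time axis.
clip_pairs = tf.zeros([batch, 2 * num_frames, 8, 8, 3])
merged = tf.concat(tf.split(clip_pairs, num_or_size_splits=2, axis=1), axis=0)
print(merged.shape)  # (8, 16, 8, 8, 3): the clips now sit side by side in the batch dim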
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input_test.py  deleted 100644 → 0
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import io

# Import libraries
import numpy as np
from PIL import Image
import tensorflow as tf

from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input

AUDIO_KEY = 'features/audio'


def fake_seq_example():
  # Create fake data.
  random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
  random_image = Image.fromarray(random_image)
  label = 42
  with io.BytesIO() as buffer:
    random_image.save(buffer, format='JPEG')
    raw_image_bytes = buffer.getvalue()

  seq_example = tf.train.SequenceExample()
  seq_example.feature_lists.feature_list.get_or_create(
      video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
          raw_image_bytes
      ]
  seq_example.feature_lists.feature_list.get_or_create(
      video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
          raw_image_bytes
      ]
  seq_example.context.feature[
      video_ssl_input.LABEL_KEY].int64_list.value[:] = [label]

  random_audio = np.random.normal(size=(10, 256)).tolist()
  for s in random_audio:
    seq_example.feature_lists.feature_list.get_or_create(
        AUDIO_KEY).feature.add().float_list.value[:] = s
  return seq_example, label


class VideoAndLabelParserTest(tf.test.TestCase):

  def test_video_ssl_input_pretrain(self):
    params = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, _ = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, _ = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (32, 224, 224, 3))

  def test_video_ssl_input_linear_train(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.train_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, label = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, label = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (32, 224, 224, 3))
    self.assertAllEqual(label.shape, (600,))

  def test_video_ssl_input_linear_eval(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.validation_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, label = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, label = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (960, 256, 256, 3))
    self.assertAllEqual(label.shape, (600,))


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops.py  deleted 100644 → 0
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for customed ops for video ssl."""
import
functools
from
typing
import
Optional
import
tensorflow
as
tf
def
random_apply
(
func
,
p
,
x
):
"""Randomly apply function func to x with probability p."""
return
tf
.
cond
(
tf
.
less
(
tf
.
random
.
uniform
([],
minval
=
0
,
maxval
=
1
,
dtype
=
tf
.
float32
),
tf
.
cast
(
p
,
tf
.
float32
)),
lambda
:
func
(
x
),
lambda
:
x
)
def
random_brightness
(
image
,
max_delta
):
"""Distort brightness of image (SimCLRv2 style)."""
factor
=
tf
.
random
.
uniform
(
[],
tf
.
maximum
(
1.0
-
max_delta
,
0
),
1.0
+
max_delta
)
image
=
image
*
factor
return
image
def
to_grayscale
(
image
,
keep_channels
=
True
):
"""Turn the input image to gray scale.
Args:
image: The input image tensor.
keep_channels: Whether maintaining the channel number for the image.
If true, the transformed image will repeat three times in channel.
If false, the transformed image will only have one channel.
Returns:
The distorted image tensor.
"""
image
=
tf
.
image
.
rgb_to_grayscale
(
image
)
if
keep_channels
:
image
=
tf
.
tile
(
image
,
[
1
,
1
,
3
])
return
image
def
color_jitter
(
image
,
strength
,
random_order
=
True
):
"""Distorts the color of the image (SimCLRv2 style).
Args:
image: The input image tensor.
strength: The floating number for the strength of the color augmentation.
random_order: A bool, specifying whether to randomize the jittering order.
Returns:
The distorted image tensor.
"""
brightness
=
0.8
*
strength
contrast
=
0.8
*
strength
saturation
=
0.8
*
strength
hue
=
0.2
*
strength
if
random_order
:
return
color_jitter_rand
(
image
,
brightness
,
contrast
,
saturation
,
hue
)
else
:
return
color_jitter_nonrand
(
image
,
brightness
,
contrast
,
saturation
,
hue
)
def
color_jitter_nonrand
(
image
,
brightness
=
0
,
contrast
=
0
,
saturation
=
0
,
hue
=
0
):
"""Distorts the color of the image (jittering order is fixed, SimCLRv2 style).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
Returns:
The distorted image tensor.
"""
with
tf
.
name_scope
(
'distort_color'
):
def
apply_transform
(
i
,
x
,
brightness
,
contrast
,
saturation
,
hue
):
"""Apply the i-th transformation."""
if
brightness
!=
0
and
i
==
0
:
x
=
random_brightness
(
x
,
max_delta
=
brightness
)
elif
contrast
!=
0
and
i
==
1
:
x
=
tf
.
image
.
random_contrast
(
x
,
lower
=
1
-
contrast
,
upper
=
1
+
contrast
)
elif
saturation
!=
0
and
i
==
2
:
x
=
tf
.
image
.
random_saturation
(
x
,
lower
=
1
-
saturation
,
upper
=
1
+
saturation
)
elif
hue
!=
0
:
x
=
tf
.
image
.
random_hue
(
x
,
max_delta
=
hue
)
return
x
for
i
in
range
(
4
):
image
=
apply_transform
(
i
,
image
,
brightness
,
contrast
,
saturation
,
hue
)
image
=
tf
.
clip_by_value
(
image
,
0.
,
1.
)
return
image
def
color_jitter_rand
(
image
,
brightness
=
0
,
contrast
=
0
,
saturation
=
0
,
hue
=
0
):
"""Distorts the color of the image (jittering order is random, SimCLRv2 style).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
Returns:
The distorted image tensor.
"""
with
tf
.
name_scope
(
'distort_color'
):
def
apply_transform
(
i
,
x
):
"""Apply the i-th transformation."""
def
brightness_transform
():
if
brightness
==
0
:
return
x
else
:
return
random_brightness
(
x
,
max_delta
=
brightness
)
def
contrast_transform
():
if
contrast
==
0
:
return
x
else
:
return
tf
.
image
.
random_contrast
(
x
,
lower
=
1
-
contrast
,
upper
=
1
+
contrast
)
def
saturation_transform
():
if
saturation
==
0
:
return
x
else
:
return
tf
.
image
.
random_saturation
(
x
,
lower
=
1
-
saturation
,
upper
=
1
+
saturation
)
def
hue_transform
():
if
hue
==
0
:
return
x
else
:
return
tf
.
image
.
random_hue
(
x
,
max_delta
=
hue
)
# pylint:disable=g-long-lambda
x
=
tf
.
cond
(
tf
.
less
(
i
,
2
),
lambda
:
tf
.
cond
(
tf
.
less
(
i
,
1
),
brightness_transform
,
contrast_transform
),
lambda
:
tf
.
cond
(
tf
.
less
(
i
,
3
),
saturation_transform
,
hue_transform
))
# pylint:disable=g-long-lambda
return
x
perm
=
tf
.
random
.
shuffle
(
tf
.
range
(
4
))
for
i
in
range
(
4
):
image
=
apply_transform
(
perm
[
i
],
image
)
image
=
tf
.
clip_by_value
(
image
,
0.
,
1.
)
return
image
def
random_color_jitter_3d
(
frames
):
"""Applies temporally consistent color jittering to one video clip.
Args:
frames: `Tensor` of shape [num_frames, height, width, channels].
Returns:
A Tensor of shape [num_frames, height, width, channels] being color jittered
with the same operation.
"""
def
random_color_jitter
(
image
,
p
=
1.0
):
def
_transform
(
image
):
color_jitter_t
=
functools
.
partial
(
color_jitter
,
strength
=
1.0
)
image
=
random_apply
(
color_jitter_t
,
p
=
0.8
,
x
=
image
)
return
random_apply
(
to_grayscale
,
p
=
0.2
,
x
=
image
)
return
random_apply
(
_transform
,
p
=
p
,
x
=
image
)
num_frames
,
width
,
height
,
channels
=
tf
.
shape
(
frames
)
big_image
=
tf
.
reshape
(
frames
,
[
num_frames
*
width
,
height
,
channels
])
big_image
=
random_color_jitter
(
big_image
)
return
tf
.
reshape
(
big_image
,
[
num_frames
,
width
,
height
,
channels
])
def
gaussian_blur
(
image
,
kernel_size
,
sigma
,
padding
=
'SAME'
):
"""Blurs the given image with separable convolution.
Args:
image: Tensor of shape [height, width, channels] and dtype float to blur.
kernel_size: Integer Tensor for the size of the blur kernel. This is should
be an odd number. If it is an even number, the actual kernel size will be
size + 1.
sigma: Sigma value for gaussian operator.
padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.
Returns:
A Tensor representing the blurred image.
"""
radius
=
tf
.
cast
(
kernel_size
/
2
,
dtype
=
tf
.
int32
)
kernel_size
=
radius
*
2
+
1
x
=
tf
.
cast
(
tf
.
range
(
-
radius
,
radius
+
1
),
dtype
=
tf
.
float32
)
blur_filter
=
tf
.
exp
(
-
tf
.
pow
(
x
,
2.0
)
/
(
2.0
*
tf
.
pow
(
tf
.
cast
(
sigma
,
dtype
=
tf
.
float32
),
2.0
)))
blur_filter
/=
tf
.
reduce_sum
(
blur_filter
)
# One vertical and one horizontal filter.
blur_v
=
tf
.
reshape
(
blur_filter
,
[
kernel_size
,
1
,
1
,
1
])
blur_h
=
tf
.
reshape
(
blur_filter
,
[
1
,
kernel_size
,
1
,
1
])
num_channels
=
tf
.
shape
(
image
)[
-
1
]
blur_h
=
tf
.
tile
(
blur_h
,
[
1
,
1
,
num_channels
,
1
])
blur_v
=
tf
.
tile
(
blur_v
,
[
1
,
1
,
num_channels
,
1
])
expand_batch_dim
=
image
.
shape
.
ndims
==
3
if
expand_batch_dim
:
# Tensorflow requires batched input to convolutions, which we can fake with
# an extra dimension.
image
=
tf
.
expand_dims
(
image
,
axis
=
0
)
blurred
=
tf
.
nn
.
depthwise_conv2d
(
image
,
blur_h
,
strides
=
[
1
,
1
,
1
,
1
],
padding
=
padding
)
blurred
=
tf
.
nn
.
depthwise_conv2d
(
blurred
,
blur_v
,
strides
=
[
1
,
1
,
1
,
1
],
padding
=
padding
)
if
expand_batch_dim
:
blurred
=
tf
.
squeeze
(
blurred
,
axis
=
0
)
return
blurred
def
random_blur
(
image
,
height
,
width
,
p
=
1.0
):
"""Randomly blur an image.
Args:
image: `Tensor` representing an image of arbitrary size.
height: Height of output image.
width: Width of output image.
p: probability of applying this transformation.
Returns:
A preprocessed image `Tensor`.
"""
del
width
def
_transform
(
image
):
sigma
=
tf
.
random
.
uniform
([],
0.1
,
2.0
,
dtype
=
tf
.
float32
)
return
gaussian_blur
(
image
,
kernel_size
=
height
//
10
,
sigma
=
sigma
,
padding
=
'SAME'
)
return
random_apply
(
_transform
,
p
=
p
,
x
=
image
)
def
random_blur_3d
(
frames
,
height
,
width
,
blur_probability
=
0.5
):
"""Apply efficient batch data transformations.
Args:
frames: `Tensor` of shape [timesteps, height, width, 3].
height: the height of image.
width: the width of image.
blur_probability: the probaility to apply the blur operator.
Returns:
Preprocessed feature list.
"""
def
generate_selector
(
p
,
bsz
):
shape
=
[
bsz
,
1
,
1
,
1
]
selector
=
tf
.
cast
(
tf
.
less
(
tf
.
random
.
uniform
(
shape
,
0
,
1
,
dtype
=
tf
.
float32
),
p
),
tf
.
float32
)
return
selector
frames_new
=
random_blur
(
frames
,
height
,
width
,
p
=
1.
)
selector
=
generate_selector
(
blur_probability
,
1
)
frames
=
frames_new
*
selector
+
frames
*
(
1
-
selector
)
frames
=
tf
.
clip_by_value
(
frames
,
0.
,
1.
)
return
frames
def
_sample_or_pad_sequence_indices
(
sequence
:
tf
.
Tensor
,
num_steps
:
int
,
stride
:
int
,
offset
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Returns indices to take for sampling or padding sequences to fixed size."""
sequence_length
=
tf
.
shape
(
sequence
)[
0
]
sel_idx
=
tf
.
range
(
sequence_length
)
# Repeats sequence until num_steps are available in total.
max_length
=
num_steps
*
stride
+
offset
num_repeats
=
tf
.
math
.
floordiv
(
max_length
+
sequence_length
-
1
,
sequence_length
)
sel_idx
=
tf
.
tile
(
sel_idx
,
[
num_repeats
])
steps
=
tf
.
range
(
offset
,
offset
+
num_steps
*
stride
,
stride
)
return
tf
.
gather
(
sel_idx
,
steps
)
def
sample_ssl_sequence
(
sequence
:
tf
.
Tensor
,
num_steps
:
int
,
random
:
bool
,
stride
:
int
=
1
,
num_windows
:
Optional
[
int
]
=
2
)
->
tf
.
Tensor
:
"""Samples two segments of size num_steps randomly from a given sequence.
Currently it only supports images, and specically designed for video self-
supervised learning.
Args:
sequence: Any tensor where the first dimension is timesteps.
num_steps: Number of steps (e.g. frames) to take.
random: A boolean indicating whether to random sample the single window. If
True, the offset is randomized. Only True is supported.
stride: Distance to sample between timesteps.
num_windows: Number of sequence sampled.
Returns:
A single Tensor with first dimension num_steps with the sampled segment.
"""
sequence_length
=
tf
.
shape
(
sequence
)[
0
]
sequence_length
=
tf
.
cast
(
sequence_length
,
tf
.
float32
)
if
random
:
max_offset
=
tf
.
cond
(
tf
.
greater
(
sequence_length
,
(
num_steps
-
1
)
*
stride
),
lambda
:
sequence_length
-
(
num_steps
-
1
)
*
stride
,
lambda
:
sequence_length
)
max_offset
=
tf
.
cast
(
max_offset
,
dtype
=
tf
.
float32
)
def
cdf
(
k
,
power
=
1.0
):
"""Cumulative distribution function for x^power."""
p
=
-
tf
.
math
.
pow
(
k
,
power
+
1
)
/
(
power
*
tf
.
math
.
pow
(
max_offset
,
power
+
1
))
+
k
*
(
power
+
1
)
/
(
power
*
max_offset
)
return
p
u
=
tf
.
random
.
uniform
(())
k_low
=
tf
.
constant
(
0
,
dtype
=
tf
.
float32
)
k_up
=
max_offset
k
=
tf
.
math
.
floordiv
(
max_offset
,
2.0
)
c
=
lambda
k_low
,
k_up
,
k
:
tf
.
greater
(
tf
.
math
.
abs
(
k_up
-
k_low
),
1.0
)
# pylint:disable=g-long-lambda
b
=
lambda
k_low
,
k_up
,
k
:
tf
.
cond
(
tf
.
greater
(
cdf
(
k
),
u
),
lambda
:
[
k_low
,
k
,
tf
.
math
.
floordiv
(
k
+
k_low
,
2.0
)],
lambda
:
[
k
,
k_up
,
tf
.
math
.
floordiv
(
k_up
+
k
,
2.0
)])
_
,
_
,
k
=
tf
.
while_loop
(
c
,
b
,
[
k_low
,
k_up
,
k
])
delta
=
tf
.
cast
(
k
,
tf
.
int32
)
max_offset
=
tf
.
cast
(
max_offset
,
tf
.
int32
)
sequence_length
=
tf
.
cast
(
sequence_length
,
tf
.
int32
)
choice_1
=
tf
.
cond
(
tf
.
equal
(
max_offset
,
sequence_length
),
lambda
:
tf
.
random
.
uniform
((),
maxval
=
tf
.
cast
(
max_offset
,
dtype
=
tf
.
int32
),
dtype
=
tf
.
int32
),
lambda
:
tf
.
random
.
uniform
((),
maxval
=
tf
.
cast
(
max_offset
-
delta
,
dtype
=
tf
.
int32
),
dtype
=
tf
.
int32
))
choice_2
=
tf
.
cond
(
tf
.
equal
(
max_offset
,
sequence_length
),
lambda
:
tf
.
random
.
uniform
((),
maxval
=
tf
.
cast
(
max_offset
,
dtype
=
tf
.
int32
),
dtype
=
tf
.
int32
),
lambda
:
choice_1
+
delta
)
# pylint:disable=g-long-lambda
shuffle_choice
=
tf
.
random
.
shuffle
((
choice_1
,
choice_2
))
offset_1
=
shuffle_choice
[
0
]
offset_2
=
shuffle_choice
[
1
]
else
:
raise
NotImplementedError
indices_1
=
_sample_or_pad_sequence_indices
(
sequence
=
sequence
,
num_steps
=
num_steps
,
stride
=
stride
,
offset
=
offset_1
)
indices_2
=
_sample_or_pad_sequence_indices
(
sequence
=
sequence
,
num_steps
=
num_steps
,
stride
=
stride
,
offset
=
offset_2
)
indices
=
tf
.
concat
([
indices_1
,
indices_2
],
axis
=
0
)
indices
.
set_shape
((
num_windows
*
num_steps
,))
output
=
tf
.
gather
(
sequence
,
indices
)
return
output
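`sample_ssl_sequence` draws the temporal gap `delta` between the two clips from a linearly decreasing distribution by inverting the closed-form CDF above with a bisection loop. A self-contained pure-Python sketch of that inversion (the helper name is hypothetical; power=1.0 as in the TF code):

# Illustrative re-implementation of the CDF bisection in sample_ssl_sequence.
import random

def sample_decreasing_gap(max_offset, power=1.0):
  """Samples k in [0, max_offset] with density proportional to
  1 - (k / max_offset) ** power, via bisection on the CDF."""
  def cdf(k):
    return (-(k ** (power + 1)) / (power * max_offset ** (power + 1))
            + k * (power + 1) / (power * max_offset))

  u = random.random()
  k_low, k_up, k = 0.0, float(max_offset), max_offset // 2
  while abs(k_up - k_low) > 1.0:
    if cdf(k) > u:          # target lies in the lower half
      k_low, k_up, k = k_low, k, (k + k_low) // 2
    else:                   # target lies in the upper half
      k_low, k_up, k = k, k_up, (k_up + k) // 2
  return int(k)

# Small gaps are favored: with power=1.0 the density at 0 is exactly twice
# the density at max_offset / 2.
print(sample_decreasing_gap(100))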
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops_test.py  deleted 100644 → 0
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf

from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops


class VideoSslPreprocessOpsTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self._raw_frames = tf.random.uniform((250, 256, 256, 3), minval=0,
                                         maxval=255, dtype=tf.dtypes.int32)
    self._sampled_frames = self._raw_frames[:16]
    self._frames = preprocess_ops_3d.normalize_image(self._sampled_frames,
                                                     False, tf.float32)

  def test_sample_ssl_sequence(self):
    sampled_seq = video_ssl_preprocess_ops.sample_ssl_sequence(
        self._raw_frames, 16, True, 2)
    self.assertAllEqual(sampled_seq.shape, (32, 256, 256, 3))

  def test_random_color_jitter_3d(self):
    jittered_clip = video_ssl_preprocess_ops.random_color_jitter_3d(
        self._frames)
    self.assertAllEqual(jittered_clip.shape, (16, 256, 256, 3))

  def test_random_blur_3d(self):
    blurred_clip = video_ssl_preprocess_ops.random_blur_3d(
        self._frames, 256, 256)
    self.assertAllEqual(blurred_clip.shape, (16, 256, 256, 3))


if __name__ == '__main__':
  tf.test.main()