Commit 565fbe88
Authored Jan 08, 2021 by A. Unique TensorFlower

Support input pipeline for video ssl.

PiperOrigin-RevId: 350874149

Parent: f79b1875

Showing 5 changed files with 570 additions and 0 deletions (+570, -0)
official/vision/beta/projects/video_ssl/configs/__init__.py                  +18   -0
official/vision/beta/projects/video_ssl/configs/video_ssl.py                 +87   -0
official/vision/beta/projects/video_ssl/configs/video_ssl_test.py            +45   -0
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input.py       +309  -0
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input_test.py  +111  -0
official/vision/beta/projects/video_ssl/configs/__init__.py (new file, mode 100644)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configs package definition."""
from official.vision.beta.projects.video_ssl.configs import video_ssl
official/vision/beta/projects/video_ssl/configs/video_ssl.py (new file, mode 100644)
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Video classification configuration definition."""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.configs import video_classification


Losses = video_classification.Losses
VideoClassificationModel = video_classification.VideoClassificationModel
VideoClassificationTask = video_classification.VideoClassificationTask


@dataclasses.dataclass
class DataConfig(video_classification.DataConfig):
  """The base configuration for building datasets."""
  is_ssl: bool = False


@exp_factory.register_config_factory('video_ssl_pretrain_kinetics400')
def video_ssl_pretrain_kinetics400() -> cfg.ExperimentConfig:
  """Pretrains SSL video classification on Kinetics 400 with ResNet."""
  exp = video_classification.video_classification_kinetics400()
  exp.task.train_data = DataConfig(is_ssl=True,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (16, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  return exp


@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics400')
def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig:
  """Linear evaluation of SSL video classification on Kinetics 400 with ResNet."""
  exp = video_classification.video_classification_kinetics400()
  exp.task.train_data = DataConfig(is_ssl=False,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  return exp


@exp_factory.register_config_factory('video_ssl_pretrain_kinetics600')
def video_ssl_pretrain_kinetics600() -> cfg.ExperimentConfig:
  """Pretrains SSL video classification on Kinetics 600 with ResNet."""
  exp = video_classification.video_classification_kinetics600()
  exp.task.train_data = DataConfig(is_ssl=True,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (16, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  return exp


@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics600')
def video_ssl_linear_eval_kinetics600() -> cfg.ExperimentConfig:
  """Linear evaluation of SSL video classification on Kinetics 600 with ResNet."""
  exp = video_classification.video_classification_kinetics600()
  exp.task.train_data = DataConfig(is_ssl=False,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  return exp
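
Importing this configs module registers the four experiment names above with `exp_factory`, so a config can afterwards be retrieved by name and its fields overridden. A minimal sketch, assuming only the factories defined in this commit plus the Model Garden's `official.core` package; the printed values follow directly from `video_ssl_pretrain_kinetics400` above:

from official.core import exp_factory
# Importing the package registers the config factories as a side effect.
from official.vision.beta.projects.video_ssl.configs import video_ssl  # pylint: disable=unused-import

config = exp_factory.get_exp_config('video_ssl_pretrain_kinetics400')
print(config.task.train_data.is_ssl)           # True
print(config.task.train_data.feature_shape)    # (16, 224, 224, 3)
print(config.task.train_data.temporal_stride)  # 2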
official/vision/beta/projects/video_ssl/configs/video_ssl_test.py (new file, mode 100644)
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision import beta
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg


class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('video_ssl_pretrain_kinetics400',),
                            ('video_ssl_linear_eval_kinetics400',),
                            ('video_ssl_pretrain_kinetics600',),
                            ('video_ssl_linear_eval_kinetics600',))
  def test_video_classification_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.VideoClassificationTask)
    self.assertIsInstance(config.task.model, exp_cfg.VideoClassificationModel)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input.py (new file, mode 100644)
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parser for video and label datasets."""
from typing import Dict, Optional, Tuple

from absl import logging
import tensorflow as tf

from official.vision.beta.dataloaders import video_input
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops

IMAGE_KEY = 'image/encoded'
LABEL_KEY = 'clip/label/index'

Decoder = video_input.Decoder
def _process_image(image: tf.Tensor,
                   is_training: bool = True,
                   is_ssl: bool = False,
                   num_frames: int = 32,
                   stride: int = 1,
                   num_test_clips: int = 1,
                   min_resize: int = 224,
                   crop_size: int = 200,
                   zero_centering_image: bool = False,
                   seed: Optional[int] = None) -> tf.Tensor:
  """Processes a serialized image tensor.

  Args:
    image: Input Tensor of shape [timesteps] and type tf.string of serialized
      frames.
    is_training: Whether or not in training mode. If True, random sample, crop
      and left-right flip are used.
    is_ssl: Whether or not in self-supervised pre-training mode.
    num_frames: Number of frames per subclip.
    stride: Temporal stride to sample frames.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test
      time. If 1, then a single clip in the middle of the video is sampled.
      The clips are aggregated in the batch dimension.
    min_resize: Frames are resized so that min(height, width) is min_resize.
    crop_size: Final size of the frame after cropping the resized frames. Both
      height and width are the same.
    zero_centering_image: If True, frames are normalized to values in [-1, 1].
      If False, values in [0, 1].
    seed: A deterministic seed to use when sampling.

  Returns:
    Processed frames. Tensor of shape
      [num_frames * num_test_clips, crop_size, crop_size, 3].
  """
  # Validate parameters.
  if is_training and num_test_clips != 1:
    logging.warning(
        '`num_test_clips` %d is ignored since `is_training` is `True`.',
        num_test_clips)

  # Temporal sampler.
  if is_training:
    # Sampler for training.
    if is_ssl:
      # Sample two clips from a linearly decreasing distribution.
      image = video_ssl_preprocess_ops.sample_ssl_sequence(
          image, num_frames, True, stride)
    else:
      # Sample a random clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, True,
                                                stride)
  else:
    # Sampler for evaluation.
    if num_test_clips > 1:
      # Sample linspace clips.
      image = preprocess_ops_3d.sample_linspace_sequence(
          image, num_test_clips, num_frames, stride)
    else:
      # Sample the middle clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, False,
                                                stride)

  # Decode JPEG string to tf.uint8.
  image = preprocess_ops_3d.decode_jpeg(image, 3)

  if is_training:
    # Standard image data augmentation: random resized crop and random flip.
    if is_ssl:
      image_1, image_2 = tf.split(image, num_or_size_splits=2, axis=0)
      image_1 = preprocess_ops_3d.random_crop_resize(
          image_1, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_1 = preprocess_ops_3d.random_flip_left_right(image_1, seed)
      image_2 = preprocess_ops_3d.random_crop_resize(
          image_2, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_2 = preprocess_ops_3d.random_flip_left_right(image_2, seed)
    else:
      image = preprocess_ops_3d.random_crop_resize(
          image, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image = preprocess_ops_3d.random_flip_left_right(image, seed)
  else:
    # Resize images (resize happens only if necessary to save compute).
    image = preprocess_ops_3d.resize_smallest(image, min_resize)
    # Three-crop of the frames.
    image = preprocess_ops_3d.crop_image(image, min_resize, min_resize, False,
                                         3)

  # Cast the frames to float32, normalizing according to zero_centering_image.
  if is_training and is_ssl:
    image_1 = preprocess_ops_3d.normalize_image(image_1, zero_centering_image)
    image_2 = preprocess_ops_3d.normalize_image(image_2, zero_centering_image)
  else:
    image = preprocess_ops_3d.normalize_image(image, zero_centering_image)

  # Self-supervised pre-training augmentations.
  if is_training and is_ssl:
    # Temporally consistent color jittering.
    image_1 = video_ssl_preprocess_ops.random_color_jitter_3d(image_1)
    image_2 = video_ssl_preprocess_ops.random_color_jitter_3d(image_2)
    # Temporally consistent gaussian blurring.
    image_1 = video_ssl_preprocess_ops.random_blur_3d(image_1, num_frames,
                                                      crop_size, crop_size)
    image_2 = video_ssl_preprocess_ops.random_blur_3d(image_2, num_frames,
                                                      crop_size, crop_size)
    image = tf.concat([image_1, image_2], axis=0)

  return image
def _postprocess_image(image: tf.Tensor,
                       is_training: bool = True,
                       is_ssl: bool = False,
                       num_frames: int = 32,
                       num_test_clips: int = 1) -> tf.Tensor:
  """Processes a batched Tensor of frames.

  The same parameters used in process should be used here.

  Args:
    image: Input Tensor of shape [batch, timesteps, height, width, 3].
    is_training: Whether or not in training mode. If True, random sample, crop
      and left-right flip are used.
    is_ssl: Whether or not in self-supervised pre-training mode.
    num_frames: Number of frames per subclip.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test
      time. If 1, then a single clip in the middle of the video is sampled.
      The clips are aggregated in the batch dimension.

  Returns:
    Processed frames. Tensor of shape
      [batch * num_test_clips, num_frames, height, width, 3].
  """
  if is_ssl and is_training:
    # In this case, the two clips used for self-supervised pre-training are
    # merged together in the batch dimension, which becomes 2 * batch.
    image = tf.concat(tf.split(image, num_or_size_splits=2, axis=1), axis=0)

  if num_test_clips > 1 and not is_training:
    # In this case, multiple clips are merged together in the batch dimension,
    # which becomes batch * num_test_clips.
    image = tf.reshape(image, (-1, num_frames) + image.shape[2:])

  return image
def _process_label(label: tf.Tensor,
                   one_hot_label: bool = True,
                   num_classes: Optional[int] = None) -> tf.Tensor:
  """Processes label Tensor."""
  # Validate parameters.
  if one_hot_label and not num_classes:
    raise ValueError(
        '`num_classes` should be given when requesting one hot label.')

  # Cast to tf.int32.
  label = tf.cast(label, dtype=tf.int32)

  if one_hot_label:
    # Replace label index by one hot representation.
    label = tf.one_hot(label, num_classes)
    if len(label.shape.as_list()) > 1:
      label = tf.reduce_sum(label, axis=0)
    if num_classes == 1:
      # The trick for single label.
      label = 1 - label

  return label
class Parser(video_input.Parser):
  """Parses a video and label dataset."""

  def __init__(self,
               input_params: exp_cfg.DataConfig,
               image_key: str = IMAGE_KEY,
               label_key: str = LABEL_KEY):
    super(Parser, self).__init__(input_params, image_key, label_key)
    self._is_ssl = input_params.is_ssl

  def _parse_train_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for training."""
    # Process image and label.
    image = decoded_tensors[self._image_key]
    image = _process_image(
        image=image,
        is_training=True,
        is_ssl=self._is_ssl,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)

    return features, label

  def _parse_eval_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for evaluation."""
    image = decoded_tensors[self._image_key]
    image = _process_image(
        image=image,
        is_training=False,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)

    if self._output_audio:
      audio = decoded_tensors[self._audio_feature]
      audio = tf.cast(audio, dtype=self._dtype)
      audio = preprocess_ops_3d.sample_sequence(
          audio, 20, random=False, stride=1)
      audio = tf.ensure_shape(audio, [20, 2048])
      features['audio'] = audio

    return features, label

  def parse_fn(self, is_training):
    """Returns a parse fn that reads and parses raw tensors from the decoder.

    Args:
      is_training: a `bool` to indicate whether it is in training mode.

    Returns:
      parse: a `callable` that takes the serialized example and generates the
        images, labels tuple where labels is a dict of Tensors that contains
        labels.
    """
    def parse(decoded_tensors):
      """Parses the serialized example data."""
      if is_training:
        return self._parse_train_data(decoded_tensors)
      else:
        return self._parse_eval_data(decoded_tensors)

    return parse
class PostBatchProcessor(object):
  """Processes a video and label dataset which is batched."""

  def __init__(self, input_params: exp_cfg.DataConfig):
    self._is_training = input_params.is_training
    self._is_ssl = input_params.is_ssl
    self._num_frames = input_params.feature_shape[0]
    self._num_test_clips = input_params.num_test_clips

  def __call__(self, features: Dict[str, tf.Tensor],
               label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Postprocesses batched image (and audio) features and label tensors."""
    for key in ['image', 'audio']:
      if key in features:
        features[key] = _postprocess_image(
            image=features[key],
            is_training=self._is_training,
            is_ssl=self._is_ssl,
            num_frames=self._num_frames,
            num_test_clips=self._num_test_clips)

    return features, label
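
End to end, `Decoder`, `Parser`, and `PostBatchProcessor` are meant to be chained on a `tf.data` pipeline: decode each serialized `SequenceExample`, parse and augment the video, batch, then fold the two SSL clips (or the multiple test clips) into the batch dimension. The following is a minimal sketch only; the TFRecord path, batch size, and hand-rolled `tf.data` plumbing are illustrative assumptions, not part of this commit, and the actual training jobs use the Model Garden's input reader instead:

import tensorflow as tf

from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input

params = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data

decoder = video_ssl_input.Decoder()
parser = video_ssl_input.Parser(params)
postprocess = video_ssl_input.PostBatchProcessor(params)

dataset = (
    tf.data.TFRecordDataset('train.tfrecord')  # placeholder path to a Kinetics SequenceExample shard
    .map(decoder.decode, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(parser.parse_fn(params.is_training),
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(2, drop_remainder=True)
    # For SSL pre-training this folds the two augmented clips per video into
    # the batch dimension, so a batch of 2 videos becomes 4 clips of 16 frames.
    .map(postprocess)
    .prefetch(tf.data.experimental.AUTOTUNE))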
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input_test.py (new file, mode 100644)
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import io

# Import libraries
import numpy as np
from PIL import Image
import tensorflow as tf

from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input

AUDIO_KEY = 'features/audio'
def fake_seq_example():
  # Create fake data.
  random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
  random_image = Image.fromarray(random_image)
  label = 42
  with io.BytesIO() as buffer:
    random_image.save(buffer, format='JPEG')
    raw_image_bytes = buffer.getvalue()

  seq_example = tf.train.SequenceExample()
  seq_example.feature_lists.feature_list.get_or_create(
      video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
          raw_image_bytes
      ]
  seq_example.feature_lists.feature_list.get_or_create(
      video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
          raw_image_bytes
      ]
  seq_example.context.feature[
      video_ssl_input.LABEL_KEY].int64_list.value[:] = [label]

  random_audio = np.random.normal(size=(10, 256)).tolist()
  for s in random_audio:
    seq_example.feature_lists.feature_list.get_or_create(
        AUDIO_KEY).feature.add().float_list.value[:] = s
  return seq_example, label
class VideoAndLabelParserTest(tf.test.TestCase):

  def test_video_ssl_input_pretrain(self):
    params = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, _ = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, _ = output_tensor
    image = image_features['image']

    # Two SSL clips of 16 frames each are concatenated along the frame axis.
    self.assertAllEqual(image.shape, (32, 224, 224, 3))

  def test_video_ssl_input_linear_train(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.train_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, label = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, label = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (32, 224, 224, 3))
    self.assertAllEqual(label.shape, (600,))

  def test_video_ssl_input_linear_eval(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.validation_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, label = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, label = output_tensor
    image = image_features['image']

    # 32 frames x 10 test clips x 3 spatial crops = 960.
    self.assertAllEqual(image.shape, (960, 256, 256, 3))
    self.assertAllEqual(label.shape, (600,))


if __name__ == '__main__':
  tf.test.main()