ModelZoo / ResNet50_tensorflow / Commits / 44e7092c

Commit 44e7092c, authored Feb 01, 2021 by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into AXg

Parents: 431a9ca3, 59434199

Changes: 113. Showing 20 changed files with 109 additions and 1110 deletions (+109, -1110).
official/vision/beta/configs/video_classification.py (+5, -0)
official/vision/beta/data/tfrecord_lib.py (+24, -6)
official/vision/beta/dataloaders/dataset_fn.py (+1, -1)
official/vision/beta/dataloaders/tf_example_decoder.py (+6, -44)
official/vision/beta/dataloaders/tf_example_label_map_decoder.py (+4, -2)
official/vision/beta/dataloaders/video_input.py (+48, -19)
official/vision/beta/modeling/layers/nn_layers.py (+0, -1)
official/vision/beta/ops/anchor.py (+2, -1)
official/vision/beta/ops/preprocess_ops_3d.py (+7, -7)
official/vision/beta/projects/video_ssl/configs/__init__.py (+0, -18)
official/vision/beta/projects/video_ssl/configs/video_ssl.py (+0, -87)
official/vision/beta/projects/video_ssl/configs/video_ssl_test.py (+0, -45)
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input.py (+0, -309)
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input_test.py (+0, -111)
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops.py (+0, -397)
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops_test.py (+0, -47)
official/vision/beta/projects/yolo/train.py (+2, -0)
official/vision/beta/serving/export_saved_model_lib.py (+1, -1)
official/vision/beta/tasks/maskrcnn.py (+6, -3)
official/vision/beta/tasks/retinanet.py (+3, -11)
official/vision/beta/configs/video_classification.py

```diff
@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig):
   feature_shape: Tuple[int, ...] = (64, 224, 224, 3)
   temporal_stride: int = 1
   num_test_clips: int = 1
+  num_test_crops: int = 1
   num_classes: int = -1
   num_channels: int = 3
   num_examples: int = -1
@@ -53,6 +54,10 @@ class DataConfig(cfg.DataConfig):
   output_audio: bool = False
   audio_feature: str = ''
   audio_feature_shape: Tuple[int, ...] = (-1,)
+  aug_min_aspect_ratio: float = 0.5
+  aug_max_aspect_ratio: float = 2.0
+  aug_min_area_ratio: float = 0.49
+  aug_max_area_ratio: float = 1.0


 def kinetics400(is_training):
```
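The new fields land with the defaults shown above. A minimal sketch of how they surface on a constructed config (illustrative only; it assumes `DataConfig` is constructible with defaults, as dataclass configs in this repo usually are):

```python
from official.vision.beta.configs import video_classification

# Inspect the defaults added by this commit (hypothetical usage).
data_cfg = video_classification.DataConfig()
assert data_cfg.num_test_crops == 1
assert (data_cfg.aug_min_aspect_ratio, data_cfg.aug_max_aspect_ratio) == (0.5, 2.0)
assert (data_cfg.aug_min_area_ratio, data_cfg.aug_max_area_ratio) == (0.49, 1.0)
```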
official/vision/beta/data/tfrecord_lib.py

```diff
@@ -19,6 +19,7 @@ import io
 import itertools
 from absl import logging
+import numpy as np
 from PIL import Image
 import tensorflow as tf
@@ -45,10 +46,10 @@ def convert_to_feature(value, value_type=None):
     if isinstance(element, bytes):
       value_type = 'bytes'
-    elif isinstance(element, int):
+    elif isinstance(element, (int, np.integer)):
       value_type = 'int64'
-    elif isinstance(element, float):
+    elif isinstance(element, (float, np.floating)):
       value_type = 'float'
     else:
@@ -104,8 +105,9 @@ def encode_binary_mask_as_png(binary_mask):
   return output_io.getvalue()


-def write_tf_record_dataset(output_path, annotation_iterator, process_func,
-                            num_shards, use_multiprocessing=True):
+def write_tf_record_dataset(output_path, annotation_iterator,
+                            process_func, num_shards,
+                            use_multiprocessing=True, unpack_arguments=True):
   """Iterates over annotations, processes them and writes into TFRecords.

   Args:
@@ -118,6 +120,9 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
     num_shards: int, the number of shards to write for the dataset.
     use_multiprocessing:
       Whether or not to use multiple processes to write TF Records.
+    unpack_arguments:
+      Whether to unpack the tuples from annotation_iterator as individual
+      arguments to the process func or to pass the returned value as it is.

   Returns:
     num_skipped: The total number of skipped annotations.
@@ -133,9 +138,15 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
   if use_multiprocessing:
     pool = mp.Pool()
-    tf_example_iterator = pool.starmap(process_func, annotation_iterator)
+    if unpack_arguments:
+      tf_example_iterator = pool.starmap(process_func, annotation_iterator)
+    else:
+      tf_example_iterator = pool.imap(process_func, annotation_iterator)
   else:
-    tf_example_iterator = itertools.starmap(process_func, annotation_iterator)
+    if unpack_arguments:
+      tf_example_iterator = itertools.starmap(process_func,
+                                              annotation_iterator)
+    else:
+      tf_example_iterator = map(process_func, annotation_iterator)
   for idx, (tf_example, num_annotations_skipped) in enumerate(
       tf_example_iterator):
@@ -155,3 +166,10 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
   logging.info('Finished writing, skipped %d annotations.',
                total_num_annotations_skipped)
   return total_num_annotations_skipped
+
+
+def check_and_make_dir(directory):
+  """Creates the directory if it doesn't exist."""
+  if not tf.io.gfile.isdir(directory):
+    tf.io.gfile.makedirs(directory)
```
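The new `unpack_arguments` flag only changes how each item from `annotation_iterator` reaches `process_func`: `starmap` splats tuples into separate arguments, while `imap`/`map` pass each item whole. A self-contained stdlib sketch of the two dispatch modes (the toy `process` function is hypothetical):

```python
import itertools

def process(image_id, file_name):
  # Stands in for a real annotation-processing function that returns
  # (tf_example, num_annotations_skipped).
  return f'{image_id}:{file_name}', 0

annotations = [(1, 'a.jpg'), (2, 'b.jpg')]

# unpack_arguments=True path: each tuple is splatted, i.e. process(1, 'a.jpg').
unpacked = list(itertools.starmap(process, annotations))

# unpack_arguments=False path: each item arrives as a single argument, so the
# callable must unpack it itself.
packed = list(map(lambda item: process(*item), annotations))

assert unpacked == packed == [('1:a.jpg', 0), ('2:b.jpg', 0)]
```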
official/vision/beta/dataloaders/dataset_fn.py

```diff
@@ -22,7 +22,7 @@ PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]]
 def pick_dataset_fn(file_type: str) -> PossibleDatasetType:
-  if file_type == 'tf_record':
+  if file_type == 'tfrecord':
     return tf.data.TFRecordDataset
   raise ValueError('Unrecognized file_type: {}'.format(file_type))
```
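Note the accepted spelling is now 'tfrecord'; callers still passing 'tf_record' will hit the `ValueError`. A minimal sketch (the file path is a placeholder):

```python
from official.vision.beta.dataloaders import dataset_fn

reader = dataset_fn.pick_dataset_fn('tfrecord')  # returns tf.data.TFRecordDataset
dataset = reader(['/tmp/example.tfrecord'])      # placeholder path

try:
  dataset_fn.pick_dataset_fn('tf_record')        # old spelling now raises
except ValueError as e:
  print(e)  # Unrecognized file_type: tf_record
```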
official/vision/beta/dataloaders/tf_example_decoder.py

```diff
@@ -17,8 +17,6 @@
 A decoder to decode string tensors containing serialized tensorflow.Example
 protos for object detection.
 """
-import csv
-# Import libraries
 import tensorflow as tf

 from official.vision.beta.dataloaders import decoder
@@ -34,7 +32,8 @@ class TfExampleDecoder(decoder.Decoder):
   def __init__(self,
                include_mask=False,
-               regenerate_source_id=False):
+               regenerate_source_id=False,
+               mask_binarize_threshold=None):
     self._include_mask = include_mask
     self._regenerate_source_id = regenerate_source_id
     self._keys_to_features = {
@@ -50,6 +49,7 @@ class TfExampleDecoder(decoder.Decoder):
         'image/object/area': tf.io.VarLenFeature(tf.float32),
         'image/object/is_crowd': tf.io.VarLenFeature(tf.int64),
     }
+    self._mask_binarize_threshold = mask_binarize_threshold
     if include_mask:
       self._keys_to_features.update({
           'image/object/mask': tf.io.VarLenFeature(tf.string),
@@ -151,6 +151,9 @@ class TfExampleDecoder(decoder.Decoder):
     if self._include_mask:
       masks = self._decode_masks(parsed_tensors)
+      if self._mask_binarize_threshold is not None:
+        masks = tf.cast(masks > self._mask_binarize_threshold, tf.float32)
     decoded_tensors = {
         'source_id': source_id,
         'image': image,
@@ -167,44 +170,3 @@ class TfExampleDecoder(decoder.Decoder):
           'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'],
       })
     return decoded_tensors
-
-
-class TfExampleDecoderLabelMap(TfExampleDecoder):
-  """Tensorflow Example proto decoder."""
-
-  def __init__(self, label_map, include_mask=False, regenerate_source_id=False):
-    super(TfExampleDecoderLabelMap, self).__init__(
-        include_mask=include_mask, regenerate_source_id=regenerate_source_id)
-    self._keys_to_features.update({
-        'image/object/class/text': tf.io.VarLenFeature(tf.string),
-    })
-    name_to_id = self._process_label_map(label_map)
-    self._name_to_id_table = tf.lookup.StaticHashTable(
-        tf.lookup.KeyValueTensorInitializer(
-            keys=tf.constant(list(name_to_id.keys()), dtype=tf.string),
-            values=tf.constant(list(name_to_id.values()), dtype=tf.int64)),
-        default_value=-1)
-
-  def _process_label_map(self, label_map):
-    if label_map.endswith('.csv'):
-      name_to_id = self._process_csv(label_map)
-    else:
-      raise ValueError('The label map file is in incorrect format.')
-    return name_to_id
-
-  def _process_csv(self, label_map):
-    name_to_id = {}
-    with tf.io.gfile.GFile(label_map, 'r') as f:
-      reader = csv.reader(f, delimiter=',')
-      for row in reader:
-        if len(row) != 2:
-          raise ValueError('Each row of the csv label map file must be in '
-                           '`id,name` format. length = {}'.format(len(row)))
-        id_index = int(row[0])
-        name = row[1]
-        name_to_id[name] = id_index
-    return name_to_id
-
-  def _decode_classes(self, parsed_tensors):
-    return self._name_to_id_table.lookup(
-        parsed_tensors['image/object/class/text'])
```
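The only behavioral addition to `TfExampleDecoder` is the optional binarization step after `_decode_masks`. A self-contained toy of that one line:

```python
import tensorflow as tf

# Soft (e.g. resized) mask values above the threshold become 1.0, the rest 0.0,
# mirroring the new branch in TfExampleDecoder.decode.
masks = tf.constant([[0.2, 0.7],
                     [0.9, 0.4]])
mask_binarize_threshold = 0.5
binarized = tf.cast(masks > mask_binarize_threshold, tf.float32)
print(binarized.numpy())  # [[0. 1.] [1. 0.]]
```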
official/vision/beta/dataloaders/tf_example_label_map_decoder.py

```diff
@@ -27,9 +27,11 @@ from official.vision.beta.dataloaders import tf_example_decoder
 class TfExampleDecoderLabelMap(tf_example_decoder.TfExampleDecoder):
   """Tensorflow Example proto decoder."""

-  def __init__(self, label_map, include_mask=False, regenerate_source_id=False):
+  def __init__(self, label_map, include_mask=False, regenerate_source_id=False,
+               mask_binarize_threshold=None):
     super(TfExampleDecoderLabelMap, self).__init__(
-        include_mask=include_mask, regenerate_source_id=regenerate_source_id)
+        include_mask=include_mask, regenerate_source_id=regenerate_source_id,
+        mask_binarize_threshold=mask_binarize_threshold)
     self._keys_to_features.update({
         'image/object/class/text': tf.io.VarLenFeature(tf.string),
     })
```
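With the threshold now forwarded through the subclass, callers can enable it from either decoder. A hedged usage sketch (the CSV path is a placeholder; per `_process_csv` it must contain `id,name` rows):

```python
from official.vision.beta.dataloaders import tf_example_label_map_decoder

decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
    label_map='/tmp/label_map.csv',  # placeholder `id,name` CSV
    include_mask=True,
    mask_binarize_threshold=0.5)     # forwarded to TfExampleDecoder
```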
official/vision/beta/dataloaders/video_input.py

```diff
@@ -34,9 +34,14 @@ def _process_image(image: tf.Tensor,
                    num_frames: int = 32,
                    stride: int = 1,
                    num_test_clips: int = 1,
-                   min_resize: int = 224,
-                   crop_size: int = 200,
+                   min_resize: int = 256,
+                   crop_size: int = 224,
+                   num_crops: int = 1,
                    zero_centering_image: bool = False,
+                   min_aspect_ratio: float = 0.5,
+                   max_aspect_ratio: float = 2,
+                   min_area_ratio: float = 0.49,
+                   max_area_ratio: float = 1.0,
                    seed: Optional[int] = None) -> tf.Tensor:
   """Processes a serialized image tensor.
@@ -54,8 +59,13 @@ def _process_image(image: tf.Tensor,
     min_resize: Frames are resized so that min(height, width) is min_resize.
     crop_size: Final size of the frame after cropping the resized frames. Both
       height and width are the same.
+    num_crops: Number of crops to perform on the resized frames.
     zero_centering_image: If True, frames are normalized to values in [-1, 1].
       If False, values in [0, 1].
+    min_aspect_ratio: The minimum aspect range for cropping.
+    max_aspect_ratio: The maximum aspect range for cropping.
+    min_area_ratio: The minimum area range for cropping.
+    max_area_ratio: The maximum area range for cropping.
     seed: A deterministic seed to use when sampling.

   Returns:
@@ -84,17 +94,19 @@ def _process_image(image: tf.Tensor,
   # Decode JPEG string to tf.uint8.
   image = preprocess_ops_3d.decode_jpeg(image, 3)

-  # Resize images (resize happens only if necessary to save compute).
-  image = preprocess_ops_3d.resize_smallest(image, min_resize)
-
   if is_training:
-    # Standard image data augmentation: random crop and random flip.
-    image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, True,
-                                         seed)
+    # Standard image data augmentation: random resized crop and random flip.
+    image = preprocess_ops_3d.random_crop_resize(
+        image, crop_size, crop_size, num_frames, 3,
+        (min_aspect_ratio, max_aspect_ratio),
+        (min_area_ratio, max_area_ratio))
     image = preprocess_ops_3d.random_flip_left_right(image, seed)
   else:
-    # Central crop of the frames.
-    image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False)
+    # Resize images (resize happens only if necessary to save compute).
+    image = preprocess_ops_3d.resize_smallest(image, min_resize)
+    # Crop of the frames.
+    image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
+                                         num_crops)

   # Cast the frames in float32, normalizing according to zero_centering_image.
   return preprocess_ops_3d.normalize_image(image, zero_centering_image)
@@ -103,7 +115,8 @@ def _process_image(image: tf.Tensor,
 def _postprocess_image(image: tf.Tensor,
                        is_training: bool = True,
                        num_frames: int = 32,
-                       num_test_clips: int = 1) -> tf.Tensor:
+                       num_test_clips: int = 1,
+                       num_test_crops: int = 1) -> tf.Tensor:
   """Processes a batched Tensor of frames.

   The same parameters used in process should be used here.
@@ -117,15 +130,19 @@ def _postprocess_image(image: tf.Tensor,
       will sample multiple linearly spaced clips within each video at test time.
       If 1, then a single clip in the middle of the video is sampled. The clips
       are aggreagated in the batch dimension.
+    num_test_crops: Number of test crops (1 by default). If more than 1, there
+      are multiple crops for each clip at test time. If 1, there is a single
+      central crop. The crops are aggreagated in the batch dimension.

   Returns:
     Processed frames. Tensor of shape
-      [batch * num_test_clips, num_frames, height, width, 3].
+      [batch * num_test_clips * num_test_crops, num_frames, height, width, 3].
   """
-  if num_test_clips > 1 and not is_training:
-    # In this case, multiple clips are merged together in batch dimenstion which
-    # will be B * num_test_clips.
-    image = tf.reshape(image, (-1, num_frames) + image.shape[2:])
+  num_views = num_test_clips * num_test_crops
+  if num_views > 1 and not is_training:
+    # In this case, multiple views are merged together in batch dimenstion which
+    # will be batch * num_views.
+    image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())
   return image
@@ -207,12 +224,17 @@ class Parser(parser.Parser):
     self._num_test_clips = input_params.num_test_clips
     self._min_resize = input_params.min_image_size
     self._crop_size = input_params.feature_shape[1]
+    self._num_crops = input_params.num_test_crops
     self._one_hot_label = input_params.one_hot
     self._num_classes = input_params.num_classes
     self._image_key = image_key
     self._label_key = label_key
     self._dtype = tf.dtypes.as_dtype(input_params.dtype)
     self._output_audio = input_params.output_audio
+    self._min_aspect_ratio = input_params.aug_min_aspect_ratio
+    self._max_aspect_ratio = input_params.aug_max_aspect_ratio
+    self._min_area_ratio = input_params.aug_min_area_ratio
+    self._max_area_ratio = input_params.aug_max_area_ratio
     if self._output_audio:
       self._audio_feature = input_params.audio_feature
       self._audio_shape = input_params.audio_feature_shape
@@ -230,7 +252,11 @@ class Parser(parser.Parser):
         stride=self._stride,
         num_test_clips=self._num_test_clips,
         min_resize=self._min_resize,
-        crop_size=self._crop_size)
+        crop_size=self._crop_size,
+        min_aspect_ratio=self._min_aspect_ratio,
+        max_aspect_ratio=self._max_aspect_ratio,
+        min_area_ratio=self._min_area_ratio,
+        max_area_ratio=self._max_area_ratio)
     image = tf.cast(image, dtype=self._dtype)
     features = {'image': image}
@@ -260,7 +286,8 @@ class Parser(parser.Parser):
         stride=self._stride,
         num_test_clips=self._num_test_clips,
         min_resize=self._min_resize,
-        crop_size=self._crop_size)
+        crop_size=self._crop_size,
+        num_crops=self._num_crops)
     image = tf.cast(image, dtype=self._dtype)
     features = {'image': image}
@@ -286,6 +313,7 @@ class PostBatchProcessor(object):
     self._num_frames = input_params.feature_shape[0]
     self._num_test_clips = input_params.num_test_clips
+    self._num_test_crops = input_params.num_test_crops

   def __call__(self, features: Dict[str, tf.Tensor],
                label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
@@ -296,6 +324,7 @@ class PostBatchProcessor(object):
           image=features[key],
           is_training=self._is_training,
           num_frames=self._num_frames,
-          num_test_clips=self._num_test_clips)
+          num_test_clips=self._num_test_clips,
+          num_test_crops=self._num_test_crops)
     return features, label
```
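The reshape in `_postprocess_image` folds every test view into the batch axis. A runnable check of that shape arithmetic with the new `num_test_crops` factored in:

```python
import tensorflow as tf

batch, num_frames, h, w = 2, 4, 8, 8
num_test_clips, num_test_crops = 10, 3
num_views = num_test_clips * num_test_crops

# Before postprocessing, all views of a video are stacked along the time axis.
image = tf.zeros([batch, num_views * num_frames, h, w, 3])
# The updated reshape moves them into the batch dimension.
image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())
print(image.shape)  # (60, 4, 8, 8, 3) == (batch * num_views, num_frames, h, w, 3)
```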
official/vision/beta/modeling/layers/nn_layers.py

```diff
@@ -150,7 +150,6 @@ class SqueezeExcitation(tf.keras.layers.Layer):
         'out_filters': self._out_filters,
         'se_ratio': self._se_ratio,
         'divisible_by': self._divisible_by,
-        'strides': self._strides,
         'kernel_initializer': self._kernel_initializer,
         'kernel_regularizer': self._kernel_regularizer,
         'bias_regularizer': self._bias_regularizer,
```
official/vision/beta/ops/anchor.py

```diff
@@ -203,7 +203,8 @@ class RpnAnchorLabeler(AnchorLabeler):
                unmatched_threshold=0.3,
                rpn_batch_size_per_im=256,
                rpn_fg_fraction=0.5):
-    AnchorLabeler.__init__(self, match_threshold=0.7, unmatched_threshold=0.3)
+    AnchorLabeler.__init__(self, match_threshold=match_threshold,
+                           unmatched_threshold=unmatched_threshold)
     self._rpn_batch_size_per_im = rpn_batch_size_per_im
     self._rpn_fg_fraction = rpn_fg_fraction
```
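This is a straight bug fix: the constructor arguments were previously shadowed by the hard-coded 0.7/0.3. A sketch of a call that now behaves as written (argument names taken from the hunk context above):

```python
from official.vision.beta.ops import anchor

# match_threshold/unmatched_threshold are now honored instead of being
# silently replaced by the hard-coded 0.7/0.3.
labeler = anchor.RpnAnchorLabeler(
    match_threshold=0.6,
    unmatched_threshold=0.2,
    rpn_batch_size_per_im=256,
    rpn_fg_fraction=0.5)
```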
official/vision/beta/ops/preprocess_ops_3d.py

```diff
@@ -151,19 +151,19 @@ def crop_image(frames: tf.Tensor,
                target_height: int,
                target_width: int,
                random: bool = False,
-               num_views: int = 1,
+               num_crops: int = 1,
                seed: Optional[int] = None) -> tf.Tensor:
   """Crops the image sequence of images.

   If requested size is bigger than image size, image is padded with 0. If not
-    random cropping, a central crop is performed.
+    random cropping, a central crop is performed if num_crops is 1.

   Args:
     frames: A Tensor of dimension [timesteps, in_height, in_width, channels].
     target_height: Target cropped image height.
     target_width: Target cropped image width.
     random: A boolean indicating if crop should be randomized.
-    num_views: Number of views to crop in evaluation.
+    num_crops: Number of crops (support 1 for central crop and 3 for 3-crop).
     seed: A deterministic seed to use when random cropping.

   Returns:
@@ -181,13 +181,13 @@ def crop_image(frames: tf.Tensor,
     frames = tf.image.random_crop(
         frames, (seq_len, target_height, target_width, channels), seed)
   else:
-    if num_views == 1:
+    if num_crops == 1:
       # Central crop or pad.
       frames = tf.image.resize_with_crop_or_pad(frames, target_height,
                                                 target_width)
-    elif num_views == 3:
-      # Three-view evaluation.
+    elif num_crops == 3:
+      # Three-crop evaluation.
       shape = tf.shape(frames)
       static_shape = frames.shape.as_list()
       seq_len = shape[0] if static_shape[0] is None else static_shape[0]
@@ -224,7 +224,7 @@ def crop_image(frames: tf.Tensor,
     else:
       raise NotImplementedError(
-          f"Only 1 crop and 3 crop are supported. Found {num_views!r}.")
+          f"Only 1-crop and 3-crop are supported. Found {num_crops!r}.")
   return frames
```
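Under the renamed argument, evaluation crops stack along the frame axis, so `num_crops=3` triples the first dimension. A hedged sketch (the output shape is inferred from the 3-crop test expectations in the deleted tests later in this diff):

```python
import tensorflow as tf
from official.vision.beta.ops import preprocess_ops_3d

frames = tf.zeros([16, 256, 320, 3])
# num_crops=1 keeps the frame count; 3-crop stacks three crops per frame
# along the first axis, i.e. roughly [48, 256, 256, 3] here.
crops = preprocess_ops_3d.crop_image(frames, 256, 256, False, num_crops=3)
```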
official/vision/beta/projects/video_ssl/configs/__init__.py (deleted, 100644 → 0)

```python
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configs package definition."""
from official.vision.beta.projects.video_ssl.configs import video_ssl
```
official/vision/beta/projects/video_ssl/configs/video_ssl.py (deleted, 100644 → 0)

```python
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Video classification configuration definition."""
import dataclasses

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.configs import video_classification

Losses = video_classification.Losses
VideoClassificationModel = video_classification.VideoClassificationModel
VideoClassificationTask = video_classification.VideoClassificationTask


@dataclasses.dataclass
class DataConfig(video_classification.DataConfig):
  """The base configuration for building datasets."""
  is_ssl: bool = False


@exp_factory.register_config_factory('video_ssl_pretrain_kinetics400')
def video_ssl_pretrain_kinetics400() -> cfg.ExperimentConfig:
  """Pretrain SSL Video classification on Kinectics 400 with resnet."""
  exp = video_classification.video_classification_kinetics400()
  exp.task.train_data = DataConfig(is_ssl=True, **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (16, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  return exp


@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics400')
def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig:
  """Pretrain SSL Video classification on Kinectics 400 with resnet."""
  exp = video_classification.video_classification_kinetics400()
  exp.task.train_data = DataConfig(is_ssl=False, **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  return exp


@exp_factory.register_config_factory('video_ssl_pretrain_kinetics600')
def video_ssl_pretrain_kinetics600() -> cfg.ExperimentConfig:
  """Pretrain SSL Video classification on Kinectics 400 with resnet."""
  exp = video_classification.video_classification_kinetics600()
  exp.task.train_data = DataConfig(is_ssl=True, **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (16, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  return exp


@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics600')
def video_ssl_linear_eval_kinetics600() -> cfg.ExperimentConfig:
  """Pretrain SSL Video classification on Kinectics 400 with resnet."""
  exp = video_classification.video_classification_kinetics600()
  exp.task.train_data = DataConfig(is_ssl=False, **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  return exp
```
official/vision/beta/projects/video_ssl/configs/video_ssl_test.py (deleted, 100644 → 0)

```python
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision import beta
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg


class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('video_ssl_pretrain_kinetics400',),
                            ('video_ssl_linear_eval_kinetics400',),
                            ('video_ssl_pretrain_kinetics600',),
                            ('video_ssl_linear_eval_kinetics600',))
  def test_video_classification_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.VideoClassificationTask)
    self.assertIsInstance(config.task.model, exp_cfg.VideoClassificationModel)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
```
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input.py (deleted, 100644 → 0)

```python
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parser for video and label datasets."""
from typing import Dict, Optional, Tuple

from absl import logging
import tensorflow as tf

from official.vision.beta.dataloaders import video_input
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops

IMAGE_KEY = 'image/encoded'
LABEL_KEY = 'clip/label/index'

Decoder = video_input.Decoder


def _process_image(image: tf.Tensor,
                   is_training: bool = True,
                   is_ssl: bool = False,
                   num_frames: int = 32,
                   stride: int = 1,
                   num_test_clips: int = 1,
                   min_resize: int = 224,
                   crop_size: int = 200,
                   zero_centering_image: bool = False,
                   seed: Optional[int] = None) -> tf.Tensor:
  """Processes a serialized image tensor.

  Args:
    image: Input Tensor of shape [timesteps] and type tf.string of serialized
      frames.
    is_training: Whether or not in training mode. If True, random sample, crop
      and left right flip is used.
    is_ssl: Whether or not in self-supervised pre-training mode.
    num_frames: Number of frames per subclip.
    stride: Temporal stride to sample frames.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test time.
      If 1, then a single clip in the middle of the video is sampled. The clips
      are aggreagated in the batch dimension.
    min_resize: Frames are resized so that min(height, width) is min_resize.
    crop_size: Final size of the frame after cropping the resized frames. Both
      height and width are the same.
    zero_centering_image: If True, frames are normalized to values in [-1, 1].
      If False, values in [0, 1].
    seed: A deterministic seed to use when sampling.

  Returns:
    Processed frames. Tensor of shape
      [num_frames * num_test_clips, crop_size, crop_size, 3].
  """
  # Validate parameters.
  if is_training and num_test_clips != 1:
    logging.warning(
        '`num_test_clips` %d is ignored since `is_training` is `True`.',
        num_test_clips)

  # Temporal sampler.
  if is_training:
    # Sampler for training.
    if is_ssl:
      # Sample two clips from linear decreasing distribution.
      image = video_ssl_preprocess_ops.sample_ssl_sequence(
          image, num_frames, True, stride)
    else:
      # Sample random clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, True,
                                                stride)
  else:
    # Sampler for evaluation.
    if num_test_clips > 1:
      # Sample linspace clips.
      image = preprocess_ops_3d.sample_linspace_sequence(
          image, num_test_clips, num_frames, stride)
    else:
      # Sample middle clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, False,
                                                stride)

  # Decode JPEG string to tf.uint8.
  image = preprocess_ops_3d.decode_jpeg(image, 3)

  if is_training:
    # Standard image data augmentation: random resized crop and random flip.
    if is_ssl:
      image_1, image_2 = tf.split(image, num_or_size_splits=2, axis=0)
      image_1 = preprocess_ops_3d.random_crop_resize(
          image_1, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_1 = preprocess_ops_3d.random_flip_left_right(image_1, seed)
      image_2 = preprocess_ops_3d.random_crop_resize(
          image_2, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_2 = preprocess_ops_3d.random_flip_left_right(image_2, seed)
    else:
      image = preprocess_ops_3d.random_crop_resize(
          image, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image = preprocess_ops_3d.random_flip_left_right(image, seed)
  else:
    # Resize images (resize happens only if necessary to save compute).
    image = preprocess_ops_3d.resize_smallest(image, min_resize)
    # Three-crop of the frames.
    image = preprocess_ops_3d.crop_image(image, min_resize, min_resize, False,
                                         3)

  # Cast the frames in float32, normalizing according to zero_centering_image.
  if is_training and is_ssl:
    image_1 = preprocess_ops_3d.normalize_image(image_1, zero_centering_image)
    image_2 = preprocess_ops_3d.normalize_image(image_2, zero_centering_image)
  else:
    image = preprocess_ops_3d.normalize_image(image, zero_centering_image)

  # Self-supervised pre-training augmentations.
  if is_training and is_ssl:
    # Temporally consistent color jittering.
    image_1 = video_ssl_preprocess_ops.random_color_jitter_3d(image_1)
    image_2 = video_ssl_preprocess_ops.random_color_jitter_3d(image_2)
    # Temporally consistent gaussian blurring.
    image_1 = video_ssl_preprocess_ops.random_blur_3d(image_1, num_frames,
                                                      crop_size, crop_size)
    image_2 = video_ssl_preprocess_ops.random_blur_3d(image_2, num_frames,
                                                      crop_size, crop_size)
    image = tf.concat([image_1, image_2], axis=0)

  return image


def _postprocess_image(image: tf.Tensor,
                       is_training: bool = True,
                       is_ssl: bool = False,
                       num_frames: int = 32,
                       num_test_clips: int = 1) -> tf.Tensor:
  """Processes a batched Tensor of frames.

  The same parameters used in process should be used here.

  Args:
    image: Input Tensor of shape [batch, timesteps, height, width, 3].
    is_training: Whether or not in training mode. If True, random sample, crop
      and left right flip is used.
    is_ssl: Whether or not in self-supervised pre-training mode.
    num_frames: Number of frames per subclip.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test time.
      If 1, then a single clip in the middle of the video is sampled. The clips
      are aggreagated in the batch dimension.

  Returns:
    Processed frames. Tensor of shape
      [batch * num_test_clips, num_frames, height, width, 3].
  """
  if is_ssl and is_training:
    # In this case, two clips of self-supervised pre-training are merged
    # together in batch dimenstion which will be 2 * batch.
    image = tf.concat(tf.split(image, num_or_size_splits=2, axis=1), axis=0)
  if num_test_clips > 1 and not is_training:
    # In this case, multiple clips are merged together in batch dimenstion which
    # will be B * num_test_clips.
    image = tf.reshape(image, (-1, num_frames) + image.shape[2:])
  return image


def _process_label(label: tf.Tensor,
                   one_hot_label: bool = True,
                   num_classes: Optional[int] = None) -> tf.Tensor:
  """Processes label Tensor."""
  # Validate parameters.
  if one_hot_label and not num_classes:
    raise ValueError(
        '`num_classes` should be given when requesting one hot label.')

  # Cast to tf.int32.
  label = tf.cast(label, dtype=tf.int32)

  if one_hot_label:
    # Replace label index by one hot representation.
    label = tf.one_hot(label, num_classes)
    if len(label.shape.as_list()) > 1:
      label = tf.reduce_sum(label, axis=0)
    if num_classes == 1:
      # The trick for single label.
      label = 1 - label

  return label


class Parser(video_input.Parser):
  """Parses a video and label dataset."""

  def __init__(self,
               input_params: exp_cfg.DataConfig,
               image_key: str = IMAGE_KEY,
               label_key: str = LABEL_KEY):
    super(Parser, self).__init__(input_params, image_key, label_key)
    self._is_ssl = input_params.is_ssl

  def _parse_train_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for training."""
    # Process image and label.
    image = decoded_tensors[self._image_key]
    image = _process_image(
        image=image,
        is_training=True,
        is_ssl=self._is_ssl,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)

    return features, label

  def _parse_eval_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for evaluation."""
    image = decoded_tensors[self._image_key]
    image = _process_image(
        image=image,
        is_training=False,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)

    if self._output_audio:
      audio = decoded_tensors[self._audio_feature]
      audio = tf.cast(audio, dtype=self._dtype)
      audio = preprocess_ops_3d.sample_sequence(audio, 20, random=False,
                                                stride=1)
      audio = tf.ensure_shape(audio, [20, 2048])
      features['audio'] = audio

    return features, label

  def parse_fn(self, is_training):
    """Returns a parse fn that reads and parses raw tensors from the decoder.

    Args:
      is_training: a `bool` to indicate whether it is in training mode.

    Returns:
      parse: a `callable` that takes the serialized examle and generate the
        images, labels tuple where labels is a dict of Tensors that contains
        labels.
    """
    def parse(decoded_tensors):
      """Parses the serialized example data."""
      if is_training:
        return self._parse_train_data(decoded_tensors)
      else:
        return self._parse_eval_data(decoded_tensors)

    return parse


class PostBatchProcessor(object):
  """Processes a video and label dataset which is batched."""

  def __init__(self, input_params: exp_cfg.DataConfig):
    self._is_training = input_params.is_training
    self._is_ssl = input_params.is_ssl
    self._num_frames = input_params.feature_shape[0]
    self._num_test_clips = input_params.num_test_clips

  def __call__(self, features: Dict[str, tf.Tensor],
               label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses a single tf.Example into image and label tensors."""
    for key in ['image', 'audio']:
      if key in features:
        features[key] = _postprocess_image(
            image=features[key],
            is_training=self._is_training,
            is_ssl=self._is_ssl,
            num_frames=self._num_frames,
            num_test_clips=self._num_test_clips)

    return features, label
```
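For the SSL branch, `_postprocess_image` splits the doubled time axis back into two clips and concatenates them along the batch axis. A toy check of that shape arithmetic:

```python
import tensorflow as tf

batch, num_frames, h, w = 2, 16, 4, 4
image = tf.zeros([batch, 2 * num_frames, h, w, 3])  # two clips per example
image = tf.concat(tf.split(image, num_or_size_splits=2, axis=1), axis=0)
print(image.shape)  # (4, 16, 4, 4, 3): the batch doubles to 2 * batch
```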
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input_test.py (deleted, 100644 → 0)

```python
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import io

# Import libraries
import numpy as np
from PIL import Image
import tensorflow as tf

from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input

AUDIO_KEY = 'features/audio'


def fake_seq_example():
  # Create fake data.
  random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
  random_image = Image.fromarray(random_image)
  label = 42
  with io.BytesIO() as buffer:
    random_image.save(buffer, format='JPEG')
    raw_image_bytes = buffer.getvalue()

  seq_example = tf.train.SequenceExample()
  seq_example.feature_lists.feature_list.get_or_create(
      video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
          raw_image_bytes
      ]
  seq_example.feature_lists.feature_list.get_or_create(
      video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
          raw_image_bytes
      ]
  seq_example.context.feature[
      video_ssl_input.LABEL_KEY].int64_list.value[:] = [label]

  random_audio = np.random.normal(size=(10, 256)).tolist()
  for s in random_audio:
    seq_example.feature_lists.feature_list.get_or_create(
        AUDIO_KEY).feature.add().float_list.value[:] = s
  return seq_example, label


class VideoAndLabelParserTest(tf.test.TestCase):

  def test_video_ssl_input_pretrain(self):
    params = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, _ = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, _ = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (32, 224, 224, 3))

  def test_video_ssl_input_linear_train(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.train_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, label = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, label = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (32, 224, 224, 3))
    self.assertAllEqual(label.shape, (600,))

  def test_video_ssl_input_linear_eval(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.validation_data
    print('!!!', params)

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, label = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, label = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (960, 256, 256, 3))
    self.assertAllEqual(label.shape, (600,))


if __name__ == '__main__':
  tf.test.main()
```
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops.py (deleted, 100644 → 0)

```python
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for customed ops for video ssl."""
import
functools
from
typing
import
Optional
import
tensorflow
as
tf
def
random_apply
(
func
,
p
,
x
):
"""Randomly apply function func to x with probability p."""
return
tf
.
cond
(
tf
.
less
(
tf
.
random
.
uniform
([],
minval
=
0
,
maxval
=
1
,
dtype
=
tf
.
float32
),
tf
.
cast
(
p
,
tf
.
float32
)),
lambda
:
func
(
x
),
lambda
:
x
)
def
random_brightness
(
image
,
max_delta
):
"""Distort brightness of image (SimCLRv2 style)."""
factor
=
tf
.
random
.
uniform
(
[],
tf
.
maximum
(
1.0
-
max_delta
,
0
),
1.0
+
max_delta
)
image
=
image
*
factor
return
image
def
to_grayscale
(
image
,
keep_channels
=
True
):
"""Turn the input image to gray scale.
Args:
image: The input image tensor.
keep_channels: Whether maintaining the channel number for the image.
If true, the transformed image will repeat three times in channel.
If false, the transformed image will only have one channel.
Returns:
The distorted image tensor.
"""
image
=
tf
.
image
.
rgb_to_grayscale
(
image
)
if
keep_channels
:
image
=
tf
.
tile
(
image
,
[
1
,
1
,
3
])
return
image
def
color_jitter
(
image
,
strength
,
random_order
=
True
):
"""Distorts the color of the image (SimCLRv2 style).
Args:
image: The input image tensor.
strength: The floating number for the strength of the color augmentation.
random_order: A bool, specifying whether to randomize the jittering order.
Returns:
The distorted image tensor.
"""
brightness
=
0.8
*
strength
contrast
=
0.8
*
strength
saturation
=
0.8
*
strength
hue
=
0.2
*
strength
if
random_order
:
return
color_jitter_rand
(
image
,
brightness
,
contrast
,
saturation
,
hue
)
else
:
return
color_jitter_nonrand
(
image
,
brightness
,
contrast
,
saturation
,
hue
)
def
color_jitter_nonrand
(
image
,
brightness
=
0
,
contrast
=
0
,
saturation
=
0
,
hue
=
0
):
"""Distorts the color of the image (jittering order is fixed, SimCLRv2 style).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
Returns:
The distorted image tensor.
"""
with
tf
.
name_scope
(
'distort_color'
):
def
apply_transform
(
i
,
x
,
brightness
,
contrast
,
saturation
,
hue
):
"""Apply the i-th transformation."""
if
brightness
!=
0
and
i
==
0
:
x
=
random_brightness
(
x
,
max_delta
=
brightness
)
elif
contrast
!=
0
and
i
==
1
:
x
=
tf
.
image
.
random_contrast
(
x
,
lower
=
1
-
contrast
,
upper
=
1
+
contrast
)
elif
saturation
!=
0
and
i
==
2
:
x
=
tf
.
image
.
random_saturation
(
x
,
lower
=
1
-
saturation
,
upper
=
1
+
saturation
)
elif
hue
!=
0
:
x
=
tf
.
image
.
random_hue
(
x
,
max_delta
=
hue
)
return
x
for
i
in
range
(
4
):
image
=
apply_transform
(
i
,
image
,
brightness
,
contrast
,
saturation
,
hue
)
image
=
tf
.
clip_by_value
(
image
,
0.
,
1.
)
return
image
def
color_jitter_rand
(
image
,
brightness
=
0
,
contrast
=
0
,
saturation
=
0
,
hue
=
0
):
"""Distorts the color of the image (jittering order is random, SimCLRv2 style).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
Returns:
The distorted image tensor.
"""
with
tf
.
name_scope
(
'distort_color'
):
def
apply_transform
(
i
,
x
):
"""Apply the i-th transformation."""
def
brightness_transform
():
if
brightness
==
0
:
return
x
else
:
return
random_brightness
(
x
,
max_delta
=
brightness
)
def
contrast_transform
():
if
contrast
==
0
:
return
x
else
:
return
tf
.
image
.
random_contrast
(
x
,
lower
=
1
-
contrast
,
upper
=
1
+
contrast
)
def
saturation_transform
():
if
saturation
==
0
:
return
x
else
:
return
tf
.
image
.
random_saturation
(
x
,
lower
=
1
-
saturation
,
upper
=
1
+
saturation
)
def
hue_transform
():
if
hue
==
0
:
return
x
else
:
return
tf
.
image
.
random_hue
(
x
,
max_delta
=
hue
)
# pylint:disable=g-long-lambda
x
=
tf
.
cond
(
tf
.
less
(
i
,
2
),
lambda
:
tf
.
cond
(
tf
.
less
(
i
,
1
),
brightness_transform
,
contrast_transform
),
lambda
:
tf
.
cond
(
tf
.
less
(
i
,
3
),
saturation_transform
,
hue_transform
))
# pylint:disable=g-long-lambda
return
x
perm
=
tf
.
random
.
shuffle
(
tf
.
range
(
4
))
for
i
in
range
(
4
):
image
=
apply_transform
(
perm
[
i
],
image
)
image
=
tf
.
clip_by_value
(
image
,
0.
,
1.
)
return
image
def
random_color_jitter_3d
(
frames
):
"""Applies temporally consistent color jittering to one video clip.
Args:
frames: `Tensor` of shape [num_frames, height, width, channels].
Returns:
A Tensor of shape [num_frames, height, width, channels] being color jittered
with the same operation.
"""
def
random_color_jitter
(
image
,
p
=
1.0
):
def
_transform
(
image
):
color_jitter_t
=
functools
.
partial
(
color_jitter
,
strength
=
1.0
)
image
=
random_apply
(
color_jitter_t
,
p
=
0.8
,
x
=
image
)
return
random_apply
(
to_grayscale
,
p
=
0.2
,
x
=
image
)
return
random_apply
(
_transform
,
p
=
p
,
x
=
image
)
num_frames
,
width
,
height
,
channels
=
tf
.
shape
(
frames
)
big_image
=
tf
.
reshape
(
frames
,
[
num_frames
*
width
,
height
,
channels
])
big_image
=
random_color_jitter
(
big_image
)
return
tf
.
reshape
(
big_image
,
[
num_frames
,
width
,
height
,
channels
])
def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
  """Blurs the given image with separable convolution.

  Args:
    image: Tensor of shape [height, width, channels] and dtype float to blur.
    kernel_size: Integer Tensor for the size of the blur kernel. This should
      be an odd number. If it is an even number, the actual kernel size will
      be size + 1.
    sigma: Sigma value for gaussian operator.
    padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.

  Returns:
    A Tensor representing the blurred image.
  """
  radius = tf.cast(kernel_size / 2, dtype=tf.int32)
  kernel_size = radius * 2 + 1
  x = tf.cast(tf.range(-radius, radius + 1), dtype=tf.float32)
  blur_filter = tf.exp(-tf.pow(x, 2.0) /
                       (2.0 * tf.pow(tf.cast(sigma, dtype=tf.float32), 2.0)))
  blur_filter /= tf.reduce_sum(blur_filter)
  # One vertical and one horizontal filter.
  blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
  blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
  num_channels = tf.shape(image)[-1]
  blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
  blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
  expand_batch_dim = image.shape.ndims == 3
  if expand_batch_dim:
    # Tensorflow requires batched input to convolutions, which we can fake with
    # an extra dimension.
    image = tf.expand_dims(image, axis=0)
  blurred = tf.nn.depthwise_conv2d(
      image, blur_h, strides=[1, 1, 1, 1], padding=padding)
  blurred = tf.nn.depthwise_conv2d(
      blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
  if expand_batch_dim:
    blurred = tf.squeeze(blurred, axis=0)
  return blurred
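A short sanity check for the separable blur above (shapes and values are illustrative). Two 1-D depthwise passes are equivalent to one 2-D Gaussian convolution because the 2-D kernel factors into an outer product of the 1-D kernel with itself:

import tensorflow as tf

image = tf.random.uniform((128, 128, 3), 0., 1., dtype=tf.float32)
blurred = gaussian_blur(image, kernel_size=9, sigma=2.0, padding='SAME')
print(blurred.shape)  # (128, 128, 3); the fake batch dim is squeezed out again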
def random_blur(image, height, width, p=1.0):
  """Randomly blur an image.

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image.
    p: Probability of applying this transformation.

  Returns:
    A preprocessed image `Tensor`.
  """
  del width
  def _transform(image):
    sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
    return gaussian_blur(
        image, kernel_size=height // 10, sigma=sigma, padding='SAME')
  return random_apply(_transform, p=p, x=image)
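Note the kernel-size heuristic: the kernel spans roughly a tenth of the image height (the SimCLR default), and sigma is re-drawn uniformly from [0.1, 2.0] on every call. A usage sketch (values illustrative):

import tensorflow as tf

image = tf.random.uniform((224, 224, 3), 0., 1., dtype=tf.float32)
maybe_blurred = random_blur(image, height=224, width=224, p=0.5)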
def random_blur_3d(frames, height, width, blur_probability=0.5):
  """Apply efficient batch data transformations.

  Args:
    frames: `Tensor` of shape [timesteps, height, width, 3].
    height: The height of the image.
    width: The width of the image.
    blur_probability: The probability to apply the blur operator.

  Returns:
    Preprocessed feature list.
  """
  def generate_selector(p, bsz):
    shape = [bsz, 1, 1, 1]
    selector = tf.cast(
        tf.less(tf.random.uniform(shape, 0, 1, dtype=tf.float32), p),
        tf.float32)
    return selector

  frames_new = random_blur(frames, height, width, p=1.)
  selector = generate_selector(blur_probability, 1)
  frames = frames_new * selector + frames * (1 - selector)
  frames = tf.clip_by_value(frames, 0., 1.)
  return frames
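The generate_selector pattern above sidesteps data-dependent control flow: the blurred clip is always computed, and a Bernoulli 0/1 mask blends blurred and original frames. The same trick in isolation (a sketch; names are illustrative):

import tensorflow as tf

def randomly_apply(fn, x, p):
  # Branchless random application of fn to a batched 4-D tensor x.
  y = fn(x)
  keep = tf.cast(
      tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) < p, x.dtype)
  return y * keep + x * (1. - keep)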
def _sample_or_pad_sequence_indices(sequence: tf.Tensor, num_steps: int,
                                    stride: int,
                                    offset: tf.Tensor) -> tf.Tensor:
  """Returns indices to take for sampling or padding sequences to fixed size."""
  sequence_length = tf.shape(sequence)[0]
  sel_idx = tf.range(sequence_length)
  # Repeats sequence until num_steps are available in total.
  max_length = num_steps * stride + offset
  num_repeats = tf.math.floordiv(max_length + sequence_length - 1,
                                 sequence_length)
  sel_idx = tf.tile(sel_idx, [num_repeats])
  steps = tf.range(offset, offset + num_steps * stride, stride)
  return tf.gather(sel_idx, steps)
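A worked example of the tiling-and-gather above (a sketch): for a 5-frame sequence with num_steps=4, stride=2 and offset=1, the index list [0..4] is tiled until it covers offset + num_steps * stride, then every stride-th entry is gathered:

import tensorflow as tf

idx = _sample_or_pad_sequence_indices(
    sequence=tf.zeros([5]), num_steps=4, stride=2, offset=tf.constant(1))
print(idx.numpy())  # [1 3 0 2] -- indices wrap around the short sequence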
def sample_ssl_sequence(sequence: tf.Tensor,
                        num_steps: int,
                        random: bool,
                        stride: int = 1,
                        num_windows: Optional[int] = 2) -> tf.Tensor:
  """Samples two segments of size num_steps randomly from a given sequence.

  Currently it only supports images, and is specifically designed for video
  self-supervised learning.

  Args:
    sequence: Any tensor where the first dimension is timesteps.
    num_steps: Number of steps (e.g. frames) to take.
    random: A boolean indicating whether to randomly sample the windows. If
      True, the offsets are randomized. Only True is supported.
    stride: Distance to sample between timesteps.
    num_windows: Number of windows sampled.

  Returns:
    A single Tensor whose first dimension is num_windows * num_steps, holding
    the sampled segments.
  """
  sequence_length = tf.shape(sequence)[0]
  sequence_length = tf.cast(sequence_length, tf.float32)
  if random:
    max_offset = tf.cond(
        tf.greater(sequence_length, (num_steps - 1) * stride),
        lambda: sequence_length - (num_steps - 1) * stride,
        lambda: sequence_length)
    max_offset = tf.cast(max_offset, dtype=tf.float32)

    def cdf(k, power=1.0):
      """Cumulative distribution function for x^power."""
      p = -tf.math.pow(k, power + 1) / (
          power * tf.math.pow(max_offset, power + 1)) + k * (power + 1) / (
              power * max_offset)
      return p

    u = tf.random.uniform(())
    k_low = tf.constant(0, dtype=tf.float32)
    k_up = max_offset
    k = tf.math.floordiv(max_offset, 2.0)
    c = lambda k_low, k_up, k: tf.greater(tf.math.abs(k_up - k_low), 1.0)
    # pylint:disable=g-long-lambda
    b = lambda k_low, k_up, k: tf.cond(
        tf.greater(cdf(k), u),
        lambda: [k_low, k, tf.math.floordiv(k + k_low, 2.0)],
        lambda: [k, k_up, tf.math.floordiv(k_up + k, 2.0)])
    _, _, k = tf.while_loop(c, b, [k_low, k_up, k])
    delta = tf.cast(k, tf.int32)
    max_offset = tf.cast(max_offset, tf.int32)
    sequence_length = tf.cast(sequence_length, tf.int32)
    choice_1 = tf.cond(
        tf.equal(max_offset, sequence_length),
        lambda: tf.random.uniform(
            (), maxval=tf.cast(max_offset, dtype=tf.int32), dtype=tf.int32),
        lambda: tf.random.uniform(
            (), maxval=tf.cast(max_offset - delta, dtype=tf.int32),
            dtype=tf.int32))
    choice_2 = tf.cond(
        tf.equal(max_offset, sequence_length),
        lambda: tf.random.uniform(
            (), maxval=tf.cast(max_offset, dtype=tf.int32), dtype=tf.int32),
        lambda: choice_1 + delta)
    # pylint:enable=g-long-lambda
    shuffle_choice = tf.random.shuffle((choice_1, choice_2))
    offset_1 = shuffle_choice[0]
    offset_2 = shuffle_choice[1]
  else:
    raise NotImplementedError

  indices_1 = _sample_or_pad_sequence_indices(
      sequence=sequence, num_steps=num_steps, stride=stride, offset=offset_1)
  indices_2 = _sample_or_pad_sequence_indices(
      sequence=sequence, num_steps=num_steps, stride=stride, offset=offset_2)
  indices = tf.concat([indices_1, indices_2], axis=0)
  indices.set_shape((num_windows * num_steps,))
  output = tf.gather(sequence, indices)
  return output
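The tf.while_loop above is inverse-CDF sampling by bisection: draw u ~ U(0, 1), then repeatedly halve the interval [k_low, k_up] until cdf(k) brackets u, which yields the random gap delta between the two windows. The same idea in plain eager-mode form (a sketch with an illustrative linear CDF, not the power-law CDF used above):

import tensorflow as tf

def sample_by_bisection(cdf, lo, hi, tol=1.0):
  # Draws one sample from the distribution whose CDF is `cdf`.
  u = tf.random.uniform(())
  while hi - lo > tol:
    mid = (lo + hi) / 2.0
    if cdf(mid) > u:
      hi = mid
    else:
      lo = mid
  return lo

sample = sample_by_bisection(lambda k: k / 100.0, 0.0, 100.0)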
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops_test.py
deleted
100644 → 0
View file @
431a9ca3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf

from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops


class VideoSslPreprocessOpsTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self._raw_frames = tf.random.uniform((250, 256, 256, 3), minval=0,
                                         maxval=255, dtype=tf.dtypes.int32)
    self._sampled_frames = self._raw_frames[:16]
    self._frames = preprocess_ops_3d.normalize_image(self._sampled_frames,
                                                     False, tf.float32)

  def test_sample_ssl_sequence(self):
    sampled_seq = video_ssl_preprocess_ops.sample_ssl_sequence(
        self._raw_frames, 16, True, 2)
    self.assertAllEqual(sampled_seq.shape, (32, 256, 256, 3))

  def test_random_color_jitter_3d(self):
    jittered_clip = video_ssl_preprocess_ops.random_color_jitter_3d(
        self._frames)
    self.assertAllEqual(jittered_clip.shape, (16, 256, 256, 3))

  def test_random_blur_3d(self):
    blurred_clip = video_ssl_preprocess_ops.random_blur_3d(
        self._frames, 256, 256)
    self.assertAllEqual(blurred_clip.shape, (16, 256, 256, 3))


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/yolo/train.py
View file @
44e7092c
...
@@ -67,6 +67,8 @@ def main(_):
       params=params,
       model_dir=model_dir)
+
+  train_utils.save_gin_config(FLAGS.mode, model_dir)
 
 if __name__ == '__main__':
   tfm_flags.define_flags()
   app.run(main)
official/vision/beta/serving/export_saved_model_lib.py
View file @
44e7092c
...
@@ -17,7 +17,7 @@ r"""Vision models export utility function for serving/inference."""
 import os
-import tensorflow.compat.v2 as tf
+import tensorflow as tf
 from official.core import train_utils
 from official.vision.beta import configs
...
official/vision/beta/tasks/maskrcnn.py
View file @
44e7092c
...
@@ -23,6 +23,7 @@ from official.core import task_factory
 from official.vision.beta.configs import maskrcnn as exp_cfg
 from official.vision.beta.dataloaders import maskrcnn_input
 from official.vision.beta.dataloaders import tf_example_decoder
+from official.vision.beta.dataloaders import dataset_fn
 from official.vision.beta.dataloaders import tf_example_label_map_decoder
 from official.vision.beta.evaluation import coco_evaluator
 from official.vision.beta.losses import maskrcnn_losses
...
@@ -110,12 +111,14 @@ class MaskRCNNTask(base_task.Task):
     if params.decoder.type == 'simple_decoder':
       decoder = tf_example_decoder.TfExampleDecoder(
           include_mask=self._task_config.model.include_mask,
-          regenerate_source_id=decoder_cfg.regenerate_source_id)
+          regenerate_source_id=decoder_cfg.regenerate_source_id,
+          mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
     elif params.decoder.type == 'label_map_decoder':
       decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
           label_map=decoder_cfg.label_map,
           include_mask=self._task_config.model.include_mask,
-          regenerate_source_id=decoder_cfg.regenerate_source_id)
+          regenerate_source_id=decoder_cfg.regenerate_source_id,
+          mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
     else:
       raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
...
@@ -141,7 +144,7 @@ class MaskRCNNTask(base_task.Task):
     reader = input_reader.InputReader(
         params,
-        dataset_fn=tf.data.TFRecordDataset,
+        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
         decoder_fn=decoder.decode,
         parser_fn=parser.parse_fn(params.is_training))
     dataset = reader.read(input_context=input_context)
...
official/vision/beta/tasks/retinanet.py
View file @
44e7092c
...
@@ -24,6 +24,7 @@ from official.vision import keras_cv
 from official.vision.beta.configs import retinanet as exp_cfg
 from official.vision.beta.dataloaders import retinanet_input
 from official.vision.beta.dataloaders import tf_example_decoder
+from official.vision.beta.dataloaders import dataset_fn
 from official.vision.beta.dataloaders import tf_example_label_map_decoder
 from official.vision.beta.evaluation import coco_evaluator
 from official.vision.beta.modeling import factory
...
@@ -93,16 +94,7 @@ class RetinaNetTask(base_task.Task):
           regenerate_source_id=decoder_cfg.regenerate_source_id)
     else:
       raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
-    decoder_cfg = params.decoder.get()
-    if params.decoder.type == 'simple_decoder':
-      decoder = tf_example_decoder.TfExampleDecoder(
-          regenerate_source_id=decoder_cfg.regenerate_source_id)
-    elif params.decoder.type == 'label_map_decoder':
-      decoder = tf_example_decoder.TfExampleDecoderLabelMap(
-          label_map=decoder_cfg.label_map,
-          regenerate_source_id=decoder_cfg.regenerate_source_id)
-    else:
-      raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
     parser = retinanet_input.Parser(
         output_size=self.task_config.model.input_size[:2],
         min_level=self.task_config.model.min_level,
...
@@ -121,7 +113,7 @@ class RetinaNetTask(base_task.Task):
     reader = input_reader.InputReader(
         params,
-        dataset_fn=tf.data.TFRecordDataset,
+        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
         decoder_fn=decoder.decode,
         parser_fn=parser.parse_fn(params.is_training))
     dataset = reader.read(input_context=input_context)
...
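Both the Mask R-CNN and RetinaNet tasks above swap the hard-coded tf.data.TFRecordDataset for dataset_fn.pick_dataset_fn(params.file_type), so the input source is driven by configuration. The real helper lives in official/vision/beta/dataloaders/dataset_fn.py (also touched by this commit); a minimal sketch of such a dispatch, shown only for orientation and not the actual implementation:

import tensorflow as tf

def pick_dataset_fn(file_type):
  # Illustrative dispatch from a config string to a tf.data source class.
  if file_type == 'tfrecord':
    return tf.data.TFRecordDataset
  raise ValueError('Unknown file_type: {}'.format(file_type))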