Commit e75624bc authored by Liangzhe Yuan, committed by A. Unique TensorFlower

Propagate zero_centering_image flag in video_input.

PiperOrigin-RevId: 444993865
parent a00ddd57
@@ -128,6 +128,9 @@ def _process_image(image: tf.Tensor,
   # Self-supervised pre-training augmentations.
   if is_training and is_ssl:
+    if zero_centering_image:
+      image_1 = 0.5 * (image_1 + 1.0)
+      image_2 = 0.5 * (image_2 + 1.0)
     # Temporally consistent color jittering.
     image_1 = video_ssl_preprocess_ops.random_color_jitter_3d(image_1)
     image_2 = video_ssl_preprocess_ops.random_color_jitter_3d(image_2)
@@ -139,6 +142,8 @@ def _process_image(image: tf.Tensor,
     image_2 = video_ssl_preprocess_ops.random_solarization(image_2)
     image = tf.concat([image_1, image_2], axis=0)
     image = tf.clip_by_value(image, 0., 1.)
+  if zero_centering_image:
+    image = 2 * (image - 0.5)
   return image
@@ -233,7 +238,8 @@ class Parser(video_input.Parser):
         stride=self._stride,
         num_test_clips=self._num_test_clips,
         min_resize=self._min_resize,
-        crop_size=self._crop_size)
+        crop_size=self._crop_size,
+        zero_centering_image=self._zero_centering_image)
     image = tf.cast(image, dtype=self._dtype)
     features = {'image': image}
@@ -255,7 +261,8 @@ class Parser(video_input.Parser):
         num_test_clips=self._num_test_clips,
         min_resize=self._min_resize,
         crop_size=self._crop_size,
-        num_crops=self._num_crops)
+        num_crops=self._num_crops,
+        zero_centering_image=self._zero_centering_image)
     image = tf.cast(image, dtype=self._dtype)
     features = {'image': image}
@@ -49,6 +49,7 @@ class DataConfig(cfg.DataConfig):
   cycle_length: int = 10
   drop_remainder: bool = True
   min_image_size: int = 256
+  zero_centering_image: bool = False
   is_multilabel: bool = False
   output_audio: bool = False
   audio_feature: str = ''
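The new field defaults to False, so existing configs keep [0, 1] pixel values; setting it to True opts the video pipeline into zero-centered [-1, 1] frames. A hypothetical usage sketch (the dataclass here only mirrors the fields visible in the hunk above; it is not the real cfg.DataConfig subclass):

import dataclasses

@dataclasses.dataclass
class DataConfig:
  # Stand-in that mirrors the fields shown in the diff above.
  cycle_length: int = 10
  drop_remainder: bool = True
  min_image_size: int = 256
  zero_centering_image: bool = False
  is_multilabel: bool = False

# Request zero-centered ([-1, 1]) frames from the video input pipeline.
train_data = DataConfig(zero_centering_image=True)
assert train_data.zero_centering_image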
@@ -271,6 +271,7 @@ class Parser(parser.Parser):
     self._min_resize = input_params.min_image_size
     self._crop_size = input_params.feature_shape[1]
     self._num_crops = input_params.num_test_crops
+    self._zero_centering_image = input_params.zero_centering_image
     self._one_hot_label = input_params.one_hot
     self._num_classes = input_params.num_classes
     self._image_key = image_key
@@ -317,7 +318,8 @@ class Parser(parser.Parser):
         max_aspect_ratio=self._max_aspect_ratio,
         min_area_ratio=self._min_area_ratio,
         max_area_ratio=self._max_area_ratio,
-        augmenter=self._augmenter)
+        augmenter=self._augmenter,
+        zero_centering_image=self._zero_centering_image)
     image = tf.cast(image, dtype=self._dtype)
     features = {'image': image}
@@ -349,7 +351,8 @@ class Parser(parser.Parser):
         num_test_clips=self._num_test_clips,
         min_resize=self._min_resize,
         crop_size=self._crop_size,
-        num_crops=self._num_crops)
+        num_crops=self._num_crops,
+        zero_centering_image=self._zero_centering_image)
     image = tf.cast(image, dtype=self._dtype)
     features = {'image': image}
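End to end, the flag now flows from the config into Parser.__init__ and on to the frame-processing call sites, so a single config field controls the output range of every parsed clip. A minimal sketch of that propagation with toy stand-ins (not the real dataloader classes; the divide-by-255 normalization is an assumption about what the real processing helper does before optional re-centering):

import types
import numpy as np

def toy_process_image(image, zero_centering_image=False):
  # Assumed normalization: uint8 frames -> [0, 1], then optionally [-1, 1].
  image = image.astype(np.float32) / 255.0
  if zero_centering_image:
    image = 2.0 * (image - 0.5)
  return image

class ToyParser:
  def __init__(self, input_params):
    # Mirrors the __init__ hunk: cache the config flag on the parser.
    self._zero_centering_image = input_params.zero_centering_image

  def parse(self, image):
    # Mirrors the call-site hunks: forward the cached flag.
    return toy_process_image(
        image, zero_centering_image=self._zero_centering_image)

input_params = types.SimpleNamespace(zero_centering_image=True)
clip = np.random.randint(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
out = ToyParser(input_params).parse(clip)
assert -1.0 <= out.min() <= out.max() <= 1.0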