"examples/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "0fde7c57da6a8a73050bbf1919c614b9f6e55d58"
Commit e75624bc authored by Liangzhe Yuan, committed by A. Unique TensorFlower
Browse files

Propagate zero_centering_image flag in video_input.

PiperOrigin-RevId: 444993865
parent a00ddd57
......@@ -128,6 +128,9 @@ def _process_image(image: tf.Tensor,
# Self-supervised pre-training augmentations.
if is_training and is_ssl:
if zero_centering_image:
image_1 = 0.5 * (image_1 + 1.0)
image_2 = 0.5 * (image_2 + 1.0)
# Temporally consistent color jittering.
image_1 = video_ssl_preprocess_ops.random_color_jitter_3d(image_1)
image_2 = video_ssl_preprocess_ops.random_color_jitter_3d(image_2)
......@@ -139,6 +142,8 @@ def _process_image(image: tf.Tensor,
image_2 = video_ssl_preprocess_ops.random_solarization(image_2)
image = tf.concat([image_1, image_2], axis=0)
image = tf.clip_by_value(image, 0., 1.)
if zero_centering_image:
image = 2 * (image - 0.5)
return image
......@@ -233,7 +238,8 @@ class Parser(video_input.Parser):
stride=self._stride,
num_test_clips=self._num_test_clips,
min_resize=self._min_resize,
crop_size=self._crop_size)
crop_size=self._crop_size,
zero_centering_image=self._zero_centering_image)
image = tf.cast(image, dtype=self._dtype)
features = {'image': image}
......@@ -255,7 +261,8 @@ class Parser(video_input.Parser):
num_test_clips=self._num_test_clips,
min_resize=self._min_resize,
crop_size=self._crop_size,
num_crops=self._num_crops)
num_crops=self._num_crops,
zero_centering_image=self._zero_centering_image)
image = tf.cast(image, dtype=self._dtype)
features = {'image': image}
......
......@@ -49,6 +49,7 @@ class DataConfig(cfg.DataConfig):
cycle_length: int = 10
drop_remainder: bool = True
min_image_size: int = 256
zero_centering_image: bool = False
is_multilabel: bool = False
output_audio: bool = False
audio_feature: str = ''
......
......@@ -271,6 +271,7 @@ class Parser(parser.Parser):
self._min_resize = input_params.min_image_size
self._crop_size = input_params.feature_shape[1]
self._num_crops = input_params.num_test_crops
self._zero_centering_image = input_params.zero_centering_image
self._one_hot_label = input_params.one_hot
self._num_classes = input_params.num_classes
self._image_key = image_key
......@@ -317,7 +318,8 @@ class Parser(parser.Parser):
max_aspect_ratio=self._max_aspect_ratio,
min_area_ratio=self._min_area_ratio,
max_area_ratio=self._max_area_ratio,
augmenter=self._augmenter)
augmenter=self._augmenter,
zero_centering_image=self._zero_centering_image)
image = tf.cast(image, dtype=self._dtype)
features = {'image': image}
......@@ -349,7 +351,8 @@ class Parser(parser.Parser):
num_test_clips=self._num_test_clips,
min_resize=self._min_resize,
crop_size=self._crop_size,
num_crops=self._num_crops)
num_crops=self._num_crops,
zero_centering_image=self._zero_centering_image)
image = tf.cast(image, dtype=self._dtype)
features = {'image': image}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment