Commit 6d6e881a authored by A. Unique TensorFlower

Add additional parameters to support different image shapes and label types.

PiperOrigin-RevId: 462275036
parent 62c74392
@@ -41,6 +41,7 @@ class DataConfig(cfg.DataConfig):
   global_batch_size: int = 128
   data_format: str = 'channels_last'
   dtype: str = 'float32'
+  label_dtype: str = 'int32'
   one_hot: bool = True
   shuffle_buffer_size: int = 64
   cache: bool = False
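For orientation, a minimal sketch of what the new field enables (the constructor call and field values below are illustrative assumptions, not part of this commit): a pipeline with regression-style targets can now skip one-hot encoding and keep labels as float32.

    # Sketch only: override the two fields this commit touches.
    data_config = DataConfig(
        global_batch_size=128,
        one_hot=False,          # raw labels instead of one-hot vectors
        label_dtype='float32')  # new field; parsed via tf.dtypes.as_dtype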
@@ -36,7 +36,8 @@ def process_image(image: tf.Tensor,
                   random_stride_range: int = 0,
                   num_test_clips: int = 1,
                   min_resize: int = 256,
-                  crop_size: int = 224,
+                  crop_size: Union[int, Tuple[int, int]] = 224,
+                  num_channels: int = 3,
                   num_crops: int = 1,
                   zero_centering_image: bool = False,
                   min_aspect_ratio: float = 0.5,
@@ -64,8 +65,10 @@ def process_image(image: tf.Tensor,
       If 1, then a single clip in the middle of the video is sampled. The clips
       are aggregated in the batch dimension.
     min_resize: Frames are resized so that min(height, width) is min_resize.
-    crop_size: Final size of the frame after cropping the resized frames. Both
-      height and width are the same.
+    crop_size: Final size of the frame after cropping the resized frames.
+      Optionally, specify a tuple of (crop_height, crop_width) if
+      crop_height != crop_width.
+    num_channels: Number of channels of the clip.
     num_crops: Number of crops to perform on the resized frames.
     zero_centering_image: If True, frames are normalized to values in [-1, 1].
       If False, values in [0, 1].
@@ -78,7 +81,7 @@ def process_image(image: tf.Tensor,
 
   Returns:
     Processed frames. Tensor of shape
-      [num_frames * num_test_clips, crop_size, crop_size, 3].
+      [num_frames * num_test_clips, crop_height, crop_width, num_channels].
   """
   # Validate parameters.
   if is_training and num_test_clips != 1:
@@ -90,6 +93,10 @@ def process_image(image: tf.Tensor,
     raise ValueError('Random stride range should be >= 0, got {}'.format(
         random_stride_range))
 
+  if isinstance(crop_size, int):
+    crop_size = (crop_size, crop_size)
+  crop_height, crop_width = crop_size
+
   # Temporal sampler.
   if is_training:
     if random_stride_range > 0:
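The normalization added above lets callers pass either a square int or an explicit (height, width) pair. A standalone sketch of the same idea (the helper name is illustrative, not part of the change):

    def _as_height_width(crop_size):
      # int -> square crop; a (height, width) tuple passes through unchanged.
      if isinstance(crop_size, int):
        crop_size = (crop_size, crop_size)
      return crop_size

    assert _as_height_width(224) == (224, 224)
    assert _as_height_width((168, 224)) == (168, 224)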
@@ -113,12 +120,12 @@ def process_image(image: tf.Tensor,
 
   # Decode JPEG string to tf.uint8.
   if image.dtype == tf.string:
-    image = preprocess_ops_3d.decode_jpeg(image, 3)
+    image = preprocess_ops_3d.decode_jpeg(image, num_channels)
 
   if is_training:
     # Standard image data augmentation: random resized crop and random flip.
     image = preprocess_ops_3d.random_crop_resize(
-        image, crop_size, crop_size, num_frames, 3,
+        image, crop_height, crop_width, num_frames, num_channels,
         (min_aspect_ratio, max_aspect_ratio),
         (min_area_ratio, max_area_ratio))
     image = preprocess_ops_3d.random_flip_left_right(image, seed)
@@ -129,7 +136,7 @@ def process_image(image: tf.Tensor,
     # Resize images (resize happens only if necessary to save compute).
     image = preprocess_ops_3d.resize_smallest(image, min_resize)
     # Crop of the frames.
-    image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
+    image = preprocess_ops_3d.crop_image(image, crop_height, crop_width, False,
                                          num_crops)
 
   # Cast the frames in float32, normalizing according to zero_centering_image.
@@ -173,15 +180,16 @@ def postprocess_image(image: tf.Tensor,
 
 def process_label(label: tf.Tensor,
                   one_hot_label: bool = True,
-                  num_classes: Optional[int] = None) -> tf.Tensor:
+                  num_classes: Optional[int] = None,
+                  label_dtype: tf.DType = tf.int32) -> tf.Tensor:
   """Processes label Tensor."""
   # Validate parameters.
   if one_hot_label and not num_classes:
     raise ValueError(
         '`num_classes` should be given when requesting one hot label.')
 
-  # Cast to tf.int32.
-  label = tf.cast(label, dtype=tf.int32)
+  # Cast to label_dtype (default = tf.int32).
+  label = tf.cast(label, dtype=label_dtype)
 
   if one_hot_label:
     # Replace label index by one hot representation.
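Illustrative calls through the updated signature (the tensor values are made up; only the keyword arguments come from this change):

    # Classification path: cast to int32 (the default), then one-hot encode.
    one_hot = process_label(
        tf.constant([3]), one_hot_label=True, num_classes=600)

    # Regression-style path: keep the raw label and cast it to float32.
    dense = process_label(
        tf.constant([0.7]), one_hot_label=False, label_dtype=tf.float32)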
@@ -269,7 +277,11 @@ class Parser(parser.Parser):
     self._random_stride_range = input_params.random_stride_range
     self._num_test_clips = input_params.num_test_clips
     self._min_resize = input_params.min_image_size
-    self._crop_size = input_params.feature_shape[1]
+    crop_height = input_params.feature_shape[1]
+    crop_width = input_params.feature_shape[2]
+    self._crop_size = crop_height if crop_height == crop_width else (
+        crop_height, crop_width)
+    self._num_channels = input_params.feature_shape[3]
     self._num_crops = input_params.num_test_crops
     self._zero_centering_image = input_params.zero_centering_image
     self._one_hot_label = input_params.one_hot
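The parser now derives both the crop geometry and the channel count from feature_shape, which is (num_frames, height, width, channels). A minimal sketch of that derivation (the helper name is illustrative):

    def _derive_crop(feature_shape):
      # Mirrors the Parser.__init__ logic above.
      crop_height, crop_width = feature_shape[1], feature_shape[2]
      crop_size = (crop_height if crop_height == crop_width else
                   (crop_height, crop_width))
      return crop_size, feature_shape[3]

    assert _derive_crop((2, 224, 224, 3)) == (224, 3)         # square: plain int
    assert _derive_crop((2, 168, 224, 1)) == ((168, 224), 1)  # non-square: tuple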
@@ -277,6 +289,7 @@ class Parser(parser.Parser):
     self._image_key = image_key
     self._label_key = label_key
     self._dtype = tf.dtypes.as_dtype(input_params.dtype)
+    self._label_dtype = tf.dtypes.as_dtype(input_params.label_dtype)
     self._output_audio = input_params.output_audio
     self._min_aspect_ratio = input_params.aug_min_aspect_ratio
     self._max_aspect_ratio = input_params.aug_max_aspect_ratio
@@ -324,6 +337,7 @@ class Parser(parser.Parser):
         num_test_clips=self._num_test_clips,
         min_resize=self._min_resize,
         crop_size=self._crop_size,
+        num_channels=self._num_channels,
         min_aspect_ratio=self._min_aspect_ratio,
         max_aspect_ratio=self._max_aspect_ratio,
         min_area_ratio=self._min_area_ratio,
@@ -335,7 +349,8 @@ class Parser(parser.Parser):
     features = {'image': image}
 
     label = decoded_tensors[self._label_key]
-    label = process_label(label, self._one_hot_label, self._num_classes)
+    label = process_label(label, self._one_hot_label, self._num_classes,
+                          self._label_dtype)
 
     if self._output_audio:
       audio = decoded_tensors[self._audio_feature]
@@ -361,13 +376,15 @@ class Parser(parser.Parser):
         num_test_clips=self._num_test_clips,
         min_resize=self._min_resize,
         crop_size=self._crop_size,
+        num_channels=self._num_channels,
         num_crops=self._num_crops,
         zero_centering_image=self._zero_centering_image)
     image = tf.cast(image, dtype=self._dtype)
     features = {'image': image}
 
     label = decoded_tensors[self._label_key]
-    label = process_label(label, self._one_hot_label, self._num_classes)
+    label = process_label(label, self._one_hot_label, self._num_classes,
+                          self._label_dtype)
 
     if self._output_audio:
       audio = decoded_tensors[self._audio_feature]
@@ -191,6 +191,28 @@ class VideoAndLabelParserTest(tf.test.TestCase):
     self.assertAllEqual(image.shape, (2, 224, 224, 3))
     self.assertAllEqual(label.shape, (600,))
 
+  def test_video_input_image_shape_label_type(self):
+    params = exp_cfg.kinetics600(is_training=True)
+    params.feature_shape = (2, 168, 224, 1)
+    params.min_image_size = 168
+    params.label_dtype = 'float32'
+    params.one_hot = False
+
+    decoder = video_input.Decoder()
+    parser = video_input.Parser(params).parse_fn(params.is_training)
+
+    seq_example, label = fake_seq_example()
+
+    input_tensor = tf.constant(seq_example.SerializeToString())
+    decoded_tensors = decoder.decode(input_tensor)
+    output_tensor = parser(decoded_tensors)
+    image_features, label = output_tensor
+    image = image_features['image']
+
+    self.assertAllEqual(image.shape, (2, 168, 224, 1))
+    self.assertAllEqual(label.shape, (1,))
+    self.assertDTypeEqual(label, tf.float32)
+
 
 if __name__ == '__main__':
   tf.test.main()