Commit 11ea5237 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Support random_crop_resize function in preprocess_ops_3d.

PiperOrigin-RevId: 348500556
parent 63bdcfb8
......@@ -15,7 +15,7 @@
# ==============================================================================
"""Utils for processing video dataset features."""
from typing import Optional
from typing import Optional, Tuple
import tensorflow as tf
......@@ -217,6 +217,55 @@ def resize_smallest(frames: tf.Tensor,
return frames
def random_crop_resize(frames: tf.Tensor,
output_h: int,
output_w: int,
num_frames: int,
num_channels: int,
aspect_ratio: Tuple[float, float],
area_range: Tuple[float, float]) -> tf.Tensor:
"""First crops clip with jittering and then resizes to (output_h, output_w).
Args:
frames: A Tensor of dimension [timesteps, input_h, input_w, channels].
output_h: Resized image height.
output_w: Resized image width.
num_frames: Number of input frames per clip.
num_channels: Number of channels of the clip.
aspect_ratio: Float tuple with the aspect range for cropping.
area_range: Float tuple with the area range for cropping.
Returns:
A Tensor of shape [timesteps, output_h, output_w, channels] of type
frames.dtype.
"""
shape = tf.shape(frames)
seq_len, _, _, channels = shape[0], shape[1], shape[2], shape[3]
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
factor = output_w / output_h
aspect_ratio = (aspect_ratio[0] * factor, aspect_ratio[1] * factor)
sample_distorted_bbox = tf.image.sample_distorted_bounding_box(
shape[1:],
bounding_boxes=bbox,
min_object_covered=0.1,
aspect_ratio_range=aspect_ratio,
area_range=area_range,
max_attempts=100,
use_image_if_no_bounding_boxes=True)
bbox_begin, bbox_size, _ = sample_distorted_bbox
offset_y, offset_x, _ = tf.unstack(bbox_begin)
target_height, target_width, _ = tf.unstack(bbox_size)
size = tf.convert_to_tensor((
seq_len, target_height, target_width, channels))
offset = tf.convert_to_tensor((
0, offset_y, offset_x, 0))
frames = tf.slice(frames, offset, size)
frames = tf.cast(
tf.image.resize(frames, (output_h, output_w)),
frames.dtype)
frames.set_shape((num_frames, output_h, output_w, num_channels))
return frames
def random_flip_left_right(
frames: tf.Tensor,
seed: Optional[int] = None) -> tf.Tensor:
......
......@@ -119,6 +119,20 @@ class ParserUtilsTest(tf.test.TestCase):
self.assertAllEqual(resized_frames_3.shape, (6, 90, 120, 3))
self.assertAllEqual(resized_frames_4.shape, (6, 60, 45, 3))
def test_random_crop_resize(self):
resized_frames_1 = preprocess_ops_3d.random_crop_resize(
self._frames, 256, 256, 6, 3, (0.5, 2), (0.3, 1))
resized_frames_2 = preprocess_ops_3d.random_crop_resize(
self._frames, 224, 224, 6, 3, (0.5, 2), (0.3, 1))
resized_frames_3 = preprocess_ops_3d.random_crop_resize(
self._frames, 256, 256, 6, 3, (0.8, 1.2), (0.3, 1))
resized_frames_4 = preprocess_ops_3d.random_crop_resize(
self._frames, 256, 256, 6, 3, (0.5, 2), (0.1, 1))
self.assertAllEqual(resized_frames_1.shape, (6, 256, 256, 3))
self.assertAllEqual(resized_frames_2.shape, (6, 224, 224, 3))
self.assertAllEqual(resized_frames_3.shape, (6, 256, 256, 3))
self.assertAllEqual(resized_frames_4.shape, (6, 256, 256, 3))
def test_random_flip_left_right(self):
flipped_frames = preprocess_ops_3d.random_flip_left_right(self._frames)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment