Commit 584b5f29 authored by A. Unique TensorFlower
Browse files

Support three-crop evaluation function in preprocess_ops_3d.

PiperOrigin-RevId: 348816819
parent 5b0e647b
...@@ -148,9 +148,10 @@ def decode_jpeg(image_string: tf.Tensor, channels: int = 0) -> tf.Tensor: ...@@ -148,9 +148,10 @@ def decode_jpeg(image_string: tf.Tensor, channels: int = 0) -> tf.Tensor:
def crop_image(frames: tf.Tensor,
               target_height: int,
               target_width: int,
               random: bool = False,
               num_views: int = 1,
               seed: Optional[int] = None) -> tf.Tensor:
  """Crops the spatial dimensions of a sequence of images.

  In training (`random=True`) a single random spatial crop is taken. In
  evaluation (`random=False`) either a central crop/pad (`num_views=1`) or a
  three-view crop (`num_views=3`) is taken; the three views are stacked along
  the time axis, so the output has `3 * timesteps` frames in that case.

  Args:
    frames: A Tensor of dimension [timesteps, in_height, in_width, channels].
    target_height: Target cropped image height.
    target_width: Target cropped image width.
    random: A boolean indicating if crop should be randomized.
    num_views: Number of views to crop in evaluation; must be 1 or 3.
    seed: A deterministic seed to use when random cropping.

  Returns:
    A Tensor of shape [timesteps, out_height, out_width, channels] of type
      frames.dtype. With `num_views=3` the leading dimension is
      `3 * timesteps`.

  Raises:
    NotImplementedError: If `random` is False and `num_views` is not 1 or 3.
  """
  if random:
    # Random spatial crop. Static dimensions are preferred over the dynamic
    # shape tensor when available so the output shape stays statically known.
    shape = tf.shape(frames)
    static_shape = frames.shape.as_list()
    seq_len = shape[0] if static_shape[0] is None else static_shape[0]
    channels = shape[3] if static_shape[3] is None else static_shape[3]
    frames = tf.image.random_crop(
        frames, (seq_len, target_height, target_width, channels), seed)
  else:
    if num_views == 1:
      # Central crop or pad.
      frames = tf.image.resize_with_crop_or_pad(frames, target_height,
                                                target_width)
    elif num_views == 3:
      # Three-view evaluation: crops are taken at the start, center and end of
      # the larger spatial dimension, then concatenated along the time axis.
      shape = tf.shape(frames)
      static_shape = frames.shape.as_list()
      seq_len = shape[0] if static_shape[0] is None else static_shape[0]
      height = shape[1] if static_shape[1] is None else static_shape[1]
      width = shape[2] if static_shape[2] is None else static_shape[2]
      channels = shape[3] if static_shape[3] is None else static_shape[3]
      size = tf.convert_to_tensor(
          (seq_len, target_height, target_width, channels))
      # First view starts at the origin; the other two are offset along the
      # larger of (height, width). Offsets are computed in float32 and rounded
      # to the nearest pixel below. NOTE(review): offsets assume
      # target_height <= height and target_width <= width — tf.slice fails
      # otherwise; confirm callers resize first.
      offset_1 = tf.broadcast_to([0, 0, 0, 0], [4])
      # pylint:disable=g-long-lambda
      offset_2 = tf.cond(
          tf.greater_equal(height, width),
          true_fn=lambda: tf.broadcast_to([
              0, tf.cast(height, tf.float32) / 2 - target_height // 2, 0, 0
          ], [4]),
          false_fn=lambda: tf.broadcast_to([
              0, 0, tf.cast(width, tf.float32) / 2 - target_width // 2, 0
          ], [4]))
      offset_3 = tf.cond(
          tf.greater_equal(height, width),
          true_fn=lambda: tf.broadcast_to(
              [0, tf.cast(height, tf.float32) - target_height, 0, 0], [4]),
          false_fn=lambda: tf.broadcast_to(
              [0, 0, tf.cast(width, tf.float32) - target_width, 0], [4]))
      # pylint:enable=g-long-lambda
      crops = []
      for offset in [offset_1, offset_2, offset_3]:
        offset = tf.cast(tf.math.round(offset), tf.int32)
        crops.append(tf.slice(frames, offset, size))
      frames = tf.concat(crops, axis=0)
    else:
      raise NotImplementedError(
          f"Only 1 crop and 3 crop are supported. Found {num_views!r}.")
  return frames
......
...@@ -91,6 +91,8 @@ class ParserUtilsTest(tf.test.TestCase): ...@@ -91,6 +91,8 @@ class ParserUtilsTest(tf.test.TestCase):
cropped_image_1 = preprocess_ops_3d.crop_image(self._frames, 50, 70) cropped_image_1 = preprocess_ops_3d.crop_image(self._frames, 50, 70)
cropped_image_2 = preprocess_ops_3d.crop_image(self._frames, 200, 200) cropped_image_2 = preprocess_ops_3d.crop_image(self._frames, 200, 200)
cropped_image_3 = preprocess_ops_3d.crop_image(self._frames, 50, 70, True) cropped_image_3 = preprocess_ops_3d.crop_image(self._frames, 50, 70, True)
cropped_image_4 = preprocess_ops_3d.crop_image(
self._frames, 90, 90, False, 3)
self.assertAllEqual(cropped_image_1.shape, (6, 50, 70, 3)) self.assertAllEqual(cropped_image_1.shape, (6, 50, 70, 3))
self.assertAllEqual(cropped_image_1, self._np_frames[:, 20:70, 25:95, :]) self.assertAllEqual(cropped_image_1, self._np_frames[:, 20:70, 25:95, :])
...@@ -106,6 +108,7 @@ class ParserUtilsTest(tf.test.TestCase): ...@@ -106,6 +108,7 @@ class ParserUtilsTest(tf.test.TestCase):
expected = expected[np.newaxis, :, :, np.newaxis] expected = expected[np.newaxis, :, :, np.newaxis]
expected = np.broadcast_to(expected, (6, 50, 70, 3)) expected = np.broadcast_to(expected, (6, 50, 70, 3))
self.assertAllEqual(cropped_image_3, expected) self.assertAllEqual(cropped_image_3, expected)
self.assertAllEqual(cropped_image_4.shape, (18, 90, 90, 3))
def test_resize_smallest(self): def test_resize_smallest(self):
resized_frames_1 = preprocess_ops_3d.resize_smallest(self._frames, 180) resized_frames_1 = preprocess_ops_3d.resize_smallest(self._frames, 180)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment