Commit b0c03e17 authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Add optional random stride to the video input pipeline.

PiperOrigin-RevId: 369747553
parent 9f3b338c
......@@ -33,6 +33,7 @@ class DataConfig(cfg.DataConfig):
split: str = 'train'
feature_shape: Tuple[int, ...] = (64, 224, 224, 3)
temporal_stride: int = 1
random_stride_range: int = 0
num_test_clips: int = 1
num_test_crops: int = 1
num_classes: int = -1
......
......@@ -33,6 +33,7 @@ def process_image(image: tf.Tensor,
is_training: bool = True,
num_frames: int = 32,
stride: int = 1,
random_stride_range: int = 0,
num_test_clips: int = 1,
min_resize: int = 256,
crop_size: int = 224,
......@@ -52,6 +53,11 @@ def process_image(image: tf.Tensor,
and left right flip is used.
num_frames: Number of frames per subclip.
stride: Temporal stride to sample frames.
random_stride_range: An int indicating the min and max bounds to uniformly
sample different strides from the video. E.g., a value of 1 with stride=2
will uniformly sample a stride in {1, 2, 3} for each video in a batch.
Only used enabled training for the purposes of frame-rate augmentation.
Defaults to 0, which disables random sampling.
num_test_clips: Number of test clips (1 by default). If more than 1, this
will sample multiple linearly spaced clips within each video at test time.
If 1, then a single clip in the middle of the video is sampled. The clips
......@@ -78,8 +84,20 @@ def process_image(image: tf.Tensor,
'`num_test_clips` %d is ignored since `is_training` is `True`.',
num_test_clips)
if random_stride_range < 0:
raise ValueError('Random stride range should be >= 0, got {}'.format(
random_stride_range))
# Temporal sampler.
if is_training:
if random_stride_range > 0:
# Uniformly sample different frame-rates
stride = tf.random.uniform(
[],
tf.maximum(stride - random_stride_range, 1),
stride + random_stride_range,
dtype=tf.int32)
# Sample random clip.
image = preprocess_ops_3d.sample_sequence(image, num_frames, True, stride,
seed)
......@@ -219,6 +237,7 @@ class Parser(parser.Parser):
label_key: str = LABEL_KEY):
self._num_frames = input_params.feature_shape[0]
self._stride = input_params.temporal_stride
self._random_stride_range = input_params.random_stride_range
self._num_test_clips = input_params.num_test_clips
self._min_resize = input_params.min_image_size
self._crop_size = input_params.feature_shape[1]
......@@ -248,6 +267,7 @@ class Parser(parser.Parser):
is_training=True,
num_frames=self._num_frames,
stride=self._stride,
random_stride_range=self._random_stride_range,
num_test_clips=self._num_test_clips,
min_resize=self._min_resize,
crop_size=self._crop_size,
......
......@@ -135,6 +135,28 @@ class VideoAndLabelParserTest(tf.test.TestCase):
self.assertAllEqual(label.shape, (600,))
self.assertEqual(audio.shape, (15, 256))
def test_video_input_random_stride(self):
params = exp_cfg.kinetics600(is_training=True)
params.feature_shape = (2, 224, 224, 3)
params.min_image_size = 224
params.temporal_stride = 2
params.random_stride_range = 1
decoder = video_input.Decoder()
parser = video_input.Parser(params).parse_fn(params.is_training)
seq_example, label = fake_seq_example()
input_tensor = tf.constant(seq_example.SerializeToString())
decoded_tensors = decoder.decode(input_tensor)
output_tensor = parser(decoded_tensors)
image_features, label = output_tensor
image = image_features['image']
self.assertAllEqual(image.shape, (2, 224, 224, 3))
self.assertAllEqual(label.shape, (600,))
if __name__ == '__main__':
tf.test.main()
......@@ -106,9 +106,10 @@ def sample_sequence(sequence: tf.Tensor,
if random:
sequence_length = tf.cast(sequence_length, tf.float32)
frame_stride = tf.cast(stride, tf.float32)
max_offset = tf.cond(
sequence_length > (num_steps - 1) * stride,
lambda: sequence_length - (num_steps - 1) * stride,
sequence_length > (num_steps - 1) * frame_stride,
lambda: sequence_length - (num_steps - 1) * frame_stride,
lambda: sequence_length)
offset = tf.random.uniform(
(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment