Commit e3774f8c authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Add optional random stride to the video input pipeline.

PiperOrigin-RevId: 369747553
parent 1b1fca6d
...@@ -33,6 +33,7 @@ class DataConfig(cfg.DataConfig): ...@@ -33,6 +33,7 @@ class DataConfig(cfg.DataConfig):
split: str = 'train' split: str = 'train'
feature_shape: Tuple[int, ...] = (64, 224, 224, 3) feature_shape: Tuple[int, ...] = (64, 224, 224, 3)
temporal_stride: int = 1 temporal_stride: int = 1
random_stride_range: int = 0
num_test_clips: int = 1 num_test_clips: int = 1
num_test_crops: int = 1 num_test_crops: int = 1
num_classes: int = -1 num_classes: int = -1
......
...@@ -33,6 +33,7 @@ def process_image(image: tf.Tensor, ...@@ -33,6 +33,7 @@ def process_image(image: tf.Tensor,
is_training: bool = True, is_training: bool = True,
num_frames: int = 32, num_frames: int = 32,
stride: int = 1, stride: int = 1,
random_stride_range: int = 0,
num_test_clips: int = 1, num_test_clips: int = 1,
min_resize: int = 256, min_resize: int = 256,
crop_size: int = 224, crop_size: int = 224,
...@@ -52,6 +53,11 @@ def process_image(image: tf.Tensor, ...@@ -52,6 +53,11 @@ def process_image(image: tf.Tensor,
and left right flip is used. and left right flip is used.
num_frames: Number of frames per subclip. num_frames: Number of frames per subclip.
stride: Temporal stride to sample frames. stride: Temporal stride to sample frames.
random_stride_range: An int indicating the min and max bounds to uniformly
sample different strides from the video. E.g., a value of 1 with stride=2
will uniformly sample a stride in {1, 2, 3} for each video in a batch.
Only used when enabled during training for the purposes of frame-rate
Defaults to 0, which disables random sampling.
num_test_clips: Number of test clips (1 by default). If more than 1, this num_test_clips: Number of test clips (1 by default). If more than 1, this
will sample multiple linearly spaced clips within each video at test time. will sample multiple linearly spaced clips within each video at test time.
If 1, then a single clip in the middle of the video is sampled. The clips If 1, then a single clip in the middle of the video is sampled. The clips
...@@ -78,8 +84,20 @@ def process_image(image: tf.Tensor, ...@@ -78,8 +84,20 @@ def process_image(image: tf.Tensor,
'`num_test_clips` %d is ignored since `is_training` is `True`.', '`num_test_clips` %d is ignored since `is_training` is `True`.',
num_test_clips) num_test_clips)
if random_stride_range < 0:
raise ValueError('Random stride range should be >= 0, got {}'.format(
random_stride_range))
# Temporal sampler. # Temporal sampler.
if is_training: if is_training:
if random_stride_range > 0:
# Uniformly sample different frame-rates
stride = tf.random.uniform(
[],
tf.maximum(stride - random_stride_range, 1),
stride + random_stride_range,
dtype=tf.int32)
# Sample random clip. # Sample random clip.
image = preprocess_ops_3d.sample_sequence(image, num_frames, True, stride, image = preprocess_ops_3d.sample_sequence(image, num_frames, True, stride,
seed) seed)
...@@ -219,6 +237,7 @@ class Parser(parser.Parser): ...@@ -219,6 +237,7 @@ class Parser(parser.Parser):
label_key: str = LABEL_KEY): label_key: str = LABEL_KEY):
self._num_frames = input_params.feature_shape[0] self._num_frames = input_params.feature_shape[0]
self._stride = input_params.temporal_stride self._stride = input_params.temporal_stride
self._random_stride_range = input_params.random_stride_range
self._num_test_clips = input_params.num_test_clips self._num_test_clips = input_params.num_test_clips
self._min_resize = input_params.min_image_size self._min_resize = input_params.min_image_size
self._crop_size = input_params.feature_shape[1] self._crop_size = input_params.feature_shape[1]
...@@ -248,6 +267,7 @@ class Parser(parser.Parser): ...@@ -248,6 +267,7 @@ class Parser(parser.Parser):
is_training=True, is_training=True,
num_frames=self._num_frames, num_frames=self._num_frames,
stride=self._stride, stride=self._stride,
random_stride_range=self._random_stride_range,
num_test_clips=self._num_test_clips, num_test_clips=self._num_test_clips,
min_resize=self._min_resize, min_resize=self._min_resize,
crop_size=self._crop_size, crop_size=self._crop_size,
......
...@@ -135,6 +135,28 @@ class VideoAndLabelParserTest(tf.test.TestCase): ...@@ -135,6 +135,28 @@ class VideoAndLabelParserTest(tf.test.TestCase):
self.assertAllEqual(label.shape, (600,)) self.assertAllEqual(label.shape, (600,))
self.assertEqual(audio.shape, (15, 256)) self.assertEqual(audio.shape, (15, 256))
def test_video_input_random_stride(self):
  """Checks that parsing with a random temporal stride yields valid shapes."""
  # Build a training config with stride jitter enabled (stride 2 +/- 1).
  params = exp_cfg.kinetics600(is_training=True)
  params.feature_shape = (2, 224, 224, 3)
  params.min_image_size = 224
  params.temporal_stride = 2
  params.random_stride_range = 1

  decoder = video_input.Decoder()
  parse_fn = video_input.Parser(params).parse_fn(params.is_training)

  example, _ = fake_seq_example()
  serialized = tf.constant(example.SerializeToString())
  features, label = parse_fn(decoder.decode(serialized))

  # Output must match the configured feature shape and Kinetics-600 classes.
  self.assertAllEqual(features['image'].shape, (2, 224, 224, 3))
  self.assertAllEqual(label.shape, (600,))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -106,9 +106,10 @@ def sample_sequence(sequence: tf.Tensor, ...@@ -106,9 +106,10 @@ def sample_sequence(sequence: tf.Tensor,
if random: if random:
sequence_length = tf.cast(sequence_length, tf.float32) sequence_length = tf.cast(sequence_length, tf.float32)
frame_stride = tf.cast(stride, tf.float32)
max_offset = tf.cond( max_offset = tf.cond(
sequence_length > (num_steps - 1) * stride, sequence_length > (num_steps - 1) * frame_stride,
lambda: sequence_length - (num_steps - 1) * stride, lambda: sequence_length - (num_steps - 1) * frame_stride,
lambda: sequence_length) lambda: sequence_length)
offset = tf.random.uniform( offset = tf.random.uniform(
(), (),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment