"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "eaa0c1dfb50dff7679c42c145a55bad7c0a728bf"
Commit 5670119e authored by Rui Qian's avatar Rui Qian Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 457627771
parent ae6f4977
...@@ -18,8 +18,7 @@ from typing import Optional, Tuple ...@@ -18,8 +18,7 @@ from typing import Optional, Tuple
import tensorflow as tf import tensorflow as tf
def _sample_or_pad_sequence_indices(sequence: tf.Tensor, def _sample_or_pad_sequence_indices(sequence: tf.Tensor, num_steps: int,
num_steps: int,
stride: int, stride: int,
offset: tf.Tensor) -> tf.Tensor: offset: tf.Tensor) -> tf.Tensor:
"""Returns indices to take for sampling or padding sequences to fixed size.""" """Returns indices to take for sampling or padding sequences to fixed size."""
...@@ -28,18 +27,16 @@ def _sample_or_pad_sequence_indices(sequence: tf.Tensor, ...@@ -28,18 +27,16 @@ def _sample_or_pad_sequence_indices(sequence: tf.Tensor,
# Repeats sequence until num_steps are available in total. # Repeats sequence until num_steps are available in total.
max_length = num_steps * stride + offset max_length = num_steps * stride + offset
num_repeats = tf.math.floordiv( num_repeats = tf.math.floordiv(max_length + sequence_length - 1,
max_length + sequence_length - 1, sequence_length) sequence_length)
sel_idx = tf.tile(sel_idx, [num_repeats]) sel_idx = tf.tile(sel_idx, [num_repeats])
steps = tf.range(offset, offset + num_steps * stride, stride) steps = tf.range(offset, offset + num_steps * stride, stride)
return tf.gather(sel_idx, steps) return tf.gather(sel_idx, steps)
def sample_linspace_sequence(sequence: tf.Tensor, def sample_linspace_sequence(sequence: tf.Tensor, num_windows: int,
num_windows: int, num_steps: int, stride: int) -> tf.Tensor:
num_steps: int,
stride: int) -> tf.Tensor:
"""Samples `num_windows` segments from sequence with linearly spaced offsets. """Samples `num_windows` segments from sequence with linearly spaced offsets.
The samples are concatenated in a single `tf.Tensor` in order to have the same The samples are concatenated in a single `tf.Tensor` in order to have the same
...@@ -66,11 +63,12 @@ def sample_linspace_sequence(sequence: tf.Tensor, ...@@ -66,11 +63,12 @@ def sample_linspace_sequence(sequence: tf.Tensor,
all_indices = [] all_indices = []
for i in range(num_windows): for i in range(num_windows):
all_indices.append(_sample_or_pad_sequence_indices( all_indices.append(
sequence=sequence, _sample_or_pad_sequence_indices(
num_steps=num_steps, sequence=sequence,
stride=stride, num_steps=num_steps,
offset=offsets[i])) stride=stride,
offset=offsets[i]))
indices = tf.concat(all_indices, axis=0) indices = tf.concat(all_indices, axis=0)
indices.set_shape((num_windows * num_steps,)) indices.set_shape((num_windows * num_steps,))
...@@ -110,25 +108,76 @@ def sample_sequence(sequence: tf.Tensor, ...@@ -110,25 +108,76 @@ def sample_sequence(sequence: tf.Tensor,
sequence_length > (num_steps - 1) * frame_stride, sequence_length > (num_steps - 1) * frame_stride,
lambda: sequence_length - (num_steps - 1) * frame_stride, lambda: sequence_length - (num_steps - 1) * frame_stride,
lambda: sequence_length) lambda: sequence_length)
offset = tf.random.uniform( offset = tf.random.uniform((),
(), maxval=tf.cast(max_offset, dtype=tf.int32),
maxval=tf.cast(max_offset, dtype=tf.int32), dtype=tf.int32,
dtype=tf.int32, seed=seed)
seed=seed)
else: else:
offset = (sequence_length - num_steps * stride) // 2 offset = (sequence_length - num_steps * stride) // 2
offset = tf.maximum(0, offset) offset = tf.maximum(0, offset)
indices = _sample_or_pad_sequence_indices( indices = _sample_or_pad_sequence_indices(
sequence=sequence, sequence=sequence, num_steps=num_steps, stride=stride, offset=offset)
num_steps=num_steps,
stride=stride,
offset=offset)
indices.set_shape((num_steps,)) indices.set_shape((num_steps,))
return tf.gather(sequence, indices) return tf.gather(sequence, indices)
def sample_segment_sequence(sequence: tf.Tensor,
num_frames: int,
is_training: bool,
seed: Optional[int] = None) -> tf.Tensor:
"""Samples a single segment of size `num_frames` from a given sequence.
This function follows the temporal segment network sampling style
(https://arxiv.org/abs/1608.00859). The video sequence would be divided into
`num_frames` non-overlapping segments with same length. If `is_training` is
`True`, we would randomly sampling one frame for each segment, and when
`is_training` is `False`, only the center frame of each segment is sampled.
Args:
sequence: Any tensor where the first dimension is timesteps.
num_frames: Number of frames to take.
is_training: A boolean indicating sampling in training or evaluation mode.
seed: A deterministic seed to use when sampling.
Returns:
A single `tf.Tensor` with first dimension `num_steps` with the sampled
segment.
"""
sequence_length = tf.shape(sequence)[0]
sequence_length = tf.cast(sequence_length, tf.float32)
segment_length = tf.cast(sequence_length // num_frames, tf.float32)
segment_indices = tf.linspace(0.0, sequence_length, num_frames + 1)
segment_indices = tf.cast(segment_indices, tf.int32)
if is_training:
segment_length = tf.cast(segment_length, tf.int32)
# pylint:disable=g-long-lambda
segment_offsets = tf.cond(
segment_length == 0,
lambda: tf.zeros(shape=(num_frames,), dtype=tf.int32),
lambda: tf.random.uniform(
shape=(num_frames,),
minval=0,
maxval=segment_length,
dtype=tf.int32,
seed=seed))
# pylint:disable=g-long-lambda
else:
# Only sampling central frame during inference for being deterministic.
segment_offsets = tf.ones(
shape=(num_frames,), dtype=tf.int32) * tf.cast(
segment_length // 2, dtype=tf.int32)
indices = segment_indices[:-1] + segment_offsets
indices.set_shape((num_frames,))
return tf.gather(sequence, indices)
def decode_jpeg(image_string: tf.Tensor, channels: int = 0) -> tf.Tensor: def decode_jpeg(image_string: tf.Tensor, channels: int = 0) -> tf.Tensor:
"""Decodes JPEG raw bytes string into a RGB uint8 Tensor. """Decodes JPEG raw bytes string into a RGB uint8 Tensor.
...@@ -144,7 +193,9 @@ def decode_jpeg(image_string: tf.Tensor, channels: int = 0) -> tf.Tensor: ...@@ -144,7 +193,9 @@ def decode_jpeg(image_string: tf.Tensor, channels: int = 0) -> tf.Tensor:
""" """
return tf.map_fn( return tf.map_fn(
lambda x: tf.image.decode_jpeg(x, channels=channels), lambda x: tf.image.decode_jpeg(x, channels=channels),
image_string, back_prop=False, dtype=tf.uint8) image_string,
back_prop=False,
dtype=tf.uint8)
def crop_image(frames: tf.Tensor, def crop_image(frames: tf.Tensor,
...@@ -229,8 +280,7 @@ def crop_image(frames: tf.Tensor, ...@@ -229,8 +280,7 @@ def crop_image(frames: tf.Tensor,
return frames return frames
def resize_smallest(frames: tf.Tensor, def resize_smallest(frames: tf.Tensor, min_resize: int) -> tf.Tensor:
min_resize: int) -> tf.Tensor:
"""Resizes frames so that min(`height`, `width`) is equal to `min_resize`. """Resizes frames so that min(`height`, `width`) is equal to `min_resize`.
This function will not do anything if the min(`height`, `width`) is already This function will not do anything if the min(`height`, `width`) is already
...@@ -255,18 +305,15 @@ def resize_smallest(frames: tf.Tensor, ...@@ -255,18 +305,15 @@ def resize_smallest(frames: tf.Tensor,
frames_resized = tf.image.resize(frames, (output_h, output_w)) frames_resized = tf.image.resize(frames, (output_h, output_w))
return tf.cast(frames_resized, frames.dtype) return tf.cast(frames_resized, frames.dtype)
should_resize = tf.math.logical_or(tf.not_equal(input_w, output_w), should_resize = tf.math.logical_or(
tf.not_equal(input_h, output_h)) tf.not_equal(input_w, output_w), tf.not_equal(input_h, output_h))
frames = tf.cond(should_resize, resize_fn, lambda: frames) frames = tf.cond(should_resize, resize_fn, lambda: frames)
return frames return frames
def random_crop_resize(frames: tf.Tensor, def random_crop_resize(frames: tf.Tensor, output_h: int, output_w: int,
output_h: int, num_frames: int, num_channels: int,
output_w: int,
num_frames: int,
num_channels: int,
aspect_ratio: Tuple[float, float], aspect_ratio: Tuple[float, float],
area_range: Tuple[float, float]) -> tf.Tensor: area_range: Tuple[float, float]) -> tf.Tensor:
"""First crops clip with jittering and then resizes to (output_h, output_w). """First crops clip with jittering and then resizes to (output_h, output_w).
...@@ -279,6 +326,7 @@ def random_crop_resize(frames: tf.Tensor, ...@@ -279,6 +326,7 @@ def random_crop_resize(frames: tf.Tensor,
num_channels: Number of channels of the clip. num_channels: Number of channels of the clip.
aspect_ratio: Float tuple with the aspect range for cropping. aspect_ratio: Float tuple with the aspect range for cropping.
area_range: Float tuple with the area range for cropping. area_range: Float tuple with the area range for cropping.
Returns: Returns:
A Tensor of shape [timesteps, output_h, output_w, channels] of type A Tensor of shape [timesteps, output_h, output_w, channels] of type
frames.dtype. frames.dtype.
...@@ -299,21 +347,16 @@ def random_crop_resize(frames: tf.Tensor, ...@@ -299,21 +347,16 @@ def random_crop_resize(frames: tf.Tensor,
bbox_begin, bbox_size, _ = sample_distorted_bbox bbox_begin, bbox_size, _ = sample_distorted_bbox
offset_y, offset_x, _ = tf.unstack(bbox_begin) offset_y, offset_x, _ = tf.unstack(bbox_begin)
target_height, target_width, _ = tf.unstack(bbox_size) target_height, target_width, _ = tf.unstack(bbox_size)
size = tf.convert_to_tensor(( size = tf.convert_to_tensor((seq_len, target_height, target_width, channels))
seq_len, target_height, target_width, channels)) offset = tf.convert_to_tensor((0, offset_y, offset_x, 0))
offset = tf.convert_to_tensor((
0, offset_y, offset_x, 0))
frames = tf.slice(frames, offset, size) frames = tf.slice(frames, offset, size)
frames = tf.cast( frames = tf.cast(tf.image.resize(frames, (output_h, output_w)), frames.dtype)
tf.image.resize(frames, (output_h, output_w)),
frames.dtype)
frames.set_shape((num_frames, output_h, output_w, num_channels)) frames.set_shape((num_frames, output_h, output_w, num_channels))
return frames return frames
def random_flip_left_right( def random_flip_left_right(frames: tf.Tensor,
frames: tf.Tensor, seed: Optional[int] = None) -> tf.Tensor:
seed: Optional[int] = None) -> tf.Tensor:
"""Flips all the frames with a probability of 50%. """Flips all the frames with a probability of 50%.
Args: Args:
...@@ -324,12 +367,16 @@ def random_flip_left_right( ...@@ -324,12 +367,16 @@ def random_flip_left_right(
A Tensor of shape [timesteps, output_h, output_w, channels] eventually A Tensor of shape [timesteps, output_h, output_w, channels] eventually
flipped left right. flipped left right.
""" """
is_flipped = tf.random.uniform( is_flipped = tf.random.uniform((),
(), minval=0, maxval=2, dtype=tf.int32, seed=seed) minval=0,
maxval=2,
frames = tf.cond(tf.equal(is_flipped, 1), dtype=tf.int32,
true_fn=lambda: tf.image.flip_left_right(frames), seed=seed)
false_fn=lambda: frames)
frames = tf.cond(
tf.equal(is_flipped, 1),
true_fn=lambda: tf.image.flip_left_right(frames),
false_fn=lambda: frames)
return frames return frames
......
...@@ -72,6 +72,16 @@ class ParserUtilsTest(tf.test.TestCase): ...@@ -72,6 +72,16 @@ class ParserUtilsTest(tf.test.TestCase):
self.assertBetween(offset_3, 0, 99) self.assertBetween(offset_3, 0, 99)
self.assertAllEqual(sampled_seq_3, range(offset_3, offset_3 + 10)) self.assertAllEqual(sampled_seq_3, range(offset_3, offset_3 + 10))
def test_sample_segment_sequence(self):
sequence = tf.range(100)
sampled_seq_1 = preprocess_ops_3d.sample_segment_sequence(
sequence, 10, False)
sampled_seq_2 = preprocess_ops_3d.sample_segment_sequence(
sequence, 10, True)
self.assertAllEqual(sampled_seq_1, [5 + i * 10 for i in range(10)])
for idx, v in enumerate(sampled_seq_2):
self.assertBetween(v - idx * 10, 0, 10)
def test_decode_jpeg(self): def test_decode_jpeg(self):
# Create a random RGB JPEG image. # Create a random RGB JPEG image.
random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8) random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment