Commit 40cd0e14 authored by Yin Cui's avatar Yin Cui Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 352087472
parent 2b949afd
# SlowOnly video classification on Kinetics-400. Expected performance to be updated. # SlowOnly video classification on Kinetics-400. Expected performance to be updated.
#
# --experiment_type=video_classification_kinetics400
# Expected accuracy: 71.5% top-1, 89.5% top-5.
runtime: runtime:
distribution_strategy: 'tpu' distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16' mixed_precision_dtype: 'bfloat16'
...@@ -61,8 +64,9 @@ task: ...@@ -61,8 +64,9 @@ task:
- 256 - 256
- 3 - 3
temporal_stride: 8 temporal_stride: 8
num_test_clips: 1 num_test_clips: 10
global_batch_size: 32 num_test_crops: 3
global_batch_size: 64
dtype: 'bfloat16' dtype: 'bfloat16'
drop_remainder: false drop_remainder: false
trainer: trainer:
......
...@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig): ...@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig):
feature_shape: Tuple[int, ...] = (64, 224, 224, 3) feature_shape: Tuple[int, ...] = (64, 224, 224, 3)
temporal_stride: int = 1 temporal_stride: int = 1
num_test_clips: int = 1 num_test_clips: int = 1
num_test_crops: int = 1
num_classes: int = -1 num_classes: int = -1
num_channels: int = 3 num_channels: int = 3
num_examples: int = -1 num_examples: int = -1
......
...@@ -34,8 +34,9 @@ def _process_image(image: tf.Tensor, ...@@ -34,8 +34,9 @@ def _process_image(image: tf.Tensor,
num_frames: int = 32, num_frames: int = 32,
stride: int = 1, stride: int = 1,
num_test_clips: int = 1, num_test_clips: int = 1,
min_resize: int = 224, min_resize: int = 256,
crop_size: int = 200, crop_size: int = 224,
num_crops: int = 1,
zero_centering_image: bool = False, zero_centering_image: bool = False,
seed: Optional[int] = None) -> tf.Tensor: seed: Optional[int] = None) -> tf.Tensor:
"""Processes a serialized image tensor. """Processes a serialized image tensor.
...@@ -54,6 +55,7 @@ def _process_image(image: tf.Tensor, ...@@ -54,6 +55,7 @@ def _process_image(image: tf.Tensor,
min_resize: Frames are resized so that min(height, width) is min_resize. min_resize: Frames are resized so that min(height, width) is min_resize.
crop_size: Final size of the frame after cropping the resized frames. Both crop_size: Final size of the frame after cropping the resized frames. Both
height and width are the same. height and width are the same.
num_crops: Number of crops to perform on the resized frames.
zero_centering_image: If True, frames are normalized to values in [-1, 1]. zero_centering_image: If True, frames are normalized to values in [-1, 1].
If False, values in [0, 1]. If False, values in [0, 1].
seed: A deterministic seed to use when sampling. seed: A deterministic seed to use when sampling.
...@@ -93,8 +95,9 @@ def _process_image(image: tf.Tensor, ...@@ -93,8 +95,9 @@ def _process_image(image: tf.Tensor,
seed) seed)
image = preprocess_ops_3d.random_flip_left_right(image, seed) image = preprocess_ops_3d.random_flip_left_right(image, seed)
else: else:
# Central crop of the frames. # Crop of the frames.
image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False) image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
num_crops)
# Cast the frames in float32, normalizing according to zero_centering_image. # Cast the frames in float32, normalizing according to zero_centering_image.
return preprocess_ops_3d.normalize_image(image, zero_centering_image) return preprocess_ops_3d.normalize_image(image, zero_centering_image)
...@@ -103,7 +106,8 @@ def _process_image(image: tf.Tensor, ...@@ -103,7 +106,8 @@ def _process_image(image: tf.Tensor,
def _postprocess_image(image: tf.Tensor, def _postprocess_image(image: tf.Tensor,
is_training: bool = True, is_training: bool = True,
num_frames: int = 32, num_frames: int = 32,
num_test_clips: int = 1) -> tf.Tensor: num_test_clips: int = 1,
num_test_crops: int = 1) -> tf.Tensor:
"""Processes a batched Tensor of frames. """Processes a batched Tensor of frames.
The same parameters used in process should be used here. The same parameters used in process should be used here.
...@@ -117,15 +121,19 @@ def _postprocess_image(image: tf.Tensor, ...@@ -117,15 +121,19 @@ def _postprocess_image(image: tf.Tensor,
will sample multiple linearly spaced clips within each video at test time. will sample multiple linearly spaced clips within each video at test time.
If 1, then a single clip in the middle of the video is sampled. The clips If 1, then a single clip in the middle of the video is sampled. The clips
are aggreagated in the batch dimension. are aggreagated in the batch dimension.
num_test_crops: Number of test crops (1 by default). If more than 1, there
are multiple crops for each clip at test time. If 1, there is a single
central crop. The crops are aggreagated in the batch dimension.
Returns: Returns:
Processed frames. Tensor of shape Processed frames. Tensor of shape
[batch * num_test_clips, num_frames, height, width, 3]. [batch * num_test_clips * num_test_crops, num_frames, height, width, 3].
""" """
if num_test_clips > 1 and not is_training: num_views = num_test_clips * num_test_crops
# In this case, multiple clips are merged together in batch dimenstion which if num_views > 1 and not is_training:
# will be B * num_test_clips. # In this case, multiple views are merged together in batch dimenstion which
image = tf.reshape(image, (-1, num_frames) + image.shape[2:]) # will be batch * num_views.
image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())
return image return image
...@@ -207,6 +215,7 @@ class Parser(parser.Parser): ...@@ -207,6 +215,7 @@ class Parser(parser.Parser):
self._num_test_clips = input_params.num_test_clips self._num_test_clips = input_params.num_test_clips
self._min_resize = input_params.min_image_size self._min_resize = input_params.min_image_size
self._crop_size = input_params.feature_shape[1] self._crop_size = input_params.feature_shape[1]
self._num_crops = input_params.num_test_crops
self._one_hot_label = input_params.one_hot self._one_hot_label = input_params.one_hot
self._num_classes = input_params.num_classes self._num_classes = input_params.num_classes
self._image_key = image_key self._image_key = image_key
...@@ -260,7 +269,8 @@ class Parser(parser.Parser): ...@@ -260,7 +269,8 @@ class Parser(parser.Parser):
stride=self._stride, stride=self._stride,
num_test_clips=self._num_test_clips, num_test_clips=self._num_test_clips,
min_resize=self._min_resize, min_resize=self._min_resize,
crop_size=self._crop_size) crop_size=self._crop_size,
num_crops=self._num_crops)
image = tf.cast(image, dtype=self._dtype) image = tf.cast(image, dtype=self._dtype)
features = {'image': image} features = {'image': image}
...@@ -286,6 +296,7 @@ class PostBatchProcessor(object): ...@@ -286,6 +296,7 @@ class PostBatchProcessor(object):
self._num_frames = input_params.feature_shape[0] self._num_frames = input_params.feature_shape[0]
self._num_test_clips = input_params.num_test_clips self._num_test_clips = input_params.num_test_clips
self._num_test_crops = input_params.num_test_crops
def __call__(self, features: Dict[str, tf.Tensor], def __call__(self, features: Dict[str, tf.Tensor],
label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]: label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
...@@ -296,6 +307,7 @@ class PostBatchProcessor(object): ...@@ -296,6 +307,7 @@ class PostBatchProcessor(object):
image=features[key], image=features[key],
is_training=self._is_training, is_training=self._is_training,
num_frames=self._num_frames, num_frames=self._num_frames,
num_test_clips=self._num_test_clips) num_test_clips=self._num_test_clips,
num_test_crops=self._num_test_crops)
return features, label return features, label
...@@ -151,19 +151,19 @@ def crop_image(frames: tf.Tensor, ...@@ -151,19 +151,19 @@ def crop_image(frames: tf.Tensor,
target_height: int, target_height: int,
target_width: int, target_width: int,
random: bool = False, random: bool = False,
num_views: int = 1, num_crops: int = 1,
seed: Optional[int] = None) -> tf.Tensor: seed: Optional[int] = None) -> tf.Tensor:
"""Crops the image sequence of images. """Crops the image sequence of images.
If requested size is bigger than image size, image is padded with 0. If not If requested size is bigger than image size, image is padded with 0. If not
random cropping, a central crop is performed. random cropping, a central crop is performed if num_crops is 1.
Args: Args:
frames: A Tensor of dimension [timesteps, in_height, in_width, channels]. frames: A Tensor of dimension [timesteps, in_height, in_width, channels].
target_height: Target cropped image height. target_height: Target cropped image height.
target_width: Target cropped image width. target_width: Target cropped image width.
random: A boolean indicating if crop should be randomized. random: A boolean indicating if crop should be randomized.
num_views: Number of views to crop in evaluation. num_crops: Number of crops (support 1 for central crop and 3 for 3-crop).
seed: A deterministic seed to use when random cropping. seed: A deterministic seed to use when random cropping.
Returns: Returns:
...@@ -181,13 +181,13 @@ def crop_image(frames: tf.Tensor, ...@@ -181,13 +181,13 @@ def crop_image(frames: tf.Tensor,
frames = tf.image.random_crop( frames = tf.image.random_crop(
frames, (seq_len, target_height, target_width, channels), seed) frames, (seq_len, target_height, target_width, channels), seed)
else: else:
if num_views == 1: if num_crops == 1:
# Central crop or pad. # Central crop or pad.
frames = tf.image.resize_with_crop_or_pad(frames, target_height, frames = tf.image.resize_with_crop_or_pad(frames, target_height,
target_width) target_width)
elif num_views == 3: elif num_crops == 3:
# Three-view evaluation. # Three-crop evaluation.
shape = tf.shape(frames) shape = tf.shape(frames)
static_shape = frames.shape.as_list() static_shape = frames.shape.as_list()
seq_len = shape[0] if static_shape[0] is None else static_shape[0] seq_len = shape[0] if static_shape[0] is None else static_shape[0]
...@@ -224,7 +224,7 @@ def crop_image(frames: tf.Tensor, ...@@ -224,7 +224,7 @@ def crop_image(frames: tf.Tensor,
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Only 1 crop and 3 crop are supported. Found {num_views!r}.") f"Only 1-crop and 3-crop are supported. Found {num_crops!r}.")
return frames return frames
......
...@@ -275,4 +275,11 @@ class VideoClassificationTask(base_task.Task): ...@@ -275,4 +275,11 @@ class VideoClassificationTask(base_task.Task):
outputs = tf.math.sigmoid(outputs) outputs = tf.math.sigmoid(outputs)
else: else:
outputs = tf.math.softmax(outputs) outputs = tf.math.softmax(outputs)
num_test_clips = self.task_config.validation_data.num_test_clips
num_test_crops = self.task_config.validation_data.num_test_crops
num_test_views = num_test_clips * num_test_crops
if num_test_views > 1:
# Averaging output probabilities across multiples views.
outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
outputs = tf.reduce_mean(outputs, axis=1)
return outputs return outputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment