"docs/source/api/vscode:/vscode.git/clone" did not exist on "886aa32730618dbcac35edc8dbdbb69e826ef6bf"
Commit e02da657 authored by Yin Cui's avatar Yin Cui Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 352854870
parent fa15ed1e
# SlowOnly video classification on Kinetics-400. Expected performance to be updated. # SlowOnly video classification on Kinetics-400. Expected performance to be updated.
# #
# --experiment_type=video_classification_kinetics400 # --experiment_type=video_classification_kinetics400
# Expected accuracy: 71.5% top-1, 89.5% top-5. # Expected accuracy: 74.1% top-1, 91.4% top-5.
runtime: runtime:
distribution_strategy: 'tpu' distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16' mixed_precision_dtype: 'bfloat16'
......
...@@ -86,15 +86,14 @@ def _process_image(image: tf.Tensor, ...@@ -86,15 +86,14 @@ def _process_image(image: tf.Tensor,
# Decode JPEG string to tf.uint8. # Decode JPEG string to tf.uint8.
image = preprocess_ops_3d.decode_jpeg(image, 3) image = preprocess_ops_3d.decode_jpeg(image, 3)
# Resize images (resize happens only if necessary to save compute).
image = preprocess_ops_3d.resize_smallest(image, min_resize)
if is_training: if is_training:
# Standard image data augmentation: random crop and random flip. # Standard image data augmentation: random resized crop and random flip.
image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, True, image = preprocess_ops_3d.random_crop_resize(
seed) image, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.08, 1))
image = preprocess_ops_3d.random_flip_left_right(image, seed) image = preprocess_ops_3d.random_flip_left_right(image, seed)
else: else:
# Resize images (resize happens only if necessary to save compute).
image = preprocess_ops_3d.resize_smallest(image, min_resize)
# Crop of the frames. # Crop of the frames.
image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False, image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
num_crops) num_crops)
......
...@@ -57,6 +57,7 @@ def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig: ...@@ -57,6 +57,7 @@ def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig:
**exp.task.validation_data.as_dict()) **exp.task.validation_data.as_dict())
exp.task.validation_data.min_image_size = 256 exp.task.validation_data.min_image_size = 256
exp.task.validation_data.num_test_clips = 10 exp.task.validation_data.num_test_clips = 10
exp.task.validation_data.num_test_crops = 3
return exp return exp
...@@ -84,4 +85,5 @@ def video_ssl_linear_eval_kinetics600() -> cfg.ExperimentConfig: ...@@ -84,4 +85,5 @@ def video_ssl_linear_eval_kinetics600() -> cfg.ExperimentConfig:
exp.task.validation_data.temporal_stride = 2 exp.task.validation_data.temporal_stride = 2
exp.task.validation_data.min_image_size = 256 exp.task.validation_data.min_image_size = 256
exp.task.validation_data.num_test_clips = 10 exp.task.validation_data.num_test_clips = 10
exp.task.validation_data.num_test_crops = 3
return exp return exp
...@@ -36,8 +36,9 @@ def _process_image(image: tf.Tensor, ...@@ -36,8 +36,9 @@ def _process_image(image: tf.Tensor,
num_frames: int = 32, num_frames: int = 32,
stride: int = 1, stride: int = 1,
num_test_clips: int = 1, num_test_clips: int = 1,
min_resize: int = 224, min_resize: int = 256,
crop_size: int = 200, crop_size: int = 224,
num_crops: int = 1,
zero_centering_image: bool = False, zero_centering_image: bool = False,
seed: Optional[int] = None) -> tf.Tensor: seed: Optional[int] = None) -> tf.Tensor:
"""Processes a serialized image tensor. """Processes a serialized image tensor.
...@@ -57,6 +58,7 @@ def _process_image(image: tf.Tensor, ...@@ -57,6 +58,7 @@ def _process_image(image: tf.Tensor,
min_resize: Frames are resized so that min(height, width) is min_resize. min_resize: Frames are resized so that min(height, width) is min_resize.
crop_size: Final size of the frame after cropping the resized frames. Both crop_size: Final size of the frame after cropping the resized frames. Both
height and width are the same. height and width are the same.
num_crops: Number of crops to perform on the resized frames.
zero_centering_image: If True, frames are normalized to values in [-1, 1]. zero_centering_image: If True, frames are normalized to values in [-1, 1].
If False, values in [0, 1]. If False, values in [0, 1].
seed: A deterministic seed to use when sampling. seed: A deterministic seed to use when sampling.
...@@ -115,8 +117,8 @@ def _process_image(image: tf.Tensor, ...@@ -115,8 +117,8 @@ def _process_image(image: tf.Tensor,
# Resize images (resize happens only if necessary to save compute). # Resize images (resize happens only if necessary to save compute).
image = preprocess_ops_3d.resize_smallest(image, min_resize) image = preprocess_ops_3d.resize_smallest(image, min_resize)
# Three-crop of the frames. # Three-crop of the frames.
image = preprocess_ops_3d.crop_image(image, min_resize, min_resize, False, image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
3) num_crops)
# Cast the frames in float32, normalizing according to zero_centering_image. # Cast the frames in float32, normalizing according to zero_centering_image.
if is_training and is_ssl: if is_training and is_ssl:
...@@ -145,7 +147,8 @@ def _postprocess_image(image: tf.Tensor, ...@@ -145,7 +147,8 @@ def _postprocess_image(image: tf.Tensor,
is_training: bool = True, is_training: bool = True,
is_ssl: bool = False, is_ssl: bool = False,
num_frames: int = 32, num_frames: int = 32,
num_test_clips: int = 1) -> tf.Tensor: num_test_clips: int = 1,
num_test_crops: int = 1) -> tf.Tensor:
"""Processes a batched Tensor of frames. """Processes a batched Tensor of frames.
The same parameters used in process should be used here. The same parameters used in process should be used here.
...@@ -160,20 +163,24 @@ def _postprocess_image(image: tf.Tensor, ...@@ -160,20 +163,24 @@ def _postprocess_image(image: tf.Tensor,
will sample multiple linearly spaced clips within each video at test time. will sample multiple linearly spaced clips within each video at test time.
If 1, then a single clip in the middle of the video is sampled. The clips If 1, then a single clip in the middle of the video is sampled. The clips
are aggregated in the batch dimension. are aggregated in the batch dimension.
num_test_crops: Number of test crops (1 by default). If more than 1, there
are multiple crops for each clip at test time. If 1, there is a single
central crop. The crops are aggregated in the batch dimension.
Returns: Returns:
Processed frames. Tensor of shape Processed frames. Tensor of shape
[batch * num_test_clips, num_frames, height, width, 3]. [batch * num_test_clips * num_test_crops, num_frames, height, width, 3].
""" """
if is_ssl and is_training: if is_ssl and is_training:
# In this case, two clips of self-supervised pre-training are merged # In this case, two clips of self-supervised pre-training are merged
# together in batch dimension which will be 2 * batch. # together in batch dimension which will be 2 * batch.
image = tf.concat(tf.split(image, num_or_size_splits=2, axis=1), axis=0) image = tf.concat(tf.split(image, num_or_size_splits=2, axis=1), axis=0)
if num_test_clips > 1 and not is_training: num_views = num_test_clips * num_test_crops
# In this case, multiple clips are merged together in batch dimension which if num_views > 1 and not is_training:
# will be B * num_test_clips. # In this case, multiple views are merged together in batch dimension which
image = tf.reshape(image, (-1, num_frames) + image.shape[2:]) # will be batch * num_views.
image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())
return image return image
...@@ -247,7 +254,8 @@ class Parser(video_input.Parser): ...@@ -247,7 +254,8 @@ class Parser(video_input.Parser):
stride=self._stride, stride=self._stride,
num_test_clips=self._num_test_clips, num_test_clips=self._num_test_clips,
min_resize=self._min_resize, min_resize=self._min_resize,
crop_size=self._crop_size) crop_size=self._crop_size,
num_crops=self._num_crops)
image = tf.cast(image, dtype=self._dtype) image = tf.cast(image, dtype=self._dtype)
features = {'image': image} features = {'image': image}
...@@ -293,6 +301,7 @@ class PostBatchProcessor(object): ...@@ -293,6 +301,7 @@ class PostBatchProcessor(object):
self._is_ssl = input_params.is_ssl self._is_ssl = input_params.is_ssl
self._num_frames = input_params.feature_shape[0] self._num_frames = input_params.feature_shape[0]
self._num_test_clips = input_params.num_test_clips self._num_test_clips = input_params.num_test_clips
self._num_test_crops = input_params.num_test_crops
def __call__(self, features: Dict[str, tf.Tensor], def __call__(self, features: Dict[str, tf.Tensor],
label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]: label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
...@@ -304,6 +313,7 @@ class PostBatchProcessor(object): ...@@ -304,6 +313,7 @@ class PostBatchProcessor(object):
is_training=self._is_training, is_training=self._is_training,
is_ssl=self._is_ssl, is_ssl=self._is_ssl,
num_frames=self._num_frames, num_frames=self._num_frames,
num_test_clips=self._num_test_clips) num_test_clips=self._num_test_clips,
num_test_crops=self._num_test_crops)
return features, label return features, label
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment