Commit 40cd0e14 authored by Yin Cui's avatar Yin Cui Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 352087472
parent 2b949afd
# SlowOnly video classification on Kinetics-400. Expected performance to be updated.
#
# --experiment_type=video_classification_kinetics400
# Expected accuracy: 71.5% top-1, 89.5% top-5.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
......@@ -61,8 +64,9 @@ task:
- 256
- 3
temporal_stride: 8
num_test_clips: 1
global_batch_size: 32
num_test_clips: 10
num_test_crops: 3
global_batch_size: 64
dtype: 'bfloat16'
drop_remainder: false
trainer:
......
......@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig):
feature_shape: Tuple[int, ...] = (64, 224, 224, 3)
temporal_stride: int = 1
num_test_clips: int = 1
num_test_crops: int = 1
num_classes: int = -1
num_channels: int = 3
num_examples: int = -1
......
......@@ -34,8 +34,9 @@ def _process_image(image: tf.Tensor,
num_frames: int = 32,
stride: int = 1,
num_test_clips: int = 1,
min_resize: int = 224,
crop_size: int = 200,
min_resize: int = 256,
crop_size: int = 224,
num_crops: int = 1,
zero_centering_image: bool = False,
seed: Optional[int] = None) -> tf.Tensor:
"""Processes a serialized image tensor.
......@@ -54,6 +55,7 @@ def _process_image(image: tf.Tensor,
min_resize: Frames are resized so that min(height, width) is min_resize.
crop_size: Final size of the frame after cropping the resized frames. Both
height and width are the same.
num_crops: Number of crops to perform on the resized frames.
zero_centering_image: If True, frames are normalized to values in [-1, 1].
If False, values in [0, 1].
seed: A deterministic seed to use when sampling.
......@@ -93,8 +95,9 @@ def _process_image(image: tf.Tensor,
seed)
image = preprocess_ops_3d.random_flip_left_right(image, seed)
else:
# Central crop of the frames.
image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False)
# Crop of the frames.
image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
num_crops)
# Cast the frames in float32, normalizing according to zero_centering_image.
return preprocess_ops_3d.normalize_image(image, zero_centering_image)
......@@ -103,7 +106,8 @@ def _process_image(image: tf.Tensor,
def _postprocess_image(image: tf.Tensor,
is_training: bool = True,
num_frames: int = 32,
num_test_clips: int = 1) -> tf.Tensor:
num_test_clips: int = 1,
num_test_crops: int = 1) -> tf.Tensor:
"""Processes a batched Tensor of frames.
The same parameters used in process should be used here.
......@@ -117,15 +121,19 @@ def _postprocess_image(image: tf.Tensor,
will sample multiple linearly spaced clips within each video at test time.
If 1, then a single clip in the middle of the video is sampled. The clips
are aggregated in the batch dimension.
num_test_crops: Number of test crops (1 by default). If more than 1, there
are multiple crops for each clip at test time. If 1, there is a single
central crop. The crops are aggregated in the batch dimension.
Returns:
Processed frames. Tensor of shape
[batch * num_test_clips, num_frames, height, width, 3].
[batch * num_test_clips * num_test_crops, num_frames, height, width, 3].
"""
if num_test_clips > 1 and not is_training:
# In this case, multiple clips are merged together in the batch dimension, which
# will be B * num_test_clips.
image = tf.reshape(image, (-1, num_frames) + image.shape[2:])
num_views = num_test_clips * num_test_crops
if num_views > 1 and not is_training:
# In this case, multiple views are merged together in the batch dimension, which
# will be batch * num_views.
image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())
return image
......@@ -207,6 +215,7 @@ class Parser(parser.Parser):
self._num_test_clips = input_params.num_test_clips
self._min_resize = input_params.min_image_size
self._crop_size = input_params.feature_shape[1]
self._num_crops = input_params.num_test_crops
self._one_hot_label = input_params.one_hot
self._num_classes = input_params.num_classes
self._image_key = image_key
......@@ -260,7 +269,8 @@ class Parser(parser.Parser):
stride=self._stride,
num_test_clips=self._num_test_clips,
min_resize=self._min_resize,
crop_size=self._crop_size)
crop_size=self._crop_size,
num_crops=self._num_crops)
image = tf.cast(image, dtype=self._dtype)
features = {'image': image}
......@@ -286,6 +296,7 @@ class PostBatchProcessor(object):
self._num_frames = input_params.feature_shape[0]
self._num_test_clips = input_params.num_test_clips
self._num_test_crops = input_params.num_test_crops
def __call__(self, features: Dict[str, tf.Tensor],
label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
......@@ -296,6 +307,7 @@ class PostBatchProcessor(object):
image=features[key],
is_training=self._is_training,
num_frames=self._num_frames,
num_test_clips=self._num_test_clips)
num_test_clips=self._num_test_clips,
num_test_crops=self._num_test_crops)
return features, label
......@@ -151,19 +151,19 @@ def crop_image(frames: tf.Tensor,
target_height: int,
target_width: int,
random: bool = False,
num_views: int = 1,
num_crops: int = 1,
seed: Optional[int] = None) -> tf.Tensor:
"""Crops the image sequence of images.
If requested size is bigger than image size, image is padded with 0. If not
random cropping, a central crop is performed.
random cropping, a central crop is performed if num_crops is 1.
Args:
frames: A Tensor of dimension [timesteps, in_height, in_width, channels].
target_height: Target cropped image height.
target_width: Target cropped image width.
random: A boolean indicating if crop should be randomized.
num_views: Number of views to crop in evaluation.
num_crops: Number of crops (support 1 for central crop and 3 for 3-crop).
seed: A deterministic seed to use when random cropping.
Returns:
......@@ -181,13 +181,13 @@ def crop_image(frames: tf.Tensor,
frames = tf.image.random_crop(
frames, (seq_len, target_height, target_width, channels), seed)
else:
if num_views == 1:
if num_crops == 1:
# Central crop or pad.
frames = tf.image.resize_with_crop_or_pad(frames, target_height,
target_width)
elif num_views == 3:
# Three-view evaluation.
elif num_crops == 3:
# Three-crop evaluation.
shape = tf.shape(frames)
static_shape = frames.shape.as_list()
seq_len = shape[0] if static_shape[0] is None else static_shape[0]
......@@ -224,7 +224,7 @@ def crop_image(frames: tf.Tensor,
else:
raise NotImplementedError(
f"Only 1 crop and 3 crop are supported. Found {num_views!r}.")
f"Only 1-crop and 3-crop are supported. Found {num_crops!r}.")
return frames
......
......@@ -275,4 +275,11 @@ class VideoClassificationTask(base_task.Task):
outputs = tf.math.sigmoid(outputs)
else:
outputs = tf.math.softmax(outputs)
num_test_clips = self.task_config.validation_data.num_test_clips
num_test_crops = self.task_config.validation_data.num_test_crops
num_test_views = num_test_clips * num_test_crops
if num_test_views > 1:
# Averaging output probabilities across multiple views.
outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
outputs = tf.reduce_mean(outputs, axis=1)
return outputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment