Commit 9c279893 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Small refactors of the video input interface and bug fixes.

PiperOrigin-RevId: 361233109
parent f6fcea21
...@@ -29,20 +29,20 @@ IMAGE_KEY = 'image/encoded' ...@@ -29,20 +29,20 @@ IMAGE_KEY = 'image/encoded'
LABEL_KEY = 'clip/label/index' LABEL_KEY = 'clip/label/index'
def _process_image(image: tf.Tensor, def process_image(image: tf.Tensor,
is_training: bool = True, is_training: bool = True,
num_frames: int = 32, num_frames: int = 32,
stride: int = 1, stride: int = 1,
num_test_clips: int = 1, num_test_clips: int = 1,
min_resize: int = 256, min_resize: int = 256,
crop_size: int = 224, crop_size: int = 224,
num_crops: int = 1, num_crops: int = 1,
zero_centering_image: bool = False, zero_centering_image: bool = False,
min_aspect_ratio: float = 0.5, min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2, max_aspect_ratio: float = 2,
min_area_ratio: float = 0.49, min_area_ratio: float = 0.49,
max_area_ratio: float = 1.0, max_area_ratio: float = 1.0,
seed: Optional[int] = None) -> tf.Tensor: seed: Optional[int] = None) -> tf.Tensor:
"""Processes a serialized image tensor. """Processes a serialized image tensor.
Args: Args:
...@@ -112,11 +112,11 @@ def _process_image(image: tf.Tensor, ...@@ -112,11 +112,11 @@ def _process_image(image: tf.Tensor,
return preprocess_ops_3d.normalize_image(image, zero_centering_image) return preprocess_ops_3d.normalize_image(image, zero_centering_image)
def _postprocess_image(image: tf.Tensor, def postprocess_image(image: tf.Tensor,
is_training: bool = True, is_training: bool = True,
num_frames: int = 32, num_frames: int = 32,
num_test_clips: int = 1, num_test_clips: int = 1,
num_test_crops: int = 1) -> tf.Tensor: num_test_crops: int = 1) -> tf.Tensor:
"""Processes a batched Tensor of frames. """Processes a batched Tensor of frames.
The same parameters used in process should be used here. The same parameters used in process should be used here.
...@@ -147,9 +147,9 @@ def _postprocess_image(image: tf.Tensor, ...@@ -147,9 +147,9 @@ def _postprocess_image(image: tf.Tensor,
return image return image
def _process_label(label: tf.Tensor, def process_label(label: tf.Tensor,
one_hot_label: bool = True, one_hot_label: bool = True,
num_classes: Optional[int] = None) -> tf.Tensor: num_classes: Optional[int] = None) -> tf.Tensor:
"""Processes label Tensor.""" """Processes label Tensor."""
# Validate parameters. # Validate parameters.
if one_hot_label and not num_classes: if one_hot_label and not num_classes:
...@@ -175,15 +175,13 @@ class Decoder(decoder.Decoder): ...@@ -175,15 +175,13 @@ class Decoder(decoder.Decoder):
"""A tf.Example decoder for classification task.""" """A tf.Example decoder for classification task."""
def __init__(self, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY): def __init__(self, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY):
self._image_key = image_key
self._label_key = label_key
self._context_description = { self._context_description = {
# One integer stored in context. # One integer stored in context.
self._label_key: tf.io.VarLenFeature(tf.int64), label_key: tf.io.VarLenFeature(tf.int64),
} }
self._sequence_description = { self._sequence_description = {
# Each image is a string encoding JPEG. # Each image is a string encoding JPEG.
self._image_key: tf.io.FixedLenSequenceFeature((), tf.string), image_key: tf.io.FixedLenSequenceFeature((), tf.string),
} }
def add_feature(self, feature_name: str, def add_feature(self, feature_name: str,
...@@ -245,7 +243,7 @@ class Parser(parser.Parser): ...@@ -245,7 +243,7 @@ class Parser(parser.Parser):
"""Parses data for training.""" """Parses data for training."""
# Process image and label. # Process image and label.
image = decoded_tensors[self._image_key] image = decoded_tensors[self._image_key]
image = _process_image( image = process_image(
image=image, image=image,
is_training=True, is_training=True,
num_frames=self._num_frames, num_frames=self._num_frames,
...@@ -261,7 +259,7 @@ class Parser(parser.Parser): ...@@ -261,7 +259,7 @@ class Parser(parser.Parser):
features = {'image': image} features = {'image': image}
label = decoded_tensors[self._label_key] label = decoded_tensors[self._label_key]
label = _process_label(label, self._one_hot_label, self._num_classes) label = process_label(label, self._one_hot_label, self._num_classes)
if self._output_audio: if self._output_audio:
audio = decoded_tensors[self._audio_feature] audio = decoded_tensors[self._audio_feature]
...@@ -279,7 +277,7 @@ class Parser(parser.Parser): ...@@ -279,7 +277,7 @@ class Parser(parser.Parser):
) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]: ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
"""Parses data for evaluation.""" """Parses data for evaluation."""
image = decoded_tensors[self._image_key] image = decoded_tensors[self._image_key]
image = _process_image( image = process_image(
image=image, image=image,
is_training=False, is_training=False,
num_frames=self._num_frames, num_frames=self._num_frames,
...@@ -292,14 +290,14 @@ class Parser(parser.Parser): ...@@ -292,14 +290,14 @@ class Parser(parser.Parser):
features = {'image': image} features = {'image': image}
label = decoded_tensors[self._label_key] label = decoded_tensors[self._label_key]
label = _process_label(label, self._one_hot_label, self._num_classes) label = process_label(label, self._one_hot_label, self._num_classes)
if self._output_audio: if self._output_audio:
audio = decoded_tensors[self._audio_feature] audio = decoded_tensors[self._audio_feature]
audio = tf.cast(audio, dtype=self._dtype) audio = tf.cast(audio, dtype=self._dtype)
audio = preprocess_ops_3d.sample_sequence( audio = preprocess_ops_3d.sample_sequence(
audio, 20, random=False, stride=1) audio, 20, random=False, stride=1)
audio = tf.ensure_shape(audio, [20, 2048]) audio = tf.ensure_shape(audio, self._audio_shape)
features['audio'] = audio features['audio'] = audio
return features, label return features, label
...@@ -318,9 +316,9 @@ class PostBatchProcessor(object): ...@@ -318,9 +316,9 @@ class PostBatchProcessor(object):
def __call__(self, features: Dict[str, tf.Tensor], def __call__(self, features: Dict[str, tf.Tensor],
label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]: label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
"""Parses a single tf.Example into image and label tensors.""" """Parses a single tf.Example into image and label tensors."""
for key in ['image', 'audio']: for key in ['image']:
if key in features: if key in features:
features[key] = _postprocess_image( features[key] = postprocess_image(
image=features[key], image=features[key],
is_training=self._is_training, is_training=self._is_training,
num_frames=self._num_frames, num_frames=self._num_frames,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment