"vscode:/vscode.git/clone" did not exist on "98ca260bc834ec94a8143e4b5cfe9516b0b951a2"
Commit ebac9847 authored by Yeqing Li, committed by A. Unique TensorFlower

Small refactors of the video input interface and bug fixes.

PiperOrigin-RevId: 361233109
parent 5c02f1ef
@@ -29,20 +29,20 @@ IMAGE_KEY = 'image/encoded'
 LABEL_KEY = 'clip/label/index'
 
 
-def _process_image(image: tf.Tensor,
-                   is_training: bool = True,
-                   num_frames: int = 32,
-                   stride: int = 1,
-                   num_test_clips: int = 1,
-                   min_resize: int = 256,
-                   crop_size: int = 224,
-                   num_crops: int = 1,
-                   zero_centering_image: bool = False,
-                   min_aspect_ratio: float = 0.5,
-                   max_aspect_ratio: float = 2,
-                   min_area_ratio: float = 0.49,
-                   max_area_ratio: float = 1.0,
-                   seed: Optional[int] = None) -> tf.Tensor:
+def process_image(image: tf.Tensor,
+                  is_training: bool = True,
+                  num_frames: int = 32,
+                  stride: int = 1,
+                  num_test_clips: int = 1,
+                  min_resize: int = 256,
+                  crop_size: int = 224,
+                  num_crops: int = 1,
+                  zero_centering_image: bool = False,
+                  min_aspect_ratio: float = 0.5,
+                  max_aspect_ratio: float = 2,
+                  min_area_ratio: float = 0.49,
+                  max_area_ratio: float = 1.0,
+                  seed: Optional[int] = None) -> tf.Tensor:
   """Processes a serialized image tensor.
 
   Args:
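Dropping the leading underscore makes the helper part of the module's public surface, so other video dataloaders can reuse it. A minimal usage sketch follows; the import path (official/vision/beta/dataloaders/video_input) and the synthetic JPEG input are assumptions for illustration, not part of this change.

import tensorflow as tf
from official.vision.beta.dataloaders import video_input  # assumed path

# Build a few synthetic JPEG-encoded frames so the sketch is self-contained.
frames = tf.random.uniform([32, 320, 480, 3], maxval=256, dtype=tf.int32)
encoded = tf.map_fn(
    lambda f: tf.io.encode_jpeg(tf.cast(f, tf.uint8)),
    frames, fn_output_signature=tf.string)

clip = video_input.process_image(
    image=encoded,
    is_training=False,   # deterministic resize + center crop for eval
    num_frames=32,       # frames sampled per clip
    stride=1,            # temporal stride between sampled frames
    min_resize=256,      # shorter side resized before cropping
    crop_size=224)       # 224x224 crop
# Expected result: a float tensor of shape [32, 224, 224, 3].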
@@ -112,11 +112,11 @@ def _process_image(image: tf.Tensor,
   return preprocess_ops_3d.normalize_image(image, zero_centering_image)
 
 
-def _postprocess_image(image: tf.Tensor,
-                       is_training: bool = True,
-                       num_frames: int = 32,
-                       num_test_clips: int = 1,
-                       num_test_crops: int = 1) -> tf.Tensor:
+def postprocess_image(image: tf.Tensor,
+                      is_training: bool = True,
+                      num_frames: int = 32,
+                      num_test_clips: int = 1,
+                      num_test_crops: int = 1) -> tf.Tensor:
   """Processes a batched Tensor of frames.
 
   The same parameters used in process should be used here.
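For evaluation with multiple test clips or crops, postprocess_image is expected to fold the extra views into the batch dimension after batching. A hedged sketch of that reshape, with the module path assumed as above:

import tensorflow as tf
from official.vision.beta.dataloaders import video_input  # assumed path

batch, num_frames, num_test_clips = 2, 32, 4
# After batching, each example packs num_test_clips * num_frames frames.
packed = tf.zeros([batch, num_test_clips * num_frames, 224, 224, 3])
unpacked = video_input.postprocess_image(
    image=packed,
    is_training=False,
    num_frames=num_frames,
    num_test_clips=num_test_clips,
    num_test_crops=1)
# Expected shape: [batch * num_test_clips, num_frames, 224, 224, 3].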
@@ -147,9 +147,9 @@ def _postprocess_image(image: tf.Tensor,
   return image
 
 
-def _process_label(label: tf.Tensor,
-                   one_hot_label: bool = True,
-                   num_classes: Optional[int] = None) -> tf.Tensor:
+def process_label(label: tf.Tensor,
+                  one_hot_label: bool = True,
+                  num_classes: Optional[int] = None) -> tf.Tensor:
   """Processes label Tensor."""
   # Validate parameters.
   if one_hot_label and not num_classes:
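process_label validates that num_classes is set whenever one_hot_label is requested, then converts the integer class index into a one-hot vector. A small hedged sketch (dense int64 label and module path are assumptions):

import tensorflow as tf
from official.vision.beta.dataloaders import video_input  # assumed path

label = tf.constant([7], dtype=tf.int64)
one_hot = video_input.process_label(
    label, one_hot_label=True, num_classes=400)
# Expected: a length-400 one-hot vector with a 1 at index 7.
# Requesting one-hot without num_classes should trip the guard shown above.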
@@ -175,15 +175,13 @@ class Decoder(decoder.Decoder):
   """A tf.Example decoder for classification task."""
 
   def __init__(self, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY):
-    self._image_key = image_key
-    self._label_key = label_key
     self._context_description = {
         # One integer stored in context.
-        self._label_key: tf.io.VarLenFeature(tf.int64),
+        label_key: tf.io.VarLenFeature(tf.int64),
     }
     self._sequence_description = {
         # Each image is a string encoding JPEG.
-        self._image_key: tf.io.FixedLenSequenceFeature((), tf.string),
+        image_key: tf.io.FixedLenSequenceFeature((), tf.string),
     }
 
   def add_feature(self, feature_name: str,
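The constructor no longer stores image_key and label_key on self; the keys are written straight into the context and sequence feature descriptions. Construction is unchanged from the caller's point of view. A hedged sketch follows; the second argument of add_feature is assumed to be a tf.io feature spec, since this hunk only shows its first parameter, and the extra feature name below is hypothetical.

import tensorflow as tf
from official.vision.beta.dataloaders import video_input  # assumed path

decoder = video_input.Decoder(
    image_key='image/encoded',
    label_key='clip/label/index')
# Extra context features can still be registered after construction.
decoder.add_feature('clip/label/text',          # hypothetical feature name
                    tf.io.VarLenFeature(tf.string))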
@@ -245,7 +243,7 @@ class Parser(parser.Parser):
     """Parses data for training."""
     # Process image and label.
     image = decoded_tensors[self._image_key]
-    image = _process_image(
+    image = process_image(
         image=image,
         is_training=True,
         num_frames=self._num_frames,
@@ -261,7 +259,7 @@ class Parser(parser.Parser):
     features = {'image': image}
 
     label = decoded_tensors[self._label_key]
-    label = _process_label(label, self._one_hot_label, self._num_classes)
+    label = process_label(label, self._one_hot_label, self._num_classes)
 
     if self._output_audio:
       audio = decoded_tensors[self._audio_feature]
@@ -279,7 +277,7 @@ class Parser(parser.Parser):
   ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
     """Parses data for evaluation."""
     image = decoded_tensors[self._image_key]
-    image = _process_image(
+    image = process_image(
         image=image,
         is_training=False,
         num_frames=self._num_frames,
@@ -292,14 +290,14 @@ class Parser(parser.Parser):
     features = {'image': image}
 
     label = decoded_tensors[self._label_key]
-    label = _process_label(label, self._one_hot_label, self._num_classes)
+    label = process_label(label, self._one_hot_label, self._num_classes)
 
     if self._output_audio:
       audio = decoded_tensors[self._audio_feature]
       audio = tf.cast(audio, dtype=self._dtype)
       audio = preprocess_ops_3d.sample_sequence(
           audio, 20, random=False, stride=1)
-      audio = tf.ensure_shape(audio, [20, 2048])
+      audio = tf.ensure_shape(audio, self._audio_shape)
       features['audio'] = audio
 
     return features, label
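The static audio shape is no longer hard-coded to [20, 2048] but read from a configurable attribute (self._audio_shape), so embeddings with a different length or width can flow through the same parser. tf.ensure_shape both validates the runtime shape and pins the static shape; a standalone illustration, where the local audio_shape stands in for the configured attribute:

import tensorflow as tf

audio_shape = (20, 2048)                 # would come from the dataset config
audio = tf.zeros([20, 2048], tf.float32)
audio = tf.ensure_shape(audio, audio_shape)  # validates and fixes static shape
print(audio.shape)                       # (20, 2048)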
@@ -318,9 +316,9 @@ class PostBatchProcessor(object):
   def __call__(self, features: Dict[str, tf.Tensor],
                label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
     """Parses a single tf.Example into image and label tensors."""
-    for key in ['image', 'audio']:
+    for key in ['image']:
       if key in features:
-        features[key] = _postprocess_image(
+        features[key] = postprocess_image(
             image=features[key],
             is_training=self._is_training,
             num_frames=self._num_frames,
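The post-batch loop now touches only the 'image' feature: the reshape in postprocess_image folds test clips of frames into the batch dimension, which does not apply to the fixed-shape audio embedding, so audio passes through unchanged after batching. A hedged sketch of the effect, with eval-style arguments and the module path assumed as above:

import tensorflow as tf
from official.vision.beta.dataloaders import video_input  # assumed path

features = {'image': tf.zeros([2, 4 * 32, 224, 224, 3]),
            'audio': tf.zeros([2, 20, 2048])}
for key in ['image']:
  if key in features:
    features[key] = video_input.postprocess_image(
        image=features[key],
        is_training=False,
        num_frames=32,
        num_test_clips=4,
        num_test_crops=1)
# features['image'] is reshaped to [2 * 4, 32, 224, 224, 3];
# features['audio'] keeps its batched shape [2, 20, 2048].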
...