Commit 91a0e443 authored by Chaochao Yan's avatar Chaochao Yan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 476402446
parent 769fbfba
......@@ -22,7 +22,7 @@
back into a range between min_quantized_value and max_quantized_value.
link for details: https://research.google.com/youtube8m/download.html
"""
from typing import Dict
from typing import Any, Dict
import tensorflow as tf
from official.projects.yt8m.dataloaders import utils
......@@ -215,15 +215,11 @@ def _concat_features(features, feature_names, feature_sizes, feature_dtypes,
feature_matrices = [None] * num_features # an array of different features
for i in range(num_features):
feature_matrix, num_frames_in_this_feature = _get_video_matrix(
features[feature_names[i]],
feature_sizes[i],
tf.dtypes.as_dtype(feature_dtypes[i]),
max_frames,
max_quantized_value,
features[feature_names[i]], feature_sizes[i],
tf.dtypes.as_dtype(feature_dtypes[i]), max_frames, max_quantized_value,
min_quantized_value)
if num_frames == -1:
num_frames = num_frames_in_this_feature
feature_matrices[i] = feature_matrix
# cap the number of frames at self.max_frames
......@@ -236,7 +232,7 @@ def _concat_features(features, feature_names, feature_sizes, feature_dtypes,
class Decoder(decoder.Decoder):
"""A tf.Example decoder for classification task."""
"""A tf.train.SequeneExample decoder for classification task."""
def __init__(
self,
......@@ -270,8 +266,7 @@ class Decoder(decoder.Decoder):
"segment_scores": tf.io.VarLenFeature(tf.float32)
})
else:
self._context_features.update(
{self._label_field: tf.io.VarLenFeature(tf.int64)})
self._add_labels_specification()
for i, name in enumerate(self._feature_names):
if self._feature_from_bytes[i]:
......@@ -291,8 +286,15 @@ class Decoder(decoder.Decoder):
raise ValueError(
f"Unknow feature source {self._feature_sources[i]} for {name}")
def decode(self, serialized_example):
"""Parses a single tf.Example into image and label tensors."""
def _add_labels_specification(self):
  """Registers the label field in the context-feature parsing spec.

  Adds a `tf.io.VarLenFeature(tf.int64)` entry for `self._label_field`
  to `self._context_features` so labels are parsed from the
  SequenceExample context.

  Raises:
    ValueError: If `self._label_field` is empty or None.
  """
  if not self._label_field:
    raise ValueError(f"Invalid label field: {self._label_field}!")
  label_spec = {self._label_field: tf.io.VarLenFeature(tf.int64)}
  self._context_features.update(label_spec)
def decode(self,
serialized_example: tf.train.SequenceExample) -> Dict[str, Any]:
"""Parses a single tf.train.SequenceExample into video and label tensors."""
contexts, features = tf.io.parse_single_sequence_example(
serialized_example,
......@@ -309,8 +311,6 @@ class Decoder(decoder.Decoder):
else:
if isinstance(decoded_tensor[name], tf.SparseTensor):
decoded_tensor[name] = tf.sparse.to_dense(decoded_tensor[name])
if not self._segment_labels:
decoded_tensor["labels"] = decoded_tensor[self._label_field]
return decoded_tensor
......@@ -349,12 +349,9 @@ class Parser(parser.Parser):
self._min_quantized_value)
if not self._include_video_id and "id" in decoded_tensors:
del decoded_tensors["id"]
output_dict = _process_segment_and_label(self.video_matrix, self.num_frames,
decoded_tensors,
self._segment_labels,
self._segment_size,
self._num_classes)
return output_dict
return self._process_label(self.video_matrix, self.num_frames,
decoded_tensors)
def _parse_eval_data(self, decoded_tensors):
"""Parses data for evaluation."""
......@@ -365,12 +362,26 @@ class Parser(parser.Parser):
self._min_quantized_value)
if not self._include_video_id and "id" in decoded_tensors:
del decoded_tensors["id"]
output_dict = _process_segment_and_label(self.video_matrix, self.num_frames,
decoded_tensors,
return self._process_label(self.video_matrix, self.num_frames,
decoded_tensors)
def _process_label(self, video_matrix, num_frames, contexts):
  """Processes a batched Tensor of frames.

  Delegates to the module-level `_process_segment_and_label` helper using
  this parser's segment/label configuration.

  Args:
    video_matrix: video feature matrix.
    num_frames: number of frames in this video.
    contexts: context information extracted from decoder.

  Returns:
    output_dict: dictionary containing batch information.
  """
  # Single return: the previous version carried a duplicated, unreachable
  # `return output_dict` line.
  return _process_segment_and_label(video_matrix, num_frames, contexts,
                                    self._segment_labels, self._segment_size,
                                    self._num_classes)
def parse_fn(self, is_training):
"""Returns a parse fn that reads and parses raw tensors from the decoder.
......@@ -394,53 +405,6 @@ class Parser(parser.Parser):
return parse
class PostBatchProcessor():
  """Processes a video and label dataset which is batched."""

  def __init__(self, input_params: exp_cfg.DataConfig):
    self.segment_labels = input_params.segment_labels
    self.num_classes = input_params.num_classes
    self.segment_size = input_params.segment_size
    # Total per-frame feature width derived from the config instead of the
    # hard-coded 1152 (the historical YT8M rgb(1024)+audio(128) size), so
    # the processor works for any feature configuration.
    self.num_features = sum(input_params.feature_sizes)

  def post_fn(self, batched_tensors):
    """Processes batched Tensors.

    Flattens segment-level batches ([batch, num_segment, ...] ->
    [batch * num_segment, ...]) when `segment_labels` is set; otherwise
    squeezes the singleton segment axis.

    Args:
      batched_tensors: dict with "video_matrix", "labels", "num_frames" and
        optionally "video_ids" / "label_weights" tensors.

    Returns:
      Dict with the reshaped tensors under the same keys.
    """
    video_ids = batched_tensors.get("video_ids", None)
    video_matrix = batched_tensors["video_matrix"]
    labels = batched_tensors["labels"]
    num_frames = batched_tensors["num_frames"]
    label_weights = None

    if self.segment_labels:
      # [batch x num_segment x segment_size x num_features]
      # -> [batch * num_segment x segment_size x num_features]
      if video_ids is not None:
        video_ids = tf.reshape(video_ids, [-1])
      video_matrix = tf.reshape(video_matrix,
                                [-1, self.segment_size, self.num_features])
      labels = tf.reshape(labels, [-1, self.num_classes])
      num_frames = tf.reshape(num_frames, [-1, 1])
      label_weights = tf.reshape(batched_tensors["label_weights"],
                                 [-1, self.num_classes])
    else:
      # NOTE(b/237445211): Must provide axis argument to tf.squeeze.
      video_matrix = tf.squeeze(video_matrix, axis=1)
      labels = tf.squeeze(labels, axis=1)

    batched_tensors = {
        "video_matrix": video_matrix,
        "labels": labels,
        "num_frames": num_frames,
    }
    if video_ids is not None:
      batched_tensors["video_ids"] = video_ids
    if label_weights is not None:
      batched_tensors["label_weights"] = label_weights

    return batched_tensors
class TransformBatcher():
"""Performs manual batching on input dataset."""
......@@ -481,3 +445,51 @@ class TransformBatcher():
drop_remainder=self._drop_remainder,
padding_values=pad_values)
return dataset
class PostBatchProcessor():
  """Processes a video and label dataset which is batched."""

  def __init__(self, input_params: exp_cfg.DataConfig):
    self.segment_labels = input_params.segment_labels
    self.num_classes = input_params.num_classes
    self.segment_size = input_params.segment_size
    self.num_features = sum(input_params.feature_sizes)

  def post_fn(self, batched_tensors: Dict[str,
                                          tf.Tensor]) -> Dict[str, tf.Tensor]:
    """Processes batched Tensors.

    With segment labels, flattens the segment axis into the batch axis
    ([batch, num_segment, ...] -> [batch * num_segment, ...]); otherwise
    squeezes the singleton segment dimension. Mutates and returns
    `batched_tensors` with the reshaped values.
    """
    ids = batched_tensors.get("video_ids", None)
    matrix = batched_tensors["video_matrix"]
    targets = batched_tensors["labels"]
    frames = batched_tensors["num_frames"]

    if self.segment_labels:
      # [batch x num_segment x segment_size x num_features]
      # -> [batch * num_segment x segment_size x num_features]
      if ids is not None:
        ids = tf.reshape(ids, [-1])
      matrix = tf.reshape(matrix, [-1, self.segment_size, self.num_features])
      targets = tf.reshape(targets, [-1, self.num_classes])
      frames = tf.reshape(frames, [-1, 1])
      batched_tensors["label_weights"] = tf.reshape(
          batched_tensors["label_weights"], [-1, self.num_classes])
    else:
      # NOTE(b/237445211): Must provide axis argument to tf.squeeze.
      matrix = tf.squeeze(matrix, axis=1)
      targets = tf.squeeze(targets, axis=1)
      frames = tf.reshape(frames, [-1, 1])
      if "label_weights" in batched_tensors:
        batched_tensors["label_weights"] = tf.squeeze(
            batched_tensors["label_weights"], axis=1)

    batched_tensors["video_matrix"] = matrix
    batched_tensors["labels"] = targets
    batched_tensors["num_frames"] = frames
    if ids is not None:
      batched_tensors["video_ids"] = ids
    return batched_tensors
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment