Commit f58534eb authored by Yeqing Li, committed by A. Unique TensorFlower

Makes video id optional for input reader.

Adds unit tests for the dataloader.
Adds segment level training pipeline test.

PiperOrigin-RevId: 445277327
parent 2f9266ac
@@ -35,13 +35,37 @@ YT8M_VAL_PATH = 'gs://youtube8m-ml/3/frame/validate/validate*.tfrecord'
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""The base configuration for building datasets."""
"""The base configuration for building datasets.
Attributes:
name: Dataset name.
split: Dataset split, 'train' or 'valid'.
feature_sizes: Shape (length) of each feature specified in feature_names.
feature_names: Names of the features in the tf.SequenceExample.
segment_size: Number of frames in each segment.
segment_labels: Whether to use segment-level labels. Defaults to False,
i.e., video-level labels.
include_video_id: `True` means include the video id (string) in the input
to the model.
temporal_stride: Not used. To be deprecated.
max_frames: Maximum number of frames in an input example. It is used to
crop the input in the temporal dimension.
num_frames: Number of frames in a single input example.
num_classes: Number of classes to classify, assuming a classification
task.
num_devices: Not used. To be deprecated.
input_path: The path to the input data.
is_training: Whether this data is used for training or not.
num_examples: Number of examples in the dataset. It is used to compute the
number of steps for training or evaluation. Set the value to `-1` to run
until the end of the dataset.
"""
name: Optional[str] = 'yt8m'
split: Optional[str] = None
feature_sizes: Tuple[int, ...] = (1024, 128)
feature_names: Tuple[str, ...] = ('rgb', 'audio')
segment_size: int = 1
segment_labels: bool = False
include_video_id: bool = False
temporal_stride: int = 1
max_frames: int = 300
num_frames: int = 300  # Set smaller to allow random sampling in the Parser.
@@ -49,7 +73,6 @@ class DataConfig(cfg.DataConfig):
num_devices: int = 1
input_path: str = ''
is_training: bool = True
random_seed: int = 123
num_examples: int = -1
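For reference, a hypothetical instantiation of the updated config exercising the new include_video_id flag (field values are illustrative; the input path is taken from the YAML config further below):

config = DataConfig(
    split='train',
    segment_labels=True,
    include_video_id=False,  # New flag: omit the string video id from model inputs.
    input_path='gs://youtube8m-ml/2/frame/train/train*.tfrecord',
    is_training=True,
)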
......
@@ -15,8 +15,9 @@
"""Contains a collection of util functions for training and evaluating."""
from absl import logging
import numpy
import numpy as np
import tensorflow as tf
from official.vision.dataloaders import tfexample_utils
def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2):
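Only the Dequantize signature appears in this hunk; for context, a minimal sketch of the conventional YT8M dequantization it implements (mapping uint8-quantized features back to approximate floats; the body is assumed, not shown in this diff):

def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2):
  """Maps quantized uint8 features back to approximate float values."""
  quantized_range = max_quantized_value - min_quantized_value
  scalar = quantized_range / 255.0
  bias = (quantized_range / 512.0) + min_quantized_value
  return feat_vector * scalar + bias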
@@ -113,7 +114,7 @@ def AddEpochSummary(summary_writer,
avg_loss = epoch_info_dict["avg_loss"]
aps = epoch_info_dict["aps"]
gap = epoch_info_dict["gap"]
mean_ap = numpy.mean(aps)
mean_ap = np.mean(aps)
summary_writer.add_summary(
MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one),
@@ -213,3 +214,26 @@ def CombineGradients(tower_grads):
))
return final_grads
def MakeYt8mExample(num_segment: int = 5) -> tf.train.SequenceExample:
"""Generate fake data for unit tests."""
rgb = np.random.randint(low=256, size=1024, dtype=np.uint8)
audio = np.random.randint(low=256, size=128, dtype=np.uint8)
seq_example = tf.train.SequenceExample()
seq_example.context.feature["id"].bytes_list.value[:] = [b"id001"]
seq_example.context.feature["labels"].int64_list.value[:] = [1, 2, 3, 4]
seq_example.context.feature["segment_labels"].int64_list.value[:] = (
[4] * num_segment)
seq_example.context.feature["segment_start_times"].int64_list.value[:] = [
i * 5 for i in range(num_segment)
]
seq_example.context.feature["segment_scores"].float_list.value[:] = (
[0.] * num_segment)
tfexample_utils.put_bytes_list_to_feature(
seq_example, rgb.tobytes(), key="rgb", repeat_num=120)
tfexample_utils.put_bytes_list_to_feature(
seq_example, audio.tobytes(), key="audio", repeat_num=120)
return seq_example
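Typical usage mirrors the unit tests added below; a minimal sketch (the output path is illustrative):

examples = [MakeYt8mExample(num_segment=6) for _ in range(8)]
tfexample_utils.dump_to_tfrecord('/tmp/data.tfrecord', tf_examples=examples)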
@@ -81,13 +81,14 @@ def _process_segment_and_label(video_matrix, num_frames, contexts,
num_frames: Number of frames per subclip.
contexts: context information extracted from decoder
segment_labels: if we read segment labels instead.
segment_size: the segment_size used for reading segments.
segment_size: the segment_size used for reading segments, i.e., the segment length in frames.
num_classes: a positive integer for the number of classes.
Returns:
output: Dictionary containing batch information.
"""
# Partition frame-level feature matrix to segment-level feature matrix.
batch_video_ids = None
if segment_labels:
start_times = contexts["segment_start_times"].values
# Here we assume all the segments that started at the same start time have
@@ -101,8 +102,9 @@
batch_video_matrix = tf.gather_nd(video_matrix,
tf.expand_dims(range_mtx, axis=-1))
num_segment = tf.shape(batch_video_matrix)[0]
batch_video_ids = tf.reshape(
tf.tile([contexts["id"]], [num_segment]), (num_segment,))
if "id" in contexts:
batch_video_ids = tf.reshape(
tf.tile([contexts["id"]], [num_segment]), (num_segment,))
batch_frames = tf.reshape(
tf.tile([segment_size], [num_segment]), (num_segment,))
batch_frames = tf.cast(tf.expand_dims(batch_frames, 1), tf.float32)
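The range matrix consumed by tf.gather_nd above enumerates the frame indices of each segment; an illustrative construction (variable names assumed, as this part of the code is not shown in the hunk):

import tensorflow as tf

start_times = tf.constant([0, 5, 10], dtype=tf.int64)  # One start time per segment.
segment_size = 5
# [num_segment, segment_size]: row i holds the frame indices of segment i.
range_mtx = tf.expand_dims(start_times, -1) + tf.range(segment_size, dtype=tf.int64)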
@@ -134,18 +136,20 @@
sparse_labels, default_value=False, validate_indices=False)
# convert to batch format.
batch_video_ids = tf.expand_dims(contexts["id"], 0)
if "id" in contexts:
batch_video_ids = tf.expand_dims(contexts["id"], 0)
batch_video_matrix = tf.expand_dims(video_matrix, 0)
batch_labels = tf.expand_dims(labels, 0)
batch_frames = tf.expand_dims(num_frames, 0)
batch_label_weights = None
output_dict = {
"video_ids": batch_video_ids,
"video_matrix": batch_video_matrix,
"labels": batch_labels,
"num_frames": batch_frames,
}
if batch_video_ids is not None:
output_dict["video_ids"] = batch_video_ids
if batch_label_weights is not None:
output_dict["label_weights"] = batch_label_weights
@@ -280,12 +284,10 @@ class Parser(parser.Parser):
self._num_classes = input_params.num_classes
self._segment_size = input_params.segment_size
self._segment_labels = input_params.segment_labels
self._include_video_id = input_params.include_video_id
self._feature_names = input_params.feature_names
self._feature_sizes = input_params.feature_sizes
self.stride = input_params.temporal_stride
self._max_frames = input_params.max_frames
self._num_frames = input_params.num_frames
self._seed = input_params.random_seed
self._max_quantized_value = max_quantized_value
self._min_quantized_value = min_quantized_value
@@ -295,6 +297,8 @@
self.video_matrix, self.num_frames = _concat_features(
decoded_tensors["features"], self._feature_names, self._feature_sizes,
self._max_frames, self._max_quantized_value, self._min_quantized_value)
if not self._include_video_id:
del decoded_tensors["contexts"]["id"]
output_dict = _process_segment_and_label(self.video_matrix, self.num_frames,
decoded_tensors["contexts"],
self._segment_labels,
@@ -308,6 +312,8 @@
self.video_matrix, self.num_frames = _concat_features(
decoded_tensors["features"], self._feature_names, self._feature_sizes,
self._max_frames, self._max_quantized_value, self._min_quantized_value)
if not self._include_video_id:
del decoded_tensors["contexts"]["id"]
output_dict = _process_segment_and_label(self.video_matrix, self.num_frames,
decoded_tensors["contexts"],
self._segment_labels,
@@ -347,7 +353,7 @@ class PostBatchProcessor():
def post_fn(self, batched_tensors):
"""Processes batched Tensors."""
video_ids = batched_tensors["video_ids"]
video_ids = batched_tensors.get("video_ids", None)
video_matrix = batched_tensors["video_matrix"]
labels = batched_tensors["labels"]
num_frames = batched_tensors["num_frames"]
@@ -356,7 +362,8 @@
if self.segment_labels:
# [batch x num_segment x segment_size x num_features]
# -> [batch * num_segment x segment_size x num_features]
video_ids = tf.reshape(video_ids, [-1])
if video_ids is not None:
video_ids = tf.reshape(video_ids, [-1])
video_matrix = tf.reshape(video_matrix, [-1, self.segment_size, 1152])
labels = tf.reshape(labels, [-1, self.num_classes])
num_frames = tf.reshape(num_frames, [-1, 1])
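The reshape collapses the segment axis into the batch axis; a standalone sketch with illustrative shapes (segment_size=5; 1152 = 1024 RGB + 128 audio features):

import tensorflow as tf

video_matrix = tf.zeros([4, 6, 5, 1152])  # [batch, num_segment, segment_size, features]
flat = tf.reshape(video_matrix, [-1, 5, 1152])  # -> [24, 5, 1152]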
@@ -369,11 +376,12 @@
labels = tf.squeeze(labels)
batched_tensors = {
"video_ids": video_ids,
"video_matrix": video_matrix,
"labels": labels,
"num_frames": num_frames,
}
if video_ids is not None:
batched_tensors["video_ids"] = video_ids
if label_weights is not None:
batched_tensors["label_weights"] = label_weights
@@ -388,6 +396,7 @@ class TransformBatcher():
self._segment_labels = input_params.segment_labels
self._global_batch_size = input_params.global_batch_size
self._is_training = input_params.is_training
self._include_video_id = input_params.include_video_id
def batch_fn(self, dataset, input_context):
"""Add padding when segment_labels is true."""
@@ -398,19 +407,20 @@
else:
# add padding
pad_shapes = {
"video_ids": [None],
"video_matrix": [None, None, None],
"labels": [None, None],
"num_frames": [None, None],
"label_weights": [None, None]
}
pad_values = {
"video_ids": None,
"video_matrix": 0.0,
"labels": -1.0,
"num_frames": 0.0,
"label_weights": 0.0
}
if self._include_video_id:
pad_shapes["video_ids"] = [None]
pad_values["video_ids"] = None
dataset = dataset.padded_batch(
per_replica_batch_size,
padded_shapes=pad_shapes,
......
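With this change, `video_ids` participates in padding only when requested; a minimal standalone sketch of the conditional spec (batch size is illustrative; a `None` padding value falls back to the dtype's default, e.g. the empty string for strings):

pad_shapes = {'video_matrix': [None, None, None], 'labels': [None, None]}
pad_values = {'video_matrix': 0.0, 'labels': -1.0}
if include_video_id:
  pad_shapes['video_ids'] = [None]
  pad_values['video_ids'] = None  # Use the default pad value for the dtype.
dataset = dataset.padded_batch(
    4, padded_shapes=pad_shapes, padding_values=pad_values)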
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl import logging
from absl.testing import parameterized
import tensorflow as tf
from official.core import input_reader
from official.projects.yt8m.configs import yt8m as yt8m_configs
from official.projects.yt8m.dataloaders import utils
from official.projects.yt8m.dataloaders import yt8m_input
from official.vision.dataloaders import tfexample_utils
class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
def setUp(self):
super().setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
tf.io.gfile.makedirs(self._model_dir)
data_dir = os.path.join(self.get_temp_dir(), 'data')
tf.io.gfile.makedirs(data_dir)
self.data_path = os.path.join(data_dir, 'data.tfrecord')
self.num_segment = 6
examples = [utils.MakeYt8mExample(self.num_segment) for _ in range(8)]
tfexample_utils.dump_to_tfrecord(self.data_path, tf_examples=examples)
def create_input_reader(self, params):
decoder = yt8m_input.Decoder(input_params=params)
decoder_fn = decoder.decode
parser = yt8m_input.Parser(input_params=params)
parser_fn = parser.parse_fn(params.is_training)
postprocess = yt8m_input.PostBatchProcessor(input_params=params)
postprocess_fn = postprocess.post_fn
transform_batch = yt8m_input.TransformBatcher(input_params=params)
batch_fn = transform_batch.batch_fn
return input_reader.InputReader(
params,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn=decoder_fn,
parser_fn=parser_fn,
postprocess_fn=postprocess_fn,
transform_and_batch_fn=batch_fn)
@parameterized.parameters((True,), (False,))
def test_read_video_level_input(self, include_video_id):
params = yt8m_configs.yt8m(is_training=False)
params.global_batch_size = 4
params.segment_labels = False
params.input_path = self.data_path
params.include_video_id = include_video_id
reader = self.create_input_reader(params)
dataset = reader.read()
iterator = iter(dataset)
example = next(iterator)
for k, v in example.items():
logging.info('DEBUG read example %r %r %r', k, v.shape, type(v))
if include_video_id:
self.assertCountEqual(
['video_matrix', 'labels', 'num_frames', 'video_ids'], example.keys())
else:
self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
example.keys())
batch_size = params.global_batch_size
self.assertEqual(
example['video_matrix'].shape.as_list(),
[batch_size, params.max_frames, sum(params.feature_sizes)])
self.assertEqual(example['labels'].shape.as_list(),
[batch_size, params.num_classes])
self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
if include_video_id:
self.assertEqual(example['video_ids'].shape.as_list(), [batch_size, 1])
@parameterized.parameters((True,), (False,))
def test_read_segment_level_input(self, include_video_id):
params = yt8m_configs.yt8m(is_training=False)
params.global_batch_size = 4
params.segment_labels = True
params.input_path = self.data_path
params.include_video_id = include_video_id
reader = self.create_input_reader(params)
dataset = reader.read()
iterator = iter(dataset)
example = next(iterator)
for k, v in example.items():
logging.info('DEBUG read example %r %r %r', k, v.shape, type(v))
if include_video_id:
self.assertCountEqual([
'video_matrix', 'labels', 'num_frames', 'label_weights', 'video_ids'
], example.keys())
else:
self.assertCountEqual(
['video_matrix', 'labels', 'num_frames', 'label_weights'],
example.keys())
batch_size = params.global_batch_size * self.num_segment
self.assertEqual(
example['video_matrix'].shape.as_list(),
[batch_size, params.segment_size, sum(params.feature_sizes)])
self.assertEqual(example['labels'].shape.as_list(),
[batch_size, params.num_classes])
self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
self.assertEqual(example['label_weights'].shape.as_list(),
[batch_size, params.num_classes])
if include_video_id:
self.assertEqual(example['video_ids'].shape.as_list(), [batch_size])
if __name__ == '__main__':
tf.test.main()
@@ -27,7 +27,6 @@ task:
num_devices: 1
input_path: 'gs://youtube8m-ml/2/frame/train/train*.tfrecord'
is_training: true
random_seed: 123
validation_data:
name: 'yt8m'
split: 'train'
@@ -46,7 +45,6 @@ task:
num_devices: 1
input_path: 'gs://youtube8m-ml/3/frame/validate/validate*.tfrecord'
is_training: false
random_seed: 123
losses:
name: 'binary_crossentropy'
from_logits: false
......
@@ -38,9 +38,10 @@ class DbofModel(tf.keras.Model):
def __init__(
self,
params: yt8m_cfg.DbofModel,
num_frames=30,
num_classes=3862,
input_specs=layers.InputSpec(shape=[None, None, 1152]),
num_frames: int = 30,
num_classes: int = 3862,
input_specs: layers.InputSpec = layers.InputSpec(
shape=[None, None, 1152]),
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
activation: str = "relu",
use_sync_bn: bool = False,
......
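Given the now type-annotated signature, a hypothetical instantiation (the params config and the remaining constructor arguments, not shown in the hunk, are assumed to keep their defaults):

model = DbofModel(
    params=yt8m_cfg.DbofModel(),  # Model config dataclass; defaults assumed.
    num_frames=30,
    num_classes=3862,
)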
@@ -12,49 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from absl import flags
from absl.testing import flagsaver
import numpy as np
from absl.testing import parameterized
import tensorflow as tf
from official.projects.yt8m import train as train_lib
from official.projects.yt8m.dataloaders import utils
from official.vision.dataloaders import tfexample_utils
FLAGS = flags.FLAGS
def make_yt8m_example():
rgb = np.random.randint(low=256, size=1024, dtype=np.uint8)
audio = np.random.randint(low=256, size=128, dtype=np.uint8)
seq_example = tf.train.SequenceExample()
seq_example.context.feature['id'].bytes_list.value[:] = [b'id001']
seq_example.context.feature['labels'].int64_list.value[:] = [1, 2, 3, 4]
tfexample_utils.put_bytes_list_to_feature(
seq_example, rgb.tobytes(), key='rgb', repeat_num=120)
tfexample_utils.put_bytes_list_to_feature(
seq_example, audio.tobytes(), key='audio', repeat_num=120)
return seq_example
class TrainTest(tf.test.TestCase):
class TrainTest(parameterized.TestCase, tf.test.TestCase):
def setUp(self):
super(TrainTest, self).setUp()
super().setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
tf.io.gfile.makedirs(self._model_dir)
data_dir = os.path.join(self.get_temp_dir(), 'data')
tf.io.gfile.makedirs(data_dir)
self._data_path = os.path.join(data_dir, 'data.tfrecord')
examples = [make_yt8m_example() for _ in range(8)]
examples = [utils.MakeYt8mExample() for _ in range(8)]
tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
def test_run(self):
@parameterized.named_parameters(
dict(testcase_name='segment', use_segment_level_labels=True),
dict(testcase_name='video', use_segment_level_labels=False))
def test_train_and_eval(self, use_segment_level_labels):
saved_flag_values = flagsaver.save_flag_values()
train_lib.tfm_flags.define_flags()
FLAGS.mode = 'train'
@@ -87,6 +75,7 @@ class TrainTest(tf.test.TestCase):
},
'validation_data': {
'input_path': self._data_path,
'segment_labels': use_segment_level_labels,
'global_batch_size': 4,
}
}
......