Commit 6941537b authored by Yeqing Li, committed by A. Unique TensorFlower

Supports reading floating point features.

Adds more options to configure reading features from tf.SequenceExample.

PiperOrigin-RevId: 445716499
parent 496a77e8
......@@ -42,6 +42,9 @@ class DataConfig(cfg.DataConfig):
split: dataset split, 'train' or 'valid'.
feature_sizes: shape(length) of each feature specified in the feature_names.
feature_names: names of the features in the tf.SequenceExample.
feature_sources: source of each feature, either 'context' or 'feature'
  (i.e. the context features or the frame-level feature lists of the
  tf.SequenceExample).
feature_dtypes: dtype of each decoded feature.
feature_from_bytes: whether each feature is stored as raw bytes (decoded
  via tf.io.decode_raw) or parsed directly as a typed list.
segment_size: Number of frames in each segment.
segment_labels: Use segment level label. Default: False, video level label.
include_video_id: `True` means include video id (string) in the input to
......@@ -58,11 +61,15 @@ class DataConfig(cfg.DataConfig):
num_examples: Number of examples in the dataset. It is used to compute the
steps for train or eval. Set the value to `-1` to make the experiment run
until the end of the dataset.
file_type: type of the input files, e.g. 'tfrecord'.
"""
name: Optional[str] = 'yt8m'
split: Optional[str] = None
feature_sizes: Tuple[int, ...] = (1024, 128)
feature_names: Tuple[str, ...] = ('rgb', 'audio')
feature_sources: Tuple[str, ...] = ('feature', 'feature')
feature_dtypes: Tuple[str, ...] = ('uint8', 'uint8')
feature_from_bytes: Tuple[bool, ...] = (True, True)
segment_size: int = 1
segment_labels: bool = False
include_video_id: bool = False
......@@ -74,6 +81,7 @@ class DataConfig(cfg.DataConfig):
input_path: str = ''
is_training: bool = True
num_examples: int = -1
file_type: str = 'tfrecord'
def yt8m(is_training):
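As a quick illustration, a minimal sketch of how the new DataConfig fields combine to read float features directly; the names and sizes below mirror the unit-test configuration further down rather than the defaults:

# `yt8m_configs` refers to this configs module, imported as in the unit test.
params = yt8m_configs.yt8m(is_training=False)
params.feature_names = ('VIDEO_EMBEDDING/context_feature/floats',
                        'FEATURE/feature/floats')
params.feature_sources = ('context', 'feature')  # context vs. frame-level lists
params.feature_dtypes = ('float32', 'float32')   # parse as typed float lists
params.feature_sizes = (256, 2048)
params.feature_from_bytes = (False, False)       # no tf.io.decode_raw needed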
......@@ -152,7 +160,7 @@ def add_trainer(
experiment.task.train_data.global_batch_size = train_batch_size
experiment.task.validation_data.global_batch_size = eval_batch_size
steps_per_epoch = YT8M_TRAIN_EXAMPLES // train_batch_size
steps_per_loop = 30
steps_per_loop = 500
experiment.trainer = cfg.TrainerConfig(
steps_per_loop=steps_per_loop,
summary_interval=steps_per_loop,
......@@ -199,14 +207,16 @@ def yt8m_experiment() -> cfg.ExperimentConfig:
'task.train_data.num_classes == task.validation_data.num_classes',
'task.train_data.feature_sizes != None',
'task.train_data.feature_names != None',
'task.train_data.feature_sources != None',
'task.train_data.feature_dtypes != None',
])
# Per-core batch size for TPUv3 (16GB HBM). `factor` in range(1, 26).
factor = 1
num_cores = 32 # for TPU 4x4
num_cores = 32 # for TPUv3 4x4
train_per_core_bs = 32 * factor
train_bs = train_per_core_bs * num_cores
eval_per_core_bs = 32 * 50 # multiplier<=100
eval_per_core_bs = 4 * 50 # multiplier<=100
eval_bs = eval_per_core_bs * num_cores
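# With factor = 1 and num_cores = 32, this gives train_bs = 32 * 1 * 32 = 1024
# and eval_bs = 4 * 50 * 32 = 6400.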
# based on lr=0.0001 for bs=512
return add_trainer(
......
......@@ -237,3 +237,30 @@ def MakeYt8mExample(num_segment: int = 5) -> tf.train.SequenceExample:
seq_example, audio.tobytes(), key="audio", repeat_num=120)
return seq_example
# TODO(yeqing): Move the test related functions to test_utils.
def MakeExampleWithFloatFeatures(
num_segment: int = 5) -> tf.train.SequenceExample:
"""Generate fake data for unit tests."""
rgb = np.random.rand(1, 2048).astype(np.float32)
audio = np.random.rand(256).astype(np.float32)
seq_example = tf.train.SequenceExample()
seq_example.context.feature["id"].bytes_list.value[:] = [b"id001"]
seq_example.context.feature["labels"].int64_list.value[:] = [1, 2, 3, 4]
seq_example.context.feature["segment_labels"].int64_list.value[:] = (
[4] * num_segment)
seq_example.context.feature["segment_start_times"].int64_list.value[:] = [
i * 5 for i in range(num_segment)
]
seq_example.context.feature["segment_scores"].float_list.value[:] = (
[0.] * num_segment)
seq_example.context.feature[
"VIDEO_EMBEDDING/context_feature/floats"].float_list.value[:] = (
audio.tolist())
tfexample_utils.put_float_list_to_feature(
seq_example, rgb.tolist(), key="FEATURE/feature/floats")
return seq_example
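A minimal sketch of reading these float features back, using the same kind of parsing spec the Decoder builds for the test configuration (context feature of width 256, frame-level feature of width 2048); shapes assume the single fake frame written above:

import tensorflow as tf

example = MakeExampleWithFloatFeatures(num_segment=5)
contexts, sequences = tf.io.parse_single_sequence_example(
    example.SerializeToString(),
    context_features={
        "VIDEO_EMBEDDING/context_feature/floats":
            tf.io.FixedLenFeature([256], tf.float32),
    },
    sequence_features={
        "FEATURE/feature/floats":
            tf.io.FixedLenSequenceFeature([2048], tf.float32),
    })
print(contexts["VIDEO_EMBEDDING/context_feature/floats"].shape)  # (256,)
print(sequences["FEATURE/feature/floats"].shape)  # (1, 2048): one fake frame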
......@@ -156,14 +156,15 @@ def _process_segment_and_label(video_matrix, num_frames, contexts,
return output_dict
def _get_video_matrix(features, feature_size, max_frames, max_quantized_value,
min_quantized_value):
def _get_video_matrix(features, feature_size, dtype, max_frames,
max_quantized_value, min_quantized_value):
"""Decodes features from an input string and quantizes it.
Args:
features: raw feature values
feature_size: length of each frame feature vector
max_frames: number of frames (rows) in the output feature_matrix
features: raw feature values.
feature_size: length of each frame feature vector.
dtype: raw type of the feature.
max_frames: number of frames (rows) in the output feature_matrix.
max_quantized_value: the maximum of the quantized value.
min_quantized_value: the minimum of the quantized value.
......@@ -171,25 +172,27 @@ def _get_video_matrix(features, feature_size, max_frames, max_quantized_value,
feature_matrix: matrix of all frame-features
num_frames: number of frames in the sequence
"""
decoded_features = tf.reshape(
tf.cast(tf.io.decode_raw(features, tf.uint8), tf.float32),
[-1, feature_size])
decoded_features = tf.reshape(features, [-1, feature_size])
num_frames = tf.math.minimum(tf.shape(decoded_features)[0], max_frames)
if dtype.is_integer:
feature_matrix = utils.Dequantize(decoded_features, max_quantized_value,
min_quantized_value)
else:
feature_matrix = decoded_features
feature_matrix = resize_axis(feature_matrix, 0, max_frames)
return feature_matrix, num_frames
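For uint8-quantized features the integer branch calls utils.Dequantize. The actual implementation lives in the project's utils module; a rough sketch, assuming the usual YT8M mapping from [0, 255] back to [min_quantized_value, max_quantized_value]:

import tensorflow as tf

def dequantize_sketch(feat, max_quantized_value=2.0, min_quantized_value=-2.0):
  # Hypothetical stand-in for utils.Dequantize: map uint8 values back to the
  # original floating-point range.
  quantized_range = max_quantized_value - min_quantized_value
  scalar = quantized_range / 255.0
  bias = (quantized_range / 512.0) + min_quantized_value
  return feat * scalar + bias

# 0 lands near the range minimum, 255 near the maximum: ~[-1.99, 2.01].
print(dequantize_sketch(tf.constant([0.0, 255.0])))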
def _concat_features(features, feature_names, feature_sizes, max_frames,
max_quantized_value, min_quantized_value):
def _concat_features(features, feature_names, feature_sizes, feature_dtypes,
max_frames, max_quantized_value, min_quantized_value):
"""Loads (potentially) different types of features and concatenates them.
Args:
features: raw feature values
feature_names: list of feature names
feature_sizes: list of feature sizes
feature_dtypes: list of feature dtypes.
max_frames: number of frames in the sequence
max_quantized_value: the maximum of the quantized value.
min_quantized_value: the minimum of the quantized value.
......@@ -205,17 +208,24 @@ def _concat_features(features, feature_names, feature_sizes, max_frames,
assert len(feature_names) == len(feature_sizes), (
"length of feature_names (={}) != length of feature_sizes (={})".format(
len(feature_names), len(feature_sizes)))
assert len(feature_names) == len(feature_dtypes), (
"length of feature_names (={}) != length of feature_sizes (={})".format(
len(feature_names), len(feature_dtypes)))
num_frames = -1 # the number of frames in the video
feature_matrices = [None] * num_features # an array of different features
for feature_index in range(num_features):
for i in range(num_features):
feature_matrix, num_frames_in_this_feature = _get_video_matrix(
features[feature_names[feature_index]], feature_sizes[feature_index],
max_frames, max_quantized_value, min_quantized_value)
features[feature_names[i]],
feature_sizes[i],
tf.dtypes.as_dtype(feature_dtypes[i]),
max_frames,
max_quantized_value,
min_quantized_value)
if num_frames == -1:
num_frames = num_frames_in_this_feature
feature_matrices[feature_index] = feature_matrix
feature_matrices[i] = feature_matrix
# cap the number of frames at self.max_frames
num_frames = tf.minimum(num_frames, max_frames)
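The per-feature matrices are then concatenated along the feature axis, so the resulting video matrix has width sum(feature_sizes) per frame (e.g. 256 + 2048 = 2304 for the float-feature test below).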
......@@ -236,9 +246,21 @@ class Decoder(decoder.Decoder):
self._segment_labels = input_params.segment_labels
self._feature_names = input_params.feature_names
self._context_features = {
"id": tf.io.FixedLenFeature([], tf.string),
}
self._feature_sources = input_params.feature_sources
self._feature_sizes = input_params.feature_sizes
self._feature_dtypes = input_params.feature_dtypes
self._feature_from_bytes = input_params.feature_from_bytes
self._include_video_id = input_params.include_video_id
assert len(self._feature_names) == len(self._feature_sources), (
"length of feature_names (={}) != length of feature_sizes (={})".format(
len(self._feature_names), len(self._feature_sources)))
self._context_features = {}
self._sequence_features = {}
if self._include_video_id:
self._context_features["id"] = tf.io.FixedLenFeature([], tf.string)
if self._segment_labels:
self._context_features.update({
# There is no need to read end-time given we always assume the segment
......@@ -250,10 +272,23 @@ class Decoder(decoder.Decoder):
else:
self._context_features.update({"labels": tf.io.VarLenFeature(tf.int64)})
self._sequence_features = {
feature_name: tf.io.FixedLenSequenceFeature([], dtype=tf.string)
for feature_name in self._feature_names
}
for i, name in enumerate(self._feature_names):
if self._feature_from_bytes[i]:
feature_type = tf.io.FixedLenSequenceFeature([], dtype=tf.string)
else:
dtype = tf.dtypes.as_dtype(self._feature_dtypes[i])
feature_shape = [self._feature_sizes[i]]
if self._feature_sources[i] == "feature":
feature_type = tf.io.FixedLenSequenceFeature(feature_shape, dtype)
else:
feature_type = tf.io.FixedLenFeature(feature_shape, dtype)
if self._feature_sources[i] == "feature":
self._sequence_features[name] = feature_type
elif self._feature_sources[i] == "context":
self._context_features[name] = feature_type
else:
raise ValueError(
f"Unknow feature source {self._feature_sources[i]} for {name}")
def decode(self, serialized_example):
"""Parses a single tf.Example into image and label tensors."""
......@@ -263,7 +298,17 @@ class Decoder(decoder.Decoder):
context_features=self._context_features,
sequence_features=self._sequence_features)
return {"contexts": contexts, "features": features}
decoded_tensor = {**contexts, **features}
for i, name in enumerate(self._feature_names):
if self._feature_from_bytes[i]:
# Decode the raw bytes into the configured dtype, then cast to float32.
dtype = tf.dtypes.as_dtype(self._feature_dtypes[i])
decoded_tensor[name] = tf.cast(
tf.io.decode_raw(decoded_tensor[name], dtype), tf.float32)
else:
# Convert VarLen features to dense tensors.
if isinstance(decoded_tensor[name], tf.SparseTensor):
decoded_tensor[name] = tf.sparse.to_dense(decoded_tensor[name])
return decoded_tensor
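For the default bytes path (feature_from_bytes=True, as with the stock 'rgb'/'audio' features), each frame arrives as a byte string and tf.io.decode_raw unpacks it; a minimal standalone sketch:

import tensorflow as tf

# One fake frame: 1024 uint8 values serialized to a single byte string,
# mimicking a bytes-encoded rgb frame.
raw_frame = tf.constant([bytes(range(256)) * 4])  # shape [1], dtype tf.string
frame = tf.cast(tf.io.decode_raw(raw_frame, tf.uint8), tf.float32)
print(frame.shape)  # (1, 1024): one frame of 1024 float32 values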
class Parser(parser.Parser):
......@@ -287,6 +332,7 @@ class Parser(parser.Parser):
self._include_video_id = input_params.include_video_id
self._feature_names = input_params.feature_names
self._feature_sizes = input_params.feature_sizes
self._feature_dtypes = input_params.feature_dtypes
self._max_frames = input_params.max_frames
self._max_quantized_value = max_quantized_value
self._min_quantized_value = min_quantized_value
......@@ -295,12 +341,13 @@ class Parser(parser.Parser):
"""Parses data for training."""
# loads (potentially) different types of features and concatenates them
self.video_matrix, self.num_frames = _concat_features(
decoded_tensors["features"], self._feature_names, self._feature_sizes,
self._max_frames, self._max_quantized_value, self._min_quantized_value)
if not self._include_video_id:
del decoded_tensors["contexts"]["id"]
decoded_tensors, self._feature_names, self._feature_sizes,
self._feature_dtypes, self._max_frames, self._max_quantized_value,
self._min_quantized_value)
if not self._include_video_id and "id" in decoded_tensors:
del decoded_tensors["id"]
output_dict = _process_segment_and_label(self.video_matrix, self.num_frames,
decoded_tensors["contexts"],
decoded_tensors,
self._segment_labels,
self._segment_size,
self._num_classes)
......@@ -310,12 +357,13 @@ class Parser(parser.Parser):
"""Parses data for evaluation."""
# loads (potentially) different types of features and concatenates them
self.video_matrix, self.num_frames = _concat_features(
decoded_tensors["features"], self._feature_names, self._feature_sizes,
self._max_frames, self._max_quantized_value, self._min_quantized_value)
if not self._include_video_id:
del decoded_tensors["contexts"]["id"]
decoded_tensors, self._feature_names, self._feature_sizes,
self._feature_dtypes, self._max_frames, self._max_quantized_value,
self._min_quantized_value)
if not self._include_video_id and "id" in decoded_tensors:
del decoded_tensors["id"]
output_dict = _process_segment_and_label(self.video_matrix, self.num_frames,
decoded_tensors["contexts"],
decoded_tensors,
self._segment_labels,
self._segment_size,
self._num_classes)
......
......@@ -123,6 +123,63 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
if include_video_id:
self.assertEqual(example['video_ids'].shape.as_list(), [batch_size])
@parameterized.parameters((True,), (False,))
def test_read_video_level_float_input(self, include_video_id):
data_dir = os.path.join(self.get_temp_dir(), 'data2')
tf.io.gfile.makedirs(data_dir)
data_path = os.path.join(data_dir, 'data2.tfrecord')
examples = [
utils.MakeExampleWithFloatFeatures(self.num_segment) for _ in range(8)
]
tfexample_utils.dump_to_tfrecord(data_path, tf_examples=examples)
params = yt8m_configs.yt8m(is_training=False)
params.global_batch_size = 4
params.segment_labels = False
params.input_path = data_path
params.num_frames = 2
params.max_frames = 2
params.feature_names = ('VIDEO_EMBEDDING/context_feature/floats',
'FEATURE/feature/floats')
params.feature_sources = ('context', 'feature')
params.feature_dtypes = ('float32', 'float32')
params.feature_sizes = (256, 2048)
params.feature_from_bytes = (False, False)
params.include_video_id = include_video_id
reader = self.create_input_reader(params)
dataset = reader.read()
iterator = iter(dataset)
example = next(iterator)
for k, v in example.items():
logging.info('DEBUG read example %r %r %r', k, v.shape, type(v))
logging.info('DEBUG read example %r', example['video_matrix'][0, 0, :])
if include_video_id:
self.assertCountEqual(
['video_matrix', 'labels', 'num_frames', 'video_ids'], example.keys())
else:
self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
example.keys())
batch_size = params.global_batch_size
expected_context = examples[0].context.feature[
'VIDEO_EMBEDDING/context_feature/floats'].float_list.value
expected_feature = examples[0].feature_lists.feature_list[
'FEATURE/feature/floats'].feature[0].float_list.value
self.assertAllEqual(
expected_feature,
example['video_matrix'][0, 0, params.feature_sizes[0]:])
self.assertAllEqual(
expected_context,
example['video_matrix'][0, 0, :params.feature_sizes[0]])
self.assertEqual(
example['video_matrix'].shape.as_list(),
[batch_size, params.max_frames, sum(params.feature_sizes)])
self.assertEqual(example['labels'].shape.as_list(),
[batch_size, params.num_classes])
self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
if include_video_id:
self.assertEqual(example['video_ids'].shape.as_list(), [batch_size, 1])
if __name__ == '__main__':
tf.test.main()