Commit babf9ad8 authored by Yeqing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 343373060
parent 69fb2164
@@ -47,8 +47,12 @@ class DataConfig(cfg.DataConfig):
  input_path: str = ''
  is_training: bool = True
  cycle_length: int = 10
+  drop_remainder: bool = True
  min_image_size: int = 256
  is_multilabel: bool = False
+  output_audio: bool = False
+  audio_feature: str = ''
+  audio_feature_shape: Tuple[int, ...] = (-1,)


def kinetics400(is_training):
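For reference, the new fields combine like this when enabling the audio branch (a minimal sketch; the input path, feature key, and shape values are illustrative placeholders, not values from this change):

```python
# Hypothetical override: only output_audio, audio_feature, audio_feature_shape
# and drop_remainder come from this commit; the concrete values are made up.
data_config = DataConfig(
    input_path='/tmp/train.tfrecord',  # placeholder path
    is_training=True,
    drop_remainder=True,               # drop the last partial batch in training
    output_audio=True,
    audio_feature='features/audio',    # key of the audio feature_list
    audio_feature_shape=(15, 256))     # (num_steps, feature_dim) after sampling
```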
@@ -58,6 +62,7 @@ def kinetics400(is_training):
      num_classes=400,
      is_training=is_training,
      split='train' if is_training else 'valid',
+      drop_remainder=is_training,
      num_examples=215570 if is_training else 17706,
      feature_shape=(64, 224, 224, 3) if is_training else (250, 224, 224, 3))
@@ -69,6 +74,7 @@ def kinetics600(is_training):
      num_classes=600,
      is_training=is_training,
      split='train' if is_training else 'valid',
+      drop_remainder=is_training,
      num_examples=366016 if is_training else 27780,
      feature_shape=(64, 224, 224, 3) if is_training else (250, 224, 224, 3))
@@ -95,8 +101,9 @@ class Losses(hyperparams.Config):
class VideoClassificationTask(cfg.TaskConfig):
  """The task config."""
  model: VideoClassificationModel = VideoClassificationModel()
-  train_data: DataConfig = DataConfig(is_training=True)
-  validation_data: DataConfig = DataConfig(is_training=False)
+  train_data: DataConfig = DataConfig(is_training=True, drop_remainder=True)
+  validation_data: DataConfig = DataConfig(
+      is_training=False, drop_remainder=False)
  losses: Losses = Losses()
@@ -15,7 +15,7 @@
# ==============================================================================
"""Parser for video and label datasets."""
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Union

from absl import logging
import tensorflow as tf
@@ -125,8 +125,7 @@ def _postprocess_image(image: tf.Tensor,
  if num_test_clips > 1 and not is_training:
    # In this case, multiple clips are merged together in the batch dimension,
    # which becomes B * num_test_clips.
-    image = tf.reshape(
-        image, (-1, num_frames, image.shape[2], image.shape[3], image.shape[4]))
+    image = tf.reshape(image, (-1, num_frames) + image.shape[2:])
  return image
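To see what the simplified reshape does, here is a standalone sanity check with invented sizes; the target shape is spelled as a Python list via as_list(), since some TF versions reject a -1 entry when concatenating a tuple with a tf.TensorShape:

```python
import tensorflow as tf

num_frames = 8
# Eval-time input where 4 test clips were merged on the frame axis:
# (B, num_test_clips * num_frames, H, W, C) with B = 2.
image = tf.zeros((2, 4 * num_frames, 16, 16, 3))
# Fold the clips into the batch dimension:
# (B * num_test_clips, num_frames, H, W, C).
merged = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())
assert merged.shape == (8, num_frames, 16, 16, 3)
```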
@@ -170,15 +169,30 @@ class Decoder(decoder.Decoder):
        self._image_key: tf.io.FixedLenSequenceFeature((), tf.string),
    }

+  def add_feature(self, feature_name: str,
+                  feature_type: Union[tf.io.VarLenFeature,
+                                      tf.io.FixedLenFeature,
+                                      tf.io.FixedLenSequenceFeature]):
+    self._sequence_description[feature_name] = feature_type
+
+  def add_context(self, feature_name: str,
+                  feature_type: Union[tf.io.VarLenFeature,
+                                      tf.io.FixedLenFeature,
+                                      tf.io.FixedLenSequenceFeature]):
+    self._context_description[feature_name] = feature_type
+
  def decode(self, serialized_example):
    """Parses a single tf.Example into image and label tensors."""
+    result = {}
    context, sequences = tf.io.parse_single_sequence_example(
        serialized_example, self._context_description,
        self._sequence_description)
-    return {
-        self._image_key: sequences[self._image_key],
-        self._label_key: tf.sparse.to_dense(context[self._label_key])
-    }
+    result.update(context)
+    result.update(sequences)
+    for key, value in result.items():
+      if isinstance(value, tf.SparseTensor):
+        result[key] = tf.sparse.to_dense(value)
+    return result
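Usage of the two new registration hooks, for reference (a minimal sketch; the context key and `serialized` are stand-ins, only add_feature, add_context, and decode come from this change):

```python
decoder = Decoder()
# Register an extra per-timestep feature, parsed from the feature_lists.
decoder.add_feature('features/audio', tf.io.VarLenFeature(dtype=tf.float32))
# Register an extra per-example feature, parsed from the context.
decoder.add_context('clip/label/text', tf.io.VarLenFeature(dtype=tf.string))

# `serialized` is assumed to be a serialized tf.train.SequenceExample string.
# decode() now returns every registered key, with SparseTensors densified.
decoded = decoder.decode(serialized)
```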
class Parser(parser.Parser):
@@ -198,6 +212,10 @@ class Parser(parser.Parser):
    self._image_key = image_key
    self._label_key = label_key
    self._dtype = tf.dtypes.as_dtype(input_params.dtype)
+    self._output_audio = input_params.output_audio
+    if self._output_audio:
+      self._audio_feature = input_params.audio_feature
+      self._audio_shape = input_params.audio_feature_shape

  def _parse_train_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
@@ -214,11 +232,21 @@ class Parser(parser.Parser):
        min_resize=self._min_resize,
        crop_size=self._crop_size)
    image = tf.cast(image, dtype=self._dtype)
+    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)
-    return {'image': image}, label
+
+    if self._output_audio:
+      audio = decoded_tensors[self._audio_feature]
+      audio = tf.cast(audio, dtype=self._dtype)
+      # TODO(yeqing): Synchronize audio/video sampling, especially the randomness.
+      audio = preprocess_ops_3d.sample_sequence(
+          audio, self._audio_shape[0], random=False, stride=1)
+      audio = tf.ensure_shape(audio, self._audio_shape)
+      features['audio'] = audio
+
+    return features, label

  def _parse_eval_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
@@ -234,11 +262,20 @@ class Parser(parser.Parser):
        min_resize=self._min_resize,
        crop_size=self._crop_size)
    image = tf.cast(image, dtype=self._dtype)
+    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)
-    return {'image': image}, label
+
+    if self._output_audio:
+      audio = decoded_tensors[self._audio_feature]
+      audio = tf.cast(audio, dtype=self._dtype)
+      # Use the configured shape, not hard-coded values, so eval matches train.
+      audio = preprocess_ops_3d.sample_sequence(
+          audio, self._audio_shape[0], random=False, stride=1)
+      audio = tf.ensure_shape(audio, self._audio_shape)
+      features['audio'] = audio
+
+    return features, label
class PostBatchProcessor(object):
@@ -250,16 +287,15 @@ class PostBatchProcessor(object):
    self._num_frames = input_params.feature_shape[0]
    self._num_test_clips = input_params.num_test_clips

-  def __call__(
-      self,
-      image: Dict[str, tf.Tensor],
-      label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
+  def __call__(self, features: Dict[str, tf.Tensor],
+               label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Postprocesses batched feature and label tensors."""
-    image = image['image']
-    image = _postprocess_image(
-        image=image,
-        is_training=self._is_training,
-        num_frames=self._num_frames,
-        num_test_clips=self._num_test_clips)
-    return {'image': image}, label
+    for key in ['image', 'audio']:
+      if key in features:
+        features[key] = _postprocess_image(
+            image=features[key],
+            is_training=self._is_training,
+            num_frames=self._num_frames,
+            num_test_clips=self._num_test_clips)
+    return features, label
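For context, the pieces compose in the usual decode/parse/batch order; a hedged sketch of the wiring (the actual pipeline lives in the input reader, and `dataset`, `batch_size`, and `params` here are assumed, not from this change):

```python
# `dataset` is assumed to yield serialized tf.train.SequenceExample strings.
decoder = Decoder()
parser = Parser(params)
post_batch = PostBatchProcessor(params)

dataset = (
    dataset
    .map(decoder.decode, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(parser.parse_fn(params.is_training),
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size, drop_remainder=params.drop_remainder)
    # After batching, 'image' and 'audio' are reshaped alike; note both reuse
    # the video num_frames when folding test clips into the batch dimension.
    .map(post_batch))
```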
@@ -25,32 +25,45 @@ from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.dataloaders import video_input

+AUDIO_KEY = 'features/audio'
+
+
+def fake_seq_example():
+  # Create fake data.
+  random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
+  random_image = Image.fromarray(random_image)
+  label = 42
+  with io.BytesIO() as buffer:
+    random_image.save(buffer, format='JPEG')
+    raw_image_bytes = buffer.getvalue()
+
+  seq_example = tf.train.SequenceExample()
+  seq_example.feature_lists.feature_list.get_or_create(
+      video_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
+          raw_image_bytes
+      ]
+  seq_example.feature_lists.feature_list.get_or_create(
+      video_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
+          raw_image_bytes
+      ]
+  seq_example.context.feature[video_input.LABEL_KEY].int64_list.value[:] = [
+      label
+  ]
+
+  random_audio = np.random.normal(size=(10, 256)).tolist()
+  for s in random_audio:
+    seq_example.feature_lists.feature_list.get_or_create(
+        AUDIO_KEY).feature.add().float_list.value[:] = s
+  return seq_example, label


class DecoderTest(tf.test.TestCase):
  """Tests the tf.SequenceExample decoder for the video classification task."""

  def test_decoder(self):
    decoder = video_input.Decoder()

-    # Create fake data.
-    random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
-    random_image = Image.fromarray(random_image)
-    label = 42
-    with io.BytesIO() as buffer:
-      random_image.save(buffer, format='JPEG')
-      raw_image_bytes = buffer.getvalue()
-
-    seq_example = tf.train.SequenceExample()
-    seq_example.feature_lists.feature_list.get_or_create(
-        video_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
-            raw_image_bytes
-        ]
-    seq_example.feature_lists.feature_list.get_or_create(
-        video_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
-            raw_image_bytes
-        ]
-    seq_example.context.feature[video_input.LABEL_KEY].int64_list.value[:] = [
-        label
-    ]
+    seq_example, label = fake_seq_example()
    serialized_example = seq_example.SerializeToString()
    decoded_tensors = decoder.decode(tf.convert_to_tensor(serialized_example))
@@ -59,6 +72,21 @@ class DecoderTest(tf.test.TestCase):
        results.keys())
    self.assertEqual(label, results[video_input.LABEL_KEY])

+  def test_decode_audio(self):
+    decoder = video_input.Decoder()
+    decoder.add_feature(AUDIO_KEY, tf.io.VarLenFeature(dtype=tf.float32))
+
+    seq_example, label = fake_seq_example()
+    serialized_example = seq_example.SerializeToString()
+    decoded_tensors = decoder.decode(tf.convert_to_tensor(serialized_example))
+
+    results = tf.nest.map_structure(lambda x: x.numpy(), decoded_tensors)
+    self.assertCountEqual(
+        [video_input.IMAGE_KEY, video_input.LABEL_KEY, AUDIO_KEY],
+        results.keys())
+    self.assertEqual(label, results[video_input.LABEL_KEY])
+    self.assertEqual(results[AUDIO_KEY].shape, (10, 256))
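The audio key is registered here as a VarLenFeature and densified inside decode(). When the per-step width is known up front, a fixed-length declaration is an equivalent alternative (a sketch, not what this change uses):

```python
decoder = video_input.Decoder()
# Each feature_lists step carries exactly 256 floats, so this parses directly
# to a dense (num_steps, 256) tensor with no sparse-to-dense conversion.
decoder.add_feature(AUDIO_KEY, tf.io.FixedLenSequenceFeature([256], tf.float32))
```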
class VideoAndLabelParserTest(tf.test.TestCase):
@@ -66,28 +94,11 @@ class VideoAndLabelParserTest(tf.test.TestCase):
    params = exp_cfg.kinetics600(is_training=True)
    params.feature_shape = (2, 224, 224, 3)
    params.min_image_size = 224

    decoder = video_input.Decoder()
    parser = video_input.Parser(params).parse_fn(params.is_training)

-    # Create fake data.
-    random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
-    random_image = Image.fromarray(random_image)
-    with io.BytesIO() as buffer:
-      random_image.save(buffer, format='JPEG')
-      raw_image_bytes = buffer.getvalue()
-
-    seq_example = tf.train.SequenceExample()
-    seq_example.feature_lists.feature_list.get_or_create(
-        video_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
-            raw_image_bytes
-        ]
-    seq_example.feature_lists.feature_list.get_or_create(
-        video_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
-            raw_image_bytes
-        ]
-    seq_example.context.feature[video_input.LABEL_KEY].int64_list.value[:] = [
-        42
-    ]
+    seq_example, label = fake_seq_example()
    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
@@ -98,6 +109,32 @@ class VideoAndLabelParserTest(tf.test.TestCase):
    self.assertAllEqual(image.shape, (2, 224, 224, 3))
    self.assertAllEqual(label.shape, (600,))

+  def test_video_audio_input(self):
+    params = exp_cfg.kinetics600(is_training=True)
+    params.feature_shape = (2, 224, 224, 3)
+    params.min_image_size = 224
+    params.output_audio = True
+    params.audio_feature = AUDIO_KEY
+    params.audio_feature_shape = (15, 256)
+
+    decoder = video_input.Decoder()
+    decoder.add_feature(params.audio_feature,
+                        tf.io.VarLenFeature(dtype=tf.float32))
+    parser = video_input.Parser(params).parse_fn(params.is_training)
+
+    seq_example, label = fake_seq_example()
+    input_tensor = tf.constant(seq_example.SerializeToString())
+
+    decoded_tensors = decoder.decode(input_tensor)
+    output_tensor = parser(decoded_tensors)
+    features, label = output_tensor
+    image = features['image']
+    audio = features['audio']
+
+    self.assertAllEqual(image.shape, (2, 224, 224, 3))
+    self.assertAllEqual(label.shape, (600,))
+    self.assertEqual(audio.shape, (15, 256))


if __name__ == '__main__':
  tf.test.main()