Commit babf9ad8 authored by Yeqing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 343373060
parent 69fb2164
@@ -47,8 +47,12 @@ class DataConfig(cfg.DataConfig):
   input_path: str = ''
   is_training: bool = True
   cycle_length: int = 10
+  drop_remainder: bool = True
   min_image_size: int = 256
   is_multilabel: bool = False
+  output_audio: bool = False
+  audio_feature: str = ''
+  audio_feature_shape: Tuple[int, ...] = (-1,)
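The new fields let a task opt into decoding an audio stream alongside the video frames. A minimal configuration sketch; the 'features/audio' key and the (15, 256) shape are illustrative values borrowed from the tests in this commit, not defaults:

params = exp_cfg.kinetics600(is_training=True)
params.output_audio = True                 # parser will emit an 'audio' feature
params.audio_feature = 'features/audio'    # key inside the tf.SequenceExample
params.audio_feature_shape = (15, 256)     # (num_steps, feature_dim)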
 def kinetics400(is_training):
@@ -58,6 +62,7 @@ def kinetics400(is_training):
       num_classes=400,
       is_training=is_training,
       split='train' if is_training else 'valid',
+      drop_remainder=is_training,
       num_examples=215570 if is_training else 17706,
       feature_shape=(64, 224, 224, 3) if is_training else (250, 224, 224, 3))

@@ -69,6 +74,7 @@ def kinetics600(is_training):
       num_classes=600,
       is_training=is_training,
       split='train' if is_training else 'valid',
+      drop_remainder=is_training,
       num_examples=366016 if is_training else 27780,
       feature_shape=(64, 224, 224, 3) if is_training else (250, 224, 224, 3))
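Setting drop_remainder=is_training keeps training batches at a fixed size (e.g. for static shapes on TPU) while evaluation keeps the final partial batch, so every validation example is scored. A minimal sketch of the downstream effect, assuming the dataloader forwards this flag to tf.data batching:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
train = ds.batch(4, drop_remainder=True)    # partial final batch dropped
evals = ds.batch(4, drop_remainder=False)   # partial final batch kept
print([int(b.shape[0]) for b in train])     # [4, 4]
print([int(b.shape[0]) for b in evals])     # [4, 4, 2]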
@@ -95,8 +101,9 @@ class Losses(hyperparams.Config):
 class VideoClassificationTask(cfg.TaskConfig):
   """The task config."""
   model: VideoClassificationModel = VideoClassificationModel()
-  train_data: DataConfig = DataConfig(is_training=True)
-  validation_data: DataConfig = DataConfig(is_training=False)
+  train_data: DataConfig = DataConfig(is_training=True, drop_remainder=True)
+  validation_data: DataConfig = DataConfig(
+      is_training=False, drop_remainder=False)
   losses: Losses = Losses()
@@ -15,7 +15,7 @@
 # ==============================================================================
 """Parser for video and label datasets."""
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Union

 from absl import logging
 import tensorflow as tf

@@ -125,8 +125,7 @@ def _postprocess_image(image: tf.Tensor,
   if num_test_clips > 1 and not is_training:
     # In this case, multiple clips are merged together in the batch dimension,
     # which will be B * num_test_clips.
-    image = tf.reshape(
-        image, (-1, num_frames, image.shape[2], image.shape[3], image.shape[4]))
+    image = tf.reshape(image, (-1, num_frames) + image.shape[2:])
   return image
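The reshape now derives the spatial dimensions from the input instead of enumerating them. A quick sketch of the intended semantics with made-up sizes (batch of 2, 3 test clips of 8 frames each, 224x224 RGB), written with as_list() so it runs standalone:

import tensorflow as tf

batch, num_test_clips, num_frames = 2, 3, 8
# After batching, the clips of each example are stacked along the frame axis.
image = tf.zeros((batch, num_test_clips * num_frames, 224, 224, 3))
# (B, num_test_clips * T, H, W, C) -> (B * num_test_clips, T, H, W, C)
image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())
print(image.shape)  # (6, 8, 224, 224, 3)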
@@ -170,15 +169,30 @@ class Decoder(decoder.Decoder):
         self._image_key: tf.io.FixedLenSequenceFeature((), tf.string),
     }

+  def add_feature(self, feature_name: str,
+                  feature_type: Union[tf.io.VarLenFeature,
+                                      tf.io.FixedLenFeature,
+                                      tf.io.FixedLenSequenceFeature]):
+    self._sequence_description[feature_name] = feature_type
+
+  def add_context(self, feature_name: str,
+                  feature_type: Union[tf.io.VarLenFeature,
+                                      tf.io.FixedLenFeature,
+                                      tf.io.FixedLenSequenceFeature]):
+    self._context_description[feature_name] = feature_type
+
   def decode(self, serialized_example):
     """Parses a single tf.Example into image and label tensors."""
+    result = {}
     context, sequences = tf.io.parse_single_sequence_example(
         serialized_example, self._context_description,
         self._sequence_description)
-    return {
-        self._image_key: sequences[self._image_key],
-        self._label_key: tf.sparse.to_dense(context[self._label_key])
-    }
+    result.update(context)
+    result.update(sequences)
+    for key, value in result.items():
+      if isinstance(value, tf.SparseTensor):
+        result[key] = tf.sparse.to_dense(value)
+    return result
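With add_feature/add_context, callers can register extra keys before decoding, and decode() now returns every parsed key, densifying any sparse tensors. Usage mirroring the tests in this commit (the add_context line is a hypothetical illustration, not from the diff):

decoder = video_input.Decoder()
decoder.add_feature('features/audio', tf.io.VarLenFeature(dtype=tf.float32))
# decoder.add_context('video/id', tf.io.FixedLenFeature((), tf.string))
decoded = decoder.decode(serialized_example)  # dict now includes 'features/audio'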
 class Parser(parser.Parser):
@@ -198,6 +212,10 @@ class Parser(parser.Parser):
     self._image_key = image_key
     self._label_key = label_key
     self._dtype = tf.dtypes.as_dtype(input_params.dtype)
+    self._output_audio = input_params.output_audio
+    if self._output_audio:
+      self._audio_feature = input_params.audio_feature
+      self._audio_shape = input_params.audio_feature_shape

   def _parse_train_data(
       self, decoded_tensors: Dict[str, tf.Tensor]
@@ -214,11 +232,21 @@ class Parser(parser.Parser):
         min_resize=self._min_resize,
         crop_size=self._crop_size)
     image = tf.cast(image, dtype=self._dtype)
+    features = {'image': image}

     label = decoded_tensors[self._label_key]
     label = _process_label(label, self._one_hot_label, self._num_classes)

-    return {'image': image}, label
+    if self._output_audio:
+      audio = decoded_tensors[self._audio_feature]
+      audio = tf.cast(audio, dtype=self._dtype)
+      # TODO(yeqing): Synchronize audio/video sampling, especially the
+      # randomness.
+      audio = preprocess_ops_3d.sample_sequence(
+          audio, self._audio_shape[0], random=False, stride=1)
+      audio = tf.ensure_shape(audio, self._audio_shape)
+      features['audio'] = audio
+
+    return features, label
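With output_audio enabled, both parse paths return a features dict rather than a bare image tensor. A sketch of the resulting structure for the test settings used below (feature_shape=(2, 224, 224, 3), audio_feature_shape=(15, 256), 600 classes):

parse_fn = video_input.Parser(params).parse_fn(params.is_training)
features, label = parse_fn(decoded_tensors)
# features['image'].shape == (2, 224, 224, 3), cast to params.dtype
# features['audio'].shape == (15, 256), enforced by tf.ensure_shape
# label.shape == (600,), one-hot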
   def _parse_eval_data(
       self, decoded_tensors: Dict[str, tf.Tensor]
@@ -234,11 +262,20 @@ class Parser(parser.Parser):
         min_resize=self._min_resize,
         crop_size=self._crop_size)
     image = tf.cast(image, dtype=self._dtype)
+    features = {'image': image}

     label = decoded_tensors[self._label_key]
     label = _process_label(label, self._one_hot_label, self._num_classes)

-    return {'image': image}, label
+    if self._output_audio:
+      audio = decoded_tensors[self._audio_feature]
+      audio = tf.cast(audio, dtype=self._dtype)
+      audio = preprocess_ops_3d.sample_sequence(
+          audio, 20, random=False, stride=1)
+      audio = tf.ensure_shape(audio, [20, 2048])
+      features['audio'] = audio
+
+    return features, label
 class PostBatchProcessor(object):
@@ -250,16 +287,15 @@ class PostBatchProcessor(object):
     self._num_frames = input_params.feature_shape[0]
     self._num_test_clips = input_params.num_test_clips

-  def __call__(
-      self,
-      image: Dict[str, tf.Tensor],
-      label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
+  def __call__(self, features: Dict[str, tf.Tensor],
+               label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
     """Parses a single tf.Example into image and label tensors."""
-    image = image['image']
-    image = _postprocess_image(
-        image=image,
-        is_training=self._is_training,
-        num_frames=self._num_frames,
-        num_test_clips=self._num_test_clips)
-    return {'image': image}, label
+    for key in ['image', 'audio']:
+      if key in features:
+        features[key] = _postprocess_image(
+            image=features[key],
+            is_training=self._is_training,
+            num_frames=self._num_frames,
+            num_test_clips=self._num_test_clips)
+    return features, label
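__call__ now postprocesses every known modality in the features dict, so merged test clips are split back out of the batch dimension for audio as well as images. A hedged sketch of where this sits, assuming the usual map -> batch -> map ordering of the input pipeline (names here are illustrative):

post_fn = video_input.PostBatchProcessor(params)
dataset = (dataset
           .map(parse_fn)     # per example: (features, label)
           .batch(batch_size, drop_remainder=params.drop_remainder)
           .map(post_fn))     # folds num_test_clips back into the batch dim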
@@ -25,12 +25,10 @@ from official.vision.beta.configs import video_classification as exp_cfg
 from official.vision.beta.dataloaders import video_input

-class DecoderTest(tf.test.TestCase):
-  """A tf.SequenceExample decoder for the video classification task."""

-  def test_decoder(self):
-    decoder = video_input.Decoder()
+AUDIO_KEY = 'features/audio'

+def fake_seq_example():
   # Create fake data.
   random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
   random_image = Image.fromarray(random_image)
@@ -51,6 +49,21 @@ class DecoderTest(tf.test.TestCase):
   seq_example.context.feature[video_input.LABEL_KEY].int64_list.value[:] = [
       label
   ]
+  random_audio = np.random.normal(size=(10, 256)).tolist()
+  for s in random_audio:
+    seq_example.feature_lists.feature_list.get_or_create(
+        AUDIO_KEY).feature.add().float_list.value[:] = s
+  return seq_example, label
+
+
+class DecoderTest(tf.test.TestCase):
+  """A tf.SequenceExample decoder for the video classification task."""
+
+  def test_decoder(self):
+    decoder = video_input.Decoder()
+
+    seq_example, label = fake_seq_example()

     serialized_example = seq_example.SerializeToString()
     decoded_tensors = decoder.decode(tf.convert_to_tensor(serialized_example))
@@ -59,6 +72,21 @@ class DecoderTest(tf.test.TestCase):
         results.keys())
     self.assertEqual(label, results[video_input.LABEL_KEY])

+  def test_decode_audio(self):
+    decoder = video_input.Decoder()
+    decoder.add_feature(AUDIO_KEY, tf.io.VarLenFeature(dtype=tf.float32))
+
+    seq_example, label = fake_seq_example()
+    serialized_example = seq_example.SerializeToString()
+    decoded_tensors = decoder.decode(tf.convert_to_tensor(serialized_example))
+
+    results = tf.nest.map_structure(lambda x: x.numpy(), decoded_tensors)
+    self.assertCountEqual(
+        [video_input.IMAGE_KEY, video_input.LABEL_KEY, AUDIO_KEY],
+        results.keys())
+    self.assertEqual(label, results[video_input.LABEL_KEY])
+    self.assertEqual(results[AUDIO_KEY].shape, (10, 256))
 class VideoAndLabelParserTest(tf.test.TestCase):
@@ -66,28 +94,11 @@ class VideoAndLabelParserTest(tf.test.TestCase):
     params = exp_cfg.kinetics600(is_training=True)
     params.feature_shape = (2, 224, 224, 3)
     params.min_image_size = 224

     decoder = video_input.Decoder()
     parser = video_input.Parser(params).parse_fn(params.is_training)

-    # Create fake data.
-    random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
-    random_image = Image.fromarray(random_image)
-    with io.BytesIO() as buffer:
-      random_image.save(buffer, format='JPEG')
-      raw_image_bytes = buffer.getvalue()
-
-    seq_example = tf.train.SequenceExample()
-    seq_example.feature_lists.feature_list.get_or_create(
-        video_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
-            raw_image_bytes
-        ]
-    seq_example.feature_lists.feature_list.get_or_create(
-        video_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
-            raw_image_bytes
-        ]
-    seq_example.context.feature[video_input.LABEL_KEY].int64_list.value[:] = [
-        42
-    ]
+    seq_example, label = fake_seq_example()

     input_tensor = tf.constant(seq_example.SerializeToString())
     decoded_tensors = decoder.decode(input_tensor)
@@ -98,6 +109,32 @@ class VideoAndLabelParserTest(tf.test.TestCase):
     self.assertAllEqual(image.shape, (2, 224, 224, 3))
     self.assertAllEqual(label.shape, (600,))

+  def test_video_audio_input(self):
+    params = exp_cfg.kinetics600(is_training=True)
+    params.feature_shape = (2, 224, 224, 3)
+    params.min_image_size = 224
+    params.output_audio = True
+    params.audio_feature = AUDIO_KEY
+    params.audio_feature_shape = (15, 256)
+
+    decoder = video_input.Decoder()
+    decoder.add_feature(params.audio_feature,
+                        tf.io.VarLenFeature(dtype=tf.float32))
+    parser = video_input.Parser(params).parse_fn(params.is_training)
+
+    seq_example, label = fake_seq_example()
+
+    input_tensor = tf.constant(seq_example.SerializeToString())
+    decoded_tensors = decoder.decode(input_tensor)
+
+    output_tensor = parser(decoded_tensors)
+    features, label = output_tensor
+    image = features['image']
+    audio = features['audio']
+
+    self.assertAllEqual(image.shape, (2, 224, 224, 3))
+    self.assertAllEqual(label.shape, (600,))
+    self.assertEqual(audio.shape, (15, 256))
 if __name__ == '__main__':
   tf.test.main()