Commit d1c17371 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Adds UCF-101 experiment. Data is from tfds.

Adds checkpoint loading logic to the video classification task.
Adds tfds decoder for video data.

PiperOrigin-RevId: 372649811
parent a67c28c8
...@@ -144,6 +144,8 @@ class VideoClassificationTask(cfg.TaskConfig): ...@@ -144,6 +144,8 @@ class VideoClassificationTask(cfg.TaskConfig):
is_training=False, drop_remainder=False) is_training=False, drop_remainder=False)
losses: Losses = Losses() losses: Losses = Losses()
metrics: Metrics = Metrics() metrics: Metrics = Metrics()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
def add_trainer(experiment: cfg.ExperimentConfig, def add_trainer(experiment: cfg.ExperimentConfig,
...@@ -210,6 +212,52 @@ def video_classification() -> cfg.ExperimentConfig: ...@@ -210,6 +212,52 @@ def video_classification() -> cfg.ExperimentConfig:
]) ])
@exp_factory.register_config_factory('video_classification_ucf101')
def video_classification_ucf101() -> cfg.ExperimentConfig:
  """Video classification on UCF-101 with resnet.

  Builds an experiment config that trains a ResNet-3D-50 video classifier
  on the UCF-101 dataset loaded from TFDS (101 classes, 9537 train /
  3783 test clips).

  Returns:
    A `cfg.ExperimentConfig` registered under 'video_classification_ucf101'.
  """
  train_dataset = DataConfig(
      name='ucf101',
      num_classes=101,
      is_training=True,
      split='train',
      drop_remainder=True,
      num_examples=9537,
      temporal_stride=2,
      feature_shape=(32, 224, 224, 3))
  train_dataset.tfds_name = 'ucf101'
  train_dataset.tfds_split = 'train'
  validation_dataset = DataConfig(
      name='ucf101',
      num_classes=101,
      # Bug fix: evaluation data must not use the training input pipeline
      # (shuffling / training-time augmentation); this was `is_training=True`.
      is_training=False,
      split='test',
      drop_remainder=False,
      num_examples=3783,
      temporal_stride=2,
      feature_shape=(32, 224, 224, 3))
  validation_dataset.tfds_name = 'ucf101'
  validation_dataset.tfds_split = 'test'
  task = VideoClassificationTask(
      model=VideoClassificationModel(
          backbone=backbones_3d.Backbone3D(
              type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50()),
          norm_activation=common.NormActivation(
              norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
      losses=Losses(l2_weight_decay=1e-4),
      train_data=train_dataset,
      validation_data=validation_dataset)
  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
      task=task,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
          'task.train_data.num_classes == task.validation_data.num_classes',
      ])
  add_trainer(config, train_batch_size=64, eval_batch_size=16,
              train_epochs=100)
  return config
@exp_factory.register_config_factory('video_classification_kinetics400') @exp_factory.register_config_factory('video_classification_kinetics400')
def video_classification_kinetics400() -> cfg.ExperimentConfig: def video_classification_kinetics400() -> cfg.ExperimentConfig:
"""Video classification on Kinectics 400 with resnet.""" """Video classification on Kinectics 400 with resnet."""
......
...@@ -113,7 +113,8 @@ def process_image(image: tf.Tensor, ...@@ -113,7 +113,8 @@ def process_image(image: tf.Tensor,
image = preprocess_ops_3d.sample_sequence(image, num_frames, False, stride) image = preprocess_ops_3d.sample_sequence(image, num_frames, False, stride)
# Decode JPEG string to tf.uint8. # Decode JPEG string to tf.uint8.
image = preprocess_ops_3d.decode_jpeg(image, 3) if image.dtype == tf.string:
image = preprocess_ops_3d.decode_jpeg(image, 3)
if is_training: if is_training:
# Standard image data augmentation: random resized crop and random flip. # Standard image data augmentation: random resized crop and random flip.
...@@ -234,6 +235,29 @@ class Decoder(decoder.Decoder): ...@@ -234,6 +235,29 @@ class Decoder(decoder.Decoder):
return result return result
class VideoTfdsDecoder(decoder.Decoder):
  """Decoder for TFDS video-classification datasets.

  Maps a TFDS FeatureDict (e.g. from the `ucf101` dataset) onto the flat
  {image_key: video, label_key: label} dict expected downstream.
  """

  def __init__(self, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY):
    """Initializes the decoder.

    Args:
      image_key: Key under which the video tensor is emitted.
      label_key: Key under which the label tensor is emitted.
    """
    self._image_key = image_key
    self._label_key = label_key

  def decode(self, features):
    """Decode the TFDS FeatureDict.

    Args:
      features: features from TFDS video dataset.
        See https://www.tensorflow.org/datasets/catalog/ucf101 for example.

    Returns:
      Dict of tensors keyed by the configured image and label keys.
    """
    return {
        self._image_key: features['video'],
        self._label_key: features['label'],
    }
class Parser(parser.Parser): class Parser(parser.Parser):
"""Parses a video and label dataset.""" """Parses a video and label dataset."""
......
...@@ -20,6 +20,7 @@ import io ...@@ -20,6 +20,7 @@ import io
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds
from official.vision.beta.configs import video_classification as exp_cfg from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.dataloaders import video_input from official.vision.beta.dataloaders import video_input
...@@ -87,6 +88,16 @@ class DecoderTest(tf.test.TestCase): ...@@ -87,6 +88,16 @@ class DecoderTest(tf.test.TestCase):
self.assertEqual(label, results[video_input.LABEL_KEY]) self.assertEqual(label, results[video_input.LABEL_KEY])
self.assertEqual(results[AUDIO_KEY].shape, (10, 256)) self.assertEqual(results[AUDIO_KEY].shape, (10, 256))
def test_tfds_decode(self):
with tfds.testing.mock_data(num_examples=1):
dataset = tfds.load('ucf101', split='train').take(1)
data = next(iter(dataset))
decoder = video_input.VideoTfdsDecoder()
decoded_tensors = decoder.decode(data)
self.assertContainsSubset([video_input.LABEL_KEY, video_input.IMAGE_KEY],
decoded_tensors.keys())
class VideoAndLabelParserTest(tf.test.TestCase): class VideoAndLabelParserTest(tf.test.TestCase):
......
...@@ -55,6 +55,31 @@ class VideoClassificationTask(base_task.Task): ...@@ -55,6 +55,31 @@ class VideoClassificationTask(base_task.Task):
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
return model return model
def initialize(self, model: tf.keras.Model):
"""Loads pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def _get_dataset_fn(self, params): def _get_dataset_fn(self, params):
if params.file_type == 'tfrecord': if params.file_type == 'tfrecord':
return tf.data.TFRecordDataset return tf.data.TFRecordDataset
...@@ -62,8 +87,12 @@ class VideoClassificationTask(base_task.Task): ...@@ -62,8 +87,12 @@ class VideoClassificationTask(base_task.Task):
raise ValueError('Unknown input file type {!r}'.format(params.file_type)) raise ValueError('Unknown input file type {!r}'.format(params.file_type))
def _get_decoder_fn(self, params): def _get_decoder_fn(self, params):
decoder = video_input.Decoder( if params.tfds_name:
image_key=params.image_field_key, label_key=params.label_field_key) decoder = video_input.VideoTfdsDecoder(
image_key=params.image_field_key, label_key=params.label_field_key)
else:
decoder = video_input.Decoder(
image_key=params.image_field_key, label_key=params.label_field_key)
if self.task_config.train_data.output_audio: if self.task_config.train_data.output_audio:
assert self.task_config.train_data.audio_feature, 'audio feature is empty' assert self.task_config.train_data.audio_feature, 'audio feature is empty'
decoder.add_feature(self.task_config.train_data.audio_feature, decoder.add_feature(self.task_config.train_data.audio_feature,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment