Commit 4c5e327e authored by Yeqing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 370198074
parent 7e370ba7
@@ -57,6 +57,8 @@ class DataConfig(cfg.DataConfig):
  aug_max_aspect_ratio: float = 2.0
  aug_min_area_ratio: float = 0.49
  aug_max_area_ratio: float = 1.0
  image_field_key: str = 'image/encoded'
  label_field_key: str = 'clip/label/index'


def kinetics400(is_training):
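The two new fields make the feature keys used by the video decoder configurable instead of hard-coded. A minimal sketch of a record that matches the defaults, assuming the input is a tf.train.SequenceExample with the label in the context and the JPEG-encoded frames in a feature list (the record layout here is an illustrative assumption, not part of this change):

import tensorflow as tf

# Assumed layout: label index in the context, encoded frames in a feature
# list, stored under the default key names from DataConfig above.
frame = tf.io.encode_jpeg(tf.zeros([224, 224, 3], tf.uint8)).numpy()
record = tf.train.SequenceExample(
    context=tf.train.Features(feature={
        'clip/label/index': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[42])),
    }),
    feature_lists=tf.train.FeatureLists(feature_list={
        'image/encoded': tf.train.FeatureList(feature=[
            tf.train.Feature(bytes_list=tf.train.BytesList(value=[frame])),
        ]),
    }))
serialized = record.SerializeToString()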
@@ -83,6 +85,30 @@ def kinetics600(is_training):
      feature_shape=(64, 224, 224, 3) if is_training else (250, 224, 224, 3))


def kinetics700(is_training):
  """Generates Kinetics 700 dataset configs."""
  return DataConfig(
      name='kinetics700',
      num_classes=700,
      is_training=is_training,
      split='train' if is_training else 'valid',
      drop_remainder=is_training,
      num_examples=522883 if is_training else 33441,
      feature_shape=(64, 224, 224, 3) if is_training else (250, 224, 224, 3))


def kinetics700_2020(is_training):
  """Generates Kinetics 700-2020 dataset configs."""
  return DataConfig(
      name='kinetics700',
      num_classes=700,
      is_training=is_training,
      split='train' if is_training else 'valid',
      drop_remainder=is_training,
      num_examples=535982 if is_training else 33640,
      feature_shape=(64, 224, 224, 3) if is_training else (250, 224, 224, 3))


@dataclasses.dataclass
class VideoClassificationModel(hyperparams.Config):
  """The model config."""
@@ -232,3 +258,55 @@ def video_classification_kinetics600() -> cfg.ExperimentConfig:
      ])
  add_trainer(config, train_batch_size=1024, eval_batch_size=64)
  return config


@exp_factory.register_config_factory('video_classification_kinetics700')
def video_classification_kinetics700() -> cfg.ExperimentConfig:
  """Video classification on Kinetics 700 with ResNet."""
  train_dataset = kinetics700(is_training=True)
  validation_dataset = kinetics700(is_training=False)
  task = VideoClassificationTask(
      model=VideoClassificationModel(
          backbone=backbones_3d.Backbone3D(
              type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50()),
          norm_activation=common.NormActivation(
              norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
      losses=Losses(l2_weight_decay=1e-4),
      train_data=train_dataset,
      validation_data=validation_dataset)
  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
      task=task,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
          'task.train_data.num_classes == task.validation_data.num_classes',
      ])
  add_trainer(config, train_batch_size=1024, eval_batch_size=64)
  return config


@exp_factory.register_config_factory('video_classification_kinetics700_2020')
def video_classification_kinetics700_2020() -> cfg.ExperimentConfig:
  """Video classification on Kinetics 700-2020 with ResNet."""
  train_dataset = kinetics700_2020(is_training=True)
  validation_dataset = kinetics700_2020(is_training=False)
  task = VideoClassificationTask(
      model=VideoClassificationModel(
          backbone=backbones_3d.Backbone3D(
              type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50()),
          norm_activation=common.NormActivation(
              norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
      losses=Losses(l2_weight_decay=1e-4),
      train_data=train_dataset,
      validation_data=validation_dataset)
  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
      task=task,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
          'task.train_data.num_classes == task.validation_data.num_classes',
      ])
  add_trainer(config, train_batch_size=1024, eval_batch_size=64)
  return config
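Once registered, the two experiments can be pulled from the factory like any other config. A sketch of selecting one and overriding the new decoder keys from user code (get_exp_config, override, and validate are assumed to behave as in the surrounding config framework; the replacement key names are hypothetical):

from official.core import exp_factory

config = exp_factory.get_exp_config('video_classification_kinetics700_2020')
# Point the decoder at differently named features (hypothetical key names).
config.override({
    'task': {
        'train_data': {'image_field_key': 'video/frames',
                       'label_field_key': 'video/class/index'},
        'validation_data': {'image_field_key': 'video/frames',
                            'label_field_key': 'video/class/index'},
    }
}, is_strict=True)
config.validate()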
@@ -62,7 +62,8 @@ class VideoClassificationTask(base_task.Task):
      raise ValueError('Unknown input file type {!r}'.format(params.file_type))

  def _get_decoder_fn(self, params):
    decoder = video_input.Decoder()
    decoder = video_input.Decoder(
        image_key=params.image_field_key, label_key=params.label_field_key)
    if self.task_config.train_data.output_audio:
      assert self.task_config.train_data.audio_feature, 'audio feature is empty'
      decoder.add_feature(self.task_config.train_data.audio_feature,
@@ -74,7 +75,10 @@ class VideoClassificationTask(base_task.Task):
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds classification input."""
    parser = video_input.Parser(input_params=params)
    parser = video_input.Parser(
        input_params=params,
        image_key=params.image_field_key,
        label_key=params.label_field_key)
    postprocess_fn = video_input.PostBatchProcessor(params)

    reader = input_reader_factory.input_reader_generator(
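With the keys threaded through, the decoder and parser look up the configured feature names instead of the previously hard-coded defaults. A small sketch tying this to the record from the DataConfig hunk above (Decoder.decode returning a dict keyed by the configured names is an assumption about video_input internals, shown for illustration only):

from official.vision.beta.configs import video_classification as vc_cfg
from official.vision.beta.dataloaders import video_input

params = vc_cfg.kinetics700(is_training=True)
decoder = video_input.Decoder(
    image_key=params.image_field_key, label_key=params.label_field_key)
decoded = decoder.decode(serialized)          # `serialized` from the earlier sketch
frames = decoded[params.image_field_key]      # encoded frame strings
label = decoded[params.label_field_key]       # clip label index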