Commit d7e9ece3 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 334942484
parent 825cf9f1
......@@ -64,6 +64,7 @@ def kinetics600(is_training):
@dataclasses.dataclass
class VideoClassificationModel(hyperparams.Config):
"""The model config."""
model_type: str = 'video_classification'
backbone: backbones_3d.Backbone3D = backbones_3d.Backbone3D(
type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50())
norm_activation: common.NormActivation = common.NormActivation()
......@@ -142,6 +143,7 @@ def add_trainer(experiment: cfg.ExperimentConfig,
def video_classification() -> cfg.ExperimentConfig:
"""Video classification general."""
return cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=VideoClassificationTask(),
trainer=cfg.TrainerConfig(),
restrictions=[
......@@ -166,6 +168,7 @@ def video_classification_kinetics600() -> cfg.ExperimentConfig:
train_data=train_dataset,
validation_data=validation_dataset)
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=task,
restrictions=[
'task.train_data.is_training != None',
......
......@@ -20,12 +20,10 @@ import tensorflow as tf
from official.vision.beta.configs import image_classification as classification_cfg
from official.vision.beta.configs import maskrcnn as maskrcnn_cfg
from official.vision.beta.configs import retinanet as retinanet_cfg
from official.vision.beta.configs import video_classification as video_classification_cfg
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import classification_model
from official.vision.beta.modeling import maskrcnn_model
from official.vision.beta.modeling import retinanet_model
from official.vision.beta.modeling import video_classification_model
from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import dense_prediction_heads
from official.vision.beta.modeling.heads import instance_heads
......@@ -234,28 +232,3 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec,
model = retinanet_model.RetinaNetModel(
backbone, decoder, head, detection_generator_obj)
return model
def build_video_classification_model(
input_specs: tf.keras.layers.InputSpec,
model_config: video_classification_cfg.VideoClassificationModel,
num_classes: int,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds the video classification model."""
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
norm_activation_config = model_config.norm_activation
model = video_classification_model.VideoClassificationModel(
backbone=backbone,
num_classes=num_classes,
input_specs=input_specs,
dropout_rate=model_config.dropout_rate,
kernel_regularizer=l2_regularizer,
add_head_batch_norm=model_config.add_head_batch_norm,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
return model
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory methods to build models."""
# Import libraries
import tensorflow as tf
from official.core import registry
from official.vision.beta.configs import video_classification as video_classification_cfg
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import video_classification_model
_REGISTERED_MODEL_CLS = {}
def register_model_builder(key: str):
"""Decorates a builder of model class.
The builder should be a Callable (a class or a function).
This decorator supports registration of backbone builder as follows:
```
class MyModel(tf.keras.Model):
pass
@register_backbone_builder('mybackbone')
def builder(input_specs, config, l2_reg):
return MyModel(...)
# Builds a MyModel object.
my_backbone = build_backbone_3d(input_specs, config, l2_reg)
```
Args:
key: the key to look up the builder.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of model class.
"""
return registry.register(_REGISTERED_MODEL_CLS, key)
def build_model(model_type: str,
input_specs: tf.keras.layers.InputSpec,
model_config: video_classification_cfg.hyperparams.Config,
num_classes: int,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds backbone from a config.
Args:
model_type: string name of model type. It should be consistent with
ModelConfig.model_type.
input_specs: tf.keras.layers.InputSpec.
model_config: a OneOfConfig. Model config.
num_classes: number of classes.
l2_regularizer: tf.keras.regularizers.Regularizer instance. Default to None.
Returns:
tf.keras.Model instance of the backbone.
"""
model_builder = registry.lookup(_REGISTERED_MODEL_CLS, model_type)
return model_builder(input_specs, model_config, num_classes, l2_regularizer)
@register_model_builder('video_classification')
def build_video_classification_model(
input_specs: tf.keras.layers.InputSpec,
model_config: video_classification_cfg.VideoClassificationModel,
num_classes: int,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds the video classification model."""
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
norm_activation_config = model_config.norm_activation
model = video_classification_model.VideoClassificationModel(
backbone=backbone,
num_classes=num_classes,
input_specs=input_specs,
dropout_rate=model_config.dropout_rate,
kernel_regularizer=l2_regularizer,
add_head_batch_norm=model_config.add_head_batch_norm,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
return model
......@@ -26,6 +26,7 @@ from official.vision.beta.configs import maskrcnn as maskrcnn_cfg
from official.vision.beta.configs import retinanet as retinanet_cfg
from official.vision.beta.configs import video_classification as video_classification_cfg
from official.vision.beta.modeling import factory
from official.vision.beta.modeling import factory_3d
class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
......@@ -105,7 +106,7 @@ class VideoClassificationModelBuilderTest(parameterized.TestCase,
backbone=backbones_3d.Backbone3D(type=backbone_type))
l2_regularizer = (
tf.keras.regularizers.l2(weight_decay) if weight_decay else None)
_ = factory.build_video_classification_model(
_ = factory_3d.build_video_classification_model(
input_specs=input_specs,
model_config=model_config,
num_classes=2,
......
......@@ -21,7 +21,7 @@ from official.core import task_factory
from official.modeling import tf_utils
from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.dataloaders import video_input
from official.vision.beta.modeling import factory
from official.vision.beta.modeling import factory_3d
@task_factory.register_task_cls(exp_cfg.VideoClassificationTask)
......@@ -39,7 +39,8 @@ class VideoClassificationTask(base_task.Task):
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
model = factory.build_video_classification_model(
model = factory_3d.build_model(
self.task_config.model.model_type,
input_specs=input_specs,
model_config=self.task_config.model,
num_classes=self.task_config.train_data.num_classes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment