Commit d7e9ece3 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 334942484
parent 825cf9f1
...@@ -64,6 +64,7 @@ def kinetics600(is_training): ...@@ -64,6 +64,7 @@ def kinetics600(is_training):
@dataclasses.dataclass @dataclasses.dataclass
class VideoClassificationModel(hyperparams.Config): class VideoClassificationModel(hyperparams.Config):
"""The model config.""" """The model config."""
model_type: str = 'video_classification'
backbone: backbones_3d.Backbone3D = backbones_3d.Backbone3D( backbone: backbones_3d.Backbone3D = backbones_3d.Backbone3D(
type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50()) type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50())
norm_activation: common.NormActivation = common.NormActivation() norm_activation: common.NormActivation = common.NormActivation()
...@@ -142,6 +143,7 @@ def add_trainer(experiment: cfg.ExperimentConfig, ...@@ -142,6 +143,7 @@ def add_trainer(experiment: cfg.ExperimentConfig,
def video_classification() -> cfg.ExperimentConfig: def video_classification() -> cfg.ExperimentConfig:
"""Video classification general.""" """Video classification general."""
return cfg.ExperimentConfig( return cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=VideoClassificationTask(), task=VideoClassificationTask(),
trainer=cfg.TrainerConfig(), trainer=cfg.TrainerConfig(),
restrictions=[ restrictions=[
...@@ -166,6 +168,7 @@ def video_classification_kinetics600() -> cfg.ExperimentConfig: ...@@ -166,6 +168,7 @@ def video_classification_kinetics600() -> cfg.ExperimentConfig:
train_data=train_dataset, train_data=train_dataset,
validation_data=validation_dataset) validation_data=validation_dataset)
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=task, task=task,
restrictions=[ restrictions=[
'task.train_data.is_training != None', 'task.train_data.is_training != None',
......
...@@ -20,12 +20,10 @@ import tensorflow as tf ...@@ -20,12 +20,10 @@ import tensorflow as tf
from official.vision.beta.configs import image_classification as classification_cfg from official.vision.beta.configs import image_classification as classification_cfg
from official.vision.beta.configs import maskrcnn as maskrcnn_cfg from official.vision.beta.configs import maskrcnn as maskrcnn_cfg
from official.vision.beta.configs import retinanet as retinanet_cfg from official.vision.beta.configs import retinanet as retinanet_cfg
from official.vision.beta.configs import video_classification as video_classification_cfg
from official.vision.beta.modeling import backbones from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import classification_model from official.vision.beta.modeling import classification_model
from official.vision.beta.modeling import maskrcnn_model from official.vision.beta.modeling import maskrcnn_model
from official.vision.beta.modeling import retinanet_model from official.vision.beta.modeling import retinanet_model
from official.vision.beta.modeling import video_classification_model
from official.vision.beta.modeling.decoders import factory as decoder_factory from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import dense_prediction_heads from official.vision.beta.modeling.heads import dense_prediction_heads
from official.vision.beta.modeling.heads import instance_heads from official.vision.beta.modeling.heads import instance_heads
...@@ -234,28 +232,3 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec, ...@@ -234,28 +232,3 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec,
model = retinanet_model.RetinaNetModel( model = retinanet_model.RetinaNetModel(
backbone, decoder, head, detection_generator_obj) backbone, decoder, head, detection_generator_obj)
return model return model
def build_video_classification_model(
input_specs: tf.keras.layers.InputSpec,
model_config: video_classification_cfg.VideoClassificationModel,
num_classes: int,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds the video classification model."""
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
norm_activation_config = model_config.norm_activation
model = video_classification_model.VideoClassificationModel(
backbone=backbone,
num_classes=num_classes,
input_specs=input_specs,
dropout_rate=model_config.dropout_rate,
kernel_regularizer=l2_regularizer,
add_head_batch_norm=model_config.add_head_batch_norm,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
return model
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory methods to build models."""
# Import libraries
import tensorflow as tf
from official.core import registry
from official.vision.beta.configs import video_classification as video_classification_cfg
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import video_classification_model
_REGISTERED_MODEL_CLS = {}
def register_model_builder(key: str):
"""Decorates a builder of model class.
The builder should be a Callable (a class or a function).
This decorator supports registration of backbone builder as follows:
```
class MyModel(tf.keras.Model):
pass
@register_backbone_builder('mybackbone')
def builder(input_specs, config, l2_reg):
return MyModel(...)
# Builds a MyModel object.
my_backbone = build_backbone_3d(input_specs, config, l2_reg)
```
Args:
key: the key to look up the builder.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of model class.
"""
return registry.register(_REGISTERED_MODEL_CLS, key)
def build_model(model_type: str,
input_specs: tf.keras.layers.InputSpec,
model_config: video_classification_cfg.hyperparams.Config,
num_classes: int,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds backbone from a config.
Args:
model_type: string name of model type. It should be consistent with
ModelConfig.model_type.
input_specs: tf.keras.layers.InputSpec.
model_config: a OneOfConfig. Model config.
num_classes: number of classes.
l2_regularizer: tf.keras.regularizers.Regularizer instance. Default to None.
Returns:
tf.keras.Model instance of the backbone.
"""
model_builder = registry.lookup(_REGISTERED_MODEL_CLS, model_type)
return model_builder(input_specs, model_config, num_classes, l2_regularizer)
@register_model_builder('video_classification')
def build_video_classification_model(
input_specs: tf.keras.layers.InputSpec,
model_config: video_classification_cfg.VideoClassificationModel,
num_classes: int,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds the video classification model."""
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
norm_activation_config = model_config.norm_activation
model = video_classification_model.VideoClassificationModel(
backbone=backbone,
num_classes=num_classes,
input_specs=input_specs,
dropout_rate=model_config.dropout_rate,
kernel_regularizer=l2_regularizer,
add_head_batch_norm=model_config.add_head_batch_norm,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
return model
...@@ -26,6 +26,7 @@ from official.vision.beta.configs import maskrcnn as maskrcnn_cfg ...@@ -26,6 +26,7 @@ from official.vision.beta.configs import maskrcnn as maskrcnn_cfg
from official.vision.beta.configs import retinanet as retinanet_cfg from official.vision.beta.configs import retinanet as retinanet_cfg
from official.vision.beta.configs import video_classification as video_classification_cfg from official.vision.beta.configs import video_classification as video_classification_cfg
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
from official.vision.beta.modeling import factory_3d
class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase): class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
...@@ -105,7 +106,7 @@ class VideoClassificationModelBuilderTest(parameterized.TestCase, ...@@ -105,7 +106,7 @@ class VideoClassificationModelBuilderTest(parameterized.TestCase,
backbone=backbones_3d.Backbone3D(type=backbone_type)) backbone=backbones_3d.Backbone3D(type=backbone_type))
l2_regularizer = ( l2_regularizer = (
tf.keras.regularizers.l2(weight_decay) if weight_decay else None) tf.keras.regularizers.l2(weight_decay) if weight_decay else None)
_ = factory.build_video_classification_model( _ = factory_3d.build_video_classification_model(
input_specs=input_specs, input_specs=input_specs,
model_config=model_config, model_config=model_config,
num_classes=2, num_classes=2,
......
...@@ -21,7 +21,7 @@ from official.core import task_factory ...@@ -21,7 +21,7 @@ from official.core import task_factory
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.configs import video_classification as exp_cfg from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.dataloaders import video_input from official.vision.beta.dataloaders import video_input
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory_3d
@task_factory.register_task_cls(exp_cfg.VideoClassificationTask) @task_factory.register_task_cls(exp_cfg.VideoClassificationTask)
...@@ -39,7 +39,8 @@ class VideoClassificationTask(base_task.Task): ...@@ -39,7 +39,8 @@ class VideoClassificationTask(base_task.Task):
l2_regularizer = (tf.keras.regularizers.l2( l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None) l2_weight_decay / 2.0) if l2_weight_decay else None)
model = factory.build_video_classification_model( model = factory_3d.build_model(
self.task_config.model.model_type,
input_specs=input_specs, input_specs=input_specs,
model_config=self.task_config.model, model_config=self.task_config.model,
num_classes=self.task_config.train_data.num_classes, num_classes=self.task_config.train_data.num_classes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment