Commit f05152f8 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 333559599
parent bb5f46bb
......@@ -19,6 +19,7 @@ import math
from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks
layers = tf.keras.layers
......@@ -290,3 +291,27 @@ class EfficientNet(tf.keras.Model):
def output_specs(self):
"""A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs
@factory.register_backbone_builder('efficientnet')
def build_efficientnet(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet 3d backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'efficientnet', (f'Inconsistent backbone type '
f'{backbone_type}')
return EfficientNet(
model_id=backbone_cfg.model_id,
input_specs=input_specs,
stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
se_ratio=backbone_cfg.se_ratio,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
......@@ -13,90 +13,76 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""factory method."""
"""Backbone registers and factory method.
One can regitered a new backbone model by the following two steps:
1 Import the factory and register the build in the backbone file.
2 Import the backbone class and add a build in __init__.py.
```
# my_backbone.py
from modeling.backbones import factory
class MyBackbone():
...
@factory.register_backbone_builder('my_backbone')
def build_my_backbone():
return MyBackbone()
# backbones/__init__.py adds import
from modeling.backbones.my_backbone import MyBackbone
```
If one wants the MyBackbone class to be used only by those binary
then don't imported the backbone module in backbones/__init__.py, but import it
in place that uses it.
"""
# Import libraries
import tensorflow as tf
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling.backbones import spinenet
from official.core import registry
def build_backbone(input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds backbone from a config.
_REGISTERED_BACKBONE_CLS = {}
def register_backbone_builder(key: str):
"""Decorates a builder of backbone class.
The builder should be a Callable (a class or a function).
This decorator supports registration of backbone builder as follows:
```
class MyBackbone(tf.keras.Model):
pass
@register_backbone_builder('mybackbone')
def builder(input_specs, config, l2_reg):
return MyBackbone(...)
# Builds a MyBackbone object.
my_backbone = build_backbone_3d(input_specs, config, l2_reg)
```
Args:
input_specs: tf.keras.layers.InputSpec.
model_config: a OneOfConfig. Model config.
l2_regularizer: tf.keras.regularizers.Regularizer instance. Default to None.
key: the key to look up the builder.
Returns:
tf.keras.Model instance of the backbone.
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
if backbone_type == 'resnet':
backbone = backbones.ResNet(
model_id=backbone_cfg.model_id,
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
elif backbone_type == 'efficientnet':
backbone = backbones.EfficientNet(
model_id=backbone_cfg.model_id,
input_specs=input_specs,
stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
se_ratio=backbone_cfg.se_ratio,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
elif backbone_type == 'spinenet':
model_id = backbone_cfg.model_id
if model_id not in spinenet.SCALING_MAP:
raise ValueError(
'SpineNet-{} is not a valid architecture.'.format(model_id))
scaling_params = spinenet.SCALING_MAP[model_id]
backbone = backbones.SpineNet(
input_specs=input_specs,
min_level=model_config.min_level,
max_level=model_config.max_level,
endpoints_num_filters=scaling_params['endpoints_num_filters'],
resample_alpha=scaling_params['resample_alpha'],
block_repeats=scaling_params['block_repeats'],
filter_size_scale=scaling_params['filter_size_scale'],
kernel_regularizer=l2_regularizer,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
elif backbone_type == 'revnet':
backbone = backbones.RevNet(
model_id=backbone_cfg.model_id,
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
else:
raise ValueError('Backbone {!r} not implement'.format(backbone_type))
return backbone
def build_backbone_3d(input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds 3d backbone from a config.
return registry.register(_REGISTERED_BACKBONE_CLS, key)
def build_backbone(input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds backbone from a config.
Args:
input_specs: tf.keras.layers.InputSpec.
......@@ -106,32 +92,7 @@ def build_backbone_3d(input_specs: tf.keras.layers.InputSpec,
Returns:
tf.keras.Model instance of the backbone.
"""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
# Flatten configs before passing to the backbone.
temporal_strides = []
temporal_kernel_sizes = []
use_self_gating = []
for block_spec in backbone_cfg.block_specs:
temporal_strides.append(block_spec.temporal_strides)
temporal_kernel_sizes.append(block_spec.temporal_kernel_sizes)
use_self_gating.append(block_spec.use_self_gating)
if backbone_type == 'resnet_3d':
backbone = backbones.ResNet3D(
model_id=backbone_cfg.model_id,
temporal_strides=temporal_strides,
temporal_kernel_sizes=temporal_kernel_sizes,
use_self_gating=use_self_gating,
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
else:
raise ValueError('Backbone {!r} not implement'.format(backbone_type))
return backbone
backbone_builder = registry.lookup(_REGISTERED_BACKBONE_CLS,
model_config.backbone.type)
return backbone_builder(input_specs, model_config, l2_regularizer)
......@@ -22,6 +22,7 @@ Residual networks (ResNets) were proposed in:
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks
layers = tf.keras.layers
......@@ -229,3 +230,25 @@ class ResNet(tf.keras.Model):
def output_specs(self):
"""A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs
@factory.register_backbone_builder('resnet')
def build_resnet(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet 3d backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'resnet', (f'Inconsistent backbone type '
f'{backbone_type}')
return ResNet(
model_id=backbone_cfg.model_id,
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
......@@ -18,6 +18,7 @@ from typing import List, Tuple
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks_3d
layers = tf.keras.layers
......@@ -259,3 +260,37 @@ class ResNet3D(tf.keras.Model):
def output_specs(self):
"""A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs
@factory.register_backbone_builder('resnet_3d')
def build_resnet3d(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet 3d backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'resnet_3d', (f'Inconsistent backbone type '
f'{backbone_type}')
# Flatten configs before passing to the backbone.
temporal_strides = []
temporal_kernel_sizes = []
use_self_gating = []
for block_spec in backbone_cfg.block_specs:
temporal_strides.append(block_spec.temporal_strides)
temporal_kernel_sizes.append(block_spec.temporal_kernel_sizes)
use_self_gating.append(block_spec.use_self_gating)
return ResNet3D(
model_id=backbone_cfg.model_id,
temporal_strides=temporal_strides,
temporal_kernel_sizes=temporal_kernel_sizes,
use_self_gating=use_self_gating,
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
......@@ -24,6 +24,7 @@ from typing import Any, Callable, Dict, Optional
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks
......@@ -202,3 +203,25 @@ class RevNet(tf.keras.Model):
def output_specs(self) -> Dict[int, tf.TensorShape]:
"""A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs
@factory.register_backbone_builder('revnet')
def build_revnet(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet 3d backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'revnet', (f'Inconsistent backbone type '
f'{backbone_type}')
return RevNet(
model_id=backbone_cfg.model_id,
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
......@@ -25,6 +25,7 @@ import math
from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.ops import spatial_transform_ops
......@@ -476,3 +477,36 @@ class SpineNet(tf.keras.Model):
def output_specs(self):
"""A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs
@factory.register_backbone_builder('spinenet')
def build_spinenet(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet 3d backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'spinenet', (f'Inconsistent backbone type '
f'{backbone_type}')
model_id = backbone_cfg.model_id
if model_id not in SCALING_MAP:
raise ValueError(
'SpineNet-{} is not a valid architecture.'.format(model_id))
scaling_params = SCALING_MAP[model_id]
return SpineNet(
input_specs=input_specs,
min_level=model_config.min_level,
max_level=model_config.max_level,
endpoints_num_filters=scaling_params['endpoints_num_filters'],
resample_alpha=scaling_params['resample_alpha'],
block_repeats=scaling_params['block_repeats'],
filter_size_scale=scaling_params['filter_size_scale'],
kernel_regularizer=l2_regularizer,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
......@@ -21,11 +21,11 @@ from official.vision.beta.configs import image_classification as classification_
from official.vision.beta.configs import maskrcnn as maskrcnn_cfg
from official.vision.beta.configs import retinanet as retinanet_cfg
from official.vision.beta.configs import video_classification as video_classification_cfg
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import classification_model
from official.vision.beta.modeling import maskrcnn_model
from official.vision.beta.modeling import retinanet_model
from official.vision.beta.modeling import video_classification_model
from official.vision.beta.modeling.backbones import factory as backbone_factory
from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import dense_prediction_heads
from official.vision.beta.modeling.heads import instance_heads
......@@ -41,7 +41,7 @@ def build_classification_model(
model_config: classification_cfg.ImageClassificationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds the classification model."""
backbone = backbone_factory.build_backbone(
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
......@@ -64,7 +64,7 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
model_config: maskrcnn_cfg.MaskRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds Mask R-CNN model."""
backbone = backbone_factory.build_backbone(
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
......@@ -193,7 +193,7 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec,
model_config: retinanet_cfg.RetinaNet,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds RetinaNet model."""
backbone = backbone_factory.build_backbone(
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
......@@ -242,7 +242,7 @@ def build_video_classification_model(
num_classes: int,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds the video classification model."""
backbone = backbone_factory.build_backbone_3d(
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment