Commit 091da63d authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 373084407
parent 3091fb64
...@@ -67,6 +67,8 @@ class SpineNet(hyperparams.Config): ...@@ -67,6 +67,8 @@ class SpineNet(hyperparams.Config):
"""SpineNet config.""" """SpineNet config."""
model_id: str = '49' model_id: str = '49'
stochastic_depth_drop_rate: float = 0.0 stochastic_depth_drop_rate: float = 0.0
min_level: int = 3
max_level: int = 7
@dataclasses.dataclass @dataclasses.dataclass
...@@ -76,6 +78,8 @@ class SpineNetMobile(hyperparams.Config): ...@@ -76,6 +78,8 @@ class SpineNetMobile(hyperparams.Config):
stochastic_depth_drop_rate: float = 0.0 stochastic_depth_drop_rate: float = 0.0
se_ratio: float = 0.2 se_ratio: float = 0.2
expand_ratio: int = 6 expand_ratio: int = 6
min_level: int = 3
max_level: int = 7
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -437,7 +437,12 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig: ...@@ -437,7 +437,12 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
'instances_val2017.json'), 'instances_val2017.json'),
model=MaskRCNN( model=MaskRCNN(
backbone=backbones.Backbone( backbone=backbones.Backbone(
type='spinenet', spinenet=backbones.SpineNet(model_id='49')), type='spinenet',
spinenet=backbones.SpineNet(
model_id='49',
min_level=3,
max_level=7,
)),
decoder=decoders.Decoder( decoder=decoders.Decoder(
type='identity', identity=decoders.Identity()), type='identity', identity=decoders.Identity()),
anchor=Anchor(anchor_size=3), anchor=Anchor(anchor_size=3),
...@@ -491,6 +496,8 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig: ...@@ -491,6 +496,8 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
})), })),
restrictions=[ restrictions=[
'task.train_data.is_training != None', 'task.train_data.is_training != None',
'task.validation_data.is_training != None' 'task.validation_data.is_training != None',
'task.model.min_level == task,model.backbone.spinenet.min_level',
'task.model.max_level == task,model.backbone.spinenet.max_level',
]) ])
return config return config
...@@ -248,7 +248,10 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig: ...@@ -248,7 +248,10 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig:
backbone=backbones.Backbone( backbone=backbones.Backbone(
type='spinenet', type='spinenet',
spinenet=backbones.SpineNet( spinenet=backbones.SpineNet(
model_id='49', stochastic_depth_drop_rate=0.2)), model_id='49',
stochastic_depth_drop_rate=0.2,
min_level=3,
max_level=7)),
decoder=decoders.Decoder( decoder=decoders.Decoder(
type='identity', identity=decoders.Identity()), type='identity', identity=decoders.Identity()),
anchor=Anchor(anchor_size=3), anchor=Anchor(anchor_size=3),
...@@ -306,7 +309,9 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig: ...@@ -306,7 +309,9 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig:
})), })),
restrictions=[ restrictions=[
'task.train_data.is_training != None', 'task.train_data.is_training != None',
'task.validation_data.is_training != None' 'task.validation_data.is_training != None',
'task.model.min_level == task,model.backbone.spinenet.min_level',
'task.model.max_level == task,model.backbone.spinenet.max_level',
]) ])
return config return config
...@@ -329,7 +334,10 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig: ...@@ -329,7 +334,10 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
backbone=backbones.Backbone( backbone=backbones.Backbone(
type='spinenet_mobile', type='spinenet_mobile',
spinenet_mobile=backbones.SpineNetMobile( spinenet_mobile=backbones.SpineNetMobile(
model_id='49', stochastic_depth_drop_rate=0.2)), model_id='49',
stochastic_depth_drop_rate=0.2,
min_level=3,
max_level=7)),
decoder=decoders.Decoder( decoder=decoders.Decoder(
type='identity', identity=decoders.Identity()), type='identity', identity=decoders.Identity()),
head=RetinaNetHead(num_filters=48, use_separable_conv=True), head=RetinaNetHead(num_filters=48, use_separable_conv=True),
...@@ -388,7 +396,9 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig: ...@@ -388,7 +396,9 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
})), })),
restrictions=[ restrictions=[
'task.train_data.is_training != None', 'task.train_data.is_training != None',
'task.validation_data.is_training != None' 'task.validation_data.is_training != None',
'task.model.min_level == task,model.backbone.spinenet_mobile.min_level',
'task.model.max_level == task,model.backbone.spinenet_mobile.max_level',
]) ])
return config return config
...@@ -297,12 +297,12 @@ class EfficientNet(tf.keras.Model): ...@@ -297,12 +297,12 @@ class EfficientNet(tf.keras.Model):
@factory.register_backbone_builder('efficientnet') @factory.register_backbone_builder('efficientnet')
def build_efficientnet( def build_efficientnet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds EfficientNet backbone from a config.""" """Builds EfficientNet backbone from a config."""
backbone_type = model_config.backbone.type backbone_type = backbone_config.type
backbone_cfg = model_config.backbone.get() backbone_cfg = backbone_config.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'efficientnet', (f'Inconsistent backbone type ' assert backbone_type == 'efficientnet', (f'Inconsistent backbone type '
f'{backbone_type}') f'{backbone_type}')
......
...@@ -42,6 +42,8 @@ in place that uses it. ...@@ -42,6 +42,8 @@ in place that uses it.
""" """
from typing import Sequence, Union
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
...@@ -81,22 +83,31 @@ def register_backbone_builder(key: str): ...@@ -81,22 +83,31 @@ def register_backbone_builder(key: str):
return registry.register(_REGISTERED_BACKBONE_CLS, key) return registry.register(_REGISTERED_BACKBONE_CLS, key)
def build_backbone( def build_backbone(input_specs: Union[tf.keras.layers.InputSpec,
input_specs: tf.keras.layers.InputSpec, Sequence[tf.keras.layers.InputSpec]],
model_config: hyperparams.Config, backbone_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None,
**kwargs) -> tf.keras.Model:
"""Builds backbone from a config. """Builds backbone from a config.
Args: Args:
input_specs: A `tf.keras.layers.InputSpec` of input. input_specs: A (sequence of) `tf.keras.layers.InputSpec` of input.
model_config: A `OneOfConfig` of model config. backbone_config: A `OneOfConfig` of backbone config.
norm_activation_config: A config for normalization/activation layer.
l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to
None. None.
**kwargs: Additional keyword args to be passed to backbone builder.
Returns: Returns:
A `tf.keras.Model` instance of the backbone. A `tf.keras.Model` instance of the backbone.
""" """
backbone_builder = registry.lookup(_REGISTERED_BACKBONE_CLS, backbone_builder = registry.lookup(_REGISTERED_BACKBONE_CLS,
model_config.backbone.type) backbone_config.type)
return backbone_builder(input_specs, model_config, l2_regularizer) return backbone_builder(
input_specs=input_specs,
backbone_config=backbone_config,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer,
**kwargs)
...@@ -22,7 +22,6 @@ from tensorflow.python.distribute import combinations ...@@ -22,7 +22,6 @@ from tensorflow.python.distribute import combinations
from official.vision.beta.configs import backbones as backbones_cfg from official.vision.beta.configs import backbones as backbones_cfg
from official.vision.beta.configs import backbones_3d as backbones_3d_cfg from official.vision.beta.configs import backbones_3d as backbones_3d_cfg
from official.vision.beta.configs import common as common_cfg from official.vision.beta.configs import common as common_cfg
from official.vision.beta.configs import retinanet as retinanet_cfg
from official.vision.beta.modeling import backbones from official.vision.beta.modeling import backbones
from official.vision.beta.modeling.backbones import factory from official.vision.beta.modeling.backbones import factory
...@@ -42,12 +41,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -42,12 +41,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
resnet=backbones_cfg.ResNet(model_id=model_id, se_ratio=0.0)) resnet=backbones_cfg.ResNet(model_id=model_id, se_ratio=0.0))
norm_activation_config = common_cfg.NormActivation( norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False) norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
factory_network = factory.build_backbone( factory_network = factory.build_backbone(
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]), input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
model_config=model_config) backbone_config=backbone_config,
norm_activation_config=norm_activation_config)
network_config = network.get_config() network_config = network.get_config()
factory_network_config = factory_network.get_config() factory_network_config = factory_network.get_config()
...@@ -74,12 +72,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -74,12 +72,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
model_id=model_id, se_ratio=se_ratio)) model_id=model_id, se_ratio=se_ratio))
norm_activation_config = common_cfg.NormActivation( norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False) norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
factory_network = factory.build_backbone( factory_network = factory.build_backbone(
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]), input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
model_config=model_config) backbone_config=backbone_config,
norm_activation_config=norm_activation_config)
network_config = network.get_config() network_config = network.get_config()
factory_network_config = factory_network.get_config() factory_network_config = factory_network.get_config()
...@@ -108,12 +105,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -108,12 +105,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
model_id=model_id, filter_size_scale=filter_size_scale)) model_id=model_id, filter_size_scale=filter_size_scale))
norm_activation_config = common_cfg.NormActivation( norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False) norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
factory_network = factory.build_backbone( factory_network = factory.build_backbone(
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]), input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
model_config=model_config) backbone_config=backbone_config,
norm_activation_config=norm_activation_config)
network_config = network.get_config() network_config = network.get_config()
factory_network_config = factory_network.get_config() factory_network_config = factory_network.get_config()
...@@ -141,13 +137,12 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -141,13 +137,12 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
spinenet=backbones_cfg.SpineNet(model_id=model_id)) spinenet=backbones_cfg.SpineNet(model_id=model_id))
norm_activation_config = common_cfg.NormActivation( norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False) norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
factory_network = factory.build_backbone( factory_network = factory.build_backbone(
input_specs=tf.keras.layers.InputSpec( input_specs=tf.keras.layers.InputSpec(
shape=[None, input_size, input_size, 3]), shape=[None, input_size, input_size, 3]),
model_config=model_config) backbone_config=backbone_config,
norm_activation_config=norm_activation_config)
network_config = network.get_config() network_config = network.get_config()
factory_network_config = factory_network.get_config() factory_network_config = factory_network.get_config()
...@@ -166,12 +161,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -166,12 +161,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
revnet=backbones_cfg.RevNet(model_id=model_id)) revnet=backbones_cfg.RevNet(model_id=model_id))
norm_activation_config = common_cfg.NormActivation( norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False) norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
factory_network = factory.build_backbone( factory_network = factory.build_backbone(
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]), input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
model_config=model_config) backbone_config=backbone_config,
norm_activation_config=norm_activation_config)
network_config = network.get_config() network_config = network.get_config()
factory_network_config = factory_network.get_config() factory_network_config = factory_network.get_config()
......
...@@ -766,12 +766,12 @@ class MobileNet(tf.keras.Model): ...@@ -766,12 +766,12 @@ class MobileNet(tf.keras.Model):
@factory.register_backbone_builder('mobilenet') @factory.register_backbone_builder('mobilenet')
def build_mobilenet( def build_mobilenet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds MobileNet backbone from a config.""" """Builds MobileNet backbone from a config."""
backbone_type = model_config.backbone.type backbone_type = backbone_config.type
backbone_cfg = model_config.backbone.get() backbone_cfg = backbone_config.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'mobilenet', (f'Inconsistent backbone type ' assert backbone_type == 'mobilenet', (f'Inconsistent backbone type '
f'{backbone_type}') f'{backbone_type}')
......
...@@ -372,12 +372,12 @@ class ResNet(tf.keras.Model): ...@@ -372,12 +372,12 @@ class ResNet(tf.keras.Model):
@factory.register_backbone_builder('resnet') @factory.register_backbone_builder('resnet')
def build_resnet( def build_resnet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet backbone from a config.""" """Builds ResNet backbone from a config."""
backbone_type = model_config.backbone.type backbone_type = backbone_config.type
backbone_cfg = model_config.backbone.get() backbone_cfg = backbone_config.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'resnet', (f'Inconsistent backbone type ' assert backbone_type == 'resnet', (f'Inconsistent backbone type '
f'{backbone_type}') f'{backbone_type}')
......
...@@ -378,11 +378,11 @@ class ResNet3D(tf.keras.Model): ...@@ -378,11 +378,11 @@ class ResNet3D(tf.keras.Model):
@factory.register_backbone_builder('resnet_3d') @factory.register_backbone_builder('resnet_3d')
def build_resnet3d( def build_resnet3d(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet 3d backbone from a config.""" """Builds ResNet 3d backbone from a config."""
backbone_cfg = model_config.backbone.get() backbone_cfg = backbone_config.get()
norm_activation_config = model_config.norm_activation
# Flatten configs before passing to the backbone. # Flatten configs before passing to the backbone.
temporal_strides = [] temporal_strides = []
...@@ -416,11 +416,11 @@ def build_resnet3d( ...@@ -416,11 +416,11 @@ def build_resnet3d(
@factory.register_backbone_builder('resnet_3d_rs') @factory.register_backbone_builder('resnet_3d_rs')
def build_resnet3d_rs( def build_resnet3d_rs(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet-3D-RS backbone from a config.""" """Builds ResNet-3D-RS backbone from a config."""
backbone_cfg = model_config.backbone.get() backbone_cfg = backbone_config.get()
norm_activation_config = model_config.norm_activation
# Flatten configs before passing to the backbone. # Flatten configs before passing to the backbone.
temporal_strides = [] temporal_strides = []
......
...@@ -18,6 +18,7 @@ from typing import Callable, Optional, Tuple, List ...@@ -18,6 +18,7 @@ from typing import Callable, Optional, Tuple, List
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks from official.vision.beta.modeling.layers import nn_blocks
...@@ -340,12 +341,12 @@ class DilatedResNet(tf.keras.Model): ...@@ -340,12 +341,12 @@ class DilatedResNet(tf.keras.Model):
@factory.register_backbone_builder('dilated_resnet') @factory.register_backbone_builder('dilated_resnet')
def build_dilated_resnet( def build_dilated_resnet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet backbone from a config.""" """Builds ResNet backbone from a config."""
backbone_type = model_config.backbone.type backbone_type = backbone_config.type
backbone_cfg = model_config.backbone.get() backbone_cfg = backbone_config.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'dilated_resnet', (f'Inconsistent backbone type ' assert backbone_type == 'dilated_resnet', (f'Inconsistent backbone type '
f'{backbone_type}') f'{backbone_type}')
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Optional
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks from official.vision.beta.modeling.layers import nn_blocks
...@@ -213,12 +214,12 @@ class RevNet(tf.keras.Model): ...@@ -213,12 +214,12 @@ class RevNet(tf.keras.Model):
@factory.register_backbone_builder('revnet') @factory.register_backbone_builder('revnet')
def build_revnet( def build_revnet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds RevNet backbone from a config.""" """Builds RevNet backbone from a config."""
backbone_type = model_config.backbone.type backbone_type = backbone_config.type
backbone_cfg = model_config.backbone.get() backbone_cfg = backbone_config.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'revnet', (f'Inconsistent backbone type ' assert backbone_type == 'revnet', (f'Inconsistent backbone type '
f'{backbone_type}') f'{backbone_type}')
......
...@@ -22,6 +22,7 @@ from typing import Any, List, Optional, Tuple ...@@ -22,6 +22,7 @@ from typing import Any, List, Optional, Tuple
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks from official.vision.beta.modeling.layers import nn_blocks
...@@ -527,12 +528,12 @@ class SpineNet(tf.keras.Model): ...@@ -527,12 +528,12 @@ class SpineNet(tf.keras.Model):
@factory.register_backbone_builder('spinenet') @factory.register_backbone_builder('spinenet')
def build_spinenet( def build_spinenet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds SpineNet backbone from a config.""" """Builds SpineNet backbone from a config."""
backbone_type = model_config.backbone.type backbone_type = backbone_config.type
backbone_cfg = model_config.backbone.get() backbone_cfg = backbone_config.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'spinenet', (f'Inconsistent backbone type ' assert backbone_type == 'spinenet', (f'Inconsistent backbone type '
f'{backbone_type}') f'{backbone_type}')
...@@ -544,8 +545,8 @@ def build_spinenet( ...@@ -544,8 +545,8 @@ def build_spinenet(
return SpineNet( return SpineNet(
input_specs=input_specs, input_specs=input_specs,
min_level=model_config.min_level, min_level=backbone_cfg.min_level,
max_level=model_config.max_level, max_level=backbone_cfg.max_level,
endpoints_num_filters=scaling_params['endpoints_num_filters'], endpoints_num_filters=scaling_params['endpoints_num_filters'],
resample_alpha=scaling_params['resample_alpha'], resample_alpha=scaling_params['resample_alpha'],
block_repeats=scaling_params['block_repeats'], block_repeats=scaling_params['block_repeats'],
......
...@@ -36,6 +36,7 @@ from typing import Any, List, Optional, Tuple ...@@ -36,6 +36,7 @@ from typing import Any, List, Optional, Tuple
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks from official.vision.beta.modeling.layers import nn_blocks
...@@ -501,12 +502,12 @@ class SpineNetMobile(tf.keras.Model): ...@@ -501,12 +502,12 @@ class SpineNetMobile(tf.keras.Model):
@factory.register_backbone_builder('spinenet_mobile') @factory.register_backbone_builder('spinenet_mobile')
def build_spinenet_mobile( def build_spinenet_mobile(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds Mobile SpineNet backbone from a config.""" """Builds Mobile SpineNet backbone from a config."""
backbone_type = model_config.backbone.type backbone_type = backbone_config.type
backbone_cfg = model_config.backbone.get() backbone_cfg = backbone_config.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'spinenet_mobile', (f'Inconsistent backbone type ' assert backbone_type == 'spinenet_mobile', (f'Inconsistent backbone type '
f'{backbone_type}') f'{backbone_type}')
...@@ -518,8 +519,8 @@ def build_spinenet_mobile( ...@@ -518,8 +519,8 @@ def build_spinenet_mobile(
return SpineNetMobile( return SpineNetMobile(
input_specs=input_specs, input_specs=input_specs,
min_level=model_config.min_level, min_level=backbone_cfg.min_level,
max_level=model_config.max_level, max_level=backbone_cfg.max_level,
endpoints_num_filters=scaling_params['endpoints_num_filters'], endpoints_num_filters=scaling_params['endpoints_num_filters'],
block_repeats=scaling_params['block_repeats'], block_repeats=scaling_params['block_repeats'],
filter_size_scale=scaling_params['filter_size_scale'], filter_size_scale=scaling_params['filter_size_scale'],
......
...@@ -44,12 +44,13 @@ def build_classification_model( ...@@ -44,12 +44,13 @@ def build_classification_model(
l2_regularizer: tf.keras.regularizers.Regularizer = None, l2_regularizer: tf.keras.regularizers.Regularizer = None,
skip_logits_layer: bool = False) -> tf.keras.Model: skip_logits_layer: bool = False) -> tf.keras.Model:
"""Builds the classification model.""" """Builds the classification model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
model_config=model_config, backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
norm_activation_config = model_config.norm_activation
model = classification_model.ClassificationModel( model = classification_model.ClassificationModel(
backbone=backbone, backbone=backbone,
num_classes=model_config.num_classes, num_classes=model_config.num_classes,
...@@ -69,9 +70,11 @@ def build_maskrcnn( ...@@ -69,9 +70,11 @@ def build_maskrcnn(
model_config: maskrcnn_cfg.MaskRCNN, model_config: maskrcnn_cfg.MaskRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds Mask R-CNN model.""" """Builds Mask R-CNN model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
model_config=model_config, backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
decoder = decoder_factory.build_decoder( decoder = decoder_factory.build_decoder(
...@@ -85,7 +88,6 @@ def build_maskrcnn( ...@@ -85,7 +88,6 @@ def build_maskrcnn(
roi_aligner_config = model_config.roi_aligner roi_aligner_config = model_config.roi_aligner
detection_head_config = model_config.detection_head detection_head_config = model_config.detection_head
generator_config = model_config.detection_generator generator_config = model_config.detection_generator
norm_activation_config = model_config.norm_activation
num_anchors_per_location = ( num_anchors_per_location = (
len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales) len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)
...@@ -242,9 +244,11 @@ def build_retinanet( ...@@ -242,9 +244,11 @@ def build_retinanet(
model_config: retinanet_cfg.RetinaNet, model_config: retinanet_cfg.RetinaNet,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds RetinaNet model.""" """Builds RetinaNet model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
model_config=model_config, backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
decoder = decoder_factory.build_decoder( decoder = decoder_factory.build_decoder(
...@@ -254,7 +258,6 @@ def build_retinanet( ...@@ -254,7 +258,6 @@ def build_retinanet(
head_config = model_config.head head_config = model_config.head
generator_config = model_config.detection_generator generator_config = model_config.detection_generator
norm_activation_config = model_config.norm_activation
num_anchors_per_location = ( num_anchors_per_location = (
len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales) len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)
...@@ -301,9 +304,11 @@ def build_segmentation_model( ...@@ -301,9 +304,11 @@ def build_segmentation_model(
model_config: segmentation_cfg.SemanticSegmentationModel, model_config: segmentation_cfg.SemanticSegmentationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds Segmentation model.""" """Builds Segmentation model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
model_config=model_config, backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
decoder = decoder_factory.build_decoder( decoder = decoder_factory.build_decoder(
...@@ -312,7 +317,6 @@ def build_segmentation_model( ...@@ -312,7 +317,6 @@ def build_segmentation_model(
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
head_config = model_config.head head_config = model_config.head
norm_activation_config = model_config.norm_activation
head = segmentation_heads.SegmentationHead( head = segmentation_heads.SegmentationHead(
num_classes=model_config.num_classes, num_classes=model_config.num_classes,
......
...@@ -85,9 +85,11 @@ def build_video_classification_model( ...@@ -85,9 +85,11 @@ def build_video_classification_model(
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds the video classification model.""" """Builds the video classification model."""
input_specs_dict = {'image': input_specs} input_specs_dict = {'image': input_specs}
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
model_config=model_config, backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
model = video_classification_model.VideoClassificationModel( model = video_classification_model.VideoClassificationModel(
......
...@@ -54,6 +54,7 @@ from absl import logging ...@@ -54,6 +54,7 @@ from absl import logging
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams
from official.vision.beta.modeling import factory_3d as model_factory from official.vision.beta.modeling import factory_3d as model_factory
from official.vision.beta.modeling.backbones import factory as backbone_factory from official.vision.beta.modeling.backbones import factory as backbone_factory
from official.vision.beta.projects.assemblenet.configs import assemblenet as cfg from official.vision.beta.projects.assemblenet.configs import assemblenet as cfg
...@@ -1015,14 +1016,14 @@ def assemblenet_v1(assemblenet_depth: int, ...@@ -1015,14 +1016,14 @@ def assemblenet_v1(assemblenet_depth: int,
@backbone_factory.register_backbone_builder('assemblenet') @backbone_factory.register_backbone_builder('assemblenet')
def build_assemblenet_v1( def build_assemblenet_v1(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: cfg.Backbone3D, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds assemblenet backbone.""" """Builds assemblenet backbone."""
del l2_regularizer del l2_regularizer
backbone_type = model_config.backbone.type backbone_type = backbone_config.type
backbone_cfg = model_config.backbone.get() backbone_cfg = backbone_config.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'assemblenet' assert backbone_type == 'assemblenet'
assemblenet_depth = int(backbone_cfg.model_id) assemblenet_depth = int(backbone_cfg.model_id)
...@@ -1060,7 +1061,8 @@ def build_assemblenet_model( ...@@ -1060,7 +1061,8 @@ def build_assemblenet_model(
l2_regularizer: tf.keras.regularizers.Regularizer = None): l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds assemblenet model.""" """Builds assemblenet model."""
input_specs_dict = {'image': input_specs} input_specs_dict = {'image': input_specs}
backbone = build_assemblenet_v1(input_specs, model_config, l2_regularizer) backbone = build_assemblenet_v1(input_specs, model_config.backbone,
model_config.norm_activation, l2_regularizer)
backbone_cfg = model_config.backbone.get() backbone_cfg = model_config.backbone.get()
model_structure, _ = cfg.blocks_to_flat_lists(backbone_cfg.blocks) model_structure, _ = cfg.blocks_to_flat_lists(backbone_cfg.blocks)
model = AssembleNetModel( model = AssembleNetModel(
......
...@@ -37,9 +37,11 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec, ...@@ -37,9 +37,11 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
model_config: deep_mask_head_rcnn_config.DeepMaskHeadRCNN, model_config: deep_mask_head_rcnn_config.DeepMaskHeadRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None): l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds Mask R-CNN model.""" """Builds Mask R-CNN model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
model_config=model_config, backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
decoder = decoder_factory.build_decoder( decoder = decoder_factory.build_decoder(
...@@ -53,7 +55,6 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec, ...@@ -53,7 +55,6 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
roi_aligner_config = model_config.roi_aligner roi_aligner_config = model_config.roi_aligner
detection_head_config = model_config.detection_head detection_head_config = model_config.detection_head
generator_config = model_config.detection_generator generator_config = model_config.detection_generator
norm_activation_config = model_config.norm_activation
num_anchors_per_location = ( num_anchors_per_location = (
len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales) len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)
......
...@@ -110,7 +110,8 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -110,7 +110,8 @@ class SimCLRPretrainTask(base_task.Task):
# Build backbone # Build backbone
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
model_config=model_config, backbone_config=model_config.backbone,
norm_activation_config=model_config.norm_activation,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
# Build projection head # Build projection head
......
...@@ -40,6 +40,7 @@ import collections ...@@ -40,6 +40,7 @@ import collections
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams
from official.vision.beta.modeling.backbones import factory from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
...@@ -428,12 +429,12 @@ class Darknet(tf.keras.Model): ...@@ -428,12 +429,12 @@ class Darknet(tf.keras.Model):
@factory.register_backbone_builder("darknet") @factory.register_backbone_builder("darknet")
def build_darknet( def build_darknet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds darknet backbone.""" """Builds darknet backbone."""
backbone_cfg = model_config.backbone.get() backbone_cfg = backbone_config.get()
norm_activation_config = model_config.norm_activation
model = Darknet( model = Darknet(
model_id=backbone_cfg.model_id, model_id=backbone_cfg.model_id,
input_shape=input_specs, input_shape=input_specs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment