Commit f8f4845c authored by Fan Yang, committed by A. Unique TensorFlower

Add more pytype checking.

PiperOrigin-RevId: 368129317
parent 4334a892
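For context: almost every hunk below adds PEP 484 annotations to Keras-style constructors and builder functions so that pytype can check call sites statically; runtime behavior is unchanged. A minimal sketch of what the checker buys, reusing the round_repeats signature from the EfficientNet hunk (the mistyped call is invented for illustration):

import math

def round_repeats(repeats: int, multiplier: float, skip: bool = False) -> int:
  if skip or not multiplier:
    return repeats
  return int(math.ceil(multiplier * repeats))

round_repeats(4, 1.1)      # OK: returns 5.
# round_repeats('4', 1.1)  # pytype reports wrong-arg-types before this ever runs.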
@@ -15,8 +15,13 @@
 """Contains definitions of EfficientNet Networks."""
 import math
+from typing import Any, List, Tuple
 # Import libraries
 import tensorflow as tf
+from official.modeling import hyperparams
 from official.modeling import tf_utils
 from official.vision.beta.modeling.backbones import factory
 from official.vision.beta.modeling.layers import nn_blocks
@@ -50,14 +55,32 @@ SCALING_MAP = {
 }
-def round_repeats(repeats, multiplier, skip=False):
+class BlockSpec():
+  """A container class that specifies the block configuration for MnasNet."""
+  def __init__(self, block_fn: str, block_repeats: int, kernel_size: int,
+               strides: int, expand_ratio: float, in_filters: int,
+               out_filters: int, is_output: bool, width_scale: float,
+               depth_scale: float):
+    self.block_fn = block_fn
+    self.block_repeats = round_repeats(block_repeats, depth_scale)
+    self.kernel_size = kernel_size
+    self.strides = strides
+    self.expand_ratio = expand_ratio
+    self.in_filters = nn_layers.round_filters(in_filters, width_scale)
+    self.out_filters = nn_layers.round_filters(out_filters, width_scale)
+    self.is_output = is_output
+def round_repeats(repeats: int, multiplier: float, skip: bool = False) -> int:
   """Returns rounded number of filters based on depth multiplier."""
   if skip or not multiplier:
     return repeats
   return int(math.ceil(multiplier * repeats))
-def block_spec_decoder(specs, width_scale, depth_scale):
+def block_spec_decoder(specs: List[Tuple[Any, ...]], width_scale: float,
+                       depth_scale: float) -> List[BlockSpec]:
   """Decodes and returns specs for a block."""
   decoded_specs = []
   for s in specs:
@@ -69,22 +92,6 @@ def block_spec_decoder(specs, width_scale, depth_scale):
   return decoded_specs
-class BlockSpec(object):
-  """A container class that specifies the block configuration for MnasNet."""
-  def __init__(self, block_fn, block_repeats, kernel_size, strides,
-               expand_ratio, in_filters, out_filters, is_output, width_scale,
-               depth_scale):
-    self.block_fn = block_fn
-    self.block_repeats = round_repeats(block_repeats, depth_scale)
-    self.kernel_size = kernel_size
-    self.strides = strides
-    self.expand_ratio = expand_ratio
-    self.in_filters = nn_layers.round_filters(in_filters, width_scale)
-    self.out_filters = nn_layers.round_filters(out_filters, width_scale)
-    self.is_output = is_output
 @tf.keras.utils.register_keras_serializable(package='Vision')
 class EfficientNet(tf.keras.Model):
   """Creates an EfficientNet family model.
@@ -96,17 +103,18 @@ class EfficientNet(tf.keras.Model):
   """
   def __init__(self,
-               model_id,
-               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
-               se_ratio=0.0,
-               stochastic_depth_drop_rate=0.0,
-               kernel_initializer='VarianceScaling',
-               kernel_regularizer=None,
-               bias_regularizer=None,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
+               model_id: str,
+               input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
+                   shape=[None, None, None, 3]),
+               se_ratio: float = 0.0,
+               stochastic_depth_drop_rate: float = 0.0,
+               kernel_initializer: str = 'VarianceScaling',
+               kernel_regularizer: tf.keras.regularizers.Regularizer = None,
+               bias_regularizer: tf.keras.regularizers.Regularizer = None,
+               activation: str = 'relu',
+               use_sync_bn: bool = False,
+               norm_momentum: float = 0.99,
+               norm_epsilon: float = 0.001,
                **kwargs):
     """Initializes an EfficientNet model.
@@ -205,7 +213,10 @@ class EfficientNet(tf.keras.Model):
     super(EfficientNet, self).__init__(
         inputs=inputs, outputs=endpoints, **kwargs)
-  def _block_group(self, inputs, specs, name='block_group'):
+  def _block_group(self,
+                   inputs: tf.Tensor,
+                   specs: BlockSpec,
+                   name: str = 'block_group'):
     """Creates one group of blocks for the EfficientNet model.
     Args:
@@ -286,7 +297,7 @@
 @factory.register_backbone_builder('efficientnet')
 def build_efficientnet(
     input_specs: tf.keras.layers.InputSpec,
-    model_config,
+    model_config: hyperparams.Config,
     l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
   """Builds EfficientNet backbone from a config."""
   backbone_type = model_config.backbone.type
......
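Side note on the scaling helpers the new top-level BlockSpec relies on: block_repeats goes through round_repeats above, and the filter counts go through nn_layers.round_filters, whose body is not part of this diff. The sketch below shows the usual EfficientNet-style filter rounding only to make the width scaling concrete; it is an assumption about round_filters' behavior, not its source:

def round_filters(filters: int, multiplier: float, divisor: int = 8) -> int:
  # Scale, then snap to a multiple of `divisor`, never losing more than ~10%.
  scaled = filters * multiplier
  new_filters = max(divisor, int(scaled + divisor / 2) // divisor * divisor)
  if new_filters < 0.9 * scaled:
    new_filters += divisor
  return int(new_filters)

assert round_filters(40, 1.1) == 48   # widths snap to multiples of 8
assert round_filters(40, 1.0) == 40   # identity when the multiplier is 1.0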
@@ -43,9 +43,11 @@ in place that uses it.
 """
 # Import libraries
 import tensorflow as tf
 from official.core import registry
+from official.modeling import hyperparams
 _REGISTERED_BACKBONE_CLS = {}
@@ -79,9 +81,10 @@ def register_backbone_builder(key: str):
   return registry.register(_REGISTERED_BACKBONE_CLS, key)
-def build_backbone(input_specs: tf.keras.layers.InputSpec,
-                   model_config,
-                   l2_regularizer: tf.keras.regularizers.Regularizer = None):
+def build_backbone(
+    input_specs: tf.keras.layers.InputSpec,
+    model_config: hyperparams.Config,
+    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
   """Builds backbone from a config.
   Args:
......
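Usage of the two functions in this factory, for reviewers who have not touched it: register_backbone_builder(key) returns a decorator that stores the builder in _REGISTERED_BACKBONE_CLS, and build_backbone later dispatches on the backbone type named in the config, handing the builder the input specs, the config, and the regularizer. A hedged sketch with an invented 'toy' key and a deliberately trivial builder body:

import tensorflow as tf
from official.vision.beta.modeling.backbones import factory

@factory.register_backbone_builder('toy')  # key invented for this example
def build_toy_backbone(input_specs: tf.keras.layers.InputSpec,
                       model_config,
                       l2_regularizer=None) -> tf.keras.Model:
  # A registered builder must return a tf.keras.Model.
  inputs = tf.keras.Input(shape=input_specs.shape[1:])
  outputs = tf.keras.layers.Conv2D(8, 3, padding='same')(inputs)
  return tf.keras.Model(inputs=inputs, outputs=outputs)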
@@ -26,7 +26,6 @@ from official.vision.beta.modeling.layers import nn_blocks
 from official.vision.beta.modeling.layers import nn_layers
 layers = tf.keras.layers
-regularizers = tf.keras.regularizers
 # pylint: disable=pointless-string-statement
@@ -417,18 +416,19 @@ class BlockSpec(hyperparams.Config):
   use_bias: bool = False
   use_normalization: bool = True
   activation: str = 'relu6'
-  # used for block type InvertedResConv
+  # Used for block type InvertedResConv.
   expand_ratio: Optional[float] = 6.
-  # used for block type InvertedResConv with SE
+  # Used for block type InvertedResConv with SE.
   se_ratio: Optional[float] = None
   use_depthwise: bool = True
   use_residual: bool = True
   is_output: bool = True
-def block_spec_decoder(specs: Dict[Any, Any],
-                       filter_size_scale: float,
-                       # set to 1 for mobilenetv1
-                       divisible_by: int = 8,
-                       finegrain_classification_mode: bool = True):
+def block_spec_decoder(
+    specs: Dict[Any, Any],
+    filter_size_scale: float,
+    # Set to 1 for mobilenetv1.
+    divisible_by: int = 8,
+    finegrain_classification_mode: bool = True):
   """Decodes specs for a block.
@@ -491,23 +491,23 @@ class MobileNet(tf.keras.Model):
       self,
       model_id: str = 'MobileNetV2',
       filter_size_scale: float = 1.0,
-      input_specs: layers.InputSpec = layers.InputSpec(
+      input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, 3]),
-      # The followings are for hyper-parameter tuning
+      # The followings are for hyper-parameter tuning.
       norm_momentum: float = 0.99,
       norm_epsilon: float = 0.001,
       kernel_initializer: str = 'VarianceScaling',
-      kernel_regularizer: Optional[regularizers.Regularizer] = None,
-      bias_regularizer: Optional[regularizers.Regularizer] = None,
-      # The followings should be kept the same most of the times
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      # The followings should be kept the same most of the times.
       output_stride: int = None,
       min_depth: int = 8,
-      # divisible is not used in MobileNetV1
+      # divisible is not used in MobileNetV1.
       divisible_by: int = 8,
       stochastic_depth_drop_rate: float = 0.0,
       regularize_depthwise: bool = False,
       use_sync_bn: bool = False,
-      # finegrain is not used in MobileNetV1
+      # finegrain is not used in MobileNetV1.
       finegrain_classification_mode: bool = True,
       **kwargs):
     """Initializes a MobileNet model.
@@ -636,8 +636,8 @@ class MobileNet(tf.keras.Model):
       # A small catch for gpooling block with None strides
       if not block_def.strides:
        block_def.strides = 1
-      if self._output_stride is not None \
-          and current_stride == self._output_stride:
+      if (self._output_stride is not None and
+          current_stride == self._output_stride):
        # If we have reached the target output_stride, then we need to employ
        # atrous convolution with stride=1 and multiply the atrous rate by the
        # current unit's stride for use in subsequent layers.
@@ -764,7 +764,7 @@ class MobileNet(tf.keras.Model):
 @factory.register_backbone_builder('mobilenet')
 def build_mobilenet(
     input_specs: tf.keras.layers.InputSpec,
-    model_config,
+    model_config: hyperparams.Config,
     l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
   """Builds MobileNet backbone from a config."""
   backbone_type = model_config.backbone.type
......
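The output_stride logic touched in the @@ -636 hunk above reads more easily with numbers: strides accumulate until they hit the requested output_stride, after which every block runs with stride 1 and folds its nominal stride into the dilation (atrous) rate instead. A standalone sketch of that bookkeeping, not the model code itself:

current_stride, rate = 1, 1
output_stride = 8                      # e.g. a dense-prediction setting
for block_stride in [2, 2, 2, 2, 1]:   # nominal per-stage strides
  if output_stride is not None and current_stride == output_stride:
    layer_stride, layer_rate = 1, rate  # stop striding, dilate instead
    rate *= block_stride
  else:
    layer_stride, layer_rate = block_stride, 1
    current_stride *= block_stride
  print(block_stride, '->', layer_stride, layer_rate)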
@@ -14,8 +14,12 @@
 """Contains definitions of ResNet and ResNet-RS models."""
+from typing import Callable, Optional
 # Import libraries
 import tensorflow as tf
+from official.modeling import hyperparams
 from official.modeling import tf_utils
 from official.vision.beta.modeling.backbones import factory
 from official.vision.beta.modeling.layers import nn_blocks
@@ -99,22 +103,24 @@ class ResNet(tf.keras.Model):
   (https://arxiv.org/abs/2103.07579).
   """
-  def __init__(self,
-               model_id,
-               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
-               depth_multiplier=1.0,
-               stem_type='v0',
-               resnetd_shortcut=False,
-               replace_stem_max_pool=False,
-               se_ratio=None,
-               init_stochastic_depth_rate=0.0,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
-               kernel_initializer='VarianceScaling',
-               kernel_regularizer=None,
-               bias_regularizer=None,
+  def __init__(
+      self,
+      model_id: int,
+      input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
+          shape=[None, None, None, 3]),
+      depth_multiplier: float = 1.0,
+      stem_type: str = 'v0',
+      resnetd_shortcut: bool = False,
+      replace_stem_max_pool: bool = False,
+      se_ratio: Optional[float] = None,
+      init_stochastic_depth_rate: float = 0.0,
+      activation: str = 'relu',
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
+      kernel_initializer: str = 'VarianceScaling',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
       **kwargs):
     """Initializes a ResNet model.
@@ -274,13 +280,13 @@ class ResNet(tf.keras.Model):
     super(ResNet, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)
   def _block_group(self,
-                   inputs,
-                   filters,
-                   strides,
-                   block_fn,
-                   block_repeats=1,
-                   stochastic_depth_drop_rate=0.0,
-                   name='block_group'):
+                   inputs: tf.Tensor,
+                   filters: int,
+                   strides: int,
+                   block_fn: Callable[..., tf.keras.layers.Layer],
+                   block_repeats: int = 1,
+                   stochastic_depth_drop_rate: float = 0.0,
+                   name: str = 'block_group'):
     """Creates one group of blocks for the ResNet model.
     Args:
@@ -366,7 +372,7 @@
 @factory.register_backbone_builder('resnet')
 def build_resnet(
     input_specs: tf.keras.layers.InputSpec,
-    model_config,
+    model_config: hyperparams.Config,
     l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
   """Builds ResNet backbone from a config."""
   backbone_type = model_config.backbone.type
......
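One annotation above that is easy to misread: block_fn: Callable[..., tf.keras.layers.Layer] says _block_group takes something that builds a layer (typically one of the block classes in nn_blocks), not a layer instance. A small illustration of what satisfies that type, using a plain Keras layer class as the callable; the helper is invented for the example:

from typing import Callable
import tensorflow as tf

def apply_block(block_fn: Callable[..., tf.keras.layers.Layer],
                filters: int, x: tf.Tensor) -> tf.Tensor:
  return block_fn(filters, 3, padding='same')(x)  # build the layer, then call it

y = apply_block(tf.keras.layers.Conv2D, 32, tf.ones([1, 8, 8, 16]))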
@@ -13,10 +13,12 @@
 # limitations under the License.
 """Contains definitions of 3D Residual Networks."""
-from typing import List, Tuple
+from typing import Callable, List, Tuple, Optional
 # Import libraries
 import tensorflow as tf
+from official.modeling import hyperparams
 from official.modeling import tf_utils
 from official.vision.beta.modeling.backbones import factory
 from official.vision.beta.modeling.layers import nn_blocks_3d
@@ -74,25 +76,27 @@ RESNET_SPECS = {
 class ResNet3D(tf.keras.Model):
   """Creates a 3D ResNet family model."""
-  def __init__(self,
-               model_id: int,
-               temporal_strides: List[int],
-               temporal_kernel_sizes: List[Tuple[int]],
-               use_self_gating: List[int] = None,
-               input_specs=layers.InputSpec(shape=[None, None, None, None, 3]),
-               stem_type='v0',
-               stem_conv_temporal_kernel_size=5,
-               stem_conv_temporal_stride=2,
-               stem_pool_temporal_stride=2,
-               init_stochastic_depth_rate=0.0,
-               activation='relu',
-               se_ratio=None,
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
-               kernel_initializer='VarianceScaling',
-               kernel_regularizer=None,
-               bias_regularizer=None,
+  def __init__(
+      self,
+      model_id: int,
+      temporal_strides: List[int],
+      temporal_kernel_sizes: List[Tuple[int]],
+      use_self_gating: List[int] = None,
+      input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
+          shape=[None, None, None, None, 3]),
+      stem_type: str = 'v0',
+      stem_conv_temporal_kernel_size: int = 5,
+      stem_conv_temporal_stride: int = 2,
+      stem_pool_temporal_stride: int = 2,
+      init_stochastic_depth_rate: float = 0.0,
+      activation: str = 'relu',
+      se_ratio: Optional[float] = None,
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
+      kernel_initializer: str = 'VarianceScaling',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
       **kwargs):
     """Initializes a 3D ResNet model.
@@ -259,16 +263,18 @@ class ResNet3D(tf.keras.Model):
     super(ResNet3D, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)
   def _block_group(self,
-                   inputs,
-                   filters,
-                   temporal_kernel_sizes,
-                   temporal_strides,
-                   spatial_strides,
-                   block_fn=nn_blocks_3d.BottleneckBlock3D,
-                   block_repeats=1,
-                   stochastic_depth_drop_rate=0.0,
-                   use_self_gating=False,
-                   name='block_group'):
+                   inputs: tf.Tensor,
+                   filters: int,
+                   temporal_kernel_sizes: Tuple[int],
+                   temporal_strides: int,
+                   spatial_strides: int,
+                   block_fn: Callable[
+                       ...,
+                       tf.keras.layers.Layer] = nn_blocks_3d.BottleneckBlock3D,
+                   block_repeats: int = 1,
+                   stochastic_depth_drop_rate: float = 0.0,
+                   use_self_gating: bool = False,
+                   name: str = 'block_group'):
     """Creates one group of blocks for the ResNet3D model.
     Args:
@@ -410,7 +416,7 @@ def build_resnet3d(
 @factory.register_backbone_builder('resnet_3d_rs')
 def build_resnet3d_rs(
     input_specs: tf.keras.layers.InputSpec,
-    model_config,
+    model_config: hyperparams.Config,
     l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
   """Builds ResNet-3D-RS backbone from a config."""
   backbone_cfg = model_config.backbone.get()
......
@@ -14,6 +14,8 @@
 """Contains definitions of Residual Networks with Deeplab modifications."""
+from typing import Callable, Optional, Tuple, List
 import numpy as np
 import tensorflow as tf
 from official.modeling import tf_utils
@@ -53,22 +55,24 @@ class DilatedResNet(tf.keras.Model):
   (https://arxiv.org/pdf/1706.05587)
   """
-  def __init__(self,
-               model_id,
-               output_stride,
-               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
-               stem_type='v0',
-               se_ratio=None,
-               init_stochastic_depth_rate=0.0,
-               multigrid=None,
-               last_stage_repeats=1,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
-               kernel_initializer='VarianceScaling',
-               kernel_regularizer=None,
-               bias_regularizer=None,
+  def __init__(
+      self,
+      model_id: int,
+      output_stride: int,
+      input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
+          shape=[None, None, None, 3]),
+      stem_type: str = 'v0',
+      se_ratio: Optional[float] = None,
+      init_stochastic_depth_rate: float = 0.0,
+      multigrid: Optional[Tuple[int]] = None,
+      last_stage_repeats: int = 1,
+      activation: str = 'relu',
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
+      kernel_initializer: str = 'VarianceScaling',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
       **kwargs):
     """Initializes a ResNet model with DeepLab modification.
@@ -234,15 +238,15 @@ class DilatedResNet(tf.keras.Model):
         inputs=inputs, outputs=endpoints, **kwargs)
   def _block_group(self,
-                   inputs,
-                   filters,
-                   strides,
-                   dilation_rate,
-                   block_fn,
-                   block_repeats=1,
-                   stochastic_depth_drop_rate=0.0,
-                   multigrid=None,
-                   name='block_group'):
+                   inputs: tf.Tensor,
+                   filters: int,
+                   strides: int,
+                   dilation_rate: int,
+                   block_fn: Callable[..., tf.keras.layers.Layer],
+                   block_repeats: int = 1,
+                   stochastic_depth_drop_rate: float = 0.0,
+                   multigrid: Optional[List[int]] = None,
+                   name: str = 'block_group'):
     """Creates one group of blocks for the ResNet model.
     Deeplab applies strides at the last block.
......
@@ -59,16 +59,17 @@ class RevNet(tf.keras.Model):
   (https://arxiv.org/pdf/1707.04585.pdf)
   """
-  def __init__(self,
-               model_id: int,
-               input_specs: tf.keras.layers.InputSpec
-               = tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
-               activation: str = 'relu',
-               use_sync_bn: bool = False,
-               norm_momentum: float = 0.99,
-               norm_epsilon: float = 0.001,
-               kernel_initializer: str = 'VarianceScaling',
-               kernel_regularizer: tf.keras.regularizers.Regularizer = None,
+  def __init__(
+      self,
+      model_id: int,
+      input_specs: tf.keras.layers.InputSpec = tf.keras.layers.InputSpec(
+          shape=[None, None, None, 3]),
+      activation: str = 'relu',
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
+      kernel_initializer: str = 'VarianceScaling',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
       **kwargs):
     """Initializes a RevNet model.
......
@@ -12,13 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 """Contains definitions of SpineNet Networks."""
 import math
+from typing import Any, List, Optional, Tuple
 # Import libraries
 from absl import logging
 import tensorflow as tf
 from official.modeling import tf_utils
 from official.vision.beta.modeling.backbones import factory
 from official.vision.beta.modeling.layers import nn_blocks
@@ -95,14 +98,16 @@ SCALING_MAP = {
 class BlockSpec(object):
   """A container class that specifies the block configuration for SpineNet."""
-  def __init__(self, level, block_fn, input_offsets, is_output):
+  def __init__(self, level: int, block_fn: str, input_offsets: Tuple[int, int],
+               is_output: bool):
     self.level = level
     self.block_fn = block_fn
     self.input_offsets = input_offsets
     self.is_output = is_output
-def build_block_specs(block_specs=None):
+def build_block_specs(
+    block_specs: Optional[List[Tuple[Any, ...]]] = None) -> List[BlockSpec]:
   """Builds the list of BlockSpec objects for SpineNet."""
   if not block_specs:
     block_specs = SPINENET_BLOCK_SPECS
@@ -121,23 +126,25 @@ class SpineNet(tf.keras.Model):
   (https://arxiv.org/abs/1912.05027)
   """
-  def __init__(self,
-               input_specs=tf.keras.layers.InputSpec(shape=[None, 640, 640, 3]),
-               min_level=3,
-               max_level=7,
-               block_specs=build_block_specs(),
-               endpoints_num_filters=256,
-               resample_alpha=0.5,
-               block_repeats=1,
-               filter_size_scale=1.0,
-               init_stochastic_depth_rate=0.0,
-               kernel_initializer='VarianceScaling',
-               kernel_regularizer=None,
-               bias_regularizer=None,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
+  def __init__(
+      self,
+      input_specs: tf.keras.layers.InputSpec = tf.keras.layers.InputSpec(
+          shape=[None, 640, 640, 3]),
+      min_level: int = 3,
+      max_level: int = 7,
+      block_specs: List[BlockSpec] = build_block_specs(),
+      endpoints_num_filters: int = 256,
+      resample_alpha: float = 0.5,
+      block_repeats: int = 1,
+      filter_size_scale: float = 1.0,
+      init_stochastic_depth_rate: float = 0.0,
+      kernel_initializer: str = 'VarianceScaling',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      activation: str = 'relu',
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
       **kwargs):
     """Initializes a SpineNet model.
@@ -145,8 +152,8 @@ class SpineNet(tf.keras.Model):
       input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
       min_level: An `int` of min level for output mutiscale features.
       max_level: An `int` of max level for output mutiscale features.
-      block_specs: The block specifications for the SpineNet model discovered by
-        NAS.
+      block_specs: A list of block specifications for the SpineNet model
+        discovered by NAS.
      endpoints_num_filters: An `int` of feature dimension for the output
        endpoints.
      resample_alpha: A `float` of resampling factor in cross-scale connections.
@@ -214,13 +221,13 @@ class SpineNet(tf.keras.Model):
     super(SpineNet, self).__init__(inputs=inputs, outputs=endpoints)
   def _block_group(self,
-                   inputs,
-                   filters,
-                   strides,
-                   block_fn_cand,
-                   block_repeats=1,
-                   stochastic_depth_drop_rate=None,
-                   name='block_group'):
+                   inputs: tf.Tensor,
+                   filters: int,
+                   strides: int,
+                   block_fn_cand: str,
+                   block_repeats: int = 1,
+                   stochastic_depth_drop_rate: Optional[float] = None,
+                   name: str = 'block_group'):
     """Creates one group of blocks for the SpineNet model."""
     block_fn_candidates = {
         'bottleneck': nn_blocks.BottleneckBlock,
......
@@ -29,10 +29,13 @@
 # ==============================================================================
 """Contains definitions of Mobile SpineNet Networks."""
 import math
+from typing import Any, List, Optional, Tuple
 # Import libraries
 from absl import logging
 import tensorflow as tf
 from official.modeling import tf_utils
 from official.vision.beta.modeling.backbones import factory
 from official.vision.beta.modeling.layers import nn_blocks
@@ -96,14 +99,16 @@ SCALING_MAP = {
 class BlockSpec(object):
   """A container class that specifies the block configuration for SpineNet."""
-  def __init__(self, level, block_fn, input_offsets, is_output):
+  def __init__(self, level: int, block_fn: str, input_offsets: Tuple[int, int],
+               is_output: bool):
     self.level = level
     self.block_fn = block_fn
     self.input_offsets = input_offsets
     self.is_output = is_output
-def build_block_specs(block_specs=None):
+def build_block_specs(
+    block_specs: Optional[List[Tuple[Any, ...]]] = None) -> List[BlockSpec]:
   """Builds the list of BlockSpec objects for SpineNet."""
   if not block_specs:
     block_specs = SPINENET_BLOCK_SPECS
@@ -126,24 +131,26 @@ class SpineNetMobile(tf.keras.Model):
   (https://arxiv.org/abs/2010.11426).
   """
-  def __init__(self,
-               input_specs=tf.keras.layers.InputSpec(shape=[None, 512, 512, 3]),
-               min_level=3,
-               max_level=7,
-               block_specs=build_block_specs(),
-               endpoints_num_filters=48,
-               se_ratio=0.2,
-               block_repeats=1,
-               filter_size_scale=1.0,
-               expand_ratio=6,
-               init_stochastic_depth_rate=0.0,
-               kernel_initializer='VarianceScaling',
-               kernel_regularizer=None,
-               bias_regularizer=None,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
+  def __init__(
+      self,
+      input_specs: tf.keras.layers.InputSpec = tf.keras.layers.InputSpec(
+          shape=[None, 512, 512, 3]),
+      min_level: int = 3,
+      max_level: int = 7,
+      block_specs: List[BlockSpec] = build_block_specs(),
+      endpoints_num_filters: int = 256,
+      se_ratio: float = 0.2,
+      block_repeats: int = 1,
+      filter_size_scale: float = 1.0,
+      expand_ratio: int = 6,
+      init_stochastic_depth_rate=0.0,
+      kernel_initializer: str = 'VarianceScaling',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      activation: str = 'relu',
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
       **kwargs):
     """Initializes a Mobile SpineNet model.
@@ -222,15 +229,15 @@ class SpineNetMobile(tf.keras.Model):
     super().__init__(inputs=inputs, outputs=endpoints)
   def _block_group(self,
-                   inputs,
-                   in_filters,
-                   out_filters,
-                   strides,
-                   expand_ratio=6,
-                   block_repeats=1,
-                   se_ratio=0.2,
-                   stochastic_depth_drop_rate=None,
-                   name='block_group'):
+                   inputs: tf.Tensor,
+                   in_filters: int,
+                   out_filters: int,
+                   strides: int,
+                   expand_ratio: int = 6,
+                   block_repeats: int = 1,
+                   se_ratio: float = 0.2,
+                   stochastic_depth_drop_rate: Optional[float] = None,
+                   name: str = 'block_group'):
     """Creates one group of blocks for the SpineNet model."""
     x = nn_blocks.InvertedBottleneckBlock(
         in_filters=in_filters,
......
@@ -14,6 +14,7 @@
 """Build classification models."""
+from typing import Any, Mapping, Optional
 # Import libraries
 import tensorflow as tf
@@ -24,15 +25,17 @@ layers = tf.keras.layers
 class ClassificationModel(tf.keras.Model):
   """A classification class builder."""
-  def __init__(self,
-               backbone,
-               num_classes,
-               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
-               dropout_rate=0.0,
-               kernel_initializer='random_uniform',
-               kernel_regularizer=None,
-               bias_regularizer=None,
-               add_head_batch_norm=False,
+  def __init__(
+      self,
+      backbone: tf.keras.Model,
+      num_classes: int,
+      input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
+          shape=[None, None, None, 3]),
+      dropout_rate: float = 0.0,
+      kernel_initializer: str = 'random_uniform',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      add_head_batch_norm: bool = False,
       use_sync_bn: bool = False,
       norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
@@ -103,15 +106,15 @@ class ClassificationModel(tf.keras.Model):
     self._norm = norm
   @property
-  def checkpoint_items(self):
+  def checkpoint_items(self) -> Mapping[str, tf.keras.Model]:
     """Returns a dictionary of items to be additionally checkpointed."""
     return dict(backbone=self.backbone)
   @property
-  def backbone(self):
+  def backbone(self) -> tf.keras.Model:
     return self._backbone
-  def get_config(self):
+  def get_config(self) -> Mapping[str, Any]:
     return self._config_dict
   @classmethod
......
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder."""
+from typing import Any, List, Optional, Mapping
 # Import libraries
 import tensorflow as tf
@@ -24,19 +25,20 @@ from official.vision import keras_cv
 class ASPP(tf.keras.layers.Layer):
   """Creates an Atrous Spatial Pyramid Pooling (ASPP) layer."""
-  def __init__(self,
-               level,
-               dilation_rates,
-               num_filters=256,
-               pool_kernel_size=None,
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
-               activation='relu',
-               dropout_rate=0.0,
-               kernel_initializer='VarianceScaling',
-               kernel_regularizer=None,
-               interpolation='bilinear',
+  def __init__(
+      self,
+      level: int,
+      dilation_rates: List[int],
+      num_filters: int = 256,
+      pool_kernel_size: Optional[int] = None,
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
+      activation: str = 'relu',
+      dropout_rate: float = 0.0,
+      kernel_initializer: str = 'VarianceScaling',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      interpolation: str = 'bilinear',
       **kwargs):
     """Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.
@@ -97,7 +99,7 @@ class ASPP(tf.keras.layers.Layer):
         kernel_regularizer=self._config_dict['kernel_regularizer'],
         interpolation=self._config_dict['interpolation'])
-  def call(self, inputs):
+  def call(self, inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
     """Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.
     The output of ASPP will be a dict of {`level`, `tf.Tensor`} even if only one
@@ -120,7 +122,7 @@ class ASPP(tf.keras.layers.Layer):
     outputs[level] = self.aspp(inputs[level])
     return outputs
-  def get_config(self):
+  def get_config(self) -> Mapping[str, Any]:
     return self._config_dict
   @classmethod
......
@@ -15,15 +15,21 @@
 # Lint as: python3
 """Contains the factory method to create decoders."""
+from typing import Mapping, Optional
 # Import libraries
 import tensorflow as tf
+from official.modeling import hyperparams
 from official.vision.beta.modeling import decoders
-def build_decoder(input_specs,
-                  model_config,
-                  l2_regularizer: tf.keras.regularizers.Regularizer = None):
+def build_decoder(
+    input_specs: Mapping[str, tf.TensorShape],
+    model_config: hyperparams.Config,
+    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
+) -> tf.keras.Model:
   """Builds decoder from a config.
   Args:
......
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Contains the definitions of Feature Pyramid Networks (FPN)."""
+from typing import Any, Mapping, Optional
 # Import libraries
 import tensorflow as tf
@@ -32,19 +33,20 @@ class FPN(tf.keras.Model):
   (https://arxiv.org/pdf/1612.03144)
   """
-  def __init__(self,
-               input_specs,
-               min_level=3,
-               max_level=7,
-               num_filters=256,
-               use_separable_conv=False,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
-               kernel_initializer='VarianceScaling',
-               kernel_regularizer=None,
-               bias_regularizer=None,
+  def __init__(
+      self,
+      input_specs: Mapping[str, tf.TensorShape],
+      min_level: int = 3,
+      max_level: int = 7,
+      num_filters: int = 256,
+      use_separable_conv: bool = False,
+      activation: str = 'relu',
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
+      kernel_initializer: str = 'VarianceScaling',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a Feature Pyramid Network (FPN).
@@ -162,7 +164,8 @@ class FPN(tf.keras.Model):
     super(FPN, self).__init__(inputs=inputs, outputs=feats, **kwargs)
-  def _build_input_pyramid(self, input_specs, min_level):
+  def _build_input_pyramid(self, input_specs: Mapping[str, tf.TensorShape],
+                           min_level: int):
     assert isinstance(input_specs, dict)
     if min(input_specs.keys()) > str(min_level):
       raise ValueError(
@@ -173,7 +176,7 @@ class FPN(tf.keras.Model):
       inputs[level] = tf.keras.Input(shape=spec[1:])
     return inputs
-  def get_config(self):
+  def get_config(self) -> Mapping[str, Any]:
     return self._config_dict
   @classmethod
@@ -181,6 +184,6 @@ class FPN(tf.keras.Model):
     return cls(**config)
   @property
-  def output_specs(self):
+  def output_specs(self) -> Mapping[str, tf.TensorShape]:
     """A dict of {level: TensorShape} pairs for the model output."""
     return self._output_specs
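The FPN annotations also pin down the input contract: input_specs is a dict from string level names to TensorShapes (usually a backbone's output_specs), and _build_input_pyramid rejects specs whose lowest level is above min_level. A hedged construction sketch, assuming the decoder lives at official.vision.beta.modeling.decoders.fpn and using illustrative shapes:

import tensorflow as tf
from official.vision.beta.modeling.decoders import fpn

backbone_output_specs = {
    '3': tf.TensorShape([None, 80, 80, 128]),
    '4': tf.TensorShape([None, 40, 40, 256]),
    '5': tf.TensorShape([None, 20, 20, 512]),
}
decoder = fpn.FPN(input_specs=backbone_output_specs, min_level=3, max_level=7)
print(decoder.output_specs)   # {level: TensorShape} pairs for levels 3..7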
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Contains definitions of NAS-FPN."""
+from typing import Any, Mapping, List, Tuple, Optional
 # Import libraries
 from absl import logging
@@ -35,17 +36,19 @@ NASFPN_BLOCK_SPECS = [
 ]
-class BlockSpec(object):
+class BlockSpec():
   """A container class that specifies the block configuration for NAS-FPN."""
-  def __init__(self, level, combine_fn, input_offsets, is_output):
+  def __init__(self, level: int, combine_fn: str,
+               input_offsets: Tuple[int, int], is_output: bool):
     self.level = level
     self.combine_fn = combine_fn
     self.input_offsets = input_offsets
     self.is_output = is_output
-def build_block_specs(block_specs=None):
+def build_block_specs(
+    block_specs: Optional[List[Tuple[Any, ...]]] = None) -> List[BlockSpec]:
   """Builds the list of BlockSpec objects for NAS-FPN."""
   if not block_specs:
     block_specs = NASFPN_BLOCK_SPECS
@@ -63,21 +66,22 @@ class NASFPN(tf.keras.Model):
   (https://arxiv.org/abs/1904.07392)
   """
-  def __init__(self,
-               input_specs,
-               min_level=3,
-               max_level=7,
-               block_specs=build_block_specs(),
-               num_filters=256,
-               num_repeats=5,
-               use_separable_conv=False,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
-               kernel_initializer='VarianceScaling',
-               kernel_regularizer=None,
-               bias_regularizer=None,
+  def __init__(
+      self,
+      input_specs: Mapping[str, tf.TensorShape],
+      min_level: int = 3,
+      max_level: int = 7,
+      block_specs: List[BlockSpec] = build_block_specs(),
+      num_filters: int = 256,
+      num_repeats: int = 5,
+      use_separable_conv: bool = False,
+      activation: str = 'relu',
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
+      kernel_initializer: str = 'VarianceScaling',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a NAS-FPN model.
@@ -191,7 +195,8 @@ class NASFPN(tf.keras.Model):
                           for level in output_feats.keys()}
     super(NASFPN, self).__init__(inputs=inputs, outputs=output_feats, **kwargs)
-  def _build_input_pyramid(self, input_specs, min_level):
+  def _build_input_pyramid(self, input_specs: Mapping[str, tf.TensorShape],
+                           min_level: int):
     assert isinstance(input_specs, dict)
     if min(input_specs.keys()) > str(min_level):
       raise ValueError(
@@ -300,7 +305,7 @@ class NASFPN(tf.keras.Model):
     logging.info('Output feature pyramid: %s', output_feats)
     return output_feats
-  def get_config(self):
+  def get_config(self) -> Mapping[str, Any]:
     return self._config_dict
   @classmethod
@@ -308,6 +313,6 @@ class NASFPN(tf.keras.Model):
     return cls(**config)
   @property
-  def output_specs(self):
+  def output_specs(self) -> Mapping[str, tf.TensorShape]:
     """A dict of {level: TensorShape} pairs for the model output."""
     return self._output_specs
@@ -42,7 +42,7 @@ def build_classification_model(
     input_specs: tf.keras.layers.InputSpec,
     model_config: classification_cfg.ImageClassificationModel,
     l2_regularizer: tf.keras.regularizers.Regularizer = None,
-    skip_logits_layer: bool = False):
+    skip_logits_layer: bool = False) -> tf.keras.Model:
   """Builds the classification model."""
   backbone = backbones.factory.build_backbone(
       input_specs=input_specs,
@@ -64,9 +64,10 @@ def build_classification_model(
   return model
-def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
-                   model_config: maskrcnn_cfg.MaskRCNN,
-                   l2_regularizer: tf.keras.regularizers.Regularizer = None):
+def build_maskrcnn(
+    input_specs: tf.keras.layers.InputSpec,
+    model_config: maskrcnn_cfg.MaskRCNN,
+    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
   """Builds Mask R-CNN model."""
   backbone = backbones.factory.build_backbone(
       input_specs=input_specs,
@@ -194,9 +195,10 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
   return model
-def build_retinanet(input_specs: tf.keras.layers.InputSpec,
-                    model_config: retinanet_cfg.RetinaNet,
-                    l2_regularizer: tf.keras.regularizers.Regularizer = None):
+def build_retinanet(
+    input_specs: tf.keras.layers.InputSpec,
+    model_config: retinanet_cfg.RetinaNet,
+    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
   """Builds RetinaNet model."""
   backbone = backbones.factory.build_backbone(
       input_specs=input_specs,
@@ -253,7 +255,7 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec,
 def build_segmentation_model(
     input_specs: tf.keras.layers.InputSpec,
     model_config: segmentation_cfg.SemanticSegmentationModel,
-    l2_regularizer: tf.keras.regularizers.Regularizer = None):
+    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
   """Builds Segmentation model."""
   backbone = backbones.factory.build_backbone(
       input_specs=input_specs,
......
@@ -53,11 +53,12 @@ def register_model_builder(key: str):
   return registry.register(_REGISTERED_MODEL_CLS, key)
-def build_model(model_type: str,
-                input_specs: tf.keras.layers.InputSpec,
-                model_config: video_classification_cfg.hyperparams.Config,
-                num_classes: int,
-                l2_regularizer: tf.keras.regularizers.Regularizer = None):
+def build_model(
+    model_type: str,
+    input_specs: tf.keras.layers.InputSpec,
+    model_config: video_classification_cfg.hyperparams.Config,
+    num_classes: int,
+    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
   """Builds backbone from a config.
   Args:
@@ -81,7 +82,7 @@ def build_video_classification_model(
     input_specs: tf.keras.layers.InputSpec,
     model_config: video_classification_cfg.VideoClassificationModel,
     num_classes: int,
-    l2_regularizer: tf.keras.regularizers.Regularizer = None):
+    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
   """Builds the video classification model."""
   input_specs_dict = {'image': input_specs}
   backbone = backbones.factory.build_backbone(
......
@@ -14,6 +14,8 @@
 """Mask R-CNN model."""
+from typing import Any, Mapping, Optional, Union
 # Import libraries
 import tensorflow as tf
@@ -25,17 +27,17 @@ class MaskRCNNModel(tf.keras.Model):
   """The Mask R-CNN model."""
   def __init__(self,
-               backbone,
-               decoder,
-               rpn_head,
-               detection_head,
-               roi_generator,
-               roi_sampler,
-               roi_aligner,
-               detection_generator,
-               mask_head=None,
-               mask_sampler=None,
-               mask_roi_aligner=None,
+               backbone: tf.keras.Model,
+               decoder: tf.keras.Model,
+               rpn_head: tf.keras.layers.Layer,
+               detection_head: tf.keras.layers.Layer,
+               roi_generator: tf.keras.layers.Layer,
+               roi_sampler: tf.keras.layers.Layer,
+               roi_aligner: tf.keras.layers.Layer,
+               detection_generator: tf.keras.layers.Layer,
+               mask_head: Optional[tf.keras.layers.Layer] = None,
+               mask_sampler: Optional[tf.keras.layers.Layer] = None,
+               mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
               **kwargs):
    """Initializes the Mask R-CNN model.
@@ -85,13 +87,13 @@ class MaskRCNNModel(tf.keras.Model):
     self.mask_roi_aligner = mask_roi_aligner
   def call(self,
-           images,
-           image_shape,
-           anchor_boxes=None,
-           gt_boxes=None,
-           gt_classes=None,
-           gt_masks=None,
-           training=None):
+           images: tf.Tensor,
+           image_shape: tf.Tensor,
+           anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
+           gt_boxes: tf.Tensor = None,
+           gt_classes: tf.Tensor = None,
+           gt_masks: tf.Tensor = None,
+           training: bool = None) -> Mapping[str, tf.Tensor]:
     model_outputs = {}
     # Feature extraction.
@@ -190,7 +192,8 @@ class MaskRCNNModel(tf.keras.Model):
     return model_outputs
   @property
-  def checkpoint_items(self):
+  def checkpoint_items(
+      self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
     """Returns a dictionary of items to be additionally checkpointed."""
     items = dict(
         backbone=self.backbone,
@@ -203,7 +206,7 @@ class MaskRCNNModel(tf.keras.Model):
     return items
-  def get_config(self):
+  def get_config(self) -> Mapping[str, Any]:
     return self._config_dict
   @classmethod
......
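checkpoint_items appears on several of these models (classification, Mask R-CNN above, RetinaNet, segmentation, video classification), and the new Union annotation documents that its values may be whole sub-models or bare layers. The mapping is meant to be handed to a tf.train.Checkpoint so tasks can save and restore pieces by name; a small self-contained sketch with a stand-in backbone and an illustrative path:

import tensorflow as tf

# Stand-in for model.checkpoint_items['backbone']; a real model returns its own modules.
backbone = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(4,))])
items = {'backbone': backbone}
path = tf.train.Checkpoint(**items).write('/tmp/backbone_ckpt')  # illustrative location
tf.train.Checkpoint(backbone=backbone).read(path).expect_partial()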
@@ -13,7 +13,7 @@
 # limitations under the License.
 """RetinaNet."""
-from typing import List, Optional
+from typing import Any, Mapping, List, Optional, Union
 # Import libraries
 import tensorflow as tf
@@ -26,10 +26,10 @@ class RetinaNetModel(tf.keras.Model):
   """The RetinaNet model class."""
   def __init__(self,
-               backbone,
-               decoder,
-               head,
-               detection_generator,
+               backbone: tf.keras.Model,
+               decoder: tf.keras.Model,
+               head: tf.keras.layers.Layer,
+               detection_generator: tf.keras.layers.Layer,
               min_level: Optional[int] = None,
               max_level: Optional[int] = None,
               num_scales: Optional[int] = None,
@@ -74,10 +74,10 @@ class RetinaNetModel(tf.keras.Model):
     self._detection_generator = detection_generator
   def call(self,
-           images,
-           image_shape=None,
-           anchor_boxes=None,
-           training=None):
+           images: tf.Tensor,
+           image_shape: Optional[tf.Tensor] = None,
+           anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
+           training: bool = None) -> Mapping[str, tf.Tensor]:
     """Forward pass of the RetinaNet model.
     Args:
@@ -163,7 +163,8 @@ class RetinaNetModel(tf.keras.Model):
     return outputs
   @property
-  def checkpoint_items(self):
+  def checkpoint_items(
+      self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
     """Returns a dictionary of items to be additionally checkpointed."""
     items = dict(backbone=self.backbone, head=self.head)
     if self.decoder is not None:
@@ -172,22 +173,22 @@ class RetinaNetModel(tf.keras.Model):
     return items
   @property
-  def backbone(self):
+  def backbone(self) -> tf.keras.Model:
     return self._backbone
   @property
-  def decoder(self):
+  def decoder(self) -> tf.keras.Model:
     return self._decoder
   @property
-  def head(self):
+  def head(self) -> tf.keras.layers.Layer:
     return self._head
   @property
-  def detection_generator(self):
+  def detection_generator(self) -> tf.keras.layers.Layer:
     return self._detection_generator
-  def get_config(self):
+  def get_config(self) -> Mapping[str, Any]:
     return self._config_dict
   @classmethod
......
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Build segmentation models."""
+from typing import Any, Mapping, Union
 # Import libraries
 import tensorflow as tf
@@ -33,11 +34,8 @@ class SegmentationModel(tf.keras.Model):
   different backbones, and decoders.
   """
-  def __init__(self,
-               backbone,
-               decoder,
-               head,
-               **kwargs):
+  def __init__(self, backbone: tf.keras.Model, decoder: tf.keras.Model,
+               head: tf.keras.layers.Layer, **kwargs):
     """Segmentation initialization function.
     Args:
@@ -56,7 +54,7 @@ class SegmentationModel(tf.keras.Model):
     self.decoder = decoder
     self.head = head
-  def call(self, inputs, training=None):
+  def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor:
     backbone_features = self.backbone(inputs)
     if self.decoder:
@@ -67,14 +65,15 @@ class SegmentationModel(tf.keras.Model):
     return self.head(backbone_features, decoder_features)
   @property
-  def checkpoint_items(self):
+  def checkpoint_items(
+      self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
     """Returns a dictionary of items to be additionally checkpointed."""
     items = dict(backbone=self.backbone, head=self.head)
     if self.decoder is not None:
       items.update(decoder=self.decoder)
     return items
-  def get_config(self):
+  def get_config(self) -> Mapping[str, Any]:
     return self._config_dict
   @classmethod
......
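The SegmentationModel.call hunk above is a straight composition: backbone features, an optional decoder, then a head that receives both feature sets. The same control flow written out with stand-in callables rather than the real sub-models:

def forward(inputs, backbone, decoder, head, training=None):
  backbone_features = backbone(inputs)
  # Mirrors the diff: the decoder is optional; the head sees both feature sets.
  decoder_features = decoder(backbone_features) if decoder else backbone_features
  return head(backbone_features, decoder_features)

logits = forward('pixels', backbone=lambda x: x, decoder=None,
                 head=lambda b, d: (b, d))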
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Build video classification models."""
-from typing import Mapping
+from typing import Any, Mapping, Optional, Union
 import tensorflow as tf
 layers = tf.keras.layers
@@ -23,15 +23,16 @@ layers = tf.keras.layers
 class VideoClassificationModel(tf.keras.Model):
   """A video classification class builder."""
-  def __init__(self,
-               backbone: tf.keras.Model,
-               num_classes: int,
-               input_specs: Mapping[str, tf.keras.layers.InputSpec] = None,
-               dropout_rate: float = 0.0,
-               aggregate_endpoints: bool = False,
-               kernel_initializer='random_uniform',
-               kernel_regularizer=None,
-               bias_regularizer=None,
+  def __init__(
+      self,
+      backbone: tf.keras.Model,
+      num_classes: int,
+      input_specs: Mapping[str, tf.keras.layers.InputSpec] = None,
+      dropout_rate: float = 0.0,
+      aggregate_endpoints: bool = False,
+      kernel_initializer: str = 'random_uniform',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
       **kwargs):
     """Video Classification initialization function.
@@ -95,15 +96,16 @@ class VideoClassificationModel(tf.keras.Model):
         inputs=inputs, outputs=x, **kwargs)
   @property
-  def checkpoint_items(self):
+  def checkpoint_items(
+      self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
     """Returns a dictionary of items to be additionally checkpointed."""
     return dict(backbone=self.backbone)
   @property
-  def backbone(self):
+  def backbone(self) -> tf.keras.Model:
     return self._backbone
-  def get_config(self):
+  def get_config(self) -> Mapping[str, Any]:
     return self._config_dict
   @classmethod
......