Unverified commit e61588cd authored by Shixin, committed by GitHub

[MobileNet] Add Mobilenet Backbone Implementation (#9303)



* factor out the make_divisible function and move round_filters to nn_layers

* modify SqueezeExcitation to add two additional parameters: divisible_by and gating_activation

* modify the InvertedBottleneckBlock to add three boolean flags: use_depthwise, use_residual, and regularize_depthwise; add control for the depthwise activation and regularizer; remove expand_ratio from SqueezeExcitation

* add Conv2DBNBlock definition

* add mobilenet v2, v3 implementation

* add mobilenet v1

* put mobilenet_base into class body

* fix a type hint error

* the InvertedBottleneckBlock differs between MobileNet and EfficientNet; made the necessary changes to support both

* add target_backbone when calling InvertedBottleneckBlock

* add relu6 and hard_sigmoid

* add test for mobilenet

* add mobilenet to factory

* fix some typos; link the references to the architectures

* remove future import
Co-authored-by: Shixin Luo <luoshixin@google.com>
parent 2f737e1e
@@ -17,3 +17,5 @@ from official.modeling.activations.gelu import gelu
 from official.modeling.activations.swish import hard_swish
 from official.modeling.activations.swish import identity
 from official.modeling.activations.swish import simple_swish
+from official.modeling.activations.relu import relu6
+from official.modeling.activations.sigmoid import hard_sigmoid
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Customized Relu activation."""
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Text')
def relu6(features):
"""Computes the Relu6 activation function.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features = tf.convert_to_tensor(features)
return tf.nn.relu6(features)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the customized Relu activation."""
import tensorflow as tf
from tensorflow.python.keras import \
keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
@keras_parameterized.run_all_keras_modes
class CustomizedReluTest(keras_parameterized.TestCase):
def test_relu6(self):
features = [[.25, 0, -.25], [-1, -2, 3]]
customized_relu6_data = activations.relu6(features)
relu6_data = tf.nn.relu6(features)
self.assertAllClose(customized_relu6_data, relu6_data)
if __name__ == '__main__':
tf.test.main()
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Customized Sigmoid activation."""
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Text')
def hard_sigmoid(features):
"""Computes the hard sigmoid activation function.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features = tf.convert_to_tensor(features)
return tf.nn.relu6(features + tf.constant(3.)) * 0.16667
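
A quick numeric check of the constant (an editor's sketch, not part of the commit): 0.16667 approximates 1/6, so hard_sigmoid(x) behaves like relu6(x + 3) / 6, a piecewise-linear stand-in for sigmoid that saturates at 0 below x = -3 and at 1 above x = 3.

import tensorflow as tf

from official.modeling import activations

# hard_sigmoid is 0 at x <= -3, roughly 0.5 at x = 0, and roughly 1 at x >= 3.
print(activations.hard_sigmoid(tf.constant([-4., 0., 4.])).numpy())
# -> approximately [0.      0.50001 1.00002]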
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the customized Sigmoid activation."""
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import \
keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
@keras_parameterized.run_all_keras_modes
class CustomizedSigmoidTest(keras_parameterized.TestCase):
def _hard_sigmoid_nn(self, x):
x = np.float32(x)
return tf.nn.relu6(x + 3.) * 0.16667
def test_hard_sigmoid(self):
features = [[.25, 0, -.25], [-1, -2, 3]]
customized_hard_sigmoid_data = activations.hard_sigmoid(features)
sigmoid_data = self._hard_sigmoid_nn(features)
self.assertAllClose(customized_hard_sigmoid_data, sigmoid_data)
if __name__ == '__main__':
tf.test.main()
@@ -104,6 +104,8 @@ def get_activation(identifier):
       "gelu": activations.gelu,
       "simple_swish": activations.simple_swish,
       "hard_swish": activations.hard_swish,
+      "relu6": activations.relu6,
+      "hard_sigmoid": activations.hard_sigmoid,
       "identity": activations.identity,
   }
   identifier = str(identifier).lower()
 ...
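
With the mapping extended, the new activations can be looked up by name, which is how the blocks below select their activations from config strings. A usage sketch (the `get_activation` shown in this hunk is the one the SqueezeExcitation change below calls through `tf_utils`):

from official.modeling import tf_utils

relu6_fn = tf_utils.get_activation('relu6')
hard_sigmoid_fn = tf_utils.get_activation('hard_sigmoid')
print(relu6_fn(7.0))  # clipped at 6.0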
@@ -36,6 +36,14 @@ class EfficientNet(hyperparams.Config):
   se_ratio: float = 0.0


+@dataclasses.dataclass
+class MobileNet(hyperparams.Config):
+  """MobileNet config."""
+  model_id: str = 'MobileNetV2'
+  width_multiplier: float = 1.0
+  stochastic_depth_drop_rate: float = 0.0
+
+
 @dataclasses.dataclass
 class SpineNet(hyperparams.Config):
   """SpineNet config."""
@@ -59,9 +67,11 @@ class Backbone(hyperparams.OneOfConfig):
     revnet: revnet backbone config.
     efficientnet: efficientnet backbone config.
     spinenet: spinenet backbone config.
+    mobilenet: mobilenet backbone config.
   """
   type: Optional[str] = None
   resnet: ResNet = ResNet()
   revnet: RevNet = RevNet()
   efficientnet: EfficientNet = EfficientNet()
   spinenet: SpineNet = SpineNet()
+  mobilenet: MobileNet = MobileNet()
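
A usage sketch of the new OneOf option, mirroring the factory test further below (the `official.vision.beta.configs.backbones` module path is assumed from that test's aliases):

from official.vision.beta.configs import backbones as backbones_cfg  # assumed path

backbone_config = backbones_cfg.Backbone(
    type='mobilenet',
    mobilenet=backbones_cfg.MobileNet(
        model_id='MobileNetV3Large', width_multiplier=0.75))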
@@ -20,3 +20,4 @@ from official.vision.beta.modeling.backbones.resnet import ResNet
 from official.vision.beta.modeling.backbones.resnet_3d import ResNet3D
 from official.vision.beta.modeling.backbones.revnet import RevNet
 from official.vision.beta.modeling.backbones.spinenet import SpineNet
+from official.vision.beta.modeling.backbones.mobilenet import MobileNet
@@ -20,6 +20,7 @@ from absl import logging
 import tensorflow as tf

 from official.modeling import tf_utils
 from official.vision.beta.modeling.layers import nn_blocks
+from official.vision.beta.modeling.layers import nn_layers

 layers = tf.keras.layers

@@ -49,22 +50,6 @@ SCALING_MAP = {
 }


-def round_filters(filters, multiplier, divisor=8, min_depth=None, skip=False):
-  """Round number of filters based on depth multiplier."""
-  orig_f = filters
-  if skip or not multiplier:
-    return filters
-  filters *= multiplier
-  min_depth = min_depth or divisor
-  new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
-  # Make sure that round down does not go down by more than 10%.
-  if new_filters < 0.9 * filters:
-    new_filters += divisor
-  logging.info('round_filter input=%s output=%s', orig_f, new_filters)
-  return int(new_filters)
-
-
 def round_repeats(repeats, multiplier, skip=False):
   """Round number of repeats based on depth multiplier."""
   if skip or not multiplier:
@@ -95,8 +80,8 @@ class BlockSpec(object):
     self.kernel_size = kernel_size
     self.strides = strides
     self.expand_ratio = expand_ratio
-    self.in_filters = round_filters(in_filters, width_scale)
-    self.out_filters = round_filters(out_filters, width_scale)
+    self.in_filters = nn_layers.round_filters(in_filters, width_scale)
+    self.out_filters = nn_layers.round_filters(out_filters, width_scale)
     self.is_output = is_output
@@ -165,7 +150,7 @@ class EfficientNet(tf.keras.Model):
     # Build stem.
     x = layers.Conv2D(
-        filters=round_filters(32, width_scale),
+        filters=nn_layers.round_filters(32, width_scale),
         kernel_size=3,
         strides=2,
         use_bias=False,
@@ -197,7 +182,7 @@ class EfficientNet(tf.keras.Model):
     # Build the final conv for classification.
     x = layers.Conv2D(
-        filters=round_filters(1280, width_scale),
+        filters=nn_layers.round_filters(1280, width_scale),
         kernel_size=1,
         strides=1,
         use_bias=False,
 ...
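
Worked examples of the relocated helpers (an editor's sketch): the multiplier scales the filter count, and make_divisible snaps the result to a multiple of `divisor`, bumping back up whenever rounding down would lose more than 10% of the value.

from official.vision.beta.modeling.layers import nn_layers

nn_layers.round_filters(32, 0.75)  # 24 is already a multiple of 8 -> 24
nn_layers.round_filters(32, 1.1)   # 35.2 rounds down to 32; within 10%, so kept
nn_layers.make_divisible(10, 8)    # rounds to 8, but 8 < 0.9 * 10, so 8 + 8 -> 16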
@@ -87,6 +87,16 @@ def build_backbone(input_specs: tf.keras.layers.InputSpec,
         norm_momentum=norm_activation_config.norm_momentum,
         norm_epsilon=norm_activation_config.norm_epsilon,
         kernel_regularizer=l2_regularizer)
+  elif backbone_type == 'mobilenet':
+    backbone = backbones.MobileNet(
+        model_id=backbone_cfg.model_id,
+        width_multiplier=backbone_cfg.width_multiplier,
+        input_specs=input_specs,
+        stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
+        use_sync_bn=norm_activation_config.use_sync_bn,
+        norm_momentum=norm_activation_config.norm_momentum,
+        norm_epsilon=norm_activation_config.norm_epsilon,
+        kernel_regularizer=l2_regularizer)
   else:
     raise ValueError('Backbone {!r} is not implemented'.format(backbone_type))
 ...
@@ -86,7 +86,41 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEqual(network_config, factory_network_config)

-  @combinations.generate(combinations.combine(model_id=['49'],))
+  @combinations.generate(
+      combinations.combine(
+          model_id=['MobileNetV1', 'MobileNetV2',
+                    'MobileNetV3Large', 'MobileNetV3Small',
+                    'MobileNetV3EdgeTPU'],
+          width_multiplier=[1.0, 0.75],
+      ))
+  def test_mobilenet_creation(self, model_id, width_multiplier):
+    """Test creation of MobileNet models."""
+    network = backbones.MobileNet(
+        model_id=model_id,
+        width_multiplier=width_multiplier,
+        norm_momentum=0.99,
+        norm_epsilon=1e-5)
+
+    backbone_config = backbones_cfg.Backbone(
+        type='mobilenet',
+        mobilenet=backbones_cfg.MobileNet(
+            model_id=model_id, width_multiplier=width_multiplier))
+    norm_activation_config = common_cfg.NormActivation(
+        norm_momentum=0.99, norm_epsilon=1e-5)
+    model_config = retinanet_cfg.RetinaNet(
+        backbone=backbone_config, norm_activation=norm_activation_config)
+
+    factory_network = factory.build_backbone(
+        input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
+        model_config=model_config)
+
+    network_config = network.get_config()
+    factory_network_config = factory_network.get_config()
+
+    self.assertEqual(network_config, factory_network_config)
+
+  @combinations.generate(combinations.combine(model_id=['49'],))
   def test_spinenet_creation(self, model_id):
     """Test creation of SpineNet models."""
     input_size = 128
 ...
This diff is collapsed.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MobileNet."""
# Import libraries
from absl.testing import parameterized
from itertools import product
import tensorflow as tf
from official.vision.beta.modeling.backbones import mobilenet
class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
  @parameterized.parameters('MobileNetV1', 'MobileNetV2',
                            'MobileNetV3Large', 'MobileNetV3Small',
                            'MobileNetV3EdgeTPU')
  def test_serialize_deserialize(self, model_id):
    # Create a network object that sets all of its config options.
    kwargs = dict(
        model_id=model_id,
        width_multiplier=1.0,
        stochastic_depth_drop_rate=None,
        use_sync_bn=False,
        kernel_initializer='VarianceScaling',
        kernel_regularizer=None,
        bias_regularizer=None,
        norm_momentum=0.99,
        norm_epsilon=0.001,
        output_stride=None,
        min_depth=8,
        divisible_by=8,
        regularize_depthwise=False,
        finegrain_classification_mode=True,
    )
    network = mobilenet.MobileNet(**kwargs)

    expected_config = dict(kwargs)
    self.assertEqual(network.get_config(), expected_config)

    # Create another network object from the first object's config.
    new_network = mobilenet.MobileNet.from_config(network.get_config())

    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())
  @parameterized.parameters(
      product((1, 3),
              ('MobileNetV1', 'MobileNetV2',
               'MobileNetV3Large', 'MobileNetV3Small',
               'MobileNetV3EdgeTPU')))
  def test_input_specs(self, input_dim, model_id):
    """Test different input feature dimensions."""
    tf.keras.backend.set_image_data_format('channels_last')

    input_specs = tf.keras.layers.InputSpec(
        shape=[None, None, None, input_dim])
    network = mobilenet.MobileNet(model_id=model_id, input_specs=input_specs)

    inputs = tf.keras.Input(shape=(128, 128, input_dim), batch_size=1)
    _ = network(inputs)
  @parameterized.parameters(32, 224)
  def test_mobilenet_v1_creation(self, input_size):
    """Test creation of MobileNetV1 models."""
    tf.keras.backend.set_image_data_format('channels_last')

    network = mobilenet.MobileNet(model_id='MobileNetV1',
                                  width_multiplier=0.75)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    endpoints = network(inputs)

    self.assertAllEqual([1, input_size / 2 ** 1, input_size / 2 ** 1, 24],
                        endpoints[1].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 1, input_size / 2 ** 1, 48],
                        endpoints[2].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 2, input_size / 2 ** 2, 96],
                        endpoints[3].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 2, input_size / 2 ** 2, 96],
                        endpoints[4].shape.as_list())

  @parameterized.parameters(32, 224)
  def test_mobilenet_v2_creation(self, input_size):
    """Test creation of MobileNetV2 models."""
    tf.keras.backend.set_image_data_format('channels_last')

    network = mobilenet.MobileNet(model_id='MobileNetV2',
                                  width_multiplier=1.0)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    endpoints = network(inputs)

    self.assertAllEqual([1, input_size / 2 ** 1, input_size / 2 ** 1, 32],
                        endpoints[1].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 1, input_size / 2 ** 1, 16],
                        endpoints[2].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 2, input_size / 2 ** 2, 24],
                        endpoints[3].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 2, input_size / 2 ** 2, 24],
                        endpoints[4].shape.as_list())

  @parameterized.parameters(32, 224)
  def test_mobilenet_v3_small_creation(self, input_size):
    """Test creation of MobileNetV3Small models."""
    tf.keras.backend.set_image_data_format('channels_last')

    network = mobilenet.MobileNet(model_id='MobileNetV3Small',
                                  width_multiplier=0.75)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    endpoints = network(inputs)

    self.assertAllEqual([1, input_size / 2 ** 1, input_size / 2 ** 1, 16],
                        endpoints[1].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 2, input_size / 2 ** 2, 16],
                        endpoints[2].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 3, input_size / 2 ** 3, 24],
                        endpoints[3].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 3, input_size / 2 ** 3, 24],
                        endpoints[4].shape.as_list())

  @parameterized.parameters(32, 224)
  def test_mobilenet_v3_large_creation(self, input_size):
    """Test creation of MobileNetV3Large models."""
    tf.keras.backend.set_image_data_format('channels_last')

    network = mobilenet.MobileNet(model_id='MobileNetV3Large',
                                  width_multiplier=0.75)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    endpoints = network(inputs)

    self.assertAllEqual([1, input_size / 2 ** 1, input_size / 2 ** 1, 16],
                        endpoints[1].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 1, input_size / 2 ** 1, 16],
                        endpoints[2].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 2, input_size / 2 ** 2, 24],
                        endpoints[3].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 2, input_size / 2 ** 2, 24],
                        endpoints[4].shape.as_list())

  @parameterized.parameters(32, 224)
  def test_mobilenet_v3_edgetpu_creation(self, input_size):
    """Test creation of MobileNetV3EdgeTPU models."""
    tf.keras.backend.set_image_data_format('channels_last')

    network = mobilenet.MobileNet(model_id='MobileNetV3EdgeTPU',
                                  width_multiplier=0.75)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    endpoints = network(inputs)

    self.assertAllEqual([1, input_size / 2 ** 1, input_size / 2 ** 1, 24],
                        endpoints[1].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 1, input_size / 2 ** 1, 16],
                        endpoints[2].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 2, input_size / 2 ** 2, 24],
                        endpoints[3].shape.as_list())
    self.assertAllEqual([1, input_size / 2 ** 2, input_size / 2 ** 2, 24],
                        endpoints[4].shape.as_list())
  @parameterized.parameters(1.0, 0.75)
  def test_mobilenet_v1_scaling(self, width_multiplier):
    mobilenet_v1_params = {
        1.0: 3228864,
        0.75: 1832976,
    }

    input_size = 224
    network = mobilenet.MobileNet(model_id='MobileNetV1',
                                  width_multiplier=width_multiplier)
    self.assertEqual(network.count_params(),
                     mobilenet_v1_params[width_multiplier])

    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    _ = network(inputs)

  @parameterized.parameters(1.0, 0.75)
  def test_mobilenet_v2_scaling(self, width_multiplier):
    mobilenet_v2_params = {
        1.0: 2257984,
        0.75: 1382064,
    }

    input_size = 224
    network = mobilenet.MobileNet(model_id='MobileNetV2',
                                  width_multiplier=width_multiplier)
    self.assertEqual(network.count_params(),
                     mobilenet_v2_params[width_multiplier])

    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    _ = network(inputs)

  @parameterized.parameters(1.0, 0.75)
  def test_mobilenet_v3_large_scaling(self, width_multiplier):
    mobilenet_v3_large_params = {
        1.0: 4226432,
        0.75: 2731616,
    }

    input_size = 224
    network = mobilenet.MobileNet(model_id='MobileNetV3Large',
                                  width_multiplier=width_multiplier)
    self.assertEqual(network.count_params(),
                     mobilenet_v3_large_params[width_multiplier])

    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    _ = network(inputs)

  @parameterized.parameters(1.0, 0.75)
  def test_mobilenet_v3_small_scaling(self, width_multiplier):
    mobilenet_v3_small_params = {
        1.0: 1529968,
        0.75: 1026552,
    }

    input_size = 224
    network = mobilenet.MobileNet(model_id='MobileNetV3Small',
                                  width_multiplier=width_multiplier)
    self.assertEqual(network.count_params(),
                     mobilenet_v3_small_params[width_multiplier])

    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    _ = network(inputs)

  @parameterized.parameters(1.0, 0.75)
  def test_mobilenet_v3_edgetpu_scaling(self, width_multiplier):
    mobilenet_v3_edgetpu_params = {
        1.0: 2849312,
        0.75: 1737288,
    }

    input_size = 224
    network = mobilenet.MobileNet(model_id='MobileNetV3EdgeTPU',
                                  width_multiplier=width_multiplier)
    self.assertEqual(network.count_params(),
                     mobilenet_v3_edgetpu_params[width_multiplier])

    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    _ = network(inputs)


if __name__ == '__main__':
  tf.test.main()
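
Although the mobilenet.py diff itself is collapsed above, the tests pin down its interface; a minimal usage sketch based on them:

import tensorflow as tf

from official.vision.beta.modeling.backbones import mobilenet

network = mobilenet.MobileNet(model_id='MobileNetV2', width_multiplier=1.0)
inputs = tf.keras.Input(shape=(224, 224, 3), batch_size=1)
endpoints = network(inputs)  # feature maps keyed by downsampling level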
@@ -15,45 +15,95 @@
 """Contains common building blocks for neural networks."""

 # Import libraries
+from absl import logging
+from typing import Optional
 import tensorflow as tf

 from official.modeling import tf_utils


+def make_divisible(value: float,
+                   divisor: int,
+                   min_value: Optional[float] = None
+                   ) -> int:
+  """Ensures that all layers have a channel number divisible by `divisor`.
+
+  Args:
+    value: `float` original value.
+    divisor: `int` the divisor to check divisibility against.
+    min_value: `float` minimum value threshold.
+
+  Returns:
+    The adjusted value as an `int`, divisible by `divisor`.
+  """
+  if min_value is None:
+    min_value = divisor
+  new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+  # Make sure that rounding down does not reduce the value by more than 10%.
+  if new_value < 0.9 * value:
+    new_value += divisor
+  return new_value
+
+
+def round_filters(filters: int,
+                  multiplier: float,
+                  divisor: int = 8,
+                  min_depth: Optional[int] = None,
+                  skip: bool = False):
+  """Rounds the number of filters based on the width multiplier."""
+  orig_f = filters
+  if skip or not multiplier:
+    return filters
+  new_filters = make_divisible(value=filters * multiplier,
+                               divisor=divisor,
+                               min_value=min_depth)
+  logging.info('round_filter input=%s output=%s', orig_f, new_filters)
+  return int(new_filters)
+
+
 @tf.keras.utils.register_keras_serializable(package='Vision')
 class SqueezeExcitation(tf.keras.layers.Layer):
   """Squeeze and excitation layer."""

   def __init__(self,
                in_filters,
+               out_filters,
                se_ratio,
-               expand_ratio,
+               divisible_by=1,
                kernel_initializer='VarianceScaling',
                kernel_regularizer=None,
                bias_regularizer=None,
                activation='relu',
+               gating_activation='sigmoid',
                **kwargs):
     """Implementation for squeeze and excitation.

     Args:
       in_filters: `int` number of filters of the input tensor.
+      out_filters: `int` number of filters of the output tensor.
       se_ratio: `float` or None. If not None, se ratio for the squeeze and
         excitation layer.
-      expand_ratio: `int` expand_ratio for a MBConv block.
+      divisible_by: `int` ensures all inner dimensions are divisible by this
+        number.
       kernel_initializer: kernel_initializer for convolutional layers.
       kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
         Defaults to None.
       bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
         Defaults to None.
       activation: `str` name of the activation function.
+      gating_activation: `str` name of the activation function for the final
+        gating function.
       **kwargs: keyword arguments to be passed.
     """
     super(SqueezeExcitation, self).__init__(**kwargs)

     self._in_filters = in_filters
+    self._out_filters = out_filters
     self._se_ratio = se_ratio
-    self._expand_ratio = expand_ratio
+    self._divisible_by = divisible_by
     self._activation = activation
+    self._gating_activation = gating_activation
     self._kernel_initializer = kernel_initializer
     self._kernel_regularizer = kernel_regularizer
     self._bias_regularizer = bias_regularizer
@@ -62,9 +112,12 @@ class SqueezeExcitation(tf.keras.layers.Layer):
     else:
       self._spatial_axis = [2, 3]
     self._activation_fn = tf_utils.get_activation(activation)
+    self._gating_activation_fn = tf_utils.get_activation(gating_activation)

   def build(self, input_shape):
-    num_reduced_filters = max(1, int(self._in_filters * self._se_ratio))
+    num_reduced_filters = make_divisible(
+        max(1, int(self._in_filters * self._se_ratio)),
+        divisor=self._divisible_by)

     self._se_reduce = tf.keras.layers.Conv2D(
         filters=num_reduced_filters,
@@ -77,7 +130,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
         bias_regularizer=self._bias_regularizer)

     self._se_expand = tf.keras.layers.Conv2D(
-        filters=self._in_filters * self._expand_ratio,
+        filters=self._out_filters,
         kernel_size=1,
         strides=1,
         padding='same',
@@ -91,22 +144,24 @@ class SqueezeExcitation(tf.keras.layers.Layer):
   def get_config(self):
     config = {
         'in_filters': self._in_filters,
+        'out_filters': self._out_filters,
         'se_ratio': self._se_ratio,
-        'expand_ratio': self._expand_ratio,
+        'divisible_by': self._divisible_by,
         'strides': self._strides,
         'kernel_initializer': self._kernel_initializer,
         'kernel_regularizer': self._kernel_regularizer,
         'bias_regularizer': self._bias_regularizer,
         'activation': self._activation,
+        'gating_activation': self._gating_activation,
     }
     base_config = super(SqueezeExcitation, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

   def call(self, inputs):
     x = tf.reduce_mean(inputs, self._spatial_axis, keepdims=True)
-    x = self._se_expand(self._activation_fn(self._se_reduce(x)))
-    return tf.sigmoid(x) * inputs
+    x = self._activation_fn(self._se_reduce(x))
+    x = self._gating_activation_fn(self._se_expand(x))
+    return x * inputs


 @tf.keras.utils.register_keras_serializable(package='Vision')
 ...
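
A minimal sketch of the reworked SqueezeExcitation (using only the constructor arguments visible in the diff above; other options are left at their defaults): the reduction width is now rounded via make_divisible, and MobileNetV3-style gating comes from passing the new hard_sigmoid by name.

import tensorflow as tf

from official.vision.beta.modeling.layers import nn_layers

se = nn_layers.SqueezeExcitation(
    in_filters=64,
    out_filters=64,
    se_ratio=0.25,
    divisible_by=8,
    gating_activation='hard_sigmoid')
x = tf.random.normal([1, 32, 32, 64])
y = se(x)  # same shape as x; channels rescaled by the gating signal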