Commit 61961346 authored by Shixin Luo's avatar Shixin Luo
Browse files

move Conv2DBNBlock to mobilenet.py; change the reference name for...

move Conv2DBNBlock to mobilenet.py; change the reference name for InvertedBottleneckBlock to align with the paper
parent 602f22ed
......@@ -18,6 +18,7 @@ from typing import Text, Optional, Dict
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.modeling.layers import nn_layers
......@@ -25,16 +26,114 @@ from official.vision.beta.modeling.layers import nn_layers
layers = tf.keras.layers
regularizers = tf.keras.regularizers
class Conv2DBNBlock(tf.keras.layers.Layer):
  """A convolution block with batch normalization."""

  def __init__(self,
               filters: int,
               kernel_size: int = 3,
               strides: int = 1,
               use_bias: bool = False,
               activation: Text = 'relu6',
               kernel_initializer: Text = 'VarianceScaling',
               kernel_regularizer: Optional[
                   tf.keras.regularizers.Regularizer] = None,
               bias_regularizer: Optional[
                   tf.keras.regularizers.Regularizer] = None,
               use_normalization: bool = True,
               use_sync_bn: bool = False,
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               **kwargs):
    """A convolution block with batch normalization.

    Args:
      filters: `int` number of filters of the convolution layer.
      kernel_size: `int` an integer specifying the height and width of the
        2D convolution window.
      strides: `int` stride of the convolution. If greater than 1, this block
        downsamples the input.
      use_bias: if True, use a bias term in the convolution layer.
      activation: `str` name of the activation function.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      use_normalization: if True, use batch normalization.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      **kwargs: keyword arguments to be passed.
    """
    super(Conv2DBNBlock, self).__init__(**kwargs)
    self._filters = filters
    self._kernel_size = kernel_size
    self._strides = strides
    self._activation = activation
    self._use_bias = use_bias
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_normalization = use_normalization
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    # Pick the normalization layer class; it is instantiated lazily in build().
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    # Channel axis depends on the backend's image data format.
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def get_config(self):
    """Returns the layer config for serialization."""
    config = {
        'filters': self._filters,
        'strides': self._strides,
        'kernel_size': self._kernel_size,
        'use_bias': self._use_bias,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'use_normalization': self._use_normalization,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
    base_config = super(Conv2DBNBlock, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    """Creates the conv (and optional normalization) sublayers."""
    self._conv0 = tf.keras.layers.Conv2D(
        filters=self._filters,
        kernel_size=self._kernel_size,
        strides=self._strides,
        padding='same',
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    if self._use_normalization:
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._norm_momentum,
          epsilon=self._norm_epsilon)
    super(Conv2DBNBlock, self).build(input_shape)

  def call(self, inputs, training=None):
    """Runs conv -> (optional) batch norm -> activation."""
    x = self._conv0(inputs)
    if self._use_normalization:
      x = self._norm0(x)
    return self._activation_fn(x)
"""
Architecture: https://arxiv.org/abs/1704.04861.
......@@ -77,31 +176,23 @@ MNV2_BLOCK_SPECS = {
'expand_ratio'],
'block_specs': [
('convbn', 3, 2, 32, None),
('mbconv', 3, 1, 16, 1.),
('mbconv', 3, 2, 24, 6.),
('mbconv', 3, 1, 24, 6.),
('mbconv', 3, 2, 32, 6.),
('mbconv', 3, 1, 32, 6.),
('mbconv', 3, 1, 32, 6.),
('mbconv', 3, 2, 64, 6.),
('mbconv', 3, 1, 64, 6.),
('mbconv', 3, 1, 64, 6.),
('mbconv', 3, 1, 64, 6.),
('mbconv', 3, 1, 96, 6.),
('mbconv', 3, 1, 96, 6.),
('mbconv', 3, 1, 96, 6.),
('mbconv', 3, 2, 160, 6.),
('mbconv', 3, 1, 160, 6.),
('mbconv', 3, 1, 160, 6.),
('mbconv', 3, 1, 320, 6.),
('invertedresidual', 3, 1, 16, 1.),
('invertedresidual', 3, 2, 24, 6.),
('invertedresidual', 3, 1, 24, 6.),
('invertedresidual', 3, 2, 32, 6.),
('invertedresidual', 3, 1, 32, 6.),
('invertedresidual', 3, 1, 32, 6.),
('invertedresidual', 3, 2, 64, 6.),
('invertedresidual', 3, 1, 64, 6.),
('invertedresidual', 3, 1, 64, 6.),
('invertedresidual', 3, 1, 64, 6.),
('invertedresidual', 3, 1, 96, 6.),
('invertedresidual', 3, 1, 96, 6.),
('invertedresidual', 3, 1, 96, 6.),
('invertedresidual', 3, 2, 160, 6.),
('invertedresidual', 3, 1, 160, 6.),
('invertedresidual', 3, 1, 160, 6.),
('invertedresidual', 3, 1, 320, 6.),
('convbn', 1, 2, 1280, None),
]
}
......@@ -120,28 +211,21 @@ MNV3Large_BLOCK_SPECS = {
'use_normalization', 'use_bias'],
'block_specs': [
('convbn', 3, 2, 16, 'hard_swish', None, None, True, False),
('mbconv', 3, 1, 16, 'relu', None, 1., None, False),
('mbconv', 3, 2, 24, 'relu', None, 4., None, False),
('mbconv', 3, 1, 24, 'relu', None, 3., None, False),
('mbconv', 5, 2, 40, 'relu', 1. / 4, 3., None, False),
('mbconv', 5, 1, 40, 'relu', 1. / 4, 3., None, False),
('mbconv', 5, 1, 40, 'relu', 1. / 4, 3., None, False),
('mbconv', 3, 2, 80, 'hard_swish', None, 6., None, False),
('mbconv', 3, 1, 80, 'hard_swish', None, 2.5, None, False),
('mbconv', 3, 1, 80, 'hard_swish', None, 2.3, None, False),
('mbconv', 3, 1, 80, 'hard_swish', None, 2.3, None, False),
('mbconv', 3, 1, 112, 'hard_swish', 1. / 4, 6., None, False),
('mbconv', 3, 1, 112, 'hard_swish', 1. / 4, 6., None, False),
('mbconv', 5, 2, 160, 'hard_swish', 1. / 4, 6, None, False),
('mbconv', 5, 1, 160, 'hard_swish', 1. / 4, 6, None, False),
('mbconv', 5, 1, 160, 'hard_swish', 1. / 4, 6, None, False),
('invertedresidual', 3, 1, 16, 'relu', None, 1., None, False),
('invertedresidual', 3, 2, 24, 'relu', None, 4., None, False),
('invertedresidual', 3, 1, 24, 'relu', None, 3., None, False),
('invertedresidual', 5, 2, 40, 'relu', 1. / 4, 3., None, False),
('invertedresidual', 5, 1, 40, 'relu', 1. / 4, 3., None, False),
('invertedresidual', 5, 1, 40, 'relu', 1. / 4, 3., None, False),
('invertedresidual', 3, 2, 80, 'hard_swish', None, 6., None, False),
('invertedresidual', 3, 1, 80, 'hard_swish', None, 2.5, None, False),
('invertedresidual', 3, 1, 80, 'hard_swish', None, 2.3, None, False),
('invertedresidual', 3, 1, 80, 'hard_swish', None, 2.3, None, False),
('invertedresidual', 3, 1, 112, 'hard_swish', 1. / 4, 6., None, False),
('invertedresidual', 3, 1, 112, 'hard_swish', 1. / 4, 6., None, False),
('invertedresidual', 5, 2, 160, 'hard_swish', 1. / 4, 6., None, False),
('invertedresidual', 5, 1, 160, 'hard_swish', 1. / 4, 6., None, False),
('invertedresidual', 5, 1, 160, 'hard_swish', 1. / 4, 6., None, False),
('convbn', 1, 1, 960, 'hard_swish', None, None, True, False),
('gpooling', None, None, None, None, None, None, None, None),
('convbn', 1, 1, 1280, 'hard_swish', None, None, False, True),
......@@ -155,23 +239,17 @@ MNV3Small_BLOCK_SPECS = {
'use_normalization', 'use_bias'],
'block_specs': [
('convbn', 3, 2, 16, 'hard_swish', None, None, True, False),
('mbconv', 3, 2, 16, 'relu', 1. / 4, 1, None, False),
('mbconv', 3, 2, 24, 'relu', None, 72. / 16, None, False),
('mbconv', 3, 1, 24, 'relu', None, 88. / 24, None, False),
('mbconv', 5, 2, 40, 'hard_swish', 1. / 4, 4., None, False),
('mbconv', 5, 1, 40, 'hard_swish', 1. / 4, 6., None, False),
('mbconv', 5, 1, 40, 'hard_swish', 1. / 4, 6., None, False),
('mbconv', 5, 1, 48, 'hard_swish', 1. / 4, 3., None, False),
('mbconv', 5, 1, 48, 'hard_swish', 1. / 4, 3., None, False),
('mbconv', 5, 2, 96, 'hard_swish', 1. / 4, 6., None, False),
('mbconv', 5, 1, 96, 'hard_swish', 1. / 4, 6., None, False),
('mbconv', 5, 1, 96, 'hard_swish', 1. / 4, 6., None, False),
('invertedresidual', 3, 2, 16, 'relu', 1. / 4, 1, None, False),
('invertedresidual', 3, 2, 24, 'relu', None, 72. / 16, None, False),
('invertedresidual', 3, 1, 24, 'relu', None, 88. / 24, None, False),
('invertedresidual', 5, 2, 40, 'hard_swish', 1. / 4, 4., None, False),
('invertedresidual', 5, 1, 40, 'hard_swish', 1. / 4, 6., None, False),
('invertedresidual', 5, 1, 40, 'hard_swish', 1. / 4, 6., None, False),
('invertedresidual', 5, 1, 48, 'hard_swish', 1. / 4, 3., None, False),
('invertedresidual', 5, 1, 48, 'hard_swish', 1. / 4, 3., None, False),
('invertedresidual', 5, 2, 96, 'hard_swish', 1. / 4, 6., None, False),
('invertedresidual', 5, 1, 96, 'hard_swish', 1. / 4, 6., None, False),
('invertedresidual', 5, 1, 96, 'hard_swish', 1. / 4, 6., None, False),
('convbn', 1, 1, 576, 'hard_swish', None, None, True, False),
('gpooling', None, None, None, None, None, None, None, None),
('convbn', 1, 1, 1024, 'hard_swish', None, None, False, True),
......@@ -189,36 +267,28 @@ MNV3EdgeTPU_BLOCK_SPECS = {
'use_residual', 'use_depthwise'],
'block_specs': [
('convbn', 3, 2, 32, 'relu', None, None, None, None),
('mbconv', 3, 1, 16, 'relu', None, 1., True, False),
('mbconv', 3, 2, 32, 'relu', None, 8., True, False),
('mbconv', 3, 1, 32, 'relu', None, 4., True, False),
('mbconv', 3, 1, 32, 'relu', None, 4., True, False),
('mbconv', 3, 1, 32, 'relu', None, 4., True, False),
('mbconv', 3, 2, 48, 'relu', None, 8., True, False),
('mbconv', 3, 1, 48, 'relu', None, 4., True, False),
('mbconv', 3, 1, 48, 'relu', None, 4., True, False),
('mbconv', 3, 1, 48, 'relu', None, 4., True, False),
('mbconv', 3, 2, 96, 'relu', None, 8., True, True),
('mbconv', 3, 1, 96, 'relu', None, 4., True, True),
('mbconv', 3, 1, 96, 'relu', None, 4., True, True),
('mbconv', 3, 1, 96, 'relu', None, 4., True, True),
('mbconv', 3, 1, 96, 'relu', None, 8., False, True),
('mbconv', 3, 1, 96, 'relu', None, 4., True, True),
('mbconv', 3, 1, 96, 'relu', None, 4., True, True),
('mbconv', 3, 1, 96, 'relu', None, 4., True, True),
('mbconv', 5, 2, 160, 'relu', None, 8., True, True),
('mbconv', 5, 1, 160, 'relu', None, 4., True, True),
('mbconv', 5, 1, 160, 'relu', None, 4., True, True),
('mbconv', 5, 1, 160, 'relu', None, 4., True, True),
('mbconv', 3, 1, 192, 'relu', None, 8., True, True),
('invertedresidual', 3, 1, 16, 'relu', None, 1., True, False),
('invertedresidual', 3, 2, 32, 'relu', None, 8., True, False),
('invertedresidual', 3, 1, 32, 'relu', None, 4., True, False),
('invertedresidual', 3, 1, 32, 'relu', None, 4., True, False),
('invertedresidual', 3, 1, 32, 'relu', None, 4., True, False),
('invertedresidual', 3, 2, 48, 'relu', None, 8., True, False),
('invertedresidual', 3, 1, 48, 'relu', None, 4., True, False),
('invertedresidual', 3, 1, 48, 'relu', None, 4., True, False),
('invertedresidual', 3, 1, 48, 'relu', None, 4., True, False),
('invertedresidual', 3, 2, 96, 'relu', None, 8., True, True),
('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
('invertedresidual', 3, 1, 96, 'relu', None, 8., False, True),
('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
('invertedresidual', 5, 2, 160, 'relu', None, 8., True, True),
('invertedresidual', 5, 1, 160, 'relu', None, 4., True, True),
('invertedresidual', 5, 1, 160, 'relu', None, 4., True, True),
('invertedresidual', 5, 1, 160, 'relu', None, 4., True, True),
('invertedresidual', 3, 1, 192, 'relu', None, 8., True, True),
('convbn', 1, 1, 1280, 'relu', None, None, None, None),
]
}
......@@ -231,13 +301,6 @@ SUPPORTED_SPECS_MAP = {
'MobileNetV3EdgeTPU': MNV3EdgeTPU_BLOCK_SPECS,
}
BLOCK_FN_MAP = {
'convbn': nn_blocks.Conv2DBNBlock,
'depsepconv': nn_blocks.DepthwiseSeparableConvBlock,
'mbconv': nn_blocks.InvertedBottleneckBlock,
'gpooling': GlobalPoolingBlock,
}
class BlockSpec(object):
"""A container class that specifies the block configuration for MobileNet."""
......@@ -297,7 +360,8 @@ def block_spec_decoder(specs: Dict,
block_specs = specs['block_specs']
if len(block_specs) == 0:
raise ValueError('The block spec cannot be empty for {} !'.format(spec_name))
raise ValueError(
'The block spec cannot be empty for {} !'.format(spec_name))
if len(block_specs[0]) != len(block_spec_schema):
raise ValueError('The block spec values {} do not match with '
......@@ -489,7 +553,7 @@ class MobileNet(tf.keras.Model):
if block_def.block_fn == 'convbn':
net = nn_blocks.Conv2DBNBlock(
net = Conv2DBNBlock(
filters=block_def.filters,
kernel_size=block_def.kernel_size,
strides=block_def.strides,
......@@ -519,7 +583,7 @@ class MobileNet(tf.keras.Model):
norm_epsilon=self._norm_epsilon,
)(net)
elif block_def.block_fn == 'mbconv':
elif block_def.block_fn == 'invertedresidual':
use_rate = rate
if layer_rate > 1 and block_def.kernel_size != 1:
# We will apply atrous rate in the following cases:
......@@ -554,7 +618,8 @@ class MobileNet(tf.keras.Model):
)(net)
elif block_def.block_fn == 'gpooling':
net = GlobalPoolingBlock()(net)
net = layers.GlobalAveragePooling2D()(net)
net = layers.Reshape((1, 1, net.shape[1]))(net)
else:
raise ValueError('Unknown block type {} for layer {}'.format(
......@@ -604,7 +669,7 @@ def build_mobilenet(
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'mobilenet', (f'Inconsistent backbone type '
f'{backbone_type}')
f'{backbone_type}')
return MobileNet(
model_id=backbone_cfg.model_id,
......
......@@ -1184,113 +1184,3 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
x = self._conv1(x)
x = self._norm1(x)
return self._activation_fn(x)
class Conv2DBNBlock(tf.keras.layers.Layer):
  """A convolution block with batch normalization."""

  def __init__(self,
               filters: int,
               kernel_size: int = 3,
               strides: int = 1,
               use_bias: bool = False,
               activation: Text = 'relu6',
               kernel_initializer: Text = 'VarianceScaling',
               kernel_regularizer: Optional[
                   tf.keras.regularizers.Regularizer] = None,
               bias_regularizer: Optional[
                   tf.keras.regularizers.Regularizer] = None,
               use_normalization: bool = True,
               use_sync_bn: bool = False,
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               **kwargs):
    """A convolution block with batch normalization.

    Args:
      filters: `int` number of filters of the convolution layer.
      kernel_size: `int` an integer specifying the height and width of the
        2D convolution window.
      strides: `int` stride of the convolution. If greater than 1, this block
        downsamples the input.
      use_bias: if True, use a bias term in the convolution layer.
      activation: `str` name of the activation function.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      use_normalization: if True, use batch normalization.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      **kwargs: keyword arguments to be passed.
    """
    super(Conv2DBNBlock, self).__init__(**kwargs)
    self._filters = filters
    self._kernel_size = kernel_size
    self._strides = strides
    self._activation = activation
    self._use_bias = use_bias
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_normalization = use_normalization
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    # Pick the normalization layer class; it is instantiated lazily in build().
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    # Channel axis depends on the backend's image data format.
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def get_config(self):
    """Returns the layer config for serialization."""
    config = {
        'filters': self._filters,
        'strides': self._strides,
        'kernel_size': self._kernel_size,
        'use_bias': self._use_bias,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'use_normalization': self._use_normalization,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
    base_config = super(Conv2DBNBlock, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    """Creates the conv (and optional normalization) sublayers."""
    self._conv0 = tf.keras.layers.Conv2D(
        filters=self._filters,
        kernel_size=self._kernel_size,
        strides=self._strides,
        padding='same',
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    if self._use_normalization:
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._norm_momentum,
          epsilon=self._norm_epsilon)
    super(Conv2DBNBlock, self).build(input_shape)

  def call(self, inputs, training=None):
    """Runs conv -> (optional) batch norm -> activation."""
    x = self._conv0(inputs)
    if self._use_normalization:
      x = self._norm0(x)
    return self._activation_fn(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment