Commit 61961346 authored by Shixin Luo


move Conv2DBNBlock to mobilenet.py; change the reference name for InvertedBottleNeckBlock to align with paper
parent 602f22ed
@@ -18,6 +18,7 @@ from typing import Text, Optional, Dict
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.modeling.layers import nn_layers
@@ -25,16 +26,114 @@ from official.vision.beta.modeling.layers import nn_layers
layers = tf.keras.layers
regularizers = tf.keras.regularizers
class GlobalPoolingBlock(tf.keras.layers.Layer):

  def __init__(self, **kwargs):
    super(GlobalPoolingBlock, self).__init__(**kwargs)

  def call(self, inputs, training=None):
    x = layers.GlobalAveragePooling2D()(inputs)
    outputs = layers.Reshape((1, 1, x.shape[1]))(x)
    return outputs


class Conv2DBNBlock(tf.keras.layers.Layer):
  """A convolution block with batch normalization."""

  def __init__(self,
               filters: int,
               kernel_size: int = 3,
               strides: int = 1,
               use_bias: bool = False,
               activation: Text = 'relu6',
               kernel_initializer: Text = 'VarianceScaling',
               kernel_regularizer: Optional[
                   tf.keras.regularizers.Regularizer] = None,
               bias_regularizer: Optional[
                   tf.keras.regularizers.Regularizer] = None,
               use_normalization: bool = True,
               use_sync_bn: bool = False,
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               **kwargs):
    """A convolution block with batch normalization.

    Args:
      filters: `int` number of filters for the convolution.
      kernel_size: `int` height and width of the 2D convolution window.
      strides: `int` stride of the convolution. If greater than 1, this block
        downsamples the input.
      use_bias: if True, use a bias in the convolution layer.
      activation: `str` name of the activation function.
      kernel_initializer: kernel_initializer for the convolutional layer.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Defaults to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Defaults to None.
      use_normalization: if True, use batch normalization.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      **kwargs: keyword arguments to be passed.
    """
    super(Conv2DBNBlock, self).__init__(**kwargs)
    self._filters = filters
    self._kernel_size = kernel_size
    self._strides = strides
    self._activation = activation
    self._use_bias = use_bias
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_normalization = use_normalization
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def get_config(self):
    config = {
        'filters': self._filters,
        'strides': self._strides,
        'kernel_size': self._kernel_size,
        'use_bias': self._use_bias,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'use_normalization': self._use_normalization,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
    base_config = super(Conv2DBNBlock, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    self._conv0 = tf.keras.layers.Conv2D(
        filters=self._filters,
        kernel_size=self._kernel_size,
        strides=self._strides,
        padding='same',
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    if self._use_normalization:
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._norm_momentum,
          epsilon=self._norm_epsilon)
    super(Conv2DBNBlock, self).build(input_shape)

  def call(self, inputs, training=None):
    x = self._conv0(inputs)
    if self._use_normalization:
      x = self._norm0(x)
    return self._activation_fn(x)
""" """
Architecture: https://arxiv.org/abs/1704.04861. Architecture: https://arxiv.org/abs/1704.04861.
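For reference, a minimal usage sketch of the block as relocated into mobilenet.py; the input shape and filter count are illustrative only and are not part of this change:

# Minimal sketch (illustrative values): exercising Conv2DBNBlock after the move.
import tensorflow as tf

from official.vision.beta.modeling.backbones import mobilenet

block = mobilenet.Conv2DBNBlock(filters=32, kernel_size=3, strides=2)
images = tf.random.uniform([8, 224, 224, 3])  # NHWC batch of dummy images
features = block(images)                      # conv -> batch norm -> relu6
print(features.shape)                         # (8, 112, 112, 32) with 'same' padding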
@@ -77,31 +176,23 @@ MNV2_BLOCK_SPECS = {
                          'expand_ratio'],
    'block_specs': [
        ('convbn', 3, 2, 32, None),
        ('invertedresidual', 3, 1, 16, 1.),
        ('invertedresidual', 3, 2, 24, 6.),
        ('invertedresidual', 3, 1, 24, 6.),
        ('invertedresidual', 3, 2, 32, 6.),
        ('invertedresidual', 3, 1, 32, 6.),
        ('invertedresidual', 3, 1, 32, 6.),
        ('invertedresidual', 3, 2, 64, 6.),
        ('invertedresidual', 3, 1, 64, 6.),
        ('invertedresidual', 3, 1, 64, 6.),
        ('invertedresidual', 3, 1, 64, 6.),
        ('invertedresidual', 3, 1, 96, 6.),
        ('invertedresidual', 3, 1, 96, 6.),
        ('invertedresidual', 3, 1, 96, 6.),
        ('invertedresidual', 3, 2, 160, 6.),
        ('invertedresidual', 3, 1, 160, 6.),
        ('invertedresidual', 3, 1, 160, 6.),
        ('invertedresidual', 3, 1, 320, 6.),
        ('convbn', 1, 2, 1280, None),
    ]
}
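To make the spec format concrete, a small illustrative pairing of one tuple with its schema; the dict/zip pairing below is only an approximation of what BlockSpec and block_spec_decoder do in this file, not the actual decoder:

# Illustrative only: pairing one MobileNetV2 spec tuple with its schema.
schema = ['block_fn', 'kernel_size', 'strides', 'filters', 'expand_ratio']
spec = ('invertedresidual', 3, 2, 24, 6.)
print(dict(zip(schema, spec)))
# {'block_fn': 'invertedresidual', 'kernel_size': 3, 'strides': 2,
#  'filters': 24, 'expand_ratio': 6.0}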
@@ -120,28 +211,21 @@ MNV3Large_BLOCK_SPECS = {
                          'use_normalization', 'use_bias'],
    'block_specs': [
        ('convbn', 3, 2, 16, 'hard_swish', None, None, True, False),
        ('invertedresidual', 3, 1, 16, 'relu', None, 1., None, False),
        ('invertedresidual', 3, 2, 24, 'relu', None, 4., None, False),
        ('invertedresidual', 3, 1, 24, 'relu', None, 3., None, False),
        ('invertedresidual', 5, 2, 40, 'relu', 1. / 4, 3., None, False),
        ('invertedresidual', 5, 1, 40, 'relu', 1. / 4, 3., None, False),
        ('invertedresidual', 5, 1, 40, 'relu', 1. / 4, 3., None, False),
        ('invertedresidual', 3, 2, 80, 'hard_swish', None, 6., None, False),
        ('invertedresidual', 3, 1, 80, 'hard_swish', None, 2.5, None, False),
        ('invertedresidual', 3, 1, 80, 'hard_swish', None, 2.3, None, False),
        ('invertedresidual', 3, 1, 80, 'hard_swish', None, 2.3, None, False),
        ('invertedresidual', 3, 1, 112, 'hard_swish', 1. / 4, 6., None, False),
        ('invertedresidual', 3, 1, 112, 'hard_swish', 1. / 4, 6., None, False),
        ('invertedresidual', 5, 2, 160, 'hard_swish', 1. / 4, 6., None, False),
        ('invertedresidual', 5, 1, 160, 'hard_swish', 1. / 4, 6., None, False),
        ('invertedresidual', 5, 1, 160, 'hard_swish', 1. / 4, 6., None, False),
        ('convbn', 1, 1, 960, 'hard_swish', None, None, True, False),
        ('gpooling', None, None, None, None, None, None, None, None),
        ('convbn', 1, 1, 1280, 'hard_swish', None, None, False, True),
@@ -155,23 +239,17 @@ MNV3Small_BLOCK_SPECS = {
                          'use_normalization', 'use_bias'],
    'block_specs': [
        ('convbn', 3, 2, 16, 'hard_swish', None, None, True, False),
        ('invertedresidual', 3, 2, 16, 'relu', 1. / 4, 1, None, False),
        ('invertedresidual', 3, 2, 24, 'relu', None, 72. / 16, None, False),
        ('invertedresidual', 3, 1, 24, 'relu', None, 88. / 24, None, False),
        ('invertedresidual', 5, 2, 40, 'hard_swish', 1. / 4, 4., None, False),
        ('invertedresidual', 5, 1, 40, 'hard_swish', 1. / 4, 6., None, False),
        ('invertedresidual', 5, 1, 40, 'hard_swish', 1. / 4, 6., None, False),
        ('invertedresidual', 5, 1, 48, 'hard_swish', 1. / 4, 3., None, False),
        ('invertedresidual', 5, 1, 48, 'hard_swish', 1. / 4, 3., None, False),
        ('invertedresidual', 5, 2, 96, 'hard_swish', 1. / 4, 6., None, False),
        ('invertedresidual', 5, 1, 96, 'hard_swish', 1. / 4, 6., None, False),
        ('invertedresidual', 5, 1, 96, 'hard_swish', 1. / 4, 6., None, False),
        ('convbn', 1, 1, 576, 'hard_swish', None, None, True, False),
        ('gpooling', None, None, None, None, None, None, None, None),
        ('convbn', 1, 1, 1024, 'hard_swish', None, None, False, True),
@@ -189,36 +267,28 @@ MNV3EdgeTPU_BLOCK_SPECS = {
                          'use_residual', 'use_depthwise'],
    'block_specs': [
        ('convbn', 3, 2, 32, 'relu', None, None, None, None),
        ('invertedresidual', 3, 1, 16, 'relu', None, 1., True, False),
        ('invertedresidual', 3, 2, 32, 'relu', None, 8., True, False),
        ('invertedresidual', 3, 1, 32, 'relu', None, 4., True, False),
        ('invertedresidual', 3, 1, 32, 'relu', None, 4., True, False),
        ('invertedresidual', 3, 1, 32, 'relu', None, 4., True, False),
        ('invertedresidual', 3, 2, 48, 'relu', None, 8., True, False),
        ('invertedresidual', 3, 1, 48, 'relu', None, 4., True, False),
        ('invertedresidual', 3, 1, 48, 'relu', None, 4., True, False),
        ('invertedresidual', 3, 1, 48, 'relu', None, 4., True, False),
        ('invertedresidual', 3, 2, 96, 'relu', None, 8., True, True),
        ('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
        ('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
        ('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
        ('invertedresidual', 3, 1, 96, 'relu', None, 8., False, True),
        ('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
        ('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
        ('invertedresidual', 3, 1, 96, 'relu', None, 4., True, True),
        ('invertedresidual', 5, 2, 160, 'relu', None, 8., True, True),
        ('invertedresidual', 5, 1, 160, 'relu', None, 4., True, True),
        ('invertedresidual', 5, 1, 160, 'relu', None, 4., True, True),
        ('invertedresidual', 5, 1, 160, 'relu', None, 4., True, True),
        ('invertedresidual', 3, 1, 192, 'relu', None, 8., True, True),
        ('convbn', 1, 1, 1280, 'relu', None, None, None, None),
    ]
}
@@ -231,13 +301,6 @@ SUPPORTED_SPECS_MAP = {
    'MobileNetV3EdgeTPU': MNV3EdgeTPU_BLOCK_SPECS,
}
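For orientation, a minimal sketch of selecting one of the spec tables by name; the 'MobileNetV3EdgeTPU' key is taken from SUPPORTED_SPECS_MAP above, and the printout is illustrative:

# Illustrative lookup of a block-spec table by model name.
specs = SUPPORTED_SPECS_MAP['MobileNetV3EdgeTPU']
print(specs['block_spec_schema'])    # field names used to decode each tuple
print(len(specs['block_specs']))     # number of block definitions listed above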
BLOCK_FN_MAP = {
    'convbn': nn_blocks.Conv2DBNBlock,
    'depsepconv': nn_blocks.DepthwiseSeparableConvBlock,
    'mbconv': nn_blocks.InvertedBottleneckBlock,
    'gpooling': GlobalPoolingBlock,
}
class BlockSpec(object):
  """A container class that specifies the block configuration for MobileNet."""
@@ -297,7 +360,8 @@ def block_spec_decoder(specs: Dict,
  block_specs = specs['block_specs']
  if len(block_specs) == 0:
    raise ValueError(
        'The block spec cannot be empty for {} !'.format(spec_name))
  if len(block_specs[0]) != len(block_spec_schema):
    raise ValueError('The block spec values {} do not match with '
@@ -489,7 +553,7 @@ class MobileNet(tf.keras.Model):
      if block_def.block_fn == 'convbn':
        net = Conv2DBNBlock(
            filters=block_def.filters,
            kernel_size=block_def.kernel_size,
            strides=block_def.strides,
@@ -519,7 +583,7 @@ class MobileNet(tf.keras.Model):
            norm_epsilon=self._norm_epsilon,
        )(net)
      elif block_def.block_fn == 'invertedresidual':
        use_rate = rate
        if layer_rate > 1 and block_def.kernel_size != 1:
          # We will apply atrous rate in the following cases:
@@ -554,7 +618,8 @@ class MobileNet(tf.keras.Model):
        )(net)
      elif block_def.block_fn == 'gpooling':
        net = layers.GlobalAveragePooling2D()(net)
        net = layers.Reshape((1, 1, net.shape[1]))(net)
      else:
        raise ValueError('Unknown block type {} for layer {}'.format(
......
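The inline 'gpooling' branch above reproduces the behavior of the removed GlobalPoolingBlock. A quick shape check, with an illustrative input size:

# Illustrative shape check for the inline global-pooling path that replaces
# GlobalPoolingBlock: (N, H, W, C) -> (N, 1, 1, C).
import tensorflow as tf

layers = tf.keras.layers
net = tf.random.uniform([2, 7, 7, 576])
net = layers.GlobalAveragePooling2D()(net)        # -> (2, 576)
net = layers.Reshape((1, 1, net.shape[1]))(net)   # -> (2, 1, 1, 576)
print(net.shape)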
@@ -1184,113 +1184,3 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
    x = self._conv1(x)
    x = self._norm1(x)
    return self._activation_fn(x)
class Conv2DBNBlock(tf.keras.layers.Layer):
  """A convolution block with batch normalization."""

  def __init__(self,
               filters: int,
               kernel_size: int = 3,
               strides: int = 1,
               use_bias: bool = False,
               activation: Text = 'relu6',
               kernel_initializer: Text = 'VarianceScaling',
               kernel_regularizer: Optional[
                   tf.keras.regularizers.Regularizer] = None,
               bias_regularizer: Optional[
                   tf.keras.regularizers.Regularizer] = None,
               use_normalization: bool = True,
               use_sync_bn: bool = False,
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               **kwargs):
    """A convolution block with batch normalization.

    Args:
      filters: `int` number of filters for the convolution.
      kernel_size: `int` height and width of the 2D convolution window.
      strides: `int` stride of the convolution. If greater than 1, this block
        downsamples the input.
      use_bias: if True, use a bias in the convolution layer.
      activation: `str` name of the activation function.
      kernel_initializer: kernel_initializer for the convolutional layer.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Defaults to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Defaults to None.
      use_normalization: if True, use batch normalization.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      **kwargs: keyword arguments to be passed.
    """
    super(Conv2DBNBlock, self).__init__(**kwargs)
    self._filters = filters
    self._kernel_size = kernel_size
    self._strides = strides
    self._activation = activation
    self._use_bias = use_bias
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_normalization = use_normalization
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def get_config(self):
    config = {
        'filters': self._filters,
        'strides': self._strides,
        'kernel_size': self._kernel_size,
        'use_bias': self._use_bias,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'use_normalization': self._use_normalization,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
    base_config = super(Conv2DBNBlock, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    self._conv0 = tf.keras.layers.Conv2D(
        filters=self._filters,
        kernel_size=self._kernel_size,
        strides=self._strides,
        padding='same',
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    if self._use_normalization:
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._norm_momentum,
          epsilon=self._norm_epsilon)
    super(Conv2DBNBlock, self).build(input_shape)

  def call(self, inputs, training=None):
    x = self._conv0(inputs)
    if self._use_normalization:
      x = self._norm0(x)
    return self._activation_fn(x)
\ No newline at end of file