Commit bdca62cc authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 394355848
parent c5b6d8da
......@@ -32,6 +32,7 @@ class ResNet(hyperparams.Config):
stochastic_depth_drop_rate: float = 0.0
resnetd_shortcut: bool = False
replace_stem_max_pool: bool = False
bn_trainable: bool = True
@dataclasses.dataclass
......
......@@ -127,6 +127,7 @@ class ResNet(tf.keras.Model):
kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bn_trainable: bool = True,
**kwargs):
"""Initializes a ResNet model.
......@@ -153,6 +154,8 @@ class ResNet(tf.keras.Model):
Conv2D. Default to None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
Default to None.
bn_trainable: A `bool` that indicates whether batch norm layers should be
trainable. Default to True.
**kwargs: Additional keyword arguments to be passed.
"""
self._model_id = model_id
......@@ -174,6 +177,7 @@ class ResNet(tf.keras.Model):
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._bn_trainable = bn_trainable
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
......@@ -195,7 +199,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
elif stem_type == 'v1':
......@@ -210,7 +217,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
x = layers.Conv2D(
......@@ -224,7 +234,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
x = layers.Conv2D(
......@@ -238,7 +251,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
else:
......@@ -256,7 +272,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
else:
......@@ -324,7 +343,8 @@ class ResNet(tf.keras.Model):
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
norm_epsilon=self._norm_epsilon,
bn_trainable=self._bn_trainable)(
inputs)
for _ in range(1, block_repeats):
......@@ -341,7 +361,8 @@ class ResNet(tf.keras.Model):
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
norm_epsilon=self._norm_epsilon,
bn_trainable=self._bn_trainable)(
x)
return tf.keras.layers.Activation('linear', name=name)(x)
......@@ -362,6 +383,7 @@ class ResNet(tf.keras.Model):
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'bn_trainable': self._bn_trainable
}
return config_dict
......@@ -400,4 +422,5 @@ def build_resnet(
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
kernel_regularizer=l2_regularizer,
bn_trainable=backbone_cfg.bn_trainable)
......@@ -135,6 +135,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
bn_trainable=True
)
network = resnet.ResNet(**kwargs)
......
......@@ -72,6 +72,7 @@ class ResidualBlock(tf.keras.layers.Layer):
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
bn_trainable=True,
**kwargs):
"""Initializes a residual block with BN after convolutions.
......@@ -99,6 +100,8 @@ class ResidualBlock(tf.keras.layers.Layer):
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
bn_trainable: A `bool` that indicates whether batch norm layers should be
trainable. Default to True.
**kwargs: Additional keyword arguments to be passed.
"""
super(ResidualBlock, self).__init__(**kwargs)
......@@ -126,6 +129,7 @@ class ResidualBlock(tf.keras.layers.Layer):
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
self._bn_trainable = bn_trainable
def build(self, input_shape):
if self._use_projection:
......@@ -140,7 +144,8 @@ class ResidualBlock(tf.keras.layers.Layer):
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
......@@ -154,7 +159,8 @@ class ResidualBlock(tf.keras.layers.Layer):
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
......@@ -168,7 +174,8 @@ class ResidualBlock(tf.keras.layers.Layer):
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation(
......@@ -203,7 +210,8 @@ class ResidualBlock(tf.keras.layers.Layer):
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
'norm_epsilon': self._norm_epsilon,
'bn_trainable': self._bn_trainable
}
base_config = super(ResidualBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
......@@ -249,6 +257,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
bn_trainable=True,
**kwargs):
"""Initializes a standard bottleneck block with BN after convolutions.
......@@ -277,6 +286,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
bn_trainable: A `bool` that indicates whether batch norm layers should be
trainable. Default to True.
**kwargs: Additional keyword arguments to be passed.
"""
super(BottleneckBlock, self).__init__(**kwargs)
......@@ -303,6 +314,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._bn_axis = -1
else:
self._bn_axis = 1
self._bn_trainable = bn_trainable
def build(self, input_shape):
if self._use_projection:
......@@ -330,7 +342,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
......@@ -343,7 +356,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation1 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
......@@ -360,7 +374,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation2 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
......@@ -375,7 +390,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm3 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation3 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
......@@ -414,7 +430,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
'norm_epsilon': self._norm_epsilon,
'bn_trainable': self._bn_trainable
}
base_config = super(BottleneckBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment