Commit 609a5b5e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 394355848
parent de1d7424
...@@ -32,6 +32,7 @@ class ResNet(hyperparams.Config): ...@@ -32,6 +32,7 @@ class ResNet(hyperparams.Config):
stochastic_depth_drop_rate: float = 0.0 stochastic_depth_drop_rate: float = 0.0
resnetd_shortcut: bool = False resnetd_shortcut: bool = False
replace_stem_max_pool: bool = False replace_stem_max_pool: bool = False
bn_trainable: bool = True
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -127,6 +127,7 @@ class ResNet(tf.keras.Model): ...@@ -127,6 +127,7 @@ class ResNet(tf.keras.Model):
kernel_initializer: str = 'VarianceScaling', kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bn_trainable: bool = True,
**kwargs): **kwargs):
"""Initializes a ResNet model. """Initializes a ResNet model.
...@@ -153,6 +154,8 @@ class ResNet(tf.keras.Model): ...@@ -153,6 +154,8 @@ class ResNet(tf.keras.Model):
Conv2D. Default to None. Conv2D. Default to None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D. bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
Default to None. Default to None.
bn_trainable: A `bool` that indicates whether batch norm layers should be
trainable. Defaults to True.
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
self._model_id = model_id self._model_id = model_id
...@@ -174,6 +177,7 @@ class ResNet(tf.keras.Model): ...@@ -174,6 +177,7 @@ class ResNet(tf.keras.Model):
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
self._bn_trainable = bn_trainable
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1 bn_axis = -1
...@@ -195,7 +199,10 @@ class ResNet(tf.keras.Model): ...@@ -195,7 +199,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
inputs) inputs)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x) x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x) x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
elif stem_type == 'v1': elif stem_type == 'v1':
...@@ -210,7 +217,10 @@ class ResNet(tf.keras.Model): ...@@ -210,7 +217,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
inputs) inputs)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x) x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x) x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
x = layers.Conv2D( x = layers.Conv2D(
...@@ -224,7 +234,10 @@ class ResNet(tf.keras.Model): ...@@ -224,7 +234,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
x) x)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x) x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x) x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
x = layers.Conv2D( x = layers.Conv2D(
...@@ -238,7 +251,10 @@ class ResNet(tf.keras.Model): ...@@ -238,7 +251,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
x) x)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x) x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x) x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
else: else:
...@@ -256,7 +272,10 @@ class ResNet(tf.keras.Model): ...@@ -256,7 +272,10 @@ class ResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
x) x)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
trainable=bn_trainable)(
x) x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x) x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
else: else:
...@@ -324,7 +343,8 @@ class ResNet(tf.keras.Model): ...@@ -324,7 +343,8 @@ class ResNet(tf.keras.Model):
activation=self._activation, activation=self._activation,
use_sync_bn=self._use_sync_bn, use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum, norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)( norm_epsilon=self._norm_epsilon,
bn_trainable=self._bn_trainable)(
inputs) inputs)
for _ in range(1, block_repeats): for _ in range(1, block_repeats):
...@@ -341,7 +361,8 @@ class ResNet(tf.keras.Model): ...@@ -341,7 +361,8 @@ class ResNet(tf.keras.Model):
activation=self._activation, activation=self._activation,
use_sync_bn=self._use_sync_bn, use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum, norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)( norm_epsilon=self._norm_epsilon,
bn_trainable=self._bn_trainable)(
x) x)
return tf.keras.layers.Activation('linear', name=name)(x) return tf.keras.layers.Activation('linear', name=name)(x)
...@@ -362,6 +383,7 @@ class ResNet(tf.keras.Model): ...@@ -362,6 +383,7 @@ class ResNet(tf.keras.Model):
'kernel_initializer': self._kernel_initializer, 'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer, 'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer, 'bias_regularizer': self._bias_regularizer,
'bn_trainable': self._bn_trainable
} }
return config_dict return config_dict
...@@ -400,4 +422,5 @@ def build_resnet( ...@@ -400,4 +422,5 @@ def build_resnet(
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer,
bn_trainable=backbone_cfg.bn_trainable)
...@@ -135,6 +135,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -135,6 +135,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
kernel_initializer='VarianceScaling', kernel_initializer='VarianceScaling',
kernel_regularizer=None, kernel_regularizer=None,
bias_regularizer=None, bias_regularizer=None,
bn_trainable=True
) )
network = resnet.ResNet(**kwargs) network = resnet.ResNet(**kwargs)
......
...@@ -72,6 +72,7 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -72,6 +72,7 @@ class ResidualBlock(tf.keras.layers.Layer):
use_sync_bn=False, use_sync_bn=False,
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
bn_trainable=True,
**kwargs): **kwargs):
"""Initializes a residual block with BN after convolutions. """Initializes a residual block with BN after convolutions.
...@@ -99,6 +100,8 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -99,6 +100,8 @@ class ResidualBlock(tf.keras.layers.Layer):
use_sync_bn: A `bool`. If True, use synchronized batch normalization. use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average. norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero. norm_epsilon: A `float` added to variance to avoid dividing by zero.
bn_trainable: A `bool` that indicates whether batch norm layers should be
trainable. Defaults to True.
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(ResidualBlock, self).__init__(**kwargs) super(ResidualBlock, self).__init__(**kwargs)
...@@ -126,6 +129,7 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -126,6 +129,7 @@ class ResidualBlock(tf.keras.layers.Layer):
else: else:
self._bn_axis = 1 self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation) self._activation_fn = tf_utils.get_activation(activation)
self._bn_trainable = bn_trainable
def build(self, input_shape): def build(self, input_shape):
if self._use_projection: if self._use_projection:
...@@ -140,7 +144,8 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -140,7 +144,8 @@ class ResidualBlock(tf.keras.layers.Layer):
self._norm0 = self._norm( self._norm0 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._conv1 = tf.keras.layers.Conv2D( self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters, filters=self._filters,
...@@ -154,7 +159,8 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -154,7 +159,8 @@ class ResidualBlock(tf.keras.layers.Layer):
self._norm1 = self._norm( self._norm1 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._conv2 = tf.keras.layers.Conv2D( self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters, filters=self._filters,
...@@ -168,7 +174,8 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -168,7 +174,8 @@ class ResidualBlock(tf.keras.layers.Layer):
self._norm2 = self._norm( self._norm2 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1: if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation( self._squeeze_excitation = nn_layers.SqueezeExcitation(
...@@ -203,7 +210,8 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -203,7 +210,8 @@ class ResidualBlock(tf.keras.layers.Layer):
'activation': self._activation, 'activation': self._activation,
'use_sync_bn': self._use_sync_bn, 'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum, 'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon 'norm_epsilon': self._norm_epsilon,
'bn_trainable': self._bn_trainable
} }
base_config = super(ResidualBlock, self).get_config() base_config = super(ResidualBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
...@@ -249,6 +257,7 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -249,6 +257,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
use_sync_bn=False, use_sync_bn=False,
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
bn_trainable=True,
**kwargs): **kwargs):
"""Initializes a standard bottleneck block with BN after convolutions. """Initializes a standard bottleneck block with BN after convolutions.
...@@ -277,6 +286,8 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -277,6 +286,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
use_sync_bn: A `bool`. If True, use synchronized batch normalization. use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average. norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero. norm_epsilon: A `float` added to variance to avoid dividing by zero.
bn_trainable: A `bool` that indicates whether batch norm layers should be
trainable. Defaults to True.
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(BottleneckBlock, self).__init__(**kwargs) super(BottleneckBlock, self).__init__(**kwargs)
...@@ -303,6 +314,7 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -303,6 +314,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._bn_axis = -1 self._bn_axis = -1
else: else:
self._bn_axis = 1 self._bn_axis = 1
self._bn_trainable = bn_trainable
def build(self, input_shape): def build(self, input_shape):
if self._use_projection: if self._use_projection:
...@@ -330,7 +342,8 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -330,7 +342,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm0 = self._norm( self._norm0 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._conv1 = tf.keras.layers.Conv2D( self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters, filters=self._filters,
...@@ -343,7 +356,8 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -343,7 +356,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm1 = self._norm( self._norm1 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation1 = tf_utils.get_activation( self._activation1 = tf_utils.get_activation(
self._activation, use_keras_layer=True) self._activation, use_keras_layer=True)
...@@ -360,7 +374,8 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -360,7 +374,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm2 = self._norm( self._norm2 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation2 = tf_utils.get_activation( self._activation2 = tf_utils.get_activation(
self._activation, use_keras_layer=True) self._activation, use_keras_layer=True)
...@@ -375,7 +390,8 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -375,7 +390,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm3 = self._norm( self._norm3 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation3 = tf_utils.get_activation( self._activation3 = tf_utils.get_activation(
self._activation, use_keras_layer=True) self._activation, use_keras_layer=True)
...@@ -414,7 +430,8 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -414,7 +430,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
'activation': self._activation, 'activation': self._activation,
'use_sync_bn': self._use_sync_bn, 'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum, 'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon 'norm_epsilon': self._norm_epsilon,
'bn_trainable': self._bn_trainable
} }
base_config = super(BottleneckBlock, self).get_config() base_config = super(BottleneckBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment