Commit 29461a63 authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 394775644
parent 2f43cff2
...@@ -30,6 +30,7 @@ class ResNet(hyperparams.Config): ...@@ -30,6 +30,7 @@ class ResNet(hyperparams.Config):
stem_type: str = 'v0' stem_type: str = 'v0'
se_ratio: float = 0.0 se_ratio: float = 0.0
stochastic_depth_drop_rate: float = 0.0 stochastic_depth_drop_rate: float = 0.0
scale_stem: bool = True
resnetd_shortcut: bool = False resnetd_shortcut: bool = False
replace_stem_max_pool: bool = False replace_stem_max_pool: bool = False
bn_trainable: bool = True bn_trainable: bool = True
......
...@@ -120,6 +120,7 @@ class ResNet(tf.keras.Model): ...@@ -120,6 +120,7 @@ class ResNet(tf.keras.Model):
replace_stem_max_pool: bool = False, replace_stem_max_pool: bool = False,
se_ratio: Optional[float] = None, se_ratio: Optional[float] = None,
init_stochastic_depth_rate: float = 0.0, init_stochastic_depth_rate: float = 0.0,
scale_stem: bool = True,
activation: str = 'relu', activation: str = 'relu',
use_sync_bn: bool = False, use_sync_bn: bool = False,
norm_momentum: float = 0.99, norm_momentum: float = 0.99,
...@@ -145,6 +146,7 @@ class ResNet(tf.keras.Model): ...@@ -145,6 +146,7 @@ class ResNet(tf.keras.Model):
with a stride-2 conv, with a stride-2 conv,
se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer. se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
init_stochastic_depth_rate: A `float` of initial stochastic depth rate. init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
scale_stem: A `bool` of whether to scale stem layers.
activation: A `str` name of the activation function. activation: A `str` name of the activation function.
use_sync_bn: If True, use synchronized batch normalization. use_sync_bn: If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average. norm_momentum: A `float` of normalization momentum for the moving average.
...@@ -166,6 +168,7 @@ class ResNet(tf.keras.Model): ...@@ -166,6 +168,7 @@ class ResNet(tf.keras.Model):
self._replace_stem_max_pool = replace_stem_max_pool self._replace_stem_max_pool = replace_stem_max_pool
self._se_ratio = se_ratio self._se_ratio = se_ratio
self._init_stochastic_depth_rate = init_stochastic_depth_rate self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._scale_stem = scale_stem
self._use_sync_bn = use_sync_bn self._use_sync_bn = use_sync_bn
self._activation = activation self._activation = activation
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
...@@ -187,9 +190,10 @@ class ResNet(tf.keras.Model): ...@@ -187,9 +190,10 @@ class ResNet(tf.keras.Model):
# Build ResNet. # Build ResNet.
inputs = tf.keras.Input(shape=input_specs.shape[1:]) inputs = tf.keras.Input(shape=input_specs.shape[1:])
stem_depth_multiplier = self._depth_multiplier if scale_stem else 1.0
if stem_type == 'v0': if stem_type == 'v0':
x = layers.Conv2D( x = layers.Conv2D(
filters=int(64 * self._depth_multiplier), filters=int(64 * stem_depth_multiplier),
kernel_size=7, kernel_size=7,
strides=2, strides=2,
use_bias=False, use_bias=False,
...@@ -207,7 +211,7 @@ class ResNet(tf.keras.Model): ...@@ -207,7 +211,7 @@ class ResNet(tf.keras.Model):
x = tf_utils.get_activation(activation, use_keras_layer=True)(x) x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
elif stem_type == 'v1': elif stem_type == 'v1':
x = layers.Conv2D( x = layers.Conv2D(
filters=int(32 * self._depth_multiplier), filters=int(32 * stem_depth_multiplier),
kernel_size=3, kernel_size=3,
strides=2, strides=2,
use_bias=False, use_bias=False,
...@@ -224,7 +228,7 @@ class ResNet(tf.keras.Model): ...@@ -224,7 +228,7 @@ class ResNet(tf.keras.Model):
x) x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x) x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
x = layers.Conv2D( x = layers.Conv2D(
filters=int(32 * self._depth_multiplier), filters=int(32 * stem_depth_multiplier),
kernel_size=3, kernel_size=3,
strides=1, strides=1,
use_bias=False, use_bias=False,
...@@ -241,7 +245,7 @@ class ResNet(tf.keras.Model): ...@@ -241,7 +245,7 @@ class ResNet(tf.keras.Model):
x) x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x) x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
x = layers.Conv2D( x = layers.Conv2D(
filters=int(64 * self._depth_multiplier), filters=int(64 * stem_depth_multiplier),
kernel_size=3, kernel_size=3,
strides=1, strides=1,
use_bias=False, use_bias=False,
...@@ -377,6 +381,7 @@ class ResNet(tf.keras.Model): ...@@ -377,6 +381,7 @@ class ResNet(tf.keras.Model):
'activation': self._activation, 'activation': self._activation,
'se_ratio': self._se_ratio, 'se_ratio': self._se_ratio,
'init_stochastic_depth_rate': self._init_stochastic_depth_rate, 'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
'scale_stem': self._scale_stem,
'use_sync_bn': self._use_sync_bn, 'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum, 'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon, 'norm_epsilon': self._norm_epsilon,
...@@ -418,6 +423,7 @@ def build_resnet( ...@@ -418,6 +423,7 @@ def build_resnet(
replace_stem_max_pool=backbone_cfg.replace_stem_max_pool, replace_stem_max_pool=backbone_cfg.replace_stem_max_pool,
se_ratio=backbone_cfg.se_ratio, se_ratio=backbone_cfg.se_ratio,
init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate, init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
scale_stem=backbone_cfg.scale_stem,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
......
...@@ -128,6 +128,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -128,6 +128,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
resnetd_shortcut=False, resnetd_shortcut=False,
replace_stem_max_pool=False, replace_stem_max_pool=False,
init_stochastic_depth_rate=0.0, init_stochastic_depth_rate=0.0,
scale_stem=True,
use_sync_bn=False, use_sync_bn=False,
activation='relu', activation='relu',
norm_momentum=0.99, norm_momentum=0.99,
...@@ -135,8 +136,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -135,8 +136,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
kernel_initializer='VarianceScaling', kernel_initializer='VarianceScaling',
kernel_regularizer=None, kernel_regularizer=None,
bias_regularizer=None, bias_regularizer=None,
bn_trainable=True bn_trainable=True)
)
network = resnet.ResNet(**kwargs) network = resnet.ResNet(**kwargs)
expected_config = dict(kwargs) expected_config = dict(kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment