Commit bfcf684e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 454634340
parent 4cd48ecc
......@@ -45,6 +45,8 @@ class DilatedResNet(hyperparams.Config):
last_stage_repeats: int = 1
se_ratio: float = 0.0
stochastic_depth_drop_rate: float = 0.0
resnetd_shortcut: bool = False
replace_stem_max_pool: bool = False
@dataclasses.dataclass
......
......@@ -75,6 +75,8 @@ class DilatedResNet(tf.keras.Model):
input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
shape=[None, None, None, 3]),
stem_type: str = 'v0',
resnetd_shortcut: bool = False,
replace_stem_max_pool: bool = False,
se_ratio: Optional[float] = None,
init_stochastic_depth_rate: float = 0.0,
multigrid: Optional[Tuple[int]] = None,
......@@ -96,6 +98,10 @@ class DilatedResNet(tf.keras.Model):
input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
stem_type: A `str` of stem type. Can be `v0` or `v1`. `v1` replaces 7x7
conv by 3 3x3 convs.
resnetd_shortcut: A `bool` of whether to use ResNet-D shortcut in
downsampling blocks.
replace_stem_max_pool: A `bool` of whether to replace the max pool in stem
with a stride-2 conv,
se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
multigrid: A tuple of the same length as the number of blocks in the last
......@@ -128,6 +134,8 @@ class DilatedResNet(tf.keras.Model):
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._stem_type = stem_type
self._resnetd_shortcut = resnetd_shortcut
self._replace_stem_max_pool = replace_stem_max_pool
self._se_ratio = se_ratio
self._init_stochastic_depth_rate = init_stochastic_depth_rate
......@@ -200,6 +208,22 @@ class DilatedResNet(tf.keras.Model):
else:
raise ValueError('Stem type {} not supported.'.format(stem_type))
if replace_stem_max_pool:
x = layers.Conv2D(
filters=64,
kernel_size=3,
strides=2,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
else:
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
normal_resnet_stage = int(np.math.log2(self._output_stride)) - 2
......@@ -296,6 +320,7 @@ class DilatedResNet(tf.keras.Model):
use_projection=True,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
se_ratio=self._se_ratio,
resnetd_shortcut=self._resnetd_shortcut,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
......@@ -311,6 +336,7 @@ class DilatedResNet(tf.keras.Model):
dilation_rate=dilation_rate * multigrid[i],
use_projection=False,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
resnetd_shortcut=self._resnetd_shortcut,
se_ratio=self._se_ratio,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
......@@ -328,6 +354,8 @@ class DilatedResNet(tf.keras.Model):
'model_id': self._model_id,
'output_stride': self._output_stride,
'stem_type': self._stem_type,
'resnetd_shortcut': self._resnetd_shortcut,
'replace_stem_max_pool': self._replace_stem_max_pool,
'se_ratio': self._se_ratio,
'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
'activation': self._activation,
......@@ -367,6 +395,8 @@ def build_dilated_resnet(
output_stride=backbone_cfg.output_stride,
input_specs=input_specs,
stem_type=backbone_cfg.stem_type,
resnetd_shortcut=backbone_cfg.resnetd_shortcut,
replace_stem_max_pool=backbone_cfg.replace_stem_max_pool,
se_ratio=backbone_cfg.se_ratio,
init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
multigrid=backbone_cfg.multigrid,
......
......@@ -52,13 +52,17 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
], endpoints[str(int(np.math.log2(output_stride)))].shape.as_list())
@parameterized.parameters(
('v0', None, 0.0),
('v1', None, 0.0),
('v1', 0.25, 0.0),
('v1', 0.25, 0.2),
('v0', None, 0.0, False, False),
('v1', None, 0.0, False, False),
('v1', 0.25, 0.0, False, False),
('v1', 0.25, 0.2, False, False),
('v1', 0.25, 0.0, True, False),
('v1', 0.25, 0.2, False, True),
('v1', None, 0.2, True, True),
)
def test_network_features(self, stem_type, se_ratio,
init_stochastic_depth_rate):
init_stochastic_depth_rate, resnetd_shortcut,
replace_stem_max_pool):
"""Test additional features of ResNet models."""
input_size = 128
model_id = 50
......@@ -71,6 +75,8 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
model_id=model_id,
output_stride=output_stride,
stem_type=stem_type,
resnetd_shortcut=resnetd_shortcut,
replace_stem_max_pool=replace_stem_max_pool,
se_ratio=se_ratio,
init_stochastic_depth_rate=init_stochastic_depth_rate)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
......@@ -120,6 +126,8 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
stem_type='v0',
se_ratio=0.25,
init_stochastic_depth_rate=0.2,
resnetd_shortcut=False,
replace_stem_max_pool=False,
use_sync_bn=False,
activation='relu',
norm_momentum=0.99,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment