Commit 0326425d authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 357209234
parent 2a370fe0
......@@ -30,6 +30,8 @@ class ResNet(hyperparams.Config):
stem_type: str = 'v0'
se_ratio: float = 0.0
stochastic_depth_drop_rate: float = 0.0
resnetd_shortcut: bool = False
replace_stem_max_pool: bool = False
@dataclasses.dataclass
......
......@@ -162,6 +162,78 @@ def image_classification_imagenet() -> cfg.ExperimentConfig:
return config
@exp_factory.register_config_factory('resnet_rs_imagenet')
def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig:
  """Image classification on imagenet with resnet-rs.

  Builds the ResNet-RS-50 ImageNet experiment: ResNet-D stem ('v1'),
  ResNet-D shortcut, stride-2 conv replacing the stem max pool,
  Squeeze-and-Excitation (ratio 0.25), RandAugment, label smoothing,
  EMA of weights, and a cosine learning-rate schedule with linear warmup.

  Returns:
    A `cfg.ExperimentConfig` for the 'resnet_rs_imagenet' experiment.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  # 360 training epochs; the cosine decay spans the full schedule.
  total_train_steps = 360 * steps_per_epoch

  backbone = backbones.Backbone(
      type='resnet',
      resnet=backbones.ResNet(
          model_id=50,
          stem_type='v1',
          resnetd_shortcut=True,
          replace_stem_max_pool=True,
          se_ratio=0.25,
          stochastic_depth_drop_rate=0.0))
  model = ImageClassificationModel(
      num_classes=1001,
      input_size=[160, 160, 3],
      backbone=backbone,
      dropout_rate=0.25,
      norm_activation=common.NormActivation(
          norm_momentum=0.0, norm_epsilon=1e-5, use_sync_bn=False))
  task = ImageClassificationTask(
      model=model,
      losses=Losses(l2_weight_decay=4e-5, label_smoothing=0.1),
      train_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
          is_training=True,
          global_batch_size=train_batch_size,
          aug_policy='randaug'),
      validation_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
          is_training=False,
          global_batch_size=eval_batch_size))
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=total_train_steps,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': {
              'type': 'sgd',
              'sgd': {
                  'momentum': 0.9
              }
          },
          'ema': {
              'average_decay': 0.9999
          },
          'learning_rate': {
              'type': 'cosine',
              'cosine': {
                  'initial_learning_rate': 0.1,
                  'decay_steps': total_train_steps
              }
          },
          'warmup': {
              'type': 'linear',
              'linear': {
                  'warmup_steps': 5 * steps_per_epoch,
                  'warmup_learning_rate': 0
              }
          }
      }))
  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
@exp_factory.register_config_factory('revnet_imagenet')
def image_classification_imagenet_revnet() -> cfg.ExperimentConfig:
"""Returns a revnet config for image classification on imagenet."""
......
......@@ -26,9 +26,12 @@ from official.vision.beta.configs import image_classification as exp_cfg
class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('resnet_imagenet',),
('revnet_imagenet',),
('mobilenet_imagenet'),)
@parameterized.parameters(
('resnet_imagenet',),
('resnet_rs_imagenet',),
('revnet_imagenet',),
('mobilenet_imagenet'),
)
def test_image_classification_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
......
......@@ -69,10 +69,10 @@ RESNET_SPECS = {
('bottleneck', 256, 36),
('bottleneck', 512, 3),
],
300: [
270: [
('bottleneck', 64, 4),
('bottleneck', 128, 36),
('bottleneck', 256, 54),
('bottleneck', 128, 29),
('bottleneck', 256, 53),
('bottleneck', 512, 4),
],
350: [
......@@ -81,6 +81,12 @@ RESNET_SPECS = {
('bottleneck', 256, 72),
('bottleneck', 512, 4),
],
420: [
('bottleneck', 64, 4),
('bottleneck', 128, 44),
('bottleneck', 256, 87),
('bottleneck', 512, 4),
],
}
......@@ -93,6 +99,8 @@ class ResNet(tf.keras.Model):
input_specs=layers.InputSpec(shape=[None, None, None, 3]),
depth_multiplier=1.0,
stem_type='v0',
resnetd_shortcut=False,
replace_stem_max_pool=False,
se_ratio=None,
init_stochastic_depth_rate=0.0,
activation='relu',
......@@ -111,7 +119,11 @@ class ResNet(tf.keras.Model):
    depth_multiplier: `float` a depth multiplier to uniformly scale up all
      layers in channel size in ResNet.
stem_type: `str` stem type of ResNet. Default to `v0`. If set to `v1`,
use ResNet-C type stem (https://arxiv.org/abs/1812.01187).
use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
resnetd_shortcut: `bool` whether to use ResNet-D shortcut in downsampling
blocks.
    replace_stem_max_pool: `bool` if True, replace the max pool in stem with
      a stride-2 conv.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
init_stochastic_depth_rate: `float` initial stochastic depth rate.
activation: `str` name of the activation function.
......@@ -130,6 +142,8 @@ class ResNet(tf.keras.Model):
self._input_specs = input_specs
self._depth_multiplier = depth_multiplier
self._stem_type = stem_type
self._resnetd_shortcut = resnetd_shortcut
self._replace_stem_max_pool = replace_stem_max_pool
self._se_ratio = se_ratio
self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._use_sync_bn = use_sync_bn
......@@ -213,7 +227,23 @@ class ResNet(tf.keras.Model):
else:
raise ValueError('Stem type {} not supported.'.format(stem_type))
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
if replace_stem_max_pool:
x = layers.Conv2D(
filters=int(64 * self._depth_multiplier),
kernel_size=3,
strides=2,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
else:
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
endpoints = {}
for i, spec in enumerate(RESNET_SPECS[model_id]):
......@@ -267,6 +297,7 @@ class ResNet(tf.keras.Model):
use_projection=True,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
se_ratio=self._se_ratio,
resnetd_shortcut=self._resnetd_shortcut,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
......@@ -283,6 +314,7 @@ class ResNet(tf.keras.Model):
use_projection=False,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
se_ratio=self._se_ratio,
resnetd_shortcut=self._resnetd_shortcut,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
......@@ -299,6 +331,8 @@ class ResNet(tf.keras.Model):
'model_id': self._model_id,
'depth_multiplier': self._depth_multiplier,
'stem_type': self._stem_type,
'resnetd_shortcut': self._resnetd_shortcut,
'replace_stem_max_pool': self._replace_stem_max_pool,
'activation': self._activation,
'se_ratio': self._se_ratio,
'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
......@@ -338,6 +372,8 @@ def build_resnet(
input_specs=input_specs,
depth_multiplier=backbone_cfg.depth_multiplier,
stem_type=backbone_cfg.stem_type,
resnetd_shortcut=backbone_cfg.resnetd_shortcut,
replace_stem_max_pool=backbone_cfg.replace_stem_max_pool,
se_ratio=backbone_cfg.se_ratio,
init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
activation=norm_activation_config.activation,
......
......@@ -84,20 +84,22 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
_ = network(inputs)
@parameterized.parameters(
(128, 34, 1, 'v0', None, 0.0, 1.0),
(128, 34, 1, 'v1', 0.25, 0.2, 1.25),
(128, 50, 4, 'v0', None, 0.0, 1.5),
(128, 50, 4, 'v1', 0.25, 0.2, 2.0),
(128, 34, 1, 'v0', None, 0.0, 1.0, False, False),
(128, 34, 1, 'v1', 0.25, 0.2, 1.25, True, True),
(128, 50, 4, 'v0', None, 0.0, 1.5, False, False),
(128, 50, 4, 'v1', 0.25, 0.2, 2.0, True, True),
)
def test_resnet_addons(self, input_size, model_id, endpoint_filter_scale,
stem_type, se_ratio, init_stochastic_depth_rate,
depth_multiplier):
def test_resnet_rs(self, input_size, model_id, endpoint_filter_scale,
stem_type, se_ratio, init_stochastic_depth_rate,
depth_multiplier, resnetd_shortcut, replace_stem_max_pool):
"""Test creation of ResNet family models."""
tf.keras.backend.set_image_data_format('channels_last')
network = resnet.ResNet(
model_id=model_id,
depth_multiplier=depth_multiplier,
stem_type=stem_type,
resnetd_shortcut=resnetd_shortcut,
replace_stem_max_pool=replace_stem_max_pool,
se_ratio=se_ratio,
init_stochastic_depth_rate=init_stochastic_depth_rate)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
......@@ -121,6 +123,8 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
depth_multiplier=1.0,
stem_type='v0',
se_ratio=None,
resnetd_shortcut=False,
replace_stem_max_pool=False,
init_stochastic_depth_rate=0.0,
use_sync_bn=False,
activation='relu',
......
......@@ -63,6 +63,7 @@ class ResidualBlock(tf.keras.layers.Layer):
strides,
use_projection=False,
se_ratio=None,
resnetd_shortcut=False,
stochastic_depth_drop_rate=None,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
......@@ -84,6 +85,8 @@ class ResidualBlock(tf.keras.layers.Layer):
for the first block of a block group, which may change the number of
filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
resnetd_shortcut: `bool` if True, apply the resnetd style modification to
the shortcut connection. Not implemented in residual blocks.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers.
......@@ -104,6 +107,7 @@ class ResidualBlock(tf.keras.layers.Layer):
self._strides = strides
self._use_projection = use_projection
self._se_ratio = se_ratio
self._resnetd_shortcut = resnetd_shortcut
self._use_sync_bn = use_sync_bn
self._activation = activation
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
......@@ -191,6 +195,7 @@ class ResidualBlock(tf.keras.layers.Layer):
'strides': self._strides,
'use_projection': self._use_projection,
'se_ratio': self._se_ratio,
'resnetd_shortcut': self._resnetd_shortcut,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
......@@ -235,6 +240,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
dilation_rate=1,
use_projection=False,
se_ratio=None,
resnetd_shortcut=False,
stochastic_depth_drop_rate=None,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
......@@ -257,6 +263,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
for the first block of a block group, which may change the number of
filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
resnetd_shortcut: `bool` if True, apply the resnetd style modification to
the shortcut connection.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers.
......@@ -278,6 +286,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._dilation_rate = dilation_rate
self._use_projection = use_projection
self._se_ratio = se_ratio
self._resnetd_shortcut = resnetd_shortcut
self._use_sync_bn = use_sync_bn
self._activation = activation
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
......@@ -298,14 +307,27 @@ class BottleneckBlock(tf.keras.layers.Layer):
def build(self, input_shape):
if self._use_projection:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
if self._resnetd_shortcut:
self._shortcut0 = tf.keras.layers.AveragePooling2D(
pool_size=2, strides=self._strides, padding='same')
self._shortcut1 = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
else:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
......@@ -378,6 +400,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
'dilation_rate': self._dilation_rate,
'use_projection': self._use_projection,
'se_ratio': self._se_ratio,
'resnetd_shortcut': self._resnetd_shortcut,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
......@@ -393,7 +416,11 @@ class BottleneckBlock(tf.keras.layers.Layer):
def call(self, inputs, training=None):
shortcut = inputs
if self._use_projection:
shortcut = self._shortcut(shortcut)
if self._resnetd_shortcut:
shortcut = self._shortcut0(shortcut)
shortcut = self._shortcut1(shortcut)
else:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment