"docs/en/_static/image" did not exist on "d7117b95ab120230bb7dc6e69c7c4c800397fcbf"
Commit 39774bc8 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 341852981
parent c061dace
...@@ -28,6 +28,7 @@ class ResNet(hyperparams.Config): ...@@ -28,6 +28,7 @@ class ResNet(hyperparams.Config):
model_id: int = 50 model_id: int = 50
stem_type: str = 'v0' stem_type: str = 'v0'
se_ratio: float = 0.0 se_ratio: float = 0.0
stochastic_depth_drop_rate: float = 0.0
@dataclasses.dataclass @dataclasses.dataclass
...@@ -41,8 +42,8 @@ class DilatedResNet(hyperparams.Config): ...@@ -41,8 +42,8 @@ class DilatedResNet(hyperparams.Config):
class EfficientNet(hyperparams.Config): class EfficientNet(hyperparams.Config):
"""EfficientNet config.""" """EfficientNet config."""
model_id: str = 'b0' model_id: str = 'b0'
stochastic_depth_drop_rate: float = 0.0
se_ratio: float = 0.0 se_ratio: float = 0.0
stochastic_depth_drop_rate: float = 0.0
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -24,6 +24,7 @@ import tensorflow as tf ...@@ -24,6 +24,7 @@ import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.modeling.layers import nn_layers
layers = tf.keras.layers layers = tf.keras.layers
...@@ -80,6 +81,7 @@ class ResNet(tf.keras.Model): ...@@ -80,6 +81,7 @@ class ResNet(tf.keras.Model):
input_specs=layers.InputSpec(shape=[None, None, None, 3]), input_specs=layers.InputSpec(shape=[None, None, None, 3]),
stem_type='v0', stem_type='v0',
se_ratio=None, se_ratio=None,
init_stochastic_depth_rate=0.0,
activation='relu', activation='relu',
use_sync_bn=False, use_sync_bn=False,
norm_momentum=0.99, norm_momentum=0.99,
...@@ -96,6 +98,7 @@ class ResNet(tf.keras.Model): ...@@ -96,6 +98,7 @@ class ResNet(tf.keras.Model):
stem_type: `str` stem type of ResNet. Default to `v0`. If set to `v1`, stem_type: `str` stem type of ResNet. Default to `v0`. If set to `v1`,
use ResNet-C type stem (https://arxiv.org/abs/1812.01187). use ResNet-C type stem (https://arxiv.org/abs/1812.01187).
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer. se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
init_stochastic_depth_rate: `float` initial stochastic depth rate.
activation: `str` name of the activation function. activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization. use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average. norm_momentum: `float` normalization omentum for the moving average.
...@@ -112,6 +115,7 @@ class ResNet(tf.keras.Model): ...@@ -112,6 +115,7 @@ class ResNet(tf.keras.Model):
self._input_specs = input_specs self._input_specs = input_specs
self._stem_type = stem_type self._stem_type = stem_type
self._se_ratio = se_ratio self._se_ratio = se_ratio
self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._use_sync_bn = use_sync_bn self._use_sync_bn = use_sync_bn
self._activation = activation self._activation = activation
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
...@@ -195,7 +199,6 @@ class ResNet(tf.keras.Model): ...@@ -195,7 +199,6 @@ class ResNet(tf.keras.Model):
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x) x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
# TODO(xianzhi): keep a list of blocks to make blocks accessible.
endpoints = {} endpoints = {}
for i, spec in enumerate(RESNET_SPECS[model_id]): for i, spec in enumerate(RESNET_SPECS[model_id]):
if spec[0] == 'residual': if spec[0] == 'residual':
...@@ -210,6 +213,8 @@ class ResNet(tf.keras.Model): ...@@ -210,6 +213,8 @@ class ResNet(tf.keras.Model):
strides=(1 if i == 0 else 2), strides=(1 if i == 0 else 2),
block_fn=block_fn, block_fn=block_fn,
block_repeats=spec[2], block_repeats=spec[2],
stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
self._init_stochastic_depth_rate, i + 2, 5),
name='block_group_l{}'.format(i + 2)) name='block_group_l{}'.format(i + 2))
endpoints[str(i + 2)] = x endpoints[str(i + 2)] = x
...@@ -223,6 +228,7 @@ class ResNet(tf.keras.Model): ...@@ -223,6 +228,7 @@ class ResNet(tf.keras.Model):
strides, strides,
block_fn, block_fn,
block_repeats=1, block_repeats=1,
stochastic_depth_drop_rate=0.0,
name='block_group'): name='block_group'):
"""Creates one group of blocks for the ResNet model. """Creates one group of blocks for the ResNet model.
...@@ -233,6 +239,7 @@ class ResNet(tf.keras.Model): ...@@ -233,6 +239,7 @@ class ResNet(tf.keras.Model):
greater than 1, this layer will downsample the input. greater than 1, this layer will downsample the input.
block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`. block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`.
block_repeats: `int` number of blocks contained in the layer. block_repeats: `int` number of blocks contained in the layer.
stochastic_depth_drop_rate: `float` drop rate of the current block group.
name: `str`name for the block. name: `str`name for the block.
Returns: Returns:
...@@ -242,6 +249,7 @@ class ResNet(tf.keras.Model): ...@@ -242,6 +249,7 @@ class ResNet(tf.keras.Model):
filters=filters, filters=filters,
strides=strides, strides=strides,
use_projection=True, use_projection=True,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
se_ratio=self._se_ratio, se_ratio=self._se_ratio,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
...@@ -257,6 +265,7 @@ class ResNet(tf.keras.Model): ...@@ -257,6 +265,7 @@ class ResNet(tf.keras.Model):
filters=filters, filters=filters,
strides=1, strides=1,
use_projection=False, use_projection=False,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
se_ratio=self._se_ratio, se_ratio=self._se_ratio,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
...@@ -275,6 +284,7 @@ class ResNet(tf.keras.Model): ...@@ -275,6 +284,7 @@ class ResNet(tf.keras.Model):
'stem_type': self._stem_type, 'stem_type': self._stem_type,
'activation': self._activation, 'activation': self._activation,
'se_ratio': self._se_ratio, 'se_ratio': self._se_ratio,
'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
'use_sync_bn': self._use_sync_bn, 'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum, 'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon, 'norm_epsilon': self._norm_epsilon,
...@@ -311,6 +321,7 @@ def build_resnet( ...@@ -311,6 +321,7 @@ def build_resnet(
input_specs=input_specs, input_specs=input_specs,
stem_type=backbone_cfg.stem_type, stem_type=backbone_cfg.stem_type,
se_ratio=backbone_cfg.se_ratio, se_ratio=backbone_cfg.se_ratio,
init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
......
...@@ -84,17 +84,20 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -84,17 +84,20 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
_ = network(inputs) _ = network(inputs)
@parameterized.parameters( @parameterized.parameters(
(128, 34, 1, 'v0', None), (128, 34, 1, 'v0', None, 0.0),
(128, 34, 1, 'v1', 0.25), (128, 34, 1, 'v1', 0.25, 0.2),
(128, 50, 4, 'v0', None), (128, 50, 4, 'v0', None, 0.0),
(128, 50, 4, 'v1', 0.25), (128, 50, 4, 'v1', 0.25, 0.2),
) )
def test_resnet_addons(self, input_size, model_id, endpoint_filter_scale, def test_resnet_addons(self, input_size, model_id, endpoint_filter_scale,
stem_type, se_ratio): stem_type, se_ratio, init_stochastic_depth_rate):
"""Test creation of ResNet family models.""" """Test creation of ResNet family models."""
tf.keras.backend.set_image_data_format('channels_last') tf.keras.backend.set_image_data_format('channels_last')
network = resnet.ResNet( network = resnet.ResNet(
model_id=model_id, stem_type=stem_type, se_ratio=se_ratio) model_id=model_id,
stem_type=stem_type,
se_ratio=se_ratio,
init_stochastic_depth_rate=init_stochastic_depth_rate)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1) inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
_ = network(inputs) _ = network(inputs)
...@@ -115,6 +118,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -115,6 +118,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
model_id=50, model_id=50,
stem_type='v0', stem_type='v0',
se_ratio=None, se_ratio=None,
init_stochastic_depth_rate=0.0,
use_sync_bn=False, use_sync_bn=False,
activation='relu', activation='relu',
norm_momentum=0.99, norm_momentum=0.99,
......
...@@ -27,6 +27,7 @@ import tensorflow as tf ...@@ -27,6 +27,7 @@ import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.modeling.layers import nn_layers
from official.vision.beta.ops import spatial_transform_ops from official.vision.beta.ops import spatial_transform_ops
layers = tf.keras.layers layers = tf.keras.layers
...@@ -114,17 +115,6 @@ def build_block_specs(block_specs=None): ...@@ -114,17 +115,6 @@ def build_block_specs(block_specs=None):
return [BlockSpec(*b) for b in block_specs] return [BlockSpec(*b) for b in block_specs]
def get_stochastic_depth_rate(init_rate, i, n):
"""Get drop connect rate for the ith block."""
if init_rate is not None:
if init_rate < 0 or init_rate > 1:
raise ValueError('Initial drop rate must be within 0 and 1.')
dc_rate = init_rate * float(i + 1) / n
else:
dc_rate = None
return dc_rate
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class SpineNet(tf.keras.Model): class SpineNet(tf.keras.Model):
"""Class to build SpineNet models.""" """Class to build SpineNet models."""
...@@ -350,8 +340,8 @@ class SpineNet(tf.keras.Model): ...@@ -350,8 +340,8 @@ class SpineNet(tf.keras.Model):
strides=1, strides=1,
block_fn_cand=target_block_fn, block_fn_cand=target_block_fn,
block_repeats=self._block_repeats, block_repeats=self._block_repeats,
stochastic_depth_drop_rate=get_stochastic_depth_rate( stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
self._init_stochastic_depth_rate, i, len(self._block_specs)), self._init_stochastic_depth_rate, i + 1, len(self._block_specs)),
name='scale_permuted_block_{}'.format(i + 1)) name='scale_permuted_block_{}'.format(i + 1))
net.append(x) net.append(x)
......
...@@ -167,6 +167,26 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -167,6 +167,26 @@ class SqueezeExcitation(tf.keras.layers.Layer):
return x * inputs return x * inputs
def get_stochastic_depth_rate(init_rate, i, n):
"""Get drop connect rate for the ith block.
Args:
init_rate: `float` initial drop rate.
i: `int` order of the current block.
n: `int` total number of blocks.
Returns:
Drop rate of the ith block.
"""
if init_rate is not None:
if init_rate < 0 or init_rate > 1:
raise ValueError('Initial drop rate must be within 0 and 1.')
rate = init_rate * float(i) / n
else:
rate = None
return rate
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class StochasticDepth(tf.keras.layers.Layer): class StochasticDepth(tf.keras.layers.Layer):
"""Stochastic depth layer.""" """Stochastic depth layer."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment