Commit 63e9d1cd authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 365678315
parent 76443347
...@@ -34,10 +34,13 @@ class ResNet3DBlock(hyperparams.Config): ...@@ -34,10 +34,13 @@ class ResNet3DBlock(hyperparams.Config):
class ResNet3D(hyperparams.Config): class ResNet3D(hyperparams.Config):
"""ResNet config.""" """ResNet config."""
model_id: int = 50 model_id: int = 50
stem_type: str = 'v0'
stem_conv_temporal_kernel_size: int = 5 stem_conv_temporal_kernel_size: int = 5
stem_conv_temporal_stride: int = 2 stem_conv_temporal_stride: int = 2
stem_pool_temporal_stride: int = 2 stem_pool_temporal_stride: int = 2
block_specs: Tuple[ResNet3DBlock, ...] = () block_specs: Tuple[ResNet3DBlock, ...] = ()
stochastic_depth_drop_rate: float = 0.0
se_ratio: float = 0.0
@dataclasses.dataclass @dataclasses.dataclass
...@@ -60,13 +63,45 @@ class ResNet3D50(ResNet3D): ...@@ -60,13 +63,45 @@ class ResNet3D50(ResNet3D):
use_self_gating=True)) use_self_gating=True))
@dataclasses.dataclass
class ResNet3DRS(ResNet3D):
  """Default configuration of the ResNet-RS (3D) backbone.

  Compared with the base `ResNet3D` config, this selects the 'v1' stem, a
  stochastic depth drop rate of 0.1 and a squeeze-and-excitation ratio of 0.2.
  It declares four block groups, each with temporal stride 1 and self-gating
  enabled.
  """
  model_id: int = 50
  stem_type: str = 'v1'
  stem_conv_temporal_kernel_size: int = 5
  stem_conv_temporal_stride: int = 2
  stem_pool_temporal_stride: int = 2
  stochastic_depth_drop_rate: float = 0.1
  se_ratio: float = 0.2
  # One spec per block group. The first two groups use a temporal kernel size
  # of 1 and the last two use 3; every group gets its own ResNet3DBlock
  # instance.
  block_specs: Tuple[
      ResNet3DBlock, ResNet3DBlock, ResNet3DBlock, ResNet3DBlock] = tuple(
          ResNet3DBlock(
              temporal_strides=1,
              temporal_kernel_sizes=(group_kernel_size,),
              use_self_gating=True)
          for group_kernel_size in (1, 1, 3, 3))
_RESNET3D50_DEFAULT_CFG = ResNet3D50()
_RESNET3DRS_DEFAULT_CFG = ResNet3DRS()
@dataclasses.dataclass @dataclasses.dataclass
class Backbone3D(hyperparams.OneOfConfig): class Backbone3D(hyperparams.OneOfConfig):
"""Configuration for backbones. """Configuration for backbones.
Attributes: Attributes:
type: 'str', type of backbone to be used, one of the fields below. type: 'str', type of backbone to be used, one of the fields below.
resnet: resnet3d backbone config. resnet_3d: resnet3d backbone config.
resnet_3d_rs: resnet3d-rs backbone config.
""" """
type: Optional[str] = None type: Optional[str] = None
resnet_3d: ResNet3D = ResNet3D50() resnet_3d: ResNet3D = _RESNET3D50_DEFAULT_CFG
resnet_3d_rs: ResNet3D = _RESNET3DRS_DEFAULT_CFG
# 3D ResNet-RS-50 video classification on Kinetics-400.
#
# --experiment_type=video_classification_kinetics400
# Expected accuracy: 78.2% top-1 accuracy.
runtime:
mixed_precision_dtype: bfloat16
task:
losses:
l2_weight_decay: 0.00004
label_smoothing: 0.1
one_hot: true
model:
aggregate_endpoints: false
backbone:
resnet_3d_rs:
model_id: 50
stem_type: 'v1'
stem_conv_temporal_kernel_size: 5
stem_conv_temporal_stride: 2
stem_pool_temporal_stride: 1
stochastic_depth_drop_rate: 0.1
se_ratio: 0.25
type: resnet_3d_rs
dropout_rate: 0.5
model_type: video_classification
norm_activation:
activation: relu
norm_epsilon: 1.0e-05
norm_momentum: 0.0
use_sync_bn: false
train_data:
data_format: channels_last
drop_remainder: true
dtype: bfloat16
feature_shape: !!python/tuple
- 32
- 224
- 224
- 3
file_type: sstable
global_batch_size: 1024
is_training: true
min_image_size: 256
name: kinetics400
num_channels: 3
num_classes: 400
num_examples: 215570
num_test_clips: 1
num_test_crops: 1
one_hot: true
temporal_stride: 2
aug_max_area_ratio: 1.0
aug_max_aspect_ratio: 2.0
aug_min_area_ratio: 0.08
aug_min_aspect_ratio: 0.5
validation_data:
data_format: channels_last
drop_remainder: false
dtype: bfloat16
feature_shape: !!python/tuple
- 32
- 256
- 256
- 3
file_type: sstable
global_batch_size: 64
is_training: false
min_image_size: 256
name: kinetics400
num_channels: 3
num_classes: 400
num_examples: 17706
num_test_clips: 10
num_test_crops: 3
one_hot: true
temporal_stride: 2
trainer:
checkpoint_interval: 210
max_to_keep: 3
optimizer_config:
ema:
average_decay: 0.9999
learning_rate:
cosine:
decay_steps: 73682
initial_learning_rate: 0.8
name: CosineDecay
type: cosine
warmup:
linear:
name: linear
warmup_learning_rate: 0
warmup_steps: 1050
type: linear
train_steps: 73682
steps_per_loop: 500
summary_interval: 500
validation_interval: 500
...@@ -20,6 +20,7 @@ import tensorflow as tf ...@@ -20,6 +20,7 @@ import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks_3d from official.vision.beta.modeling.layers import nn_blocks_3d
from official.vision.beta.modeling.layers import nn_layers
layers = tf.keras.layers layers = tf.keras.layers
...@@ -36,6 +37,36 @@ RESNET_SPECS = { ...@@ -36,6 +37,36 @@ RESNET_SPECS = {
('bottleneck3d', 256, 23), ('bottleneck3d', 256, 23),
('bottleneck3d', 512, 3), ('bottleneck3d', 512, 3),
], ],
152: [
('bottleneck3d', 64, 3),
('bottleneck3d', 128, 8),
('bottleneck3d', 256, 36),
('bottleneck3d', 512, 3),
],
200: [
('bottleneck3d', 64, 3),
('bottleneck3d', 128, 24),
('bottleneck3d', 256, 36),
('bottleneck3d', 512, 3),
],
270: [
('bottleneck3d', 64, 4),
('bottleneck3d', 128, 29),
('bottleneck3d', 256, 53),
('bottleneck3d', 512, 4),
],
300: [
('bottleneck3d', 64, 4),
('bottleneck3d', 128, 36),
('bottleneck3d', 256, 54),
('bottleneck3d', 512, 4),
],
350: [
('bottleneck3d', 64, 4),
('bottleneck3d', 128, 36),
('bottleneck3d', 256, 72),
('bottleneck3d', 512, 4),
],
} }
...@@ -49,10 +80,13 @@ class ResNet3D(tf.keras.Model): ...@@ -49,10 +80,13 @@ class ResNet3D(tf.keras.Model):
temporal_kernel_sizes: List[Tuple[int]], temporal_kernel_sizes: List[Tuple[int]],
use_self_gating: List[int] = None, use_self_gating: List[int] = None,
input_specs=layers.InputSpec(shape=[None, None, None, None, 3]), input_specs=layers.InputSpec(shape=[None, None, None, None, 3]),
stem_type='v0',
stem_conv_temporal_kernel_size=5, stem_conv_temporal_kernel_size=5,
stem_conv_temporal_stride=2, stem_conv_temporal_stride=2,
stem_pool_temporal_stride=2, stem_pool_temporal_stride=2,
init_stochastic_depth_rate=0.0,
activation='relu', activation='relu',
se_ratio=None,
use_sync_bn=False, use_sync_bn=False,
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
...@@ -71,13 +105,17 @@ class ResNet3D(tf.keras.Model): ...@@ -71,13 +105,17 @@ class ResNet3D(tf.keras.Model):
use_self_gating: A list of booleans to specify applying self-gating module use_self_gating: A list of booleans to specify applying self-gating module
or not in each block group. If None, self-gating is not applied. or not in each block group. If None, self-gating is not applied.
input_specs: A `tf.keras.layers.InputSpec` of the input tensor. input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
stem_type: A `str` of stem type of ResNet. Default to `v0`. If set to
`v1`, use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
stem_conv_temporal_kernel_size: An `int` of temporal kernel size for the stem_conv_temporal_kernel_size: An `int` of temporal kernel size for the
first conv layer. first conv layer.
stem_conv_temporal_stride: An `int` of temporal stride for the first conv stem_conv_temporal_stride: An `int` of temporal stride for the first conv
layer. layer.
stem_pool_temporal_stride: An `int` of temporal stride for the first pool stem_pool_temporal_stride: An `int` of temporal stride for the first pool
layer. layer.
init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
activation: A `str` of name of the activation function. activation: A `str` of name of the activation function.
se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
use_sync_bn: If True, use synchronized batch normalization. use_sync_bn: If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average. norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero. norm_epsilon: A `float` added to variance to avoid dividing by zero.
...@@ -92,10 +130,13 @@ class ResNet3D(tf.keras.Model): ...@@ -92,10 +130,13 @@ class ResNet3D(tf.keras.Model):
self._temporal_strides = temporal_strides self._temporal_strides = temporal_strides
self._temporal_kernel_sizes = temporal_kernel_sizes self._temporal_kernel_sizes = temporal_kernel_sizes
self._input_specs = input_specs self._input_specs = input_specs
self._stem_type = stem_type
self._stem_conv_temporal_kernel_size = stem_conv_temporal_kernel_size self._stem_conv_temporal_kernel_size = stem_conv_temporal_kernel_size
self._stem_conv_temporal_stride = stem_conv_temporal_stride self._stem_conv_temporal_stride = stem_conv_temporal_stride
self._stem_pool_temporal_stride = stem_pool_temporal_stride self._stem_pool_temporal_stride = stem_pool_temporal_stride
self._use_self_gating = use_self_gating self._use_self_gating = use_self_gating
self._se_ratio = se_ratio
self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._use_sync_bn = use_sync_bn self._use_sync_bn = use_sync_bn
self._activation = activation self._activation = activation
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
...@@ -116,6 +157,7 @@ class ResNet3D(tf.keras.Model): ...@@ -116,6 +157,7 @@ class ResNet3D(tf.keras.Model):
inputs = tf.keras.Input(shape=input_specs.shape[1:]) inputs = tf.keras.Input(shape=input_specs.shape[1:])
# Build stem. # Build stem.
if stem_type == 'v0':
x = layers.Conv3D( x = layers.Conv3D(
filters=64, filters=64,
kernel_size=[stem_conv_temporal_kernel_size, 7, 7], kernel_size=[stem_conv_temporal_kernel_size, 7, 7],
...@@ -130,6 +172,51 @@ class ResNet3D(tf.keras.Model): ...@@ -130,6 +172,51 @@ class ResNet3D(tf.keras.Model):
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x) x)
x = tf_utils.get_activation(activation)(x) x = tf_utils.get_activation(activation)(x)
elif stem_type == 'v1':
x = layers.Conv3D(
filters=32,
kernel_size=[stem_conv_temporal_kernel_size, 3, 3],
strides=[stem_conv_temporal_stride, 2, 2],
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = layers.Conv3D(
filters=32,
kernel_size=[1, 3, 3],
strides=[1, 1, 1],
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = layers.Conv3D(
filters=64,
kernel_size=[1, 3, 3],
strides=[1, 1, 1],
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
else:
raise ValueError(f'Stem type {stem_type} not supported.')
temporal_kernel_size = 1 if stem_pool_temporal_stride == 1 else 3 temporal_kernel_size = 1 if stem_pool_temporal_stride == 1 else 3
x = layers.MaxPool3D( x = layers.MaxPool3D(
...@@ -161,6 +248,8 @@ class ResNet3D(tf.keras.Model): ...@@ -161,6 +248,8 @@ class ResNet3D(tf.keras.Model):
spatial_strides=(1 if i == 0 else 2), spatial_strides=(1 if i == 0 else 2),
block_fn=block_fn, block_fn=block_fn,
block_repeats=resnet_spec[2], block_repeats=resnet_spec[2],
stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
self._init_stochastic_depth_rate, i + 2, 5),
use_self_gating=use_self_gating[i] if use_self_gating else False, use_self_gating=use_self_gating[i] if use_self_gating else False,
name='block_group_l{}'.format(i + 2)) name='block_group_l{}'.format(i + 2))
endpoints[str(i + 2)] = x endpoints[str(i + 2)] = x
...@@ -177,6 +266,7 @@ class ResNet3D(tf.keras.Model): ...@@ -177,6 +266,7 @@ class ResNet3D(tf.keras.Model):
spatial_strides, spatial_strides,
block_fn=nn_blocks_3d.BottleneckBlock3D, block_fn=nn_blocks_3d.BottleneckBlock3D,
block_repeats=1, block_repeats=1,
stochastic_depth_drop_rate=0.0,
use_self_gating=False, use_self_gating=False,
name='block_group'): name='block_group'):
"""Creates one group of blocks for the ResNet3D model. """Creates one group of blocks for the ResNet3D model.
...@@ -193,6 +283,8 @@ class ResNet3D(tf.keras.Model): ...@@ -193,6 +283,8 @@ class ResNet3D(tf.keras.Model):
layer. If greater than 1, this layer will downsample the input. layer. If greater than 1, this layer will downsample the input.
block_fn: The block layer class to instantiate; defaults to `nn_blocks_3d.BottleneckBlock3D`. block_fn: The block layer class to instantiate; defaults to `nn_blocks_3d.BottleneckBlock3D`.
block_repeats: An `int` of number of blocks contained in the layer. block_repeats: An `int` of number of blocks contained in the layer.
stochastic_depth_drop_rate: A `float` of drop rate of the current block
group.
use_self_gating: A `bool` that specifies whether to apply self-gating use_self_gating: A `bool` that specifies whether to apply self-gating
module or not. module or not.
name: A `str` name for the block. name: A `str` name for the block.
...@@ -213,7 +305,9 @@ class ResNet3D(tf.keras.Model): ...@@ -213,7 +305,9 @@ class ResNet3D(tf.keras.Model):
temporal_kernel_size=temporal_kernel_sizes[0], temporal_kernel_size=temporal_kernel_sizes[0],
temporal_strides=temporal_strides, temporal_strides=temporal_strides,
spatial_strides=spatial_strides, spatial_strides=spatial_strides,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
use_self_gating=use_self_gating_list[0], use_self_gating=use_self_gating_list[0],
se_ratio=self._se_ratio,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
...@@ -229,7 +323,9 @@ class ResNet3D(tf.keras.Model): ...@@ -229,7 +323,9 @@ class ResNet3D(tf.keras.Model):
temporal_kernel_size=temporal_kernel_sizes[i], temporal_kernel_size=temporal_kernel_sizes[i],
temporal_strides=1, temporal_strides=1,
spatial_strides=1, spatial_strides=1,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
use_self_gating=use_self_gating_list[i], use_self_gating=use_self_gating_list[i],
se_ratio=self._se_ratio,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
...@@ -246,10 +342,13 @@ class ResNet3D(tf.keras.Model): ...@@ -246,10 +342,13 @@ class ResNet3D(tf.keras.Model):
'model_id': self._model_id, 'model_id': self._model_id,
'temporal_strides': self._temporal_strides, 'temporal_strides': self._temporal_strides,
'temporal_kernel_sizes': self._temporal_kernel_sizes, 'temporal_kernel_sizes': self._temporal_kernel_sizes,
'stem_type': self._stem_type,
'stem_conv_temporal_kernel_size': self._stem_conv_temporal_kernel_size, 'stem_conv_temporal_kernel_size': self._stem_conv_temporal_kernel_size,
'stem_conv_temporal_stride': self._stem_conv_temporal_stride, 'stem_conv_temporal_stride': self._stem_conv_temporal_stride,
'stem_pool_temporal_stride': self._stem_pool_temporal_stride, 'stem_pool_temporal_stride': self._stem_pool_temporal_stride,
'use_self_gating': self._use_self_gating, 'use_self_gating': self._use_self_gating,
'se_ratio': self._se_ratio,
'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
'activation': self._activation, 'activation': self._activation,
'use_sync_bn': self._use_sync_bn, 'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum, 'norm_momentum': self._norm_momentum,
...@@ -276,11 +375,8 @@ def build_resnet3d( ...@@ -276,11 +375,8 @@ def build_resnet3d(
model_config, model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet 3d backbone from a config.""" """Builds ResNet 3d backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get() backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
assert backbone_type == 'resnet_3d', (f'Inconsistent backbone type '
f'{backbone_type}')
# Flatten configs before passing to the backbone. # Flatten configs before passing to the backbone.
temporal_strides = [] temporal_strides = []
...@@ -297,10 +393,52 @@ def build_resnet3d( ...@@ -297,10 +393,52 @@ def build_resnet3d(
temporal_kernel_sizes=temporal_kernel_sizes, temporal_kernel_sizes=temporal_kernel_sizes,
use_self_gating=use_self_gating, use_self_gating=use_self_gating,
input_specs=input_specs, input_specs=input_specs,
stem_type=backbone_cfg.stem_type,
stem_conv_temporal_kernel_size=backbone_cfg
.stem_conv_temporal_kernel_size,
stem_conv_temporal_stride=backbone_cfg.stem_conv_temporal_stride,
stem_pool_temporal_stride=backbone_cfg.stem_pool_temporal_stride,
init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
se_ratio=backbone_cfg.se_ratio,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
@factory.register_backbone_builder('resnet_3d_rs')
def build_resnet3d_rs(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet-3D-RS backbone from a config."""
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
# Flatten configs before passing to the backbone.
temporal_strides = []
temporal_kernel_sizes = []
use_self_gating = []
for i, block_spec in enumerate(backbone_cfg.block_specs):
temporal_strides.append(block_spec.temporal_strides)
use_self_gating.append(block_spec.use_self_gating)
block_repeats_i = RESNET_SPECS[backbone_cfg.model_id][i][-1]
temporal_kernel_sizes.append(list(block_spec.temporal_kernel_sizes) *
block_repeats_i)
return ResNet3D(
model_id=backbone_cfg.model_id,
temporal_strides=temporal_strides,
temporal_kernel_sizes=temporal_kernel_sizes,
use_self_gating=use_self_gating,
input_specs=input_specs,
stem_type=backbone_cfg.stem_type,
stem_conv_temporal_kernel_size=backbone_cfg stem_conv_temporal_kernel_size=backbone_cfg
.stem_conv_temporal_kernel_size, .stem_conv_temporal_kernel_size,
stem_conv_temporal_stride=backbone_cfg.stem_conv_temporal_stride, stem_conv_temporal_stride=backbone_cfg.stem_conv_temporal_stride,
stem_pool_temporal_stride=backbone_cfg.stem_pool_temporal_stride, stem_pool_temporal_stride=backbone_cfg.stem_pool_temporal_stride,
init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
se_ratio=backbone_cfg.se_ratio,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
......
...@@ -25,10 +25,12 @@ from official.vision.beta.modeling.backbones import resnet_3d ...@@ -25,10 +25,12 @@ from official.vision.beta.modeling.backbones import resnet_3d
class ResNet3DTest(parameterized.TestCase, tf.test.TestCase): class ResNet3DTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
(128, 50, 4), (128, 50, 4, 'v0', False, 0.0),
(128, 50, 4, 'v1', False, 0.2),
(256, 50, 4, 'v1', True, 0.2),
) )
def test_network_creation(self, input_size, model_id, def test_network_creation(self, input_size, model_id, endpoint_filter_scale,
endpoint_filter_scale): stem_type, se_ratio, init_stochastic_depth_rate):
"""Test creation of ResNet3D family models.""" """Test creation of ResNet3D family models."""
tf.keras.backend.set_image_data_format('channels_last') tf.keras.backend.set_image_data_format('channels_last')
temporal_strides = [1, 1, 1, 1] temporal_strides = [1, 1, 1, 1]
...@@ -41,7 +43,9 @@ class ResNet3DTest(parameterized.TestCase, tf.test.TestCase): ...@@ -41,7 +43,9 @@ class ResNet3DTest(parameterized.TestCase, tf.test.TestCase):
temporal_strides=temporal_strides, temporal_strides=temporal_strides,
temporal_kernel_sizes=temporal_kernel_sizes, temporal_kernel_sizes=temporal_kernel_sizes,
use_self_gating=use_self_gating, use_self_gating=use_self_gating,
) stem_type=stem_type,
se_ratio=se_ratio,
init_stochastic_depth_rate=init_stochastic_depth_rate)
inputs = tf.keras.Input(shape=(8, input_size, input_size, 3), batch_size=1) inputs = tf.keras.Input(shape=(8, input_size, input_size, 3), batch_size=1)
endpoints = network(inputs) endpoints = network(inputs)
...@@ -65,10 +69,13 @@ class ResNet3DTest(parameterized.TestCase, tf.test.TestCase): ...@@ -65,10 +69,13 @@ class ResNet3DTest(parameterized.TestCase, tf.test.TestCase):
temporal_strides=[1, 1, 1, 1], temporal_strides=[1, 1, 1, 1],
temporal_kernel_sizes=[(3, 3, 3), (3, 1, 3, 1), (3, 1, 3, 1, 3, 1), temporal_kernel_sizes=[(3, 3, 3), (3, 1, 3, 1), (3, 1, 3, 1, 3, 1),
(1, 3, 1)], (1, 3, 1)],
stem_type='v0',
stem_conv_temporal_kernel_size=5, stem_conv_temporal_kernel_size=5,
stem_conv_temporal_stride=2, stem_conv_temporal_stride=2,
stem_pool_temporal_stride=2, stem_pool_temporal_stride=2,
se_ratio=0.0,
use_self_gating=None, use_self_gating=None,
init_stochastic_depth_rate=0.0,
use_sync_bn=False, use_sync_bn=False,
activation='relu', activation='relu',
norm_momentum=0.99, norm_momentum=0.99,
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.layers import nn_layers
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
...@@ -75,6 +76,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -75,6 +76,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
temporal_kernel_size, temporal_kernel_size,
temporal_strides, temporal_strides,
spatial_strides, spatial_strides,
stochastic_depth_drop_rate=0.0,
se_ratio=None,
use_self_gating=False, use_self_gating=False,
kernel_initializer='VarianceScaling', kernel_initializer='VarianceScaling',
kernel_regularizer=None, kernel_regularizer=None,
...@@ -95,6 +98,9 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -95,6 +98,9 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
convolutional layer. convolutional layer.
spatial_strides: An `int` of spatial stride for the spatial convolutional spatial_strides: An `int` of spatial stride for the spatial convolutional
layer. layer.
stochastic_depth_drop_rate: A `float` or None. If not None, drop rate for
the stochastic depth layer.
se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
use_self_gating: A `bool` of whether to apply self-gating module or not. use_self_gating: A `bool` of whether to apply self-gating module or not.
kernel_initializer: A `str` of kernel_initializer for convolutional kernel_initializer: A `str` of kernel_initializer for convolutional
layers. layers.
...@@ -114,7 +120,9 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -114,7 +120,9 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
self._temporal_kernel_size = temporal_kernel_size self._temporal_kernel_size = temporal_kernel_size
self._spatial_strides = spatial_strides self._spatial_strides = spatial_strides
self._temporal_strides = temporal_strides self._temporal_strides = temporal_strides
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._use_self_gating = use_self_gating self._use_self_gating = use_self_gating
self._se_ratio = se_ratio
self._use_sync_bn = use_sync_bn self._use_sync_bn = use_sync_bn
self._activation = activation self._activation = activation
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
...@@ -197,6 +205,24 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -197,6 +205,24 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation(
in_filters=self._filters * 4,
out_filters=self._filters * 4,
se_ratio=self._se_ratio,
use_3d_input=True,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
else:
self._squeeze_excitation = None
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = None
if self._use_self_gating: if self._use_self_gating:
self._self_gating = SelfGating(filters=4 * self._filters) self._self_gating = SelfGating(filters=4 * self._filters)
else: else:
...@@ -211,6 +237,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -211,6 +237,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
'temporal_strides': self._temporal_strides, 'temporal_strides': self._temporal_strides,
'spatial_strides': self._spatial_strides, 'spatial_strides': self._spatial_strides,
'use_self_gating': self._use_self_gating, 'use_self_gating': self._use_self_gating,
'se_ratio': self._se_ratio,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer, 'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer, 'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer, 'bias_regularizer': self._bias_regularizer,
...@@ -243,10 +271,16 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -243,10 +271,16 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
x = self._expand_conv(x) x = self._expand_conv(x)
x = self._norm3(x) x = self._norm3(x)
# Apply activation before additional modules.
x = self._activation_fn(x + shortcut)
# Apply self-gating, SE, stochastic depth.
if self._self_gating: if self._self_gating:
x = self._self_gating(x) x = self._self_gating(x)
if self._squeeze_excitation:
x = self._squeeze_excitation(x)
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
# Add the shortcut and apply activation after the additional modules.
x = self._activation_fn(x + shortcut)
return x return x
...@@ -25,12 +25,13 @@ from official.vision.beta.modeling.layers import nn_blocks_3d ...@@ -25,12 +25,13 @@ from official.vision.beta.modeling.layers import nn_blocks_3d
class NNBlocksTest(parameterized.TestCase, tf.test.TestCase): class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
(nn_blocks_3d.BottleneckBlock3D, 1, 1, 2, True), (nn_blocks_3d.BottleneckBlock3D, 1, 1, 2, True, 0.2, 0.1),
(nn_blocks_3d.BottleneckBlock3D, 3, 2, 1, False), (nn_blocks_3d.BottleneckBlock3D, 3, 2, 1, False, 0.0, 0.0),
) )
def test_bottleneck_block_creation(self, block_fn, temporal_kernel_size, def test_bottleneck_block_creation(self, block_fn, temporal_kernel_size,
temporal_strides, spatial_strides, temporal_strides, spatial_strides,
use_self_gating): use_self_gating, se_ratio,
stochastic_depth):
temporal_size = 16 temporal_size = 16
spatial_size = 128 spatial_size = 128
filters = 256 filters = 256
...@@ -42,7 +43,9 @@ class NNBlocksTest(parameterized.TestCase, tf.test.TestCase): ...@@ -42,7 +43,9 @@ class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
temporal_kernel_size=temporal_kernel_size, temporal_kernel_size=temporal_kernel_size,
temporal_strides=temporal_strides, temporal_strides=temporal_strides,
spatial_strides=spatial_strides, spatial_strides=spatial_strides,
use_self_gating=use_self_gating) use_self_gating=use_self_gating,
se_ratio=se_ratio,
stochastic_depth_drop_rate=stochastic_depth)
features = block(inputs) features = block(inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment