Commit 3dce64dc authored by Yuqi Li's avatar Yuqi Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398633511
parent 71ab0c31
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
# Lint as: python3 # Lint as: python3
"""Backbones configurations.""" """Backbones configurations."""
import dataclasses
from typing import Optional, List from typing import Optional, List
# Import libraries # Import libraries
import dataclasses
from official.modeling import hyperparams from official.modeling import hyperparams
...@@ -62,6 +62,8 @@ class MobileNet(hyperparams.Config): ...@@ -62,6 +62,8 @@ class MobileNet(hyperparams.Config):
model_id: str = 'MobileNetV2' model_id: str = 'MobileNetV2'
filter_size_scale: float = 1.0 filter_size_scale: float = 1.0
stochastic_depth_drop_rate: float = 0.0 stochastic_depth_drop_rate: float = 0.0
output_stride: Optional[int] = None
output_intermediate_endpoints: bool = False
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -67,7 +67,7 @@ class SegmentationHead(hyperparams.Config): ...@@ -67,7 +67,7 @@ class SegmentationHead(hyperparams.Config):
upsample_factor: int = 1 upsample_factor: int = 1
feature_fusion: Optional[str] = None # None, deeplabv3plus, or pyramid_fusion feature_fusion: Optional[str] = None # None, deeplabv3plus, or pyramid_fusion
# deeplabv3plus feature fusion params # deeplabv3plus feature fusion params
low_level: int = 2 low_level: Union[int, str] = 2
low_level_num_filters: int = 48 low_level_num_filters: int = 48
...@@ -137,7 +137,7 @@ PASCAL_INPUT_PATH_BASE = 'pascal_voc_seg' ...@@ -137,7 +137,7 @@ PASCAL_INPUT_PATH_BASE = 'pascal_voc_seg'
@exp_factory.register_config_factory('seg_deeplabv3_pascal') @exp_factory.register_config_factory('seg_deeplabv3_pascal')
def seg_deeplabv3_pascal() -> cfg.ExperimentConfig: def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
"""Image segmentation on imagenet with resnet deeplabv3.""" """Image segmentation on pascal voc with resnet deeplabv3."""
train_batch_size = 16 train_batch_size = 16
eval_batch_size = 8 eval_batch_size = 8
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
...@@ -225,7 +225,7 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig: ...@@ -225,7 +225,7 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
@exp_factory.register_config_factory('seg_deeplabv3plus_pascal') @exp_factory.register_config_factory('seg_deeplabv3plus_pascal')
def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig: def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
"""Image segmentation on imagenet with resnet deeplabv3+.""" """Image segmentation on pascal voc with resnet deeplabv3+."""
train_batch_size = 16 train_batch_size = 16
eval_batch_size = 8 eval_batch_size = 8
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
...@@ -318,7 +318,7 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig: ...@@ -318,7 +318,7 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
@exp_factory.register_config_factory('seg_resnetfpn_pascal') @exp_factory.register_config_factory('seg_resnetfpn_pascal')
def seg_resnetfpn_pascal() -> cfg.ExperimentConfig: def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
"""Image segmentation on imagenet with resnet-fpn.""" """Image segmentation on pascal voc with resnet-fpn."""
train_batch_size = 256 train_batch_size = 256
eval_batch_size = 32 eval_batch_size = 32
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
...@@ -390,6 +390,99 @@ def seg_resnetfpn_pascal() -> cfg.ExperimentConfig: ...@@ -390,6 +390,99 @@ def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
return config return config
@exp_factory.register_config_factory('mnv2_deeplabv3_pascal')
def mnv2_deeplabv3_pascal() -> cfg.ExperimentConfig:
  """Image segmentation on pascal with mobilenetv2 deeplabv3.

  Returns:
    A `cfg.ExperimentConfig` for PASCAL VOC semantic segmentation with a
    MobileNetV2 backbone at output stride 16, an ASPP decoder, and weights
    initialized from a COCO-pretrained deeplabv3-mobilenetv2 checkpoint.
  """
  train_batch_size = 16
  eval_batch_size = 16
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  # No dilated branches: ASPP keeps only its 1x1 conv and pooling branches.
  aspp_dilation_rates = []
  # Feature level fed to decoder/head: log2 of the (power-of-two) stride.
  # Use np.log2, not np.math.log2 — the `np.math` alias was removed in
  # NumPy 1.25 and raises AttributeError there.
  level = int(np.log2(output_stride))
  # Empty kernel size — presumably lets ASPP pooling infer it from the
  # input; TODO confirm against decoders.ASPP.
  pool_kernel_size = []
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=21,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='mobilenet',
                  mobilenet=backbones.MobileNet(
                      model_id='MobileNetV2', output_stride=output_stride)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level,
                      dilation_rates=aspp_dilation_rates,
                      pool_kernel_size=pool_kernel_size)),
              head=SegmentationHead(level=level, num_convs=0),
              norm_activation=common.NormActivation(
                  activation='relu',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=4e-5),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              output_size=[512, 512],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              output_size=[512, 512],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
          # mobilenetv2
          init_checkpoint='gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63',
          init_checkpoint_modules=['backbone', 'decoder']),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=30000,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          best_checkpoint_eval_metric='mean_iou',
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_metric_comp='higher',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      # Linearly scale the 0.007 base LR with batch size.
                      'initial_learning_rate': 0.007 * train_batch_size / 16,
                      'decay_steps': 30000,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
# Cityscapes Dataset (Download and process the dataset yourself) # Cityscapes Dataset (Download and process the dataset yourself)
CITYSCAPES_TRAIN_EXAMPLES = 2975 CITYSCAPES_TRAIN_EXAMPLES = 2975
CITYSCAPES_VAL_EXAMPLES = 500 CITYSCAPES_VAL_EXAMPLES = 500
...@@ -398,7 +491,7 @@ CITYSCAPES_INPUT_PATH_BASE = 'cityscapes' ...@@ -398,7 +491,7 @@ CITYSCAPES_INPUT_PATH_BASE = 'cityscapes'
@exp_factory.register_config_factory('seg_deeplabv3plus_cityscapes') @exp_factory.register_config_factory('seg_deeplabv3plus_cityscapes')
def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig: def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
"""Image segmentation on imagenet with resnet deeplabv3+.""" """Image segmentation on cityscapes with resnet deeplabv3+."""
train_batch_size = 16 train_batch_size = 16
eval_batch_size = 16 eval_batch_size = 16
steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size
...@@ -491,3 +584,114 @@ def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig: ...@@ -491,3 +584,114 @@ def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
]) ])
return config return config
@exp_factory.register_config_factory('mnv2_deeplabv3_cityscapes')
def mnv2_deeplabv3_cityscapes() -> cfg.ExperimentConfig:
  """Image segmentation on cityscapes with mobilenetv2 deeplabv3.

  Returns:
    A `cfg.ExperimentConfig` for Cityscapes semantic segmentation with a
    MobileNetV2 backbone at output stride 16, an ASPP decoder, and the
    backbone initialized from a COCO-pretrained checkpoint.
  """
  train_batch_size = 16
  eval_batch_size = 16
  steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  # No dilated branches: ASPP keeps only its 1x1 conv and pooling branches.
  aspp_dilation_rates = []
  # Fixed pooling kernel matching the training crop size.
  pool_kernel_size = [512, 1024]
  # Feature level fed to decoder/head: log2 of the (power-of-two) stride.
  # Use np.log2, not np.math.log2 — the `np.math` alias was removed in
  # NumPy 1.25 and raises AttributeError there.
  level = int(np.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              # Cityscapes uses only 19 semantic classes for train/evaluation.
              # The void (background) class is ignored in train and evaluation.
              num_classes=19,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='mobilenet',
                  mobilenet=backbones.MobileNet(
                      model_id='MobileNetV2', output_stride=output_stride)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level,
                      dilation_rates=aspp_dilation_rates,
                      pool_kernel_size=pool_kernel_size)),
              head=SegmentationHead(level=level, num_convs=0),
              norm_activation=common.NormActivation(
                  activation='relu',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=4e-5),
          train_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
                                      'train_fine**'),
              crop_size=[512, 1024],
              output_size=[1024, 2048],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE, 'val_fine*'),
              output_size=[1024, 2048],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=True,
              drop_remainder=False),
          # Coco pre-trained mobilenetv2 checkpoint
          init_checkpoint='gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=100000,
          validation_steps=CITYSCAPES_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          best_checkpoint_eval_metric='mean_iou',
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_metric_comp='higher',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.01,
                      'decay_steps': 100000,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
@exp_factory.register_config_factory('mnv2_deeplabv3plus_cityscapes')
def mnv2_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
  """Image segmentation on cityscapes with mobilenetv2 deeplabv3plus."""
  # Start from the plain deeplabv3 cityscapes config and swap in a
  # deeplabv3plus head.
  config = mnv2_deeplabv3_cityscapes()
  model = config.task.model
  # The deeplabv3plus fusion consumes a low-level depthwise endpoint, so the
  # backbone must be told to expose intermediate endpoints.
  model.backbone.mobilenet.output_intermediate_endpoints = True
  model.head = SegmentationHead(
      level=4,
      num_convs=2,
      feature_fusion='deeplabv3plus',
      use_depthwise_convolution=True,
      low_level='2/depthwise',
      low_level_num_filters=48)
  return config
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
"""Contains definitions of MobileNet Networks.""" """Contains definitions of MobileNet Networks."""
import dataclasses
from typing import Optional, Dict, Any, Tuple from typing import Optional, Dict, Any, Tuple
# Import libraries # Import libraries
import dataclasses
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import tf_utils from official.modeling import tf_utils
...@@ -517,6 +517,7 @@ class MobileNet(tf.keras.Model): ...@@ -517,6 +517,7 @@ class MobileNet(tf.keras.Model):
use_sync_bn: bool = False, use_sync_bn: bool = False,
# finegrain is not used in MobileNetV1. # finegrain is not used in MobileNetV1.
finegrain_classification_mode: bool = True, finegrain_classification_mode: bool = True,
output_intermediate_endpoints: bool = False,
**kwargs): **kwargs):
"""Initializes a MobileNet model. """Initializes a MobileNet model.
...@@ -554,6 +555,8 @@ class MobileNet(tf.keras.Model): ...@@ -554,6 +555,8 @@ class MobileNet(tf.keras.Model):
finegrain_classification_mode: If True, the model will keep the last layer finegrain_classification_mode: If True, the model will keep the last layer
large even for small multipliers, following large even for small multipliers, following
https://arxiv.org/abs/1801.04381. https://arxiv.org/abs/1801.04381.
output_intermediate_endpoints: A `bool` of whether or not output the
intermediate endpoints.
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
if model_id not in SUPPORTED_SPECS_MAP: if model_id not in SUPPORTED_SPECS_MAP:
...@@ -586,6 +589,7 @@ class MobileNet(tf.keras.Model): ...@@ -586,6 +589,7 @@ class MobileNet(tf.keras.Model):
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._finegrain_classification_mode = finegrain_classification_mode self._finegrain_classification_mode = finegrain_classification_mode
self._output_intermediate_endpoints = output_intermediate_endpoints
inputs = tf.keras.Input(shape=input_specs.shape[1:]) inputs = tf.keras.Input(shape=input_specs.shape[1:])
...@@ -658,6 +662,7 @@ class MobileNet(tf.keras.Model): ...@@ -658,6 +662,7 @@ class MobileNet(tf.keras.Model):
layer_rate = 1 layer_rate = 1
current_stride *= block_def.strides current_stride *= block_def.strides
intermediate_endpoints = {}
if block_def.block_fn == 'convbn': if block_def.block_fn == 'convbn':
net = Conv2DBNBlock( net = Conv2DBNBlock(
...@@ -679,7 +684,7 @@ class MobileNet(tf.keras.Model): ...@@ -679,7 +684,7 @@ class MobileNet(tf.keras.Model):
net = nn_blocks.DepthwiseSeparableConvBlock( net = nn_blocks.DepthwiseSeparableConvBlock(
filters=block_def.filters, filters=block_def.filters,
kernel_size=block_def.kernel_size, kernel_size=block_def.kernel_size,
strides=block_def.strides, strides=layer_stride,
activation=block_def.activation, activation=block_def.activation,
dilation_rate=layer_rate, dilation_rate=layer_rate,
regularize_depthwise=self._regularize_depthwise, regularize_depthwise=self._regularize_depthwise,
...@@ -701,7 +706,7 @@ class MobileNet(tf.keras.Model): ...@@ -701,7 +706,7 @@ class MobileNet(tf.keras.Model):
# any 1x1 convolution). # any 1x1 convolution).
use_rate = layer_rate use_rate = layer_rate
in_filters = net.shape.as_list()[-1] in_filters = net.shape.as_list()[-1]
net = nn_blocks.InvertedBottleneckBlock( block = nn_blocks.InvertedBottleneckBlock(
in_filters=in_filters, in_filters=in_filters,
out_filters=block_def.filters, out_filters=block_def.filters,
kernel_size=block_def.kernel_size, kernel_size=block_def.kernel_size,
...@@ -722,8 +727,13 @@ class MobileNet(tf.keras.Model): ...@@ -722,8 +727,13 @@ class MobileNet(tf.keras.Model):
norm_momentum=self._norm_momentum, norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon, norm_epsilon=self._norm_epsilon,
stochastic_depth_drop_rate=self._stochastic_depth_drop_rate, stochastic_depth_drop_rate=self._stochastic_depth_drop_rate,
divisible_by=self._get_divisible_by() divisible_by=self._get_divisible_by(),
)(net) output_intermediate_endpoints=self._output_intermediate_endpoints,
)
if self._output_intermediate_endpoints:
net, intermediate_endpoints = block(net)
else:
net = block(net)
elif block_def.block_fn == 'gpooling': elif block_def.block_fn == 'gpooling':
net = layers.GlobalAveragePooling2D()(net) net = layers.GlobalAveragePooling2D()(net)
...@@ -737,8 +747,13 @@ class MobileNet(tf.keras.Model): ...@@ -737,8 +747,13 @@ class MobileNet(tf.keras.Model):
if block_def.is_output: if block_def.is_output:
endpoints[str(endpoint_level)] = net endpoints[str(endpoint_level)] = net
for key, tensor in intermediate_endpoints.items():
endpoints[str(endpoint_level) + '/' + key] = tensor
if current_stride != self._output_stride:
endpoint_level += 1 endpoint_level += 1
if str(endpoint_level) in endpoints:
endpoint_level += 1
return net, endpoints, endpoint_level return net, endpoints, endpoint_level
def get_config(self): def get_config(self):
...@@ -788,6 +803,8 @@ def build_mobilenet( ...@@ -788,6 +803,8 @@ def build_mobilenet(
filter_size_scale=backbone_cfg.filter_size_scale, filter_size_scale=backbone_cfg.filter_size_scale,
input_specs=input_specs, input_specs=input_specs,
stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate, stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
output_stride=backbone_cfg.output_stride,
output_intermediate_endpoints=backbone_cfg.output_intermediate_endpoints,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
"""Tests for MobileNet.""" """Tests for MobileNet."""
import itertools import itertools
import math
# Import libraries # Import libraries
from absl.testing import parameterized from absl.testing import parameterized
...@@ -131,6 +133,51 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -131,6 +133,51 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
[1, input_size / 2 ** (idx+2), input_size / 2 ** (idx+2), num_filter], [1, input_size / 2 ** (idx+2), input_size / 2 ** (idx+2), num_filter],
endpoints[str(idx+2)].shape.as_list()) endpoints[str(idx+2)].shape.as_list())
@parameterized.parameters(
    itertools.product(
        [
            'MobileNetV1',
            'MobileNetV2',
            'MobileNetV3Large',
            'MobileNetV3Small',
            'MobileNetV3EdgeTPU',
            'MobileNetMultiAVG',
            'MobileNetMultiMAX',
        ],
        [32, 224],
    ))
def test_mobilenet_intermediate_layers(self, model_id, input_size):
  """Checks shapes of collected intermediate depthwise endpoints."""
  tf.keras.backend.set_image_data_format('channels_last')
  # Expected filter counts of the collected depthwise outputs at
  # filter_size_scale = 1.0. Only models built from inverted bottleneck
  # blocks with a depthwise conv are exercised, which excludes MobileNetV1;
  # a None entry marks a level without a depthwise conv.
  expected_depthwise_filters = {
      'MobileNetV1': [],
      'MobileNetV2': [144, 192, 576, 960],
      'MobileNetV3Small': [16, 88, 144, 576],
      'MobileNetV3Large': [72, 120, 672, 960],
      'MobileNetV3EdgeTPU': [None, None, 384, 1280],
      'MobileNetMultiMAX': [96, 128, 384, 640],
      'MobileNetMultiAVG': [64, 192, 640, 768],
  }
  model = mobilenet.MobileNet(
      model_id=model_id,
      filter_size_scale=1.0,
      output_intermediate_endpoints=True)
  image = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
  outputs = model(image)
  for idx, filters in enumerate(expected_depthwise_filters[model_id]):
    if filters is None:
      # This level does not use a depthwise conv.
      continue
    spatial = input_size / 2**(idx + 2)
    self.assertAllEqual(
        [1, spatial, spatial, filters],
        outputs[str(idx + 2) + '/depthwise'].shape.as_list())
@parameterized.parameters( @parameterized.parameters(
itertools.product( itertools.product(
[ [
...@@ -173,5 +220,47 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -173,5 +220,47 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1) inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
_ = network(inputs) _ = network(inputs)
@parameterized.parameters(
    itertools.product(
        [
            'MobileNetV1',
            'MobileNetV2',
            'MobileNetV3Large',
            'MobileNetV3Small',
            'MobileNetV3EdgeTPU',
            'MobileNetMultiAVG',
            'MobileNetMultiMAX',
        ],
        [8, 16, 32],
    ))
def test_mobilenet_output_stride(self, model_id, output_stride):
  """Test for creation of a MobileNet with different output strides."""
  tf.keras.backend.set_image_data_format('channels_last')
  # Filter counts of the final collected endpoint for each model at
  # filter_size_scale = 1.0.
  final_endpoint_filters = {
      'MobileNetV1': 1024,
      'MobileNetV2': 320,
      'MobileNetV3Small': 96,
      'MobileNetV3Large': 160,
      'MobileNetV3EdgeTPU': 192,
      'MobileNetMultiMAX': 160,
      'MobileNetMultiAVG': 192,
  }
  input_size = 224
  model = mobilenet.MobileNet(
      model_id=model_id, filter_size_scale=1.0, output_stride=output_stride)
  image = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
  outputs = model(image)
  # The endpoint key for a power-of-two stride is its log2 level.
  endpoint_level = int(math.log2(output_stride))
  spatial = input_size / output_stride
  self.assertAllEqual(
      [1, spatial, spatial, final_endpoint_filters[model_id]],
      outputs[str(endpoint_level)].shape.as_list())
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -495,6 +495,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -495,6 +495,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
use_residual=True, use_residual=True,
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
output_intermediate_endpoints=False,
**kwargs): **kwargs):
"""Initializes an inverted bottleneck block with BN after convolutions. """Initializes an inverted bottleneck block with BN after convolutions.
...@@ -537,6 +538,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -537,6 +538,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
input and output. input and output.
norm_momentum: A `float` of normalization momentum for the moving average. norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero. norm_epsilon: A `float` added to variance to avoid dividing by zero.
output_intermediate_endpoints: A `bool` of whether or not output the
intermediate endpoints.
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(InvertedBottleneckBlock, self).__init__(**kwargs) super(InvertedBottleneckBlock, self).__init__(**kwargs)
...@@ -564,6 +567,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -564,6 +567,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
self._expand_se_in_filters = expand_se_in_filters self._expand_se_in_filters = expand_se_in_filters
self._output_intermediate_endpoints = output_intermediate_endpoints
if use_sync_bn: if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization self._norm = tf.keras.layers.experimental.SyncBatchNormalization
...@@ -698,6 +702,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -698,6 +702,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None): def call(self, inputs, training=None):
endpoints = {}
shortcut = inputs shortcut = inputs
if self._expand_ratio > 1: if self._expand_ratio > 1:
x = self._conv0(inputs) x = self._conv0(inputs)
...@@ -710,6 +715,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -710,6 +715,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
x = self._conv1(x) x = self._conv1(x)
x = self._norm1(x) x = self._norm1(x)
x = self._depthwise_activation_layer(x) x = self._depthwise_activation_layer(x)
if self._output_intermediate_endpoints:
endpoints['depthwise'] = x
if self._squeeze_excitation: if self._squeeze_excitation:
x = self._squeeze_excitation(x) x = self._squeeze_excitation(x)
...@@ -724,6 +731,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -724,6 +731,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
x = self._stochastic_depth(x, training=training) x = self._stochastic_depth(x, training=training)
x = self._add([x, shortcut]) x = self._add([x, shortcut])
if self._output_intermediate_endpoints:
return x, endpoints
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment