Unverified Commit 09d9656f authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling

parents ac671306 49a5706c
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definitions of MobileDet Networks."""
import dataclasses
from typing import Any, Dict, Optional, Tuple, List
import tensorflow as tf
from official.modeling import hyperparams
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.backbones import mobilenet
from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.modeling.layers import nn_layers
layers = tf.keras.layers
# pylint: disable=pointless-string-statement
"""
Architecture: https://arxiv.org/abs/2004.14525.
"MobileDets: Searching for Object Detection Architectures for
Mobile Accelerators" Yunyang Xiong, Hanxiao Liu, Suyog Gupta, Berkin Akin,
Gabriel Bender, Yongzhe Wang, Pieter-Jan Kindermans, Mingxing Tan, Vikas Singh,
Bo Chen
Note that `round_down_protection` flag should be set to false for scaling
of the network.
"""
MD_CPU_BLOCK_SPECS = {
    'spec_name': 'MobileDetCPU',
    # [expand_ratio] is set to 1 and [use_residual] is set to false
    # for inverted_bottleneck_no_expansion
    # [se_ratio] is set to 0.25 for all inverted_bottleneck layers
    # [activation] is set to 'hard_swish' for all applicable layers
    # Rows with is_output=True are collected as backbone endpoints by
    # MobileDet._mobiledet_base (one endpoint level per True row).
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
                          'activation', 'se_ratio', 'expand_ratio',
                          'use_residual', 'is_output'],
    'block_specs': [
        ('convbn', 3, 2, 16, 'hard_swish', None, None, None, False),
        # inverted_bottleneck_no_expansion
        ('invertedbottleneck', 3, 1, 8, 'hard_swish', 0.25, 1., False, True),
        ('invertedbottleneck', 3, 2, 16, 'hard_swish', 0.25, 4., False, True),
        ('invertedbottleneck', 3, 2, 32, 'hard_swish', 0.25, 8., False, False),
        ('invertedbottleneck', 3, 1, 32, 'hard_swish', 0.25, 4., True, False),
        ('invertedbottleneck', 3, 1, 32, 'hard_swish', 0.25, 4., True, False),
        ('invertedbottleneck', 3, 1, 32, 'hard_swish', 0.25, 4., True, True),
        ('invertedbottleneck', 5, 2, 72, 'hard_swish', 0.25, 8., False, False),
        ('invertedbottleneck', 3, 1, 72, 'hard_swish', 0.25, 8., True, False),
        ('invertedbottleneck', 5, 1, 72, 'hard_swish', 0.25, 4., True, False),
        ('invertedbottleneck', 3, 1, 72, 'hard_swish', 0.25, 4., True, False),
        ('invertedbottleneck', 3, 1, 72, 'hard_swish', 0.25, 8., False, False),
        ('invertedbottleneck', 3, 1, 72, 'hard_swish', 0.25, 8., True, False),
        ('invertedbottleneck', 3, 1, 72, 'hard_swish', 0.25, 8., True, False),
        ('invertedbottleneck', 3, 1, 72, 'hard_swish', 0.25, 8., True, True),
        ('invertedbottleneck', 5, 2, 104, 'hard_swish', 0.25, 8., False, False),
        ('invertedbottleneck', 5, 1, 104, 'hard_swish', 0.25, 4., True, False),
        ('invertedbottleneck', 5, 1, 104, 'hard_swish', 0.25, 4., True, False),
        ('invertedbottleneck', 3, 1, 104, 'hard_swish', 0.25, 4., True, False),
        ('invertedbottleneck', 3, 1, 144, 'hard_swish', 0.25, 8., False, True),
    ]
}
MD_DSP_BLOCK_SPECS = {
    'spec_name': 'MobileDetDSP',
    # [expand_ratio] is set to 1 and [use_residual] is set to false
    # for inverted_bottleneck_no_expansion
    # [use_depthwise] is set to False for fused_conv
    # [se_ratio] is set to None for all inverted_bottleneck layers
    # [activation] is set to 'relu6' for all applicable layers
    # [expand_ratio] entries are written as floats throughout
    # (8 == 8., so values are unchanged numerically).
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
                          'activation', 'se_ratio', 'expand_ratio',
                          'input_compression_ratio', 'output_compression_ratio',
                          'use_depthwise', 'use_residual', 'is_output'],
    'block_specs': [
        ('convbn', 3, 2, 32, 'relu6',
         None, None, None, None, None, None, False),
        # inverted_bottleneck_no_expansion
        ('invertedbottleneck', 3, 1, 24, 'relu6',
         None, 1., None, None, True, False, True),
        ('invertedbottleneck', 3, 2, 32, 'relu6',
         None, 4., None, None, False, False, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 32, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 32, 'relu6',
         None, 4., None, None, True, True, False),
        ('tucker', 3, 1, 32, 'relu6',
         None, None, 0.25, 0.75, None, True, True),
        ('invertedbottleneck', 3, 2, 64, 'relu6',
         None, 8., None, None, False, False, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 64, 'relu6',
         None, 4., None, None, True, True, False),
        ('invertedbottleneck', 3, 1, 64, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 64, 'relu6',
         None, 4., None, None, False, True, True),  # fused_conv
        ('invertedbottleneck', 3, 2, 120, 'relu6',
         None, 8., None, None, False, False, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 120, 'relu6',
         None, 4., None, None, True, True, False),
        ('invertedbottleneck', 3, 1, 120, 'relu6',
         None, 8., None, None, True, True, False),
        ('invertedbottleneck', 3, 1, 120, 'relu6',
         None, 8., None, None, True, True, False),
        ('invertedbottleneck', 3, 1, 144, 'relu6',
         None, 8., None, None, False, False, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 144, 'relu6',
         None, 8., None, None, True, True, False),
        ('invertedbottleneck', 3, 1, 144, 'relu6',
         None, 8., None, None, True, True, False),
        ('invertedbottleneck', 3, 1, 144, 'relu6',
         None, 8., None, None, True, True, True),
        ('invertedbottleneck', 3, 2, 160, 'relu6',
         None, 4., None, None, True, False, False),
        ('invertedbottleneck', 3, 1, 160, 'relu6',
         None, 4., None, None, True, True, False),
        ('invertedbottleneck', 3, 1, 160, 'relu6',
         None, 4., None, None, False, False, False),  # fused_conv
        ('tucker', 3, 1, 160, 'relu6',
         None, None, 0.75, 0.75, None, True, False),
        ('invertedbottleneck', 3, 1, 240, 'relu6',
         None, 8., None, None, True, False, True),
    ]
}
MD_EdgeTPU_BLOCK_SPECS = {
    'spec_name': 'MobileDetEdgeTPU',
    # [use_depthwise] is set to False for fused_conv
    # [se_ratio] is set to None for all inverted_bottleneck layers
    # [activation] is set to 'relu6' for all applicable layers
    # [expand_ratio] entries are written as floats throughout
    # (8 == 8., so values are unchanged numerically).
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
                          'activation', 'se_ratio', 'expand_ratio',
                          'input_compression_ratio', 'output_compression_ratio',
                          'use_depthwise', 'use_residual', 'is_output'],
    'block_specs': [
        ('convbn', 3, 2, 32, 'relu6',
         None, None, None, None, None, None, False),
        ('tucker', 3, 1, 16, 'relu6',
         None, None, 0.25, 0.75, None, False, True),
        ('invertedbottleneck', 3, 2, 16, 'relu6',
         None, 8., None, None, False, False, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 16, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 16, 'relu6',
         None, 8., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 16, 'relu6',
         None, 4., None, None, False, True, True),  # fused_conv
        ('invertedbottleneck', 5, 2, 40, 'relu6',
         None, 8., None, None, False, False, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 40, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 40, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 40, 'relu6',
         None, 4., None, None, False, True, True),  # fused_conv
        ('invertedbottleneck', 3, 2, 72, 'relu6',
         None, 8., None, None, True, False, False),
        ('invertedbottleneck', 3, 1, 72, 'relu6',
         None, 8., None, None, True, True, False),
        ('invertedbottleneck', 3, 1, 72, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 72, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 5, 1, 96, 'relu6',
         None, 8., None, None, True, False, False),
        ('invertedbottleneck', 5, 1, 96, 'relu6',
         None, 8., None, None, True, True, False),
        ('invertedbottleneck', 3, 1, 96, 'relu6',
         None, 8., None, None, True, True, False),
        ('invertedbottleneck', 3, 1, 96, 'relu6',
         None, 8., None, None, True, True, True),
        ('invertedbottleneck', 5, 2, 120, 'relu6',
         None, 8., None, None, True, False, False),
        ('invertedbottleneck', 3, 1, 120, 'relu6',
         None, 8., None, None, True, True, False),
        ('invertedbottleneck', 5, 1, 120, 'relu6',
         None, 4., None, None, True, True, False),
        ('invertedbottleneck', 3, 1, 120, 'relu6',
         None, 8., None, None, True, True, False),
        ('invertedbottleneck', 5, 1, 384, 'relu6',
         None, 8., None, None, True, False, True),
    ]
}
MD_GPU_BLOCK_SPECS = {
    'spec_name': 'MobileDetGPU',
    # [use_depthwise] is set to False for fused_conv
    # [se_ratio] is set to None for all inverted_bottleneck layers
    # [activation] is set to 'relu6' for all applicable layers
    # [expand_ratio] entries are written as floats throughout
    # (8 == 8., so values are unchanged numerically).
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
                          'activation', 'se_ratio', 'expand_ratio',
                          'input_compression_ratio', 'output_compression_ratio',
                          'use_depthwise', 'use_residual', 'is_output'],
    'block_specs': [
        # block 0
        ('convbn', 3, 2, 32, 'relu6',
         None, None, None, None, None, None, False),
        # block 1
        ('tucker', 3, 1, 16, 'relu6',
         None, None, 0.25, 0.25, None, False, True),
        # block 2
        ('invertedbottleneck', 3, 2, 32, 'relu6',
         None, 8., None, None, False, False, False),  # fused_conv
        ('tucker', 3, 1, 32, 'relu6',
         None, None, 0.25, 0.25, None, True, False),
        ('tucker', 3, 1, 32, 'relu6',
         None, None, 0.25, 0.25, None, True, False),
        ('tucker', 3, 1, 32, 'relu6',
         None, None, 0.25, 0.25, None, True, True),
        # block 3
        ('invertedbottleneck', 3, 2, 64, 'relu6',
         None, 8., None, None, False, False, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 64, 'relu6',
         None, 8., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 64, 'relu6',
         None, 8., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 64, 'relu6',
         None, 4., None, None, False, True, True),  # fused_conv
        # block 4
        ('invertedbottleneck', 3, 2, 128, 'relu6',
         None, 8., None, None, False, False, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 128, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 128, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 128, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        # block 5
        ('invertedbottleneck', 3, 1, 128, 'relu6',
         None, 8., None, None, False, False, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 128, 'relu6',
         None, 8., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 128, 'relu6',
         None, 8., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 128, 'relu6',
         None, 8., None, None, False, True, True),  # fused_conv
        # block 6
        ('invertedbottleneck', 3, 2, 128, 'relu6',
         None, 4., None, None, False, False, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 128, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 128, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        ('invertedbottleneck', 3, 1, 128, 'relu6',
         None, 4., None, None, False, True, False),  # fused_conv
        # block 7
        ('invertedbottleneck', 3, 1, 384, 'relu6',
         None, 8., None, None, True, False, True),
    ]
}
# Maps the `model_id` accepted by `MobileDet` / `build_mobiledet` to its
# architecture specification dict.
SUPPORTED_SPECS_MAP = {
    'MobileDetCPU': MD_CPU_BLOCK_SPECS,
    'MobileDetDSP': MD_DSP_BLOCK_SPECS,
    'MobileDetEdgeTPU': MD_EdgeTPU_BLOCK_SPECS,
    'MobileDetGPU': MD_GPU_BLOCK_SPECS,
}
@dataclasses.dataclass
class BlockSpec(hyperparams.Config):
  """A container class that specifies the block configuration for MobileDet.

  One instance describes a single layer decoded from a row of a
  `*_BLOCK_SPECS['block_specs']` table (see `block_spec_decoder`).
  """
  # One of 'convbn', 'invertedbottleneck', or 'tucker' (dispatched in
  # MobileDet._mobiledet_base).
  block_fn: str = 'convbn'
  kernel_size: int = 3
  strides: int = 1
  filters: int = 32
  use_bias: bool = False
  use_normalization: bool = True
  activation: str = 'relu6'
  # If True, the block's output is exposed as a backbone endpoint.
  is_output: bool = True
  # Used for block type InvertedResConv and TuckerConvBlock.
  use_residual: bool = True
  # Used for block type InvertedResConv only.
  use_depthwise: bool = True
  expand_ratio: Optional[float] = 8.
  se_ratio: Optional[float] = None
  # Used for block type TuckerConvBlock only.
  input_compression_ratio: Optional[float] = None
  output_compression_ratio: Optional[float] = None
def block_spec_decoder(
    specs: Dict[Any, Any],
    filter_size_scale: float,
    divisible_by: int = 8,
    min_depth: int = 8) -> List[BlockSpec]:
  """Decodes specs for a block.

  Args:
    specs: A `dict` specification of block specs of a mobiledet version.
    filter_size_scale: A `float` multiplier for the filter size for all
      convolution ops. The value must be greater than zero. Typical usage will
      be to set this value in (0, 1) to reduce the number of parameters or
      computation cost of the model.
    divisible_by: An `int` that ensures all inner dimensions are divisible by
      this number.
    min_depth: An `int` lower bound on the scaled filter counts. Defaults to 8,
      matching the previously hard-coded value.

  Returns:
    A list of `BlockSpec` that defines structure of the base network.

  Raises:
    ValueError: If `block_specs` is empty, or any row's length does not match
      the schema.
  """
  spec_name = specs['spec_name']
  block_spec_schema = specs['block_spec_schema']
  block_specs = specs['block_specs']

  if not block_specs:
    raise ValueError(
        'The block spec cannot be empty for {} !'.format(spec_name))

  # Validate every row, not just the first one: a row of the wrong length
  # would otherwise be silently truncated/filled by zip() and produce a
  # BlockSpec with wrong defaults.
  for s in block_specs:
    if len(s) != len(block_spec_schema):
      raise ValueError('The block spec values {} do not match with '
                       'the schema {}'.format(s, block_spec_schema))

  decoded_specs = [
      BlockSpec(**dict(zip(block_spec_schema, s))) for s in block_specs
  ]

  # Scale the filter counts, rounding to multiples of `divisible_by` without
  # round-down protection (required for correct MobileDet scaling; see the
  # module docstring).
  for ds in decoded_specs:
    if ds.filters:
      ds.filters = nn_layers.round_filters(filters=ds.filters,
                                           multiplier=filter_size_scale,
                                           divisor=divisible_by,
                                           round_down_protect=False,
                                           min_depth=min_depth)

  return decoded_specs
@tf.keras.utils.register_keras_serializable(package='Vision')
class MobileDet(tf.keras.Model):
  """Creates a MobileDet family model."""

  def __init__(
      self,
      model_id: str = 'MobileDetCPU',
      filter_size_scale: float = 1.0,
      input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, 3]),
      # The followings are for hyper-parameter tuning.
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      # The followings should be kept the same most of the times.
      min_depth: int = 8,
      divisible_by: int = 8,
      regularize_depthwise: bool = False,
      use_sync_bn: bool = False,
      **kwargs):
    """Initializes a MobileDet model.

    Args:
      model_id: A `str` of MobileDet version. The supported values are
        `MobileDetCPU`, `MobileDetDSP`, `MobileDetEdgeTPU`, `MobileDetGPU`.
      filter_size_scale: A `float` of multiplier for the filters (number of
        channels) for all convolution ops. The value must be greater than zero.
        Typical usage will be to set this value in (0, 1) to reduce the number
        of parameters or computation cost of the model.
      input_specs: A `tf.keras.layers.InputSpec` of specs of the input tensor.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A `str` for kernel initializer of convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
        Default to None.
      min_depth: An `int` of minimum depth (number of channels) for all
        convolution ops. Enforced when filter_size_scale < 1, and not an active
        constraint when filter_size_scale >= 1.
      divisible_by: An `int` that ensures all inner dimensions are divisible by
        this number.
      regularize_depthwise: If True, apply regularization on depthwise.
      use_sync_bn: If True, use synchronized batch normalization.
      **kwargs: Additional keyword arguments to be passed.

    Raises:
      ValueError: If `model_id` is unsupported or `filter_size_scale` is not
        positive.
    """
    if model_id not in SUPPORTED_SPECS_MAP:
      raise ValueError('The MobileDet version {} '
                       'is not supported'.format(model_id))
    if filter_size_scale <= 0:
      raise ValueError('filter_size_scale is not greater than zero.')

    self._model_id = model_id
    self._input_specs = input_specs
    self._filter_size_scale = filter_size_scale
    # NOTE(review): `min_depth` is stored and serialized via get_config(), but
    # is not forwarded to `block_spec_decoder` (which uses its own default of
    # 8) — confirm before relying on non-default values here.
    self._min_depth = min_depth
    self._divisible_by = divisible_by
    self._regularize_depthwise = regularize_depthwise
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon

    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    # Membership was validated above, so direct indexing is safe and fails
    # loudly rather than yielding None.
    block_specs = SUPPORTED_SPECS_MAP[model_id]
    self._decoded_specs = block_spec_decoder(
        specs=block_specs,
        filter_size_scale=self._filter_size_scale,
        divisible_by=self._get_divisible_by())

    # Only the endpoint dict is needed here; the final feature map and the
    # next endpoint level are unused.
    _, endpoints, _ = self._mobiledet_base(inputs=inputs)

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

    super().__init__(inputs=inputs, outputs=endpoints, **kwargs)

  def _get_divisible_by(self):
    """Returns the divisor used when rounding scaled filter counts."""
    return self._divisible_by

  def _mobiledet_base(self,
                      inputs: tf.Tensor
                      ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor], int]:
    """Builds the base MobileDet architecture.

    Args:
      inputs: A `tf.Tensor` of shape `[batch_size, height, width, channels]`.

    Returns:
      A tuple of the final output tensor, a dictionary mapping endpoint level
      (as `str`) to its tensor, and the next unused endpoint level.

    Raises:
      ValueError: If the input is not rank 4, or a block spec has an unknown
        `block_fn`.
    """
    input_shape = inputs.get_shape().as_list()
    if len(input_shape) != 4:
      raise ValueError('Expected rank 4 input, was: %d' % len(input_shape))

    net = inputs
    endpoints = {}
    endpoint_level = 1
    for i, block_def in enumerate(self._decoded_specs):
      block_name = 'block_group_{}_{}'.format(block_def.block_fn, i)

      if block_def.block_fn == 'convbn':
        net = mobilenet.Conv2DBNBlock(
            filters=block_def.filters,
            kernel_size=block_def.kernel_size,
            strides=block_def.strides,
            activation=block_def.activation,
            use_bias=block_def.use_bias,
            use_normalization=block_def.use_normalization,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon
        )(net)

      elif block_def.block_fn == 'invertedbottleneck':
        in_filters = net.shape.as_list()[-1]
        net = nn_blocks.InvertedBottleneckBlock(
            in_filters=in_filters,
            out_filters=block_def.filters,
            kernel_size=block_def.kernel_size,
            strides=block_def.strides,
            expand_ratio=block_def.expand_ratio,
            se_ratio=block_def.se_ratio,
            se_inner_activation=block_def.activation,
            se_gating_activation='sigmoid',
            se_round_down_protect=False,
            expand_se_in_filters=True,
            activation=block_def.activation,
            use_depthwise=block_def.use_depthwise,
            use_residual=block_def.use_residual,
            regularize_depthwise=self._regularize_depthwise,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon,
            divisible_by=self._get_divisible_by()
        )(net)

      elif block_def.block_fn == 'tucker':
        in_filters = net.shape.as_list()[-1]
        net = nn_blocks.TuckerConvBlock(
            in_filters=in_filters,
            out_filters=block_def.filters,
            kernel_size=block_def.kernel_size,
            strides=block_def.strides,
            input_compression_ratio=block_def.input_compression_ratio,
            output_compression_ratio=block_def.output_compression_ratio,
            activation=block_def.activation,
            use_residual=block_def.use_residual,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon,
            divisible_by=self._get_divisible_by()
        )(net)

      else:
        raise ValueError('Unknown block type {} for layer {}'.format(
            block_def.block_fn, i))

      # Identity layer gives the block a stable, readable name in the graph.
      net = tf.keras.layers.Activation('linear', name=block_name)(net)

      if block_def.is_output:
        endpoints[str(endpoint_level)] = net
        endpoint_level += 1

    return net, endpoints, endpoint_level

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this model."""
    config_dict = {
        'model_id': self._model_id,
        'filter_size_scale': self._filter_size_scale,
        'min_depth': self._min_depth,
        'divisible_by': self._divisible_by,
        'regularize_depthwise': self._regularize_depthwise,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a model from its `get_config()` output."""
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
@factory.register_backbone_builder('mobiledet')
def build_mobiledet(
    input_specs: tf.keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds MobileDet backbone from a config.

  Args:
    input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
    backbone_config: A backbone config object; its `type` must be 'mobiledet'.
    norm_activation_config: A config providing normalization settings.
    l2_regularizer: An optional kernel regularizer applied to Conv2D layers.

  Returns:
    A `MobileDet` backbone model.

  Raises:
    ValueError: If `backbone_config.type` is not 'mobiledet'.
  """
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  # Raise instead of assert: asserts are stripped under `python -O`, which
  # would silently build the wrong backbone.
  if backbone_type != 'mobiledet':
    raise ValueError(f'Inconsistent backbone type {backbone_type}')

  return MobileDet(
      model_id=backbone_cfg.model_id,
      filter_size_scale=backbone_cfg.filter_size_scale,
      input_specs=input_specs,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Mobiledet."""
import itertools
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.modeling.backbones import mobiledet
class MobileDetTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the MobileDet backbone family."""

  @parameterized.parameters(
      'MobileDetCPU',
      'MobileDetDSP',
      'MobileDetEdgeTPU',
      'MobileDetGPU',
  )
  def test_serialize_deserialize(self, model_id):
    # Create a network object that sets all of its config options.
    kwargs = dict(
        model_id=model_id,
        filter_size_scale=1.0,
        use_sync_bn=False,
        kernel_initializer='VarianceScaling',
        kernel_regularizer=None,
        bias_regularizer=None,
        norm_momentum=0.99,
        norm_epsilon=0.001,
        min_depth=8,
        divisible_by=8,
        regularize_depthwise=False,
    )
    network = mobiledet.MobileDet(**kwargs)

    # get_config() must round-trip all constructor kwargs exactly.
    expected_config = dict(kwargs)
    self.assertEqual(network.get_config(), expected_config)

    # Create another network object from the first object's config.
    new_network = mobiledet.MobileDet.from_config(network.get_config())

    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())

  @parameterized.parameters(
      itertools.product(
          [1, 3],
          [
              'MobileDetCPU',
              'MobileDetDSP',
              'MobileDetEdgeTPU',
              'MobileDetGPU',
          ],
      ))
  def test_input_specs(self, input_dim, model_id):
    """Test different input feature dimensions."""
    tf.keras.backend.set_image_data_format('channels_last')

    input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, input_dim])
    network = mobiledet.MobileDet(model_id=model_id, input_specs=input_specs)

    inputs = tf.keras.Input(shape=(128, 128, input_dim), batch_size=1)
    _ = network(inputs)

  @parameterized.parameters(
      itertools.product(
          [
              'MobileDetCPU',
              'MobileDetDSP',
              'MobileDetEdgeTPU',
              'MobileDetGPU',
          ],
          [32, 224],
      ))
  def test_mobiledet_creation(self, model_id, input_size):
    """Test creation of MobileDet family models."""
    tf.keras.backend.set_image_data_format('channels_last')

    mobiledet_layers = {
        # The number of filters of layers having outputs been collected
        # for filter_size_scale = 1.0
        'MobileDetCPU': [8, 16, 32, 72, 144],
        'MobileDetDSP': [24, 32, 64, 144, 240],
        'MobileDetEdgeTPU': [16, 16, 40, 96, 384],
        'MobileDetGPU': [16, 32, 64, 128, 384],
    }

    network = mobiledet.MobileDet(model_id=model_id,
                                  filter_size_scale=1.0)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    endpoints = network(inputs)

    # Endpoint k is expected at stride 2**k with the tabulated channel count.
    for idx, num_filter in enumerate(mobiledet_layers[model_id]):
      self.assertAllEqual(
          [1, input_size / 2 ** (idx+1), input_size / 2 ** (idx+1), num_filter],
          endpoints[str(idx+1)].shape.as_list())
...@@ -420,7 +420,8 @@ MNMultiAVG_BLOCK_SPECS = { ...@@ -420,7 +420,8 @@ MNMultiAVG_BLOCK_SPECS = {
# Similar to MobileNetMultiAVG and used for segmentation task. # Similar to MobileNetMultiAVG and used for segmentation task.
# Reduced the filters by a factor of 2 in the last block. # Reduced the filters by a factor of 2 in the last block.
MNMultiAVG_SEG_BLOCK_SPECS = { MNMultiAVG_SEG_BLOCK_SPECS = {
'spec_name': 'MobileNetMultiAVGSeg', 'spec_name':
'MobileNetMultiAVGSeg',
'block_spec_schema': [ 'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation', 'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'expand_ratio', 'use_normalization', 'use_bias', 'is_output' 'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
...@@ -443,7 +444,7 @@ MNMultiAVG_SEG_BLOCK_SPECS = { ...@@ -443,7 +444,7 @@ MNMultiAVG_SEG_BLOCK_SPECS = {
('invertedbottleneck', 5, 1, 96, 'relu', 2., True, False, False), ('invertedbottleneck', 5, 1, 96, 'relu', 2., True, False, False),
('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, False), ('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, False),
('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, True), ('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, True),
('convbn', 1, 1, 480, 'relu', None, True, False, False), ('convbn', 1, 1, 448, 'relu', None, True, False, True),
('gpooling', None, None, None, None, None, None, None, False), ('gpooling', None, None, None, None, None, None, None, False),
# Remove bias and add batch norm for the last layer to support QAT # Remove bias and add batch norm for the last layer to support QAT
# and achieve slightly better accuracy. # and achieve slightly better accuracy.
...@@ -451,6 +452,78 @@ MNMultiAVG_SEG_BLOCK_SPECS = { ...@@ -451,6 +452,78 @@ MNMultiAVG_SEG_BLOCK_SPECS = {
] ]
} }
# Similar to MobileNetMultiMax and used for segmentation task.
# Reduced the filters by a factor of 2 in the last block.
MNMultiMAX_SEG_BLOCK_SPECS = {
'spec_name':
'MobileNetMultiMAXSeg',
'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
],
'block_specs': [
('convbn', 3, 2, 32, 'relu', None, True, False, False),
('invertedbottleneck', 3, 2, 32, 'relu', 3., True, False, True),
('invertedbottleneck', 5, 2, 64, 'relu', 6., True, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 2., True, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 2., True, False, True),
('invertedbottleneck', 5, 2, 128, 'relu', 6., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 4., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 6., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, True),
('invertedbottleneck', 3, 2, 160, 'relu', 6., True, False, False),
('invertedbottleneck', 5, 1, 96, 'relu', 2., True, False, False),
('invertedbottleneck', 3, 1, 96, 'relu', 4., True, False, False),
('invertedbottleneck', 5, 1, 96, 'relu', 320.0 / 96, True, False, True),
('convbn', 1, 1, 448, 'relu', None, True, False, True),
('gpooling', None, None, None, None, None, None, None, False),
# Remove bias and add batch norm for the last layer to support QAT
# and achieve slightly better accuracy.
('convbn', 1, 1, 1280, 'relu', None, True, False, False),
]
}
# A smaller MNV3Small, with reduced filters for the last few layers
MNV3SmallReducedFilters = {
'spec_name':
'MobilenetV3SmallReducedFilters',
'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'se_ratio', 'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
],
'block_specs': [
('convbn', 3, 2, 16, 'hard_swish', None, None, True, False, False),
('invertedbottleneck', 3, 2, 16, 'relu', 0.25, 1, None, False, True),
('invertedbottleneck', 3, 2, 24, 'relu', None, 72. / 16, None, False,
False),
('invertedbottleneck', 3, 1, 24, 'relu', None, 88. / 24, None, False,
True),
('invertedbottleneck', 5, 2, 40, 'hard_swish', 0.25, 4, None, False,
False),
('invertedbottleneck', 5, 1, 40, 'hard_swish', 0.25, 6, None, False,
False),
('invertedbottleneck', 5, 1, 40, 'hard_swish', 0.25, 6, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 3, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 3, None, False,
True),
# Layers below are different from MobileNetV3Small and have
# half as many filters
('invertedbottleneck', 5, 2, 48, 'hard_swish', 0.25, 3, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 6, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 6, None, False,
True),
('convbn', 1, 1, 288, 'hard_swish', None, None, True, False, False),
('gpooling', None, None, None, None, None, None, None, None, False),
('convbn', 1, 1, 1024, 'hard_swish', None, None, False, True, False),
]
}
SUPPORTED_SPECS_MAP = { SUPPORTED_SPECS_MAP = {
'MobileNetV1': MNV1_BLOCK_SPECS, 'MobileNetV1': MNV1_BLOCK_SPECS,
'MobileNetV2': MNV2_BLOCK_SPECS, 'MobileNetV2': MNV2_BLOCK_SPECS,
...@@ -460,6 +533,8 @@ SUPPORTED_SPECS_MAP = { ...@@ -460,6 +533,8 @@ SUPPORTED_SPECS_MAP = {
'MobileNetMultiMAX': MNMultiMAX_BLOCK_SPECS, 'MobileNetMultiMAX': MNMultiMAX_BLOCK_SPECS,
'MobileNetMultiAVG': MNMultiAVG_BLOCK_SPECS, 'MobileNetMultiAVG': MNMultiAVG_BLOCK_SPECS,
'MobileNetMultiAVGSeg': MNMultiAVG_SEG_BLOCK_SPECS, 'MobileNetMultiAVGSeg': MNMultiAVG_SEG_BLOCK_SPECS,
'MobileNetMultiMAXSeg': MNMultiMAX_SEG_BLOCK_SPECS,
'MobileNetV3SmallReducedFilters': MNV3SmallReducedFilters,
} }
......
...@@ -37,6 +37,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -37,6 +37,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG', 'MobileNetMultiAVG',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
) )
def test_serialize_deserialize(self, model_id): def test_serialize_deserialize(self, model_id):
# Create a network object that sets all of its config options. # Create a network object that sets all of its config options.
...@@ -82,6 +84,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -82,6 +84,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG', 'MobileNetMultiAVG',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
], ],
)) ))
def test_input_specs(self, input_dim, model_id): def test_input_specs(self, input_dim, model_id):
...@@ -105,6 +109,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -105,6 +109,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG', 'MobileNetMultiAVG',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetV3SmallReducedFilters',
], ],
[32, 224], [32, 224],
)) ))
...@@ -124,6 +129,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -124,6 +129,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX': [32, 64, 128, 160], 'MobileNetMultiMAX': [32, 64, 128, 160],
'MobileNetMultiAVG': [32, 64, 160, 192], 'MobileNetMultiAVG': [32, 64, 160, 192],
'MobileNetMultiAVGSeg': [32, 64, 160, 96], 'MobileNetMultiAVGSeg': [32, 64, 160, 96],
'MobileNetMultiMAXSeg': [32, 64, 128, 96],
'MobileNetV3SmallReducedFilters': [16, 24, 48, 48],
} }
network = mobilenet.MobileNet(model_id=model_id, network = mobilenet.MobileNet(model_id=model_id,
...@@ -148,6 +155,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -148,6 +155,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG', 'MobileNetMultiAVG',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
], ],
[32, 224], [32, 224],
)) ))
...@@ -167,6 +176,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -167,6 +176,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX': [96, 128, 384, 640], 'MobileNetMultiMAX': [96, 128, 384, 640],
'MobileNetMultiAVG': [64, 192, 640, 768], 'MobileNetMultiAVG': [64, 192, 640, 768],
'MobileNetMultiAVGSeg': [64, 192, 640, 384], 'MobileNetMultiAVGSeg': [64, 192, 640, 384],
'MobileNetMultiMAXSeg': [96, 128, 384, 320],
'MobileNetV3SmallReducedFilters': [16, 88, 144, 288],
} }
network = mobilenet.MobileNet(model_id=model_id, network = mobilenet.MobileNet(model_id=model_id,
filter_size_scale=1.0, filter_size_scale=1.0,
...@@ -196,6 +207,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -196,6 +207,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
], ],
[1.0, 0.75], [1.0, 0.75],
)) ))
...@@ -217,8 +230,12 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -217,8 +230,12 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
('MobileNetMultiAVG', 0.75): 2349704, ('MobileNetMultiAVG', 0.75): 2349704,
('MobileNetMultiMAX', 1.0): 3174560, ('MobileNetMultiMAX', 1.0): 3174560,
('MobileNetMultiMAX', 0.75): 2045816, ('MobileNetMultiMAX', 0.75): 2045816,
('MobileNetMultiAVGSeg', 1.0): 2284000, ('MobileNetMultiAVGSeg', 1.0): 2239840,
('MobileNetMultiAVGSeg', 0.75): 1427816, ('MobileNetMultiAVGSeg', 0.75): 1395272,
('MobileNetMultiMAXSeg', 1.0): 1929088,
('MobileNetMultiMAXSeg', 0.75): 1216544,
('MobileNetV3SmallReducedFilters', 1.0): 694880,
('MobileNetV3SmallReducedFilters', 0.75): 505960,
} }
input_size = 224 input_size = 224
...@@ -241,6 +258,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -241,6 +258,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG', 'MobileNetMultiAVG',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
], ],
[8, 16, 32], [8, 16, 32],
)) ))
...@@ -258,7 +277,9 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -258,7 +277,9 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetV3EdgeTPU': 192, 'MobileNetV3EdgeTPU': 192,
'MobileNetMultiMAX': 160, 'MobileNetMultiMAX': 160,
'MobileNetMultiAVG': 192, 'MobileNetMultiAVG': 192,
'MobileNetMultiAVGSeg': 96, 'MobileNetMultiAVGSeg': 448,
'MobileNetMultiMAXSeg': 448,
'MobileNetV3SmallReducedFilters': 48,
} }
network = mobilenet.MobileNet( network = mobilenet.MobileNet(
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder.""" """Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder."""
from typing import Any, List, Mapping, Optional from typing import Any, List, Mapping, Optional, Union
# Import libraries # Import libraries
...@@ -22,6 +22,9 @@ import tensorflow as tf ...@@ -22,6 +22,9 @@ import tensorflow as tf
from official.modeling import hyperparams from official.modeling import hyperparams
from official.vision.beta.modeling.decoders import factory from official.vision.beta.modeling.decoders import factory
from official.vision.beta.modeling.layers import deeplab from official.vision.beta.modeling.layers import deeplab
from official.vision.beta.modeling.layers import nn_layers
TensorMapUnion = Union[tf.Tensor, Mapping[str, tf.Tensor]]
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
...@@ -43,6 +46,8 @@ class ASPP(tf.keras.layers.Layer): ...@@ -43,6 +46,8 @@ class ASPP(tf.keras.layers.Layer):
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear', interpolation: str = 'bilinear',
use_depthwise_convolution: bool = False, use_depthwise_convolution: bool = False,
spp_layer_version: str = 'v1',
output_tensor: bool = False,
**kwargs): **kwargs):
"""Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer. """Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.
...@@ -67,9 +72,12 @@ class ASPP(tf.keras.layers.Layer): ...@@ -67,9 +72,12 @@ class ASPP(tf.keras.layers.Layer):
`gaussian`, or `mitchellcubic`. `gaussian`, or `mitchellcubic`.
use_depthwise_convolution: If True depthwise separable convolutions will use_depthwise_convolution: If True depthwise separable convolutions will
be added to the Atrous spatial pyramid pooling. be added to the Atrous spatial pyramid pooling.
spp_layer_version: A `str` of spatial pyramid pooling layer version.
output_tensor: Whether to output a single tensor or a dictionary of tensor.
Default is false.
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(ASPP, self).__init__(**kwargs) super().__init__(**kwargs)
self._config_dict = { self._config_dict = {
'level': level, 'level': level,
'dilation_rates': dilation_rates, 'dilation_rates': dilation_rates,
...@@ -84,7 +92,11 @@ class ASPP(tf.keras.layers.Layer): ...@@ -84,7 +92,11 @@ class ASPP(tf.keras.layers.Layer):
'kernel_regularizer': kernel_regularizer, 'kernel_regularizer': kernel_regularizer,
'interpolation': interpolation, 'interpolation': interpolation,
'use_depthwise_convolution': use_depthwise_convolution, 'use_depthwise_convolution': use_depthwise_convolution,
'spp_layer_version': spp_layer_version,
'output_tensor': output_tensor
} }
self._aspp_layer = deeplab.SpatialPyramidPooling if self._config_dict[
'spp_layer_version'] == 'v1' else nn_layers.SpatialPyramidPooling
def build(self, input_shape): def build(self, input_shape):
pool_kernel_size = None pool_kernel_size = None
...@@ -93,7 +105,8 @@ class ASPP(tf.keras.layers.Layer): ...@@ -93,7 +105,8 @@ class ASPP(tf.keras.layers.Layer):
int(p_size // 2**self._config_dict['level']) int(p_size // 2**self._config_dict['level'])
for p_size in self._config_dict['pool_kernel_size'] for p_size in self._config_dict['pool_kernel_size']
] ]
self.aspp = deeplab.SpatialPyramidPooling(
self.aspp = self._aspp_layer(
output_channels=self._config_dict['num_filters'], output_channels=self._config_dict['num_filters'],
dilation_rates=self._config_dict['dilation_rates'], dilation_rates=self._config_dict['dilation_rates'],
pool_kernel_size=pool_kernel_size, pool_kernel_size=pool_kernel_size,
...@@ -108,31 +121,36 @@ class ASPP(tf.keras.layers.Layer): ...@@ -108,31 +121,36 @@ class ASPP(tf.keras.layers.Layer):
use_depthwise_convolution=self._config_dict['use_depthwise_convolution'] use_depthwise_convolution=self._config_dict['use_depthwise_convolution']
) )
def call(self, inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]: def call(self, inputs: TensorMapUnion) -> TensorMapUnion:
"""Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input. """Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.
The output of ASPP will be a dict of {`level`, `tf.Tensor`} even if only one The output of ASPP will be a dict of {`level`, `tf.Tensor`} even if only one
level is present. Hence, this will be compatible with the rest of the level is present, if output_tensor is false. Hence, this will be compatible
segmentation model interfaces. with the rest of the segmentation model interfaces.
If output_tensor is true, a single tensot is output.
Args: Args:
inputs: A `dict` of `tf.Tensor` where inputs: A `tf.Tensor` of shape [batch, height_l, width_l, filter_size] or
a `dict` of `tf.Tensor` where
- key: A `str` of the level of the multilevel feature maps. - key: A `str` of the level of the multilevel feature maps.
- values: A `tf.Tensor` of shape [batch, height_l, width_l, - values: A `tf.Tensor` of shape [batch, height_l, width_l,
filter_size]. filter_size].
Returns: Returns:
A `dict` of `tf.Tensor` where A `tf.Tensor` of shape [batch, height_l, width_l, filter_size] or a `dict`
of `tf.Tensor` where
- key: A `str` of the level of the multilevel feature maps. - key: A `str` of the level of the multilevel feature maps.
- values: A `tf.Tensor` of output of ASPP module. - values: A `tf.Tensor` of output of ASPP module.
""" """
outputs = {} outputs = {}
level = str(self._config_dict['level']) level = str(self._config_dict['level'])
outputs[level] = self.aspp(inputs[level]) backbone_output = inputs[level] if isinstance(inputs, dict) else inputs
return outputs outputs = self.aspp(backbone_output)
return outputs if self._config_dict['output_tensor'] else {level: outputs}
def get_config(self) -> Mapping[str, Any]: def get_config(self) -> Mapping[str, Any]:
return self._config_dict base_config = super().get_config()
return dict(list(base_config.items()) + list(self._config_dict.items()))
@classmethod @classmethod
def from_config(cls, config, custom_objects=None): def from_config(cls, config, custom_objects=None):
...@@ -180,4 +198,6 @@ def build_aspp_decoder( ...@@ -180,4 +198,6 @@ def build_aspp_decoder(
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer,
spp_layer_version=decoder_cfg.spp_layer_version,
output_tensor=decoder_cfg.output_tensor)
...@@ -26,14 +26,15 @@ from official.vision.beta.modeling.decoders import aspp ...@@ -26,14 +26,15 @@ from official.vision.beta.modeling.decoders import aspp
class ASPPTest(parameterized.TestCase, tf.test.TestCase): class ASPPTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
(3, [6, 12, 18, 24], 128), (3, [6, 12, 18, 24], 128, 'v1'),
(3, [6, 12, 18], 128), (3, [6, 12, 18], 128, 'v1'),
(3, [6, 12], 256), (3, [6, 12], 256, 'v1'),
(4, [6, 12, 18, 24], 128), (4, [6, 12, 18, 24], 128, 'v2'),
(4, [6, 12, 18], 128), (4, [6, 12, 18], 128, 'v2'),
(4, [6, 12], 256), (4, [6, 12], 256, 'v2'),
) )
def test_network_creation(self, level, dilation_rates, num_filters): def test_network_creation(self, level, dilation_rates, num_filters,
spp_layer_version):
"""Test creation of ASPP.""" """Test creation of ASPP."""
input_size = 256 input_size = 256
...@@ -45,7 +46,8 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase): ...@@ -45,7 +46,8 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
network = aspp.ASPP( network = aspp.ASPP(
level=level, level=level,
dilation_rates=dilation_rates, dilation_rates=dilation_rates,
num_filters=num_filters) num_filters=num_filters,
spp_layer_version=spp_layer_version)
endpoints = backbone(inputs) endpoints = backbone(inputs)
feats = network(endpoints) feats = network(endpoints)
...@@ -71,7 +73,11 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase): ...@@ -71,7 +73,11 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
interpolation='bilinear', interpolation='bilinear',
dropout_rate=0.2, dropout_rate=0.2,
use_depthwise_convolution='false', use_depthwise_convolution='false',
) spp_layer_version='v1',
output_tensor=False,
dtype='float32',
name='aspp',
trainable=True)
network = aspp.ASPP(**kwargs) network = aspp.ASPP(**kwargs)
expected_config = dict(kwargs) expected_config = dict(kwargs)
......
...@@ -133,6 +133,10 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -133,6 +133,10 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
network_config = network.get_config() network_config = network.get_config()
factory_network_config = factory_network.get_config() factory_network_config = factory_network.get_config()
# Due to calling `super().get_config()` in aspp layer, everything but the
# the name of two layer instances are the same, so we force equal name so it
# will not give false alarm.
factory_network_config['name'] = network_config['name']
self.assertEqual(network_config, factory_network_config) self.assertEqual(network_config, factory_network_config)
......
...@@ -22,6 +22,7 @@ from absl import logging ...@@ -22,6 +22,7 @@ from absl import logging
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.beta.modeling.decoders import factory from official.vision.beta.modeling.decoders import factory
from official.vision.beta.ops import spatial_transform_ops from official.vision.beta.ops import spatial_transform_ops
...@@ -165,12 +166,7 @@ class NASFPN(tf.keras.Model): ...@@ -165,12 +166,7 @@ class NASFPN(tf.keras.Model):
'momentum': self._config_dict['norm_momentum'], 'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'], 'epsilon': self._config_dict['norm_epsilon'],
} }
if activation == 'relu': self._activation = tf_utils.get_activation(activation)
self._activation = tf.nn.relu
elif activation == 'swish':
self._activation = tf.nn.swish
else:
raise ValueError('Activation {} not implemented.'.format(activation))
# Gets input feature pyramid from backbone. # Gets input feature pyramid from backbone.
inputs = self._build_input_pyramid(input_specs, min_level) inputs = self._build_input_pyramid(input_specs, min_level)
...@@ -238,7 +234,11 @@ class NASFPN(tf.keras.Model): ...@@ -238,7 +234,11 @@ class NASFPN(tf.keras.Model):
# dtype mismatch when one input (by default float32 dtype) does not meet all # dtype mismatch when one input (by default float32 dtype) does not meet all
# the above conditions and is output unchanged, while other inputs are # the above conditions and is output unchanged, while other inputs are
# processed to have different dtype, e.g., using bfloat16 on TPU. # processed to have different dtype, e.g., using bfloat16 on TPU.
return tf.cast(x, dtype=tf.keras.layers.Layer().dtype_policy.compute_dtype) compute_dtype = tf.keras.layers.Layer().dtype_policy.compute_dtype
if (compute_dtype is not None) and (x.dtype != compute_dtype):
return tf.cast(x, dtype=compute_dtype)
else:
return x
def _global_attention(self, feat0, feat1): def _global_attention(self, feat0, feat1):
m = tf.math.reduce_max(feat0, axis=[1, 2], keepdims=True) m = tf.math.reduce_max(feat0, axis=[1, 2], keepdims=True)
......
...@@ -370,5 +370,17 @@ def build_segmentation_model( ...@@ -370,5 +370,17 @@ def build_segmentation_model(
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer)
model = segmentation_model.SegmentationModel(backbone, decoder, head) mask_scoring_head = None
if model_config.mask_scoring_head:
mask_scoring_head = segmentation_heads.MaskScoring(
num_classes=model_config.num_classes,
**model_config.mask_scoring_head.as_dict(),
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
model = segmentation_model.SegmentationModel(
backbone, decoder, head, mask_scoring_head=mask_scoring_head)
return model return model
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains definitions of segmentation heads.""" """Contains definitions of segmentation heads."""
from typing import List, Union, Optional, Mapping, Tuple from typing import List, Union, Optional, Mapping, Tuple, Any
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
...@@ -21,6 +21,176 @@ from official.vision.beta.modeling.layers import nn_layers ...@@ -21,6 +21,176 @@ from official.vision.beta.modeling.layers import nn_layers
from official.vision.beta.ops import spatial_transform_ops from official.vision.beta.ops import spatial_transform_ops
class MaskScoring(tf.keras.Model):
  """Creates a mask scoring layer.

  This implements the mask scoring layer from the paper:

  Zhaojin Huang, Lichao Huang, Yongchao Gong, Chang Huang, Xinggang Wang.
  Mask Scoring R-CNN.
  (https://arxiv.org/pdf/1903.00241.pdf)

  The head consumes segmentation logits and predicts a per-class IoU score
  for the predicted masks via a small conv + fully-connected tower.
  """

  def __init__(
      self,
      num_classes: int,
      fc_input_size: List[int],
      num_convs: int = 3,
      num_filters: int = 256,
      fc_dims: int = 1024,
      num_fcs: int = 2,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes mask scoring layer.

    Args:
      num_classes: An `int` for number of classes.
      fc_input_size: A List of `int` for the input size of the
        fully connected layers.
      num_convs: An `int` for number of conv layers.
      num_filters: An `int` for the number of filters for conv layers.
      fc_dims: An `int` number of filters for each fully connected layers.
      num_fcs: An `int` for number of fully connected layers.
      activation: A `str` name of the activation function.
      use_sync_bn: A bool, whether or not to use sync batch normalization.
      norm_momentum: A float for the momentum in BatchNorm. Defaults to 0.99.
      norm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
        0.001.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(MaskScoring, self).__init__(**kwargs)

    # Kept as a plain dict so get_config()/from_config() round-trips the full
    # constructor signature.
    self._config_dict = {
        'num_classes': num_classes,
        'num_convs': num_convs,
        'num_filters': num_filters,
        'fc_input_size': fc_input_size,
        'fc_dims': fc_dims,
        'num_fcs': num_fcs,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'activation': activation,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }

    # BatchNorm axis follows the image data format (NHWC vs NCHW).
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation = tf_utils.get_activation(activation)

  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
    """Creates the variables of the mask scoring head."""
    conv_op = tf.keras.layers.Conv2D
    conv_kwargs = {
        'filters': self._config_dict['num_filters'],
        'kernel_size': 3,
        'padding': 'same',
    }
    conv_kwargs.update({
        'kernel_initializer': tf.keras.initializers.VarianceScaling(
            scale=2, mode='fan_out', distribution='untruncated_normal'),
        'bias_initializer': tf.zeros_initializer(),
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    })
    bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
             if self._config_dict['use_sync_bn']
             else tf.keras.layers.BatchNormalization)
    bn_kwargs = {
        'axis': self._bn_axis,
        'momentum': self._config_dict['norm_momentum'],
        'epsilon': self._config_dict['norm_epsilon'],
    }

    # Conv tower: num_convs x (conv -> bn -> activation). NOTE(review): the
    # layer names below ('mask-scoring_*') likely appear in checkpoints; do
    # not rename without a migration plan.
    self._convs = []
    self._conv_norms = []
    for i in range(self._config_dict['num_convs']):
      conv_name = 'mask-scoring_{}'.format(i)
      self._convs.append(conv_op(name=conv_name, **conv_kwargs))
      bn_name = 'mask-scoring-bn_{}'.format(i)
      self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))

    # FC tower: num_fcs x (dense -> bn -> activation).
    self._fcs = []
    self._fc_norms = []
    for i in range(self._config_dict['num_fcs']):
      fc_name = 'mask-scoring-fc_{}'.format(i)
      self._fcs.append(
          tf.keras.layers.Dense(
              units=self._config_dict['fc_dims'],
              kernel_initializer=tf.keras.initializers.VarianceScaling(
                  scale=1 / 3.0, mode='fan_out', distribution='uniform'),
              kernel_regularizer=self._config_dict['kernel_regularizer'],
              bias_regularizer=self._config_dict['bias_regularizer'],
              name=fc_name))
      bn_name = 'mask-scoring-fc-bn_{}'.format(i)
      self._fc_norms.append(bn_op(name=bn_name, **bn_kwargs))

    # Final linear projection to one IoU score per class.
    self._classifier = tf.keras.layers.Dense(
        units=self._config_dict['num_classes'],
        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=self._config_dict['kernel_regularizer'],
        bias_regularizer=self._config_dict['bias_regularizer'],
        name='iou-scores')
    super(MaskScoring, self).build(input_shape)

  def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor:
    """Forward pass mask scoring head.

    Args:
      inputs: A `tf.Tensor` of the shape [batch_size, width, size, num_classes],
        representing the segmentation logits.
      training: a `bool` indicating whether it is in `training` mode.

    Returns:
      mask_scores: A `tf.Tensor` of predicted mask scores
        [batch_size, num_classes].
    """
    # Scoring must not backpropagate into the segmentation logits.
    x = tf.stop_gradient(inputs)
    for conv, bn in zip(self._convs, self._conv_norms):
      x = conv(x)
      x = bn(x)
      x = self._activation(x)

    # Casts feat to float32 so the resize op can be run on TPU.
    x = tf.cast(x, tf.float32)
    # Resizes to a fixed spatial size so the flattened FC input is static.
    x = tf.image.resize(x, size=self._config_dict['fc_input_size'],
                        method=tf.image.ResizeMethod.BILINEAR)
    # Casts it back to be compatible with the rest of the operations.
    x = tf.cast(x, inputs.dtype)

    # Flatten [batch, h, w, filters] -> [batch, h * w * filters] for the FCs.
    _, h, w, filters = x.get_shape().as_list()
    x = tf.reshape(x, [-1, h * w * filters])

    for fc, bn in zip(self._fcs, self._fc_norms):
      x = fc(x)
      x = bn(x)
      x = self._activation(x)

    ious = self._classifier(x)
    return ious

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor configuration for serialization."""
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the layer from its `get_config()` output."""
    return cls(**config)
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class SegmentationHead(tf.keras.layers.Layer): class SegmentationHead(tf.keras.layers.Layer):
"""Creates a segmentation head.""" """Creates a segmentation head."""
...@@ -220,7 +390,7 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -220,7 +390,7 @@ class SegmentationHead(tf.keras.layers.Layer):
kernel_regularizer=self._config_dict['kernel_regularizer'], kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer']) bias_regularizer=self._config_dict['bias_regularizer'])
super(SegmentationHead, self).build(input_shape) super().build(input_shape)
def _fuse_features(self, inputs): def _fuse_features(self, inputs):
backbone_output = inputs[0] backbone_output = inputs[0]
...@@ -285,7 +455,8 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -285,7 +455,8 @@ class SegmentationHead(tf.keras.layers.Layer):
return self._prediction_conv(x) return self._prediction_conv(x)
def get_config(self): def get_config(self):
return self._config_dict base_config = super().get_config()
return dict(list(base_config.items()) + list(self._config_dict.items()))
@classmethod @classmethod
def from_config(cls, config): def from_config(cls, config):
......
...@@ -81,5 +81,36 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase): ...@@ -81,5 +81,36 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
new_head = segmentation_heads.SegmentationHead.from_config(config) new_head = segmentation_heads.SegmentationHead.from_config(config)
self.assertAllEqual(head.get_config(), new_head.get_config()) self.assertAllEqual(head.get_config(), new_head.get_config())
class MaskScoringHeadTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the MaskScoring head."""

  @parameterized.parameters(
      (1, 1, 64, [4, 4]),
      (2, 1, 64, [4, 4]),
      (3, 1, 64, [4, 4]),
      (1, 2, 32, [8, 8]),
      (2, 2, 32, [8, 8]),
      (3, 2, 32, [8, 8]),)
  def test_forward(self, num_convs, num_fcs,
                   num_filters, fc_input_size):
    """Checks the head maps logits to per-class scores of shape [batch, 2]."""
    dummy_logits = np.random.rand(2, 64, 64, 16)
    scoring_head = segmentation_heads.MaskScoring(
        num_classes=2,
        num_convs=num_convs,
        num_filters=num_filters,
        fc_dims=128,
        fc_input_size=fc_input_size)
    iou_scores = scoring_head(dummy_logits)
    self.assertAllEqual(iou_scores.numpy().shape, [2, 2])

  def test_serialize_deserialize(self):
    """Checks the head survives a get_config/from_config round trip."""
    original = segmentation_heads.MaskScoring(
        num_classes=2, fc_input_size=[4, 4], fc_dims=128)
    restored = segmentation_heads.MaskScoring.from_config(
        original.get_config())
    self.assertAllEqual(original.get_config(), restored.get_config())
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -368,7 +368,7 @@ def _generate_detections_v2(boxes: tf.Tensor, ...@@ -368,7 +368,7 @@ def _generate_detections_v2(boxes: tf.Tensor,
nmsed_boxes = tf.gather(nmsed_boxes, indices, batch_dims=1, axis=1) nmsed_boxes = tf.gather(nmsed_boxes, indices, batch_dims=1, axis=1)
nmsed_classes = tf.gather(nmsed_classes, indices, batch_dims=1) nmsed_classes = tf.gather(nmsed_classes, indices, batch_dims=1)
valid_detections = tf.reduce_sum( valid_detections = tf.reduce_sum(
input_tensor=tf.cast(tf.greater(nmsed_scores, -1), tf.int32), axis=1) input_tensor=tf.cast(tf.greater(nmsed_scores, 0.0), tf.int32), axis=1)
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
......
...@@ -497,6 +497,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -497,6 +497,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
activation='relu', activation='relu',
se_inner_activation='relu', se_inner_activation='relu',
se_gating_activation='sigmoid', se_gating_activation='sigmoid',
se_round_down_protect=True,
expand_se_in_filters=False, expand_se_in_filters=False,
depthwise_activation=None, depthwise_activation=None,
use_sync_bn=False, use_sync_bn=False,
...@@ -532,6 +533,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -532,6 +533,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
se_inner_activation: A `str` name of squeeze-excitation inner activation. se_inner_activation: A `str` name of squeeze-excitation inner activation.
se_gating_activation: A `str` name of squeeze-excitation gating se_gating_activation: A `str` name of squeeze-excitation gating
activation. activation.
se_round_down_protect: A `bool` of whether round down more than 10%
will be allowed in SE layer.
expand_se_in_filters: A `bool` of whether or not to expand in_filter in expand_se_in_filters: A `bool` of whether or not to expand in_filter in
squeeze and excitation layer. squeeze and excitation layer.
depthwise_activation: A `str` name of the activation function for depthwise_activation: A `str` name of the activation function for
...@@ -573,6 +576,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -573,6 +576,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._se_inner_activation = se_inner_activation self._se_inner_activation = se_inner_activation
self._se_gating_activation = se_gating_activation self._se_gating_activation = se_gating_activation
self._depthwise_activation = depthwise_activation self._depthwise_activation = depthwise_activation
self._se_round_down_protect = se_round_down_protect
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
...@@ -652,6 +656,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -652,6 +656,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
out_filters=expand_filters, out_filters=expand_filters,
se_ratio=self._se_ratio, se_ratio=self._se_ratio,
divisible_by=self._divisible_by, divisible_by=self._divisible_by,
round_down_protect=self._se_round_down_protect,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
...@@ -700,6 +705,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -700,6 +705,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
'activation': self._activation, 'activation': self._activation,
'se_inner_activation': self._se_inner_activation, 'se_inner_activation': self._se_inner_activation,
'se_gating_activation': self._se_gating_activation, 'se_gating_activation': self._se_gating_activation,
'se_round_down_protect': self._se_round_down_protect,
'expand_se_in_filters': self._expand_se_in_filters, 'expand_se_in_filters': self._expand_se_in_filters,
'depthwise_activation': self._depthwise_activation, 'depthwise_activation': self._depthwise_activation,
'dilation_rate': self._dilation_rate, 'dilation_rate': self._dilation_rate,
...@@ -1310,3 +1316,196 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer): ...@@ -1310,3 +1316,196 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
x = self._conv1(x) x = self._conv1(x)
x = self._norm1(x) x = self._norm1(x)
return self._activation_fn(x) return self._activation_fn(x)
@tf.keras.utils.register_keras_serializable(package='Vision')
class TuckerConvBlock(tf.keras.layers.Layer):
"""An Tucker block (generalized bottleneck)."""
def __init__(self,
in_filters,
out_filters,
input_compression_ratio,
output_compression_ratio,
strides,
kernel_size=3,
stochastic_depth_drop_rate=None,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
divisible_by=1,
use_residual=True,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""Initializes an inverted bottleneck block with BN after convolutions.
Args:
in_filters: An `int` number of filters of the input tensor.
out_filters: An `int` number of filters of the output tensor.
input_compression_ratio: An `float` of compression ratio for
input filters.
output_compression_ratio: An `float` of compression ratio for
output filters.
strides: An `int` block stride. If greater than 1, this block will
ultimately downsample the input.
kernel_size: An `int` kernel_size of the depthwise conv layer.
stochastic_depth_drop_rate: A `float` or None. if not None, drop rate for
the stochastic depth layer.
kernel_initializer: A `str` of kernel_initializer for convolutional
layers.
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default to None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2d.
Default to None.
activation: A `str` name of the activation function.
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
divisible_by: An `int` that ensures all inner dimensions are divisible by
this number.
use_residual: A `bool` of whether to include residual connection between
input and output.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
**kwargs: Additional keyword arguments to be passed.
"""
super(TuckerConvBlock, self).__init__(**kwargs)
self._in_filters = in_filters
self._out_filters = out_filters
self._input_compression_ratio = input_compression_ratio
self._output_compression_ratio = output_compression_ratio
self._strides = strides
self._kernel_size = kernel_size
self._divisible_by = divisible_by
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._use_sync_bn = use_sync_bn
self._use_residual = use_residual
self._activation = activation
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
  def build(self, input_shape):
    """Builds the sub-layers of the Tucker bottleneck block.

    Three stages are created: a 1x1 conv that compresses the input channels,
    a (possibly strided) spatial conv at the compressed width, and a final
    1x1 conv projecting back to `out_filters`. Every conv is followed by
    batch norm; only the first two stages apply the activation.

    Args:
      input_shape: Shape of the input tensor.
    """
    # Compressed channel count for the input 1x1 conv. Round-down protection
    # is disabled so the scaled value may round down by more than 10%, which
    # is intended when width-scaling this network family.
    input_compressed_filters = nn_layers.make_divisible(
        value=self._in_filters * self._input_compression_ratio,
        divisor=self._divisible_by,
        round_down_protect=False)

    # Stage 1: 1x1 input-compression conv + BN + activation.
    self._conv0 = tf.keras.layers.Conv2D(
        filters=input_compressed_filters,
        kernel_size=1,
        strides=1,
        padding='same',
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm0 = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)
    self._activation_layer0 = tf_utils.get_activation(
        self._activation, use_keras_layer=True)

    # Compressed channel count for the spatial conv's output.
    output_compressed_filters = nn_layers.make_divisible(
        value=self._out_filters * self._output_compression_ratio,
        divisor=self._divisible_by,
        round_down_protect=False)

    # Stage 2: spatial conv (carries the block's stride) + BN + activation.
    self._conv1 = tf.keras.layers.Conv2D(
        filters=output_compressed_filters,
        kernel_size=self._kernel_size,
        strides=self._strides,
        padding='same',
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm1 = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)
    self._activation_layer1 = tf_utils.get_activation(
        self._activation, use_keras_layer=True)

    # Stage 3: last 1x1 projection conv + BN (no activation afterwards).
    self._conv2 = tf.keras.layers.Conv2D(
        filters=self._out_filters,
        kernel_size=1,
        strides=1,
        padding='same',
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm2 = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)

    # Optional stochastic depth, applied on the residual branch in `call`.
    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = nn_layers.StochasticDepth(
          self._stochastic_depth_drop_rate)
    else:
      self._stochastic_depth = None
    self._add = tf.keras.layers.Add()

    super(TuckerConvBlock, self).build(input_shape)
def get_config(self):
config = {
'in_filters': self._in_filters,
'out_filters': self._out_filters,
'input_compression_ratio': self._input_compression_ratio,
'output_compression_ratio': self._output_compression_ratio,
'strides': self._strides,
'kernel_size': self._kernel_size,
'divisible_by': self._divisible_by,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'use_residual': self._use_residual,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(TuckerConvBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None):
shortcut = inputs
x = self._conv0(inputs)
x = self._norm0(x)
x = self._activation_layer0(x)
x = self._conv1(x)
x = self._norm1(x)
x = self._activation_layer1(x)
x = self._conv2(x)
x = self._norm2(x)
if (self._use_residual and
self._in_filters == self._out_filters and
self._strides == 1):
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
x = self._add([x, shortcut])
return x
...@@ -113,6 +113,31 @@ class NNBlocksTest(parameterized.TestCase, tf.test.TestCase): ...@@ -113,6 +113,31 @@ class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
[1, input_size // strides, input_size // strides, out_filters], [1, input_size // strides, input_size // strides, out_filters],
features.shape.as_list()) features.shape.as_list())
@parameterized.parameters(
(nn_blocks.TuckerConvBlock, 1, 0.25, 0.25),
(nn_blocks.TuckerConvBlock, 2, 0.25, 0.25),
)
def test_tucker_conv_block(
self, block_fn, strides,
input_compression_ratio, output_compression_ratio):
input_size = 128
in_filters = 24
out_filters = 24
inputs = tf.keras.Input(
shape=(input_size, input_size, in_filters), batch_size=1)
block = block_fn(
in_filters=in_filters,
out_filters=out_filters,
input_compression_ratio=input_compression_ratio,
output_compression_ratio=output_compression_ratio,
strides=strides)
features = block(inputs)
self.assertAllEqual(
[1, input_size // strides, input_size // strides, out_filters],
features.shape.as_list())
class ResidualInnerTest(parameterized.TestCase, tf.test.TestCase): class ResidualInnerTest(parameterized.TestCase, tf.test.TestCase):
......
...@@ -30,7 +30,8 @@ Activation = Union[str, Callable] ...@@ -30,7 +30,8 @@ Activation = Union[str, Callable]
def make_divisible(value: float, def make_divisible(value: float,
divisor: int, divisor: int,
min_value: Optional[float] = None min_value: Optional[float] = None,
round_down_protect: bool = True,
) -> int: ) -> int:
"""This is to ensure that all layers have channels that are divisible by 8. """This is to ensure that all layers have channels that are divisible by 8.
...@@ -38,6 +39,8 @@ def make_divisible(value: float, ...@@ -38,6 +39,8 @@ def make_divisible(value: float,
value: A `float` of original value. value: A `float` of original value.
divisor: An `int` of the divisor that need to be checked upon. divisor: An `int` of the divisor that need to be checked upon.
min_value: A `float` of minimum value threshold. min_value: A `float` of minimum value threshold.
round_down_protect: A `bool` indicating whether round down more than 10%
will be allowed.
Returns: Returns:
The adjusted value in `int` that is divisible against divisor. The adjusted value in `int` that is divisible against divisor.
...@@ -46,7 +49,7 @@ def make_divisible(value: float, ...@@ -46,7 +49,7 @@ def make_divisible(value: float,
min_value = divisor min_value = divisor
new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%. # Make sure that round down does not go down by more than 10%.
if new_value < 0.9 * value: if round_down_protect and new_value < 0.9 * value:
new_value += divisor new_value += divisor
return int(new_value) return int(new_value)
...@@ -55,7 +58,8 @@ def round_filters(filters: int, ...@@ -55,7 +58,8 @@ def round_filters(filters: int,
multiplier: float, multiplier: float,
divisor: int = 8, divisor: int = 8,
min_depth: Optional[int] = None, min_depth: Optional[int] = None,
skip: bool = False): round_down_protect: bool = True,
skip: bool = False) -> int:
"""Rounds number of filters based on width multiplier.""" """Rounds number of filters based on width multiplier."""
orig_f = filters orig_f = filters
if skip or not multiplier: if skip or not multiplier:
...@@ -63,7 +67,8 @@ def round_filters(filters: int, ...@@ -63,7 +67,8 @@ def round_filters(filters: int,
new_filters = make_divisible(value=filters * multiplier, new_filters = make_divisible(value=filters * multiplier,
divisor=divisor, divisor=divisor,
min_value=min_depth) min_value=min_depth,
round_down_protect=round_down_protect)
logging.info('round_filter input=%s output=%s', orig_f, new_filters) logging.info('round_filter input=%s output=%s', orig_f, new_filters)
return int(new_filters) return int(new_filters)
...@@ -80,39 +85,6 @@ def get_padding_for_kernel_size(kernel_size): ...@@ -80,39 +85,6 @@ def get_padding_for_kernel_size(kernel_size):
kernel_size)) kernel_size))
def hard_swish(x: tf.Tensor) -> tf.Tensor:
"""A Swish6/H-Swish activation function.
Reference: Section 5.2 of Howard et al. "Searching for MobileNet V3."
https://arxiv.org/pdf/1905.02244.pdf
Args:
x: the input tensor.
Returns:
The activation output.
"""
return x * tf.nn.relu6(x + 3.) * (1. / 6.)
tf.keras.utils.get_custom_objects().update({'hard_swish': hard_swish})
def simple_swish(x: tf.Tensor) -> tf.Tensor:
"""A swish/silu activation function without custom gradients.
Useful for exporting to SavedModel to avoid custom gradient warnings.
Args:
x: the input tensor.
Returns:
The activation output.
"""
return x * tf.math.sigmoid(x)
tf.keras.utils.get_custom_objects().update({'simple_swish': simple_swish})
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class SqueezeExcitation(tf.keras.layers.Layer): class SqueezeExcitation(tf.keras.layers.Layer):
"""Creates a squeeze and excitation layer.""" """Creates a squeeze and excitation layer."""
...@@ -128,6 +100,7 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -128,6 +100,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
bias_regularizer=None, bias_regularizer=None,
activation='relu', activation='relu',
gating_activation='sigmoid', gating_activation='sigmoid',
round_down_protect=True,
**kwargs): **kwargs):
"""Initializes a squeeze and excitation layer. """Initializes a squeeze and excitation layer.
...@@ -148,6 +121,8 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -148,6 +121,8 @@ class SqueezeExcitation(tf.keras.layers.Layer):
activation: A `str` name of the activation function. activation: A `str` name of the activation function.
gating_activation: A `str` name of the activation function for final gating_activation: A `str` name of the activation function for final
gating function. gating function.
round_down_protect: A `bool` of whether round down more than 10% will be
allowed.
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(SqueezeExcitation, self).__init__(**kwargs) super(SqueezeExcitation, self).__init__(**kwargs)
...@@ -156,6 +131,7 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -156,6 +131,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
self._out_filters = out_filters self._out_filters = out_filters
self._se_ratio = se_ratio self._se_ratio = se_ratio
self._divisible_by = divisible_by self._divisible_by = divisible_by
self._round_down_protect = round_down_protect
self._use_3d_input = use_3d_input self._use_3d_input = use_3d_input
self._activation = activation self._activation = activation
self._gating_activation = gating_activation self._gating_activation = gating_activation
...@@ -178,7 +154,8 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -178,7 +154,8 @@ class SqueezeExcitation(tf.keras.layers.Layer):
def build(self, input_shape): def build(self, input_shape):
num_reduced_filters = make_divisible( num_reduced_filters = make_divisible(
max(1, int(self._in_filters * self._se_ratio)), max(1, int(self._in_filters * self._se_ratio)),
divisor=self._divisible_by) divisor=self._divisible_by,
round_down_protect=self._round_down_protect)
self._se_reduce = tf.keras.layers.Conv2D( self._se_reduce = tf.keras.layers.Conv2D(
filters=num_reduced_filters, filters=num_reduced_filters,
...@@ -214,6 +191,7 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -214,6 +191,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
'bias_regularizer': self._bias_regularizer, 'bias_regularizer': self._bias_regularizer,
'activation': self._activation, 'activation': self._activation,
'gating_activation': self._gating_activation, 'gating_activation': self._gating_activation,
'round_down_protect': self._round_down_protect,
} }
base_config = super(SqueezeExcitation, self).get_config() base_config = super(SqueezeExcitation, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
...@@ -1369,7 +1347,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1369,7 +1347,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.aspp_layers.append(pooling + [conv2, norm2]) self.aspp_layers.append(pooling + [conv2, norm2])
self._resize_layer = tf.keras.layers.Resizing( self._resizing_layer = tf.keras.layers.Resizing(
height, width, interpolation=self._interpolation, dtype=tf.float32) height, width, interpolation=self._interpolation, dtype=tf.float32)
self._projection = [ self._projection = [
...@@ -1402,7 +1380,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1402,7 +1380,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
# Apply resize layer to the end of the last set of layers. # Apply resize layer to the end of the last set of layers.
if i == len(self.aspp_layers) - 1: if i == len(self.aspp_layers) - 1:
x = self._resize_layer(x) x = self._resizing_layer(x)
result.append(tf.cast(x, inputs.dtype)) result.append(tf.cast(x, inputs.dtype))
x = self._concat_layer(result) x = self._concat_layer(result)
......
...@@ -24,11 +24,6 @@ from official.vision.beta.modeling.layers import nn_layers ...@@ -24,11 +24,6 @@ from official.vision.beta.modeling.layers import nn_layers
class NNLayersTest(parameterized.TestCase, tf.test.TestCase): class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
def test_hard_swish(self):
activation = tf.keras.layers.Activation('hard_swish')
output = activation(tf.constant([-3, -1.5, 0, 3]))
self.assertAllEqual(output, [0., -0.375, 0., 3.])
def test_scale(self): def test_scale(self):
scale = nn_layers.Scale(initializer=tf.keras.initializers.constant(10.)) scale = nn_layers.Scale(initializer=tf.keras.initializers.constant(10.))
output = scale(3.) output = scale(3.)
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Build segmentation models.""" """Build segmentation models."""
from typing import Any, Mapping, Union from typing import Any, Mapping, Union, Optional, Dict
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
...@@ -35,13 +35,16 @@ class SegmentationModel(tf.keras.Model): ...@@ -35,13 +35,16 @@ class SegmentationModel(tf.keras.Model):
""" """
def __init__(self, backbone: tf.keras.Model, decoder: tf.keras.Model, def __init__(self, backbone: tf.keras.Model, decoder: tf.keras.Model,
head: tf.keras.layers.Layer, **kwargs): head: tf.keras.layers.Layer,
mask_scoring_head: Optional[tf.keras.layers.Layer] = None,
**kwargs):
"""Segmentation initialization function. """Segmentation initialization function.
Args: Args:
backbone: a backbone network. backbone: a backbone network.
decoder: a decoder network. E.g. FPN. decoder: a decoder network. E.g. FPN.
head: segmentation head. head: segmentation head.
mask_scoring_head: mask scoring head.
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
super(SegmentationModel, self).__init__(**kwargs) super(SegmentationModel, self).__init__(**kwargs)
...@@ -49,12 +52,15 @@ class SegmentationModel(tf.keras.Model): ...@@ -49,12 +52,15 @@ class SegmentationModel(tf.keras.Model):
'backbone': backbone, 'backbone': backbone,
'decoder': decoder, 'decoder': decoder,
'head': head, 'head': head,
'mask_scoring_head': mask_scoring_head,
} }
self.backbone = backbone self.backbone = backbone
self.decoder = decoder self.decoder = decoder
self.head = head self.head = head
self.mask_scoring_head = mask_scoring_head
def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor: def call(self, inputs: tf.Tensor, training: bool = None
) -> Dict[str, tf.Tensor]:
backbone_features = self.backbone(inputs) backbone_features = self.backbone(inputs)
if self.decoder: if self.decoder:
...@@ -62,7 +68,12 @@ class SegmentationModel(tf.keras.Model): ...@@ -62,7 +68,12 @@ class SegmentationModel(tf.keras.Model):
else: else:
decoder_features = backbone_features decoder_features = backbone_features
return self.head((backbone_features, decoder_features)) logits = self.head((backbone_features, decoder_features))
outputs = {'logits': logits}
if self.mask_scoring_head:
mask_scores = self.mask_scoring_head(logits)
outputs.update({'mask_scores': mask_scores})
return outputs
@property @property
def checkpoint_items( def checkpoint_items(
...@@ -71,6 +82,8 @@ class SegmentationModel(tf.keras.Model): ...@@ -71,6 +82,8 @@ class SegmentationModel(tf.keras.Model):
items = dict(backbone=self.backbone, head=self.head) items = dict(backbone=self.backbone, head=self.head)
if self.decoder is not None: if self.decoder is not None:
items.update(decoder=self.decoder) items.update(decoder=self.decoder)
if self.mask_scoring_head is not None:
items.update(mask_scoring_head=self.mask_scoring_head)
return items return items
def get_config(self) -> Mapping[str, Any]: def get_config(self) -> Mapping[str, Any]:
......
...@@ -50,13 +50,14 @@ class SegmentationNetworkTest(parameterized.TestCase, tf.test.TestCase): ...@@ -50,13 +50,14 @@ class SegmentationNetworkTest(parameterized.TestCase, tf.test.TestCase):
model = segmentation_model.SegmentationModel( model = segmentation_model.SegmentationModel(
backbone=backbone, backbone=backbone,
decoder=decoder, decoder=decoder,
head=head head=head,
mask_scoring_head=None,
) )
logits = model(inputs) outputs = model(inputs)
self.assertAllEqual( self.assertAllEqual(
[2, input_size // (2**level), input_size // (2**level), num_classes], [2, input_size // (2**level), input_size // (2**level), num_classes],
logits.numpy().shape) outputs['logits'].numpy().shape)
def test_serialize_deserialize(self): def test_serialize_deserialize(self):
"""Validate the network can be serialized and deserialized.""" """Validate the network can be serialized and deserialized."""
......
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
"""Augmentation policies for enhanced image/video preprocessing. """Augmentation policies for enhanced image/video preprocessing.
AutoAugment Reference: https://arxiv.org/abs/1805.09501 AutoAugment Reference:
- AutoAugment Reference: https://arxiv.org/abs/1805.09501
- AutoAugment for Object Detection Reference: https://arxiv.org/abs/1906.11172
RandAugment Reference: https://arxiv.org/abs/1909.13719 RandAugment Reference: https://arxiv.org/abs/1909.13719
RandomErasing Reference: https://arxiv.org/abs/1708.04896 RandomErasing Reference: https://arxiv.org/abs/1708.04896
MixupAndCutmix: MixupAndCutmix:
...@@ -25,6 +27,7 @@ RandomErasing, Mixup and Cutmix are inspired by ...@@ -25,6 +27,7 @@ RandomErasing, Mixup and Cutmix are inspired by
https://github.com/rwightman/pytorch-image-models https://github.com/rwightman/pytorch-image-models
""" """
import inspect
import math import math
from typing import Any, List, Iterable, Optional, Text, Tuple from typing import Any, List, Iterable, Optional, Text, Tuple
...@@ -702,6 +705,572 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor: ...@@ -702,6 +705,572 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
return image return image
def _scale_bbox_only_op_probability(prob):
"""Reduce the probability of the bbox-only operation.
Probability is reduced so that we do not distort the content of too many
bounding boxes that are close to each other. The value of 3.0 was a chosen
hyper parameter when designing the autoaugment algorithm that we found
empirically to work well.
Args:
prob: Float that is the probability of applying the bbox-only operation.
Returns:
Reduced probability.
"""
return prob / 3.0
def _apply_bbox_augmentation(image, bbox, augmentation_func, *args):
  """Applies augmentation_func to the subsection of image indicated by bbox.

  Args:
    image: 3D uint8 Tensor.
    bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
      of type float that represents the normalized coordinates between 0 and 1.
    augmentation_func: Augmentation function that will be applied to the
      subsection of image.
    *args: Additional parameters that will be passed into augmentation_func
      when it is called.

  Returns:
    A modified version of image, where the bbox location in the image will
    have `augmentation_func` applied to it.
  """
  image_height = tf.cast(tf.shape(image)[0], tf.float32)
  image_width = tf.cast(tf.shape(image)[1], tf.float32)
  # Convert normalized bbox coordinates to absolute pixel indices.
  min_y = tf.cast(image_height * bbox[0], tf.int32)
  min_x = tf.cast(image_width * bbox[1], tf.int32)
  max_y = tf.cast(image_height * bbox[2], tf.int32)
  max_x = tf.cast(image_width * bbox[3], tf.int32)
  image_height = tf.cast(image_height, tf.int32)
  image_width = tf.cast(image_width, tf.int32)

  # Clip to be sure the max values do not fall out of range.
  max_y = tf.minimum(max_y, image_height - 1)
  max_x = tf.minimum(max_x, image_width - 1)

  # Get the sub-tensor that is the image within the bounding box region.
  bbox_content = image[min_y:max_y + 1, min_x:max_x + 1, :]

  # Apply the augmentation function to the bbox portion of the image.
  augmented_bbox_content = augmentation_func(bbox_content, *args)

  # Pad the augmented_bbox_content and the mask to match the shape of original
  # image.
  augmented_bbox_content = tf.pad(augmented_bbox_content,
                                  [[min_y, (image_height - 1) - max_y],
                                   [min_x, (image_width - 1) - max_x],
                                   [0, 0]])

  # Create a mask that will be used to zero out a part of the original image.
  # The mask is 1 everywhere except inside the bbox (where it is 0), so
  # multiplying zeroes the bbox region before adding the augmented content.
  mask_tensor = tf.zeros_like(bbox_content)

  mask_tensor = tf.pad(mask_tensor,
                       [[min_y, (image_height - 1) - max_y],
                        [min_x, (image_width - 1) - max_x],
                        [0, 0]],
                       constant_values=1)
  # Replace the old bbox content with the new augmented content.
  image = image * mask_tensor + augmented_bbox_content
  return image
def _concat_bbox(bbox, bboxes):
  """Concatenates `bbox` to `bboxes` along the first dimension.

  If all elements in `bboxes` are -1 (the `_INVALID_BOX` placeholder), the
  placeholder is discarded and the result starts with the current `bbox`.

  Args:
    bbox: 1D Tensor with 4 bbox coordinates.
    bboxes: 2D Tensor of accumulated bboxes, possibly the placeholder.

  Returns:
    A 2D Tensor of the accumulated bboxes including `bbox`.
  """
  expanded_bbox = tf.expand_dims(bbox, 0)
  # The _INVALID_BOX placeholder is a single box of four -1.0 entries, so its
  # elements sum to exactly -4.0.
  is_placeholder = tf.equal(tf.reduce_sum(bboxes), -4.0)
  return tf.cond(is_placeholder,
                 lambda: expanded_bbox,
                 lambda: tf.concat([bboxes, expanded_bbox], 0))
def _apply_bbox_augmentation_wrapper(image, bbox, new_bboxes, prob,
                                     augmentation_func, func_changes_bbox,
                                     *args):
  """Applies _apply_bbox_augmentation with probability prob.

  Args:
    image: 3D uint8 Tensor.
    bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
      of type float that represents the normalized coordinates between 0 and 1.
    new_bboxes: 2D Tensor that is a list of the bboxes in the image after they
      have been altered by aug_func. These will only be changed when
      func_changes_bbox is set to true. Each bbox has 4 elements
      (min_y, min_x, max_y, max_x) of type float that are the normalized
      bbox coordinates between 0 and 1.
    prob: Float that is the probability of applying _apply_bbox_augmentation.
    augmentation_func: Augmentation function that will be applied to the
      subsection of image.
    func_changes_bbox: Boolean. Does augmentation_func return bbox in addition
      to image.
    *args: Additional parameters that will be passed into augmentation_func
      when it is called.

  Returns:
    A tuple. First element is a modified version of image, where the bbox
    location in the image will have augmentation_func applied to it if it is
    chosen to be called with probability `prob`. The second element is a
    Tensor of Tensors of length 4 that will contain the altered bbox after
    applying augmentation_func.
  """
  # floor(U + prob) with U ~ Uniform[0, 1) equals 1 with probability `prob`,
  # i.e. a Bernoulli(prob) sample expressed with graph-compatible ops.
  should_apply_op = tf.cast(
      tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
  if func_changes_bbox:
    # The augmentation itself returns the (possibly moved) bbox.
    augmented_image, bbox = tf.cond(
        should_apply_op,
        lambda: augmentation_func(image, bbox, *args),
        lambda: (image, bbox))
  else:
    # The bbox coordinates are unchanged; only pixels inside it are altered.
    augmented_image = tf.cond(
        should_apply_op,
        lambda: _apply_bbox_augmentation(image, bbox, augmentation_func, *args),
        lambda: image)
  new_bboxes = _concat_bbox(bbox, new_bboxes)
  return augmented_image, new_bboxes
def _apply_multi_bbox_augmentation_wrapper(image, bboxes, prob, aug_func,
                                           func_changes_bbox, *args):
  """Checks to be sure num bboxes > 0 before calling inner function."""
  # pylint:disable=g-long-lambda
  augment_all_bboxes = lambda: _apply_multi_bbox_augmentation(
      image, bboxes, prob, aug_func, func_changes_bbox, *args)
  # pylint:enable=g-long-lambda
  no_bboxes = tf.equal(tf.shape(bboxes)[0], 0)
  # Skip augmentation entirely when there are no boxes to process.
  image, bboxes = tf.cond(no_bboxes, lambda: (image, bboxes),
                          augment_all_bboxes)
  return image, bboxes
# Represents an invalid bounding box that is used for checking for padding
# lists of bounding box coordinates for a few augmentation operations.
# NOTE: helpers detect this placeholder by checking that the coordinates sum
# to exactly -4.0 (see `_concat_bbox`), so the four -1.0 entries matter.
_INVALID_BOX = [[-1.0, -1.0, -1.0, -1.0]]
def _apply_multi_bbox_augmentation(image, bboxes, prob, aug_func,
                                   func_changes_bbox, *args):
  """Applies aug_func to the image for each bbox in bboxes.

  Args:
    image: 3D uint8 Tensor.
    bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
      has 4 elements (min_y, min_x, max_y, max_x) of type float.
    prob: Float that is the probability of applying aug_func to a specific
      bounding box within the image.
    aug_func: Augmentation function that will be applied to the
      subsections of image indicated by the bbox values in bboxes.
    func_changes_bbox: Boolean. Does augmentation_func return bbox in addition
      to image.
    *args: Additional parameters that will be passed into augmentation_func
      when it is called.

  Returns:
    A modified version of image, where each bbox location in the image will
    have augmentation_func applied to it if it is chosen to be called with
    probability prob independently across all bboxes. Also the final
    bboxes are returned that will be unchanged if func_changes_bbox is set to
    false and if true, the new altered ones will be returned.

  Raises:
    ValueError: If applied to video (rank-4 input).
  """
  if image.shape.rank == 4:
    raise ValueError('Image rank 4 is not supported')

  # Will keep track of the new altered bboxes after aug_func is repeatedly
  # applied. The -1 values are a dummy value and this first Tensor will be
  # removed upon appending the first real bbox.
  new_bboxes = tf.constant(_INVALID_BOX)

  # If the bboxes are empty, then just give it _INVALID_BOX. The result
  # will be thrown away.
  bboxes = tf.cond(tf.equal(tf.size(bboxes), 0),
                   lambda: tf.constant(_INVALID_BOX),
                   lambda: bboxes)

  bboxes = tf.ensure_shape(bboxes, (None, 4))

  # pylint:disable=g-long-lambda
  wrapped_aug_func = (
      lambda _image, bbox, _new_bboxes: _apply_bbox_augmentation_wrapper(
          _image, bbox, _new_bboxes, prob, aug_func, func_changes_bbox, *args))
  # pylint:enable=g-long-lambda

  # Setup the while_loop.
  num_bboxes = tf.shape(bboxes)[0]  # We loop until we go over all bboxes.
  idx = tf.constant(0)  # Counter for the while loop.

  # Conditional function when to end the loop once we go over all bboxes
  # images_and_bboxes contain (_image, _new_bboxes)
  cond = lambda _idx, _images_and_bboxes: tf.less(_idx, num_bboxes)

  # Shuffle the bboxes so that the augmentation order is not deterministic if
  # we are not changing the bboxes with aug_func.
  if not func_changes_bbox:
    loop_bboxes = tf.random.shuffle(bboxes)
  else:
    loop_bboxes = bboxes

  # Main function of while_loop where we repeatedly apply augmentation on the
  # bboxes in the image.
  # pylint:disable=g-long-lambda
  body = lambda _idx, _images_and_bboxes: [
      _idx + 1, wrapped_aug_func(_images_and_bboxes[0],
                                 loop_bboxes[_idx],
                                 _images_and_bboxes[1])]
  # pylint:enable=g-long-lambda

  # new_bboxes grows by one row per iteration, hence the [None, 4] invariant.
  _, (image, new_bboxes) = tf.while_loop(
      cond, body, [idx, (image, new_bboxes)],
      shape_invariants=[idx.get_shape(),
                        (image.get_shape(), tf.TensorShape([None, 4]))])

  # Either return the altered bboxes or the original ones depending on if
  # we altered them in anyway.
  if func_changes_bbox:
    final_bboxes = new_bboxes
  else:
    final_bboxes = bboxes
  return image, final_bboxes
def _clip_bbox(min_y, min_x, max_y, max_x):
  """Clip bounding box coordinates between 0 and 1.

  Args:
    min_y: Normalized bbox coordinate of type float between 0 and 1.
    min_x: Normalized bbox coordinate of type float between 0 and 1.
    max_y: Normalized bbox coordinate of type float between 0 and 1.
    max_x: Normalized bbox coordinate of type float between 0 and 1.

  Returns:
    Clipped coordinate values between 0 and 1.
  """
  # Apply the same [0, 1] clamp to all four coordinates.
  clipped = tuple(
      tf.clip_by_value(coordinate, 0.0, 1.0)
      for coordinate in (min_y, min_x, max_y, max_x))
  return clipped
def _check_bbox_area(min_y, min_x, max_y, max_x, delta=0.05):
  """Adjusts bbox coordinates to make sure the area is > 0.

  Args:
    min_y: Normalized bbox coordinate of type float between 0 and 1.
    min_x: Normalized bbox coordinate of type float between 0 and 1.
    max_y: Normalized bbox coordinate of type float between 0 and 1.
    max_x: Normalized bbox coordinate of type float between 0 and 1.
    delta: Float, this is used to create a gap of size 2 * delta between
      bbox min/max coordinates that are the same on the boundary.
      This prevents the bbox from having an area of zero.

  Returns:
    Tuple of new bbox coordinates between 0 and 1 that will now have a
    guaranteed area > 0.
  """
  height = max_y - min_y
  width = max_x - min_x
  def _adjust_bbox_boundaries(min_coord, max_coord):
    # Make sure max is never 0 and min is never 1.
    max_coord = tf.maximum(max_coord, 0.0 + delta)
    min_coord = tf.minimum(min_coord, 1.0 - delta)
    return min_coord, max_coord
  # Only nudge the coordinates when the box has collapsed to zero height
  # (or zero width below); non-degenerate boxes pass through unchanged.
  min_y, max_y = tf.cond(tf.equal(height, 0.0),
                         lambda: _adjust_bbox_boundaries(min_y, max_y),
                         lambda: (min_y, max_y))
  min_x, max_x = tf.cond(tf.equal(width, 0.0),
                         lambda: _adjust_bbox_boundaries(min_x, max_x),
                         lambda: (min_x, max_x))
  return min_y, min_x, max_y, max_x
def _rotate_bbox(bbox, image_height, image_width, degrees):
  """Rotates the bbox coordinated by degrees.

  Args:
    bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
      of type float that represents the normalized coordinates between 0 and 1.
    image_height: Int, height of the image.
    image_width: Int, width of the image.
    degrees: Float, a scalar angle in degrees to rotate all images by. If
      degrees is positive the image will be rotated clockwise otherwise it will
      be rotated counterclockwise.

  Returns:
    A tensor of the same shape as bbox, but now with the rotated coordinates.
  """
  image_height, image_width = (
      tf.cast(image_height, tf.float32), tf.cast(image_width, tf.float32))

  # Convert from degrees to radians.
  degrees_to_radians = math.pi / 180.0
  radians = degrees * degrees_to_radians

  # Translate the bbox to the center of the image and turn the normalized 0-1
  # coordinates to absolute pixel locations.
  # Y coordinates are made negative as the y axis of images goes down with
  # increasing pixel values, so we negate to make sure x axis and y axis points
  # are in the traditionally positive direction.
  min_y = -tf.cast(image_height * (bbox[0] - 0.5), tf.int32)
  min_x = tf.cast(image_width * (bbox[1] - 0.5), tf.int32)
  max_y = -tf.cast(image_height * (bbox[2] - 0.5), tf.int32)
  max_x = tf.cast(image_width * (bbox[3] - 0.5), tf.int32)
  # All four corners of the box are rotated, since an axis-aligned box does
  # not stay axis-aligned under rotation.
  coordinates = tf.stack(
      [[min_y, min_x], [min_y, max_x], [max_y, min_x], [max_y, max_x]])
  coordinates = tf.cast(coordinates, tf.float32)

  # Rotate the coordinates according to the rotation matrix clockwise if
  # radians is positive, else negative
  rotation_matrix = tf.stack(
      [[tf.cos(radians), tf.sin(radians)],
       [-tf.sin(radians), tf.cos(radians)]])

  new_coords = tf.cast(
      tf.matmul(rotation_matrix, tf.transpose(coordinates)), tf.int32)

  # Find min/max values and convert them back to normalized 0-1 floats.
  min_y = -(
      tf.cast(tf.reduce_max(new_coords[0, :]), tf.float32) / image_height - 0.5)
  min_x = tf.cast(tf.reduce_min(new_coords[1, :]),
                  tf.float32) / image_width + 0.5
  max_y = -(
      tf.cast(tf.reduce_min(new_coords[0, :]), tf.float32) / image_height - 0.5)
  max_x = tf.cast(tf.reduce_max(new_coords[1, :]),
                  tf.float32) / image_width + 0.5

  # Clip the bboxes to be sure they fall between [0, 1].
  min_y, min_x, max_y, max_x = _clip_bbox(min_y, min_x, max_y, max_x)
  min_y, min_x, max_y, max_x = _check_bbox_area(min_y, min_x, max_y, max_x)
  return tf.stack([min_y, min_x, max_y, max_x])
def rotate_with_bboxes(image, bboxes, degrees, replace):
  """Equivalent of PIL Rotate that rotates the image and bbox.

  Args:
    image: 3D uint8 Tensor.
    bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
      has 4 elements (min_y, min_x, max_y, max_x) of type float.
    degrees: Float, a scalar angle in degrees to rotate all images by. If
      degrees is positive the image will be rotated clockwise otherwise it
      will be rotated counterclockwise.
    replace: A one or three value 1D tensor to fill empty pixels.

  Returns:
    A tuple containing a 3D uint8 Tensor that will be the result of rotating
    image by degrees. The second element of the tuple is bboxes, where now
    the coordinates will be shifted to reflect the rotated image.

  Raises:
    ValueError: If applied to video.
  """
  if image.shape.rank == 4:
    raise ValueError('Image rank 4 is not supported')

  # Rotate the image first; bboxes are then rotated with the same geometry.
  rotated_image = wrapped_rotate(image, degrees, replace)

  height = tf.shape(rotated_image)[0]
  width = tf.shape(rotated_image)[1]
  # pylint:disable=g-long-lambda
  rotate_single_bbox = lambda bbox: _rotate_bbox(
      bbox, height, width, degrees)
  # pylint:enable=g-long-lambda
  rotated_bboxes = tf.map_fn(rotate_single_bbox, bboxes)
  return rotated_image, rotated_bboxes
def _shear_bbox(bbox, image_height, image_width, level, shear_horizontal):
  """Shifts the bbox according to how the image was sheared.
  Args:
    bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
    of type float that represents the normalized coordinates between 0 and 1.
    image_height: Int, height of the image.
    image_width: Int, width of the image.
    level: Float. How much to shear the image.
    shear_horizontal: If true then shear in X dimension else shear in
    the Y dimension.
  Returns:
    A tensor of the same shape as bbox, but now with the shifted coordinates.
  """
  image_height, image_width = (
      tf.cast(image_height, tf.float32), tf.cast(image_width, tf.float32))
  # Change bbox coordinates to be pixels. Truncating casts intentionally snap
  # normalized coordinates to integer pixel locations before the transform.
  min_y = tf.cast(image_height * bbox[0], tf.int32)
  min_x = tf.cast(image_width * bbox[1], tf.int32)
  max_y = tf.cast(image_height * bbox[2], tf.int32)
  max_x = tf.cast(image_width * bbox[3], tf.int32)
  # The four box corners, stacked as (y, x) rows.
  coordinates = tf.stack(
      [[min_y, min_x], [min_y, max_x], [max_y, min_x], [max_y, max_x]])
  coordinates = tf.cast(coordinates, tf.float32)
  # Shear the coordinates according to the translation matrix. The matrix
  # multiplies (y, x) column vectors, so horizontal shear gives
  # new_x = x - level * y, and vertical shear gives new_y = y - level * x
  # (presumably mirroring the shear_x/shear_y image transforms -- confirm).
  if shear_horizontal:
    translation_matrix = tf.stack(
        [[1, 0], [-level, 1]])
  else:
    translation_matrix = tf.stack(
        [[1, -level], [0, 1]])
  translation_matrix = tf.cast(translation_matrix, tf.float32)
  new_coords = tf.cast(
      tf.matmul(translation_matrix, tf.transpose(coordinates)), tf.int32)
  # Find min/max values over the four sheared corners (row 0 holds y values,
  # row 1 holds x values) and convert them back to normalized 0-1 floats.
  min_y = tf.cast(tf.reduce_min(new_coords[0, :]), tf.float32) / image_height
  min_x = tf.cast(tf.reduce_min(new_coords[1, :]), tf.float32) / image_width
  max_y = tf.cast(tf.reduce_max(new_coords[0, :]), tf.float32) / image_height
  max_x = tf.cast(tf.reduce_max(new_coords[1, :]), tf.float32) / image_width
  # Clip the bboxes to be sure they fall between [0, 1].
  min_y, min_x, max_y, max_x = _clip_bbox(min_y, min_x, max_y, max_x)
  min_y, min_x, max_y, max_x = _check_bbox_area(min_y, min_x, max_y, max_x)
  return tf.stack([min_y, min_x, max_y, max_x])
def shear_with_bboxes(image, bboxes, level, replace, shear_horizontal):
  """Applies a shear transformation to the image and shifts the bboxes.
  Args:
    image: 3D uint8 Tensor.
    bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
      has 4 elements (min_y, min_x, max_y, max_x) of type float with values
      between [0, 1].
    level: Float. How much to shear the image. This value will be between
      -0.3 to 0.3.
    replace: A one or three value 1D tensor to fill empty pixels.
    shear_horizontal: Boolean. If true then shear in X dimension else shear in
      the Y dimension.
  Returns:
    A tuple containing a 3D uint8 Tensor that will be the result of shearing
    image by level. The second element of the tuple is bboxes, where now
    the coordinates will be shifted to reflect the sheared image.
  Raises:
    ValueError: If applied to video (rank-4 input).
  """
  if image.shape.rank == 4:
    raise ValueError('Image rank 4 is not supported')

  # Shear the pixels along the requested axis.
  sheared = (
      shear_x(image, level, replace)
      if shear_horizontal else shear_y(image, level, replace))
  height = tf.shape(sheared)[0]
  width = tf.shape(sheared)[1]

  def _shear_one(box):
    # Shift a single normalized box to match the sheared image.
    return _shear_bbox(box, height, width, level, shear_horizontal)

  return sheared, tf.map_fn(_shear_one, bboxes)
def _shift_bbox(bbox, image_height, image_width, pixels, shift_horizontal):
  """Shifts the bbox coordinates by pixels.
  Args:
    bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
    of type float that represents the normalized coordinates between 0 and 1.
    image_height: Int, height of the image.
    image_width: Int, width of the image.
    pixels: An int. How many pixels to shift the bbox.
    shift_horizontal: Boolean. If true then shift in X dimension else shift in
    Y dimension.
  Returns:
    A tensor of the same shape as bbox, but now with the shifted coordinates.
  """
  pixels = tf.cast(pixels, tf.int32)
  # Convert bbox to integer pixel locations (truncating cast).
  min_y = tf.cast(tf.cast(image_height, tf.float32) * bbox[0], tf.int32)
  min_x = tf.cast(tf.cast(image_width, tf.float32) * bbox[1], tf.int32)
  max_y = tf.cast(tf.cast(image_height, tf.float32) * bbox[2], tf.int32)
  max_x = tf.cast(tf.cast(image_width, tf.float32) * bbox[3], tf.int32)
  # Note the subtraction: a positive `pixels` moves the box toward the origin,
  # presumably because translate_x/translate_y move the image content in the
  # opposite direction of the translation argument -- confirm against those
  # ops. Coordinates are clamped to the image bounds on each side.
  if shift_horizontal:
    min_x = tf.maximum(0, min_x - pixels)
    max_x = tf.minimum(image_width, max_x - pixels)
  else:
    min_y = tf.maximum(0, min_y - pixels)
    max_y = tf.minimum(image_height, max_y - pixels)
  # Convert bbox back to normalized floats.
  min_y = tf.cast(min_y, tf.float32) / tf.cast(image_height, tf.float32)
  min_x = tf.cast(min_x, tf.float32) / tf.cast(image_width, tf.float32)
  max_y = tf.cast(max_y, tf.float32) / tf.cast(image_height, tf.float32)
  max_x = tf.cast(max_x, tf.float32) / tf.cast(image_width, tf.float32)
  # Clip the bboxes to be sure they fall between [0, 1].
  min_y, min_x, max_y, max_x = _clip_bbox(min_y, min_x, max_y, max_x)
  min_y, min_x, max_y, max_x = _check_bbox_area(min_y, min_x, max_y, max_x)
  return tf.stack([min_y, min_x, max_y, max_x])
def translate_bbox(image, bboxes, pixels, replace, shift_horizontal):
  """Equivalent of PIL Translate in X/Y dimension that shifts image and bbox.
  Args:
    image: 3D uint8 Tensor.
    bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
      has 4 elements (min_y, min_x, max_y, max_x) of type float with values
      between [0, 1].
    pixels: An int. How many pixels to shift the image and bboxes
    replace: A one or three value 1D tensor to fill empty pixels.
    shift_horizontal: Boolean. If true then shift in X dimension else shift in
      Y dimension.
  Returns:
    A tuple containing a 3D uint8 Tensor that will be the result of
    translating image by pixels. The second element of the tuple is bboxes,
    where now the coordinates will be shifted to reflect the shifted image.
  Raises:
    ValueError: If applied to video (rank-4 input).
  """
  if image.shape.rank == 4:
    raise ValueError('Image rank 4 is not supported')

  # Translate the pixels along the requested axis.
  shifted = (
      translate_x(image, pixels, replace)
      if shift_horizontal else translate_y(image, pixels, replace))
  height = tf.shape(shifted)[0]
  width = tf.shape(shifted)[1]

  def _shift_one(box):
    # Shift a single normalized box to match the translated image.
    return _shift_bbox(box, height, width, pixels, shift_horizontal)

  return shifted, tf.map_fn(_shift_one, bboxes)
def translate_y_only_bboxes(
    image: tf.Tensor, bboxes: tf.Tensor, prob: float, pixels: int, replace):
  """Apply translate_y to each bbox in the image with probability prob.

  Args:
    image: 3D uint8 Tensor.
    bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
      has 4 elements (min_y, min_x, max_y, max_x) of type float.
    prob: Float, probability of applying translate_y to each bbox.
    pixels: An int. How many pixels to translate within each bbox.
    replace: A one or three value 1D tensor to fill empty pixels.

  Returns:
    The result of the multi-bbox augmentation wrapper applied with
    translate_y (image and bboxes).

  Raises:
    ValueError: If `bboxes` has rank 4.
  """
  # NOTE(review): video boxes would be rank 3 ([num_frames, num_boxes, 4]);
  # this guard only rejects rank 4 -- confirm the intended unsupported rank.
  if bboxes.shape.rank == 4:
    raise ValueError('translate_y_only_bboxes does not support rank 4 boxes')

  # func_changes_bbox=False signals to the wrapper that translate_y moves
  # pixels only, so the box coordinates themselves are left unchanged.
  func_changes_bbox = False
  prob = _scale_bbox_only_op_probability(prob)
  return _apply_multi_bbox_augmentation_wrapper(
      image, bboxes, prob, translate_y, func_changes_bbox, pixels, replace)
def _randomly_negate_tensor(tensor): def _randomly_negate_tensor(tensor):
"""With 50% prob turn the tensor negative.""" """With 50% prob turn the tensor negative."""
should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool) should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
...@@ -746,29 +1315,35 @@ def _mult_to_arg(level: float, multiplier: float = 1.): ...@@ -746,29 +1315,35 @@ def _mult_to_arg(level: float, multiplier: float = 1.):
return (int((level / _MAX_LEVEL) * multiplier),) return (int((level / _MAX_LEVEL) * multiplier),)
def _apply_func_with_prob(func: Any, image: tf.Tensor, args: Any, prob: float): def _apply_func_with_prob(func: Any, image: tf.Tensor,
bboxes: Optional[tf.Tensor], args: Any, prob: float):
"""Apply `func` to image w/ `args` as input with probability `prob`.""" """Apply `func` to image w/ `args` as input with probability `prob`."""
assert isinstance(args, tuple) assert isinstance(args, tuple)
assert inspect.getfullargspec(func)[0][1] == 'bboxes'
# Apply the function with probability `prob`. # Apply the function with probability `prob`.
should_apply_op = tf.cast( should_apply_op = tf.cast(
tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool) tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
augmented_image = tf.cond(should_apply_op, lambda: func(image, *args), augmented_image, augmented_bboxes = tf.cond(
lambda: image) should_apply_op,
return augmented_image lambda: func(image, bboxes, *args),
lambda: (image, bboxes))
return augmented_image, augmented_bboxes
def select_and_apply_random_policy(policies: Any, image: tf.Tensor): def select_and_apply_random_policy(policies: Any,
image: tf.Tensor,
bboxes: Optional[tf.Tensor] = None):
"""Select a random policy from `policies` and apply it to `image`.""" """Select a random policy from `policies` and apply it to `image`."""
policy_to_select = tf.random.uniform([], maxval=len(policies), dtype=tf.int32) policy_to_select = tf.random.uniform([], maxval=len(policies), dtype=tf.int32)
# Note that using tf.case instead of tf.conds would result in significantly # Note that using tf.case instead of tf.conds would result in significantly
# larger graphs and would even break export for some larger policies. # larger graphs and would even break export for some larger policies.
for (i, policy) in enumerate(policies): for (i, policy) in enumerate(policies):
image = tf.cond( image, bboxes = tf.cond(
tf.equal(i, policy_to_select), tf.equal(i, policy_to_select),
lambda selected_policy=policy: selected_policy(image), lambda selected_policy=policy: selected_policy(image, bboxes),
lambda: image) lambda: (image, bboxes))
return image return image, bboxes
NAME_TO_FUNC = { NAME_TO_FUNC = {
...@@ -788,8 +1363,35 @@ NAME_TO_FUNC = { ...@@ -788,8 +1363,35 @@ NAME_TO_FUNC = {
'TranslateX': translate_x, 'TranslateX': translate_x,
'TranslateY': translate_y, 'TranslateY': translate_y,
'Cutout': cutout, 'Cutout': cutout,
'Rotate_BBox': rotate_with_bboxes,
# pylint:disable=g-long-lambda
'ShearX_BBox': lambda image, bboxes, level, replace: shear_with_bboxes(
image, bboxes, level, replace, shear_horizontal=True),
'ShearY_BBox': lambda image, bboxes, level, replace: shear_with_bboxes(
image, bboxes, level, replace, shear_horizontal=False),
'TranslateX_BBox': lambda image, bboxes, pixels, replace: translate_bbox(
image, bboxes, pixels, replace, shift_horizontal=True),
'TranslateY_BBox': lambda image, bboxes, pixels, replace: translate_bbox(
image, bboxes, pixels, replace, shift_horizontal=False),
# pylint:enable=g-long-lambda
'TranslateY_Only_BBoxes': translate_y_only_bboxes,
} }
# Functions that require a `bboxes` parameter.
REQUIRE_BOXES_FUNCS = frozenset({
'Rotate_BBox',
'ShearX_BBox',
'ShearY_BBox',
'TranslateX_BBox',
'TranslateY_BBox',
'TranslateY_Only_BBoxes',
})
# Functions that have a 'prob' parameter
PROB_FUNCS = frozenset({
'TranslateY_Only_BBoxes',
})
# Functions that have a 'replace' parameter # Functions that have a 'replace' parameter
REPLACE_FUNCS = frozenset({ REPLACE_FUNCS = frozenset({
'Rotate', 'Rotate',
...@@ -798,6 +1400,12 @@ REPLACE_FUNCS = frozenset({ ...@@ -798,6 +1400,12 @@ REPLACE_FUNCS = frozenset({
'ShearY', 'ShearY',
'TranslateY', 'TranslateY',
'Cutout', 'Cutout',
'Rotate_BBox',
'ShearX_BBox',
'ShearY_BBox',
'TranslateX_BBox',
'TranslateY_BBox',
'TranslateY_Only_BBoxes',
}) })
...@@ -810,6 +1418,7 @@ def level_to_arg(cutout_const: float, translate_const: float): ...@@ -810,6 +1418,7 @@ def level_to_arg(cutout_const: float, translate_const: float):
solarize_add_arg = lambda level: _mult_to_arg(level, 110) solarize_add_arg = lambda level: _mult_to_arg(level, 110)
cutout_arg = lambda level: _mult_to_arg(level, cutout_const) cutout_arg = lambda level: _mult_to_arg(level, cutout_const)
translate_arg = lambda level: _translate_level_to_arg(level, translate_const) translate_arg = lambda level: _translate_level_to_arg(level, translate_const)
translate_bbox_arg = lambda level: _translate_level_to_arg(level, 120)
args = { args = {
'AutoContrast': no_arg, 'AutoContrast': no_arg,
...@@ -828,10 +1437,27 @@ def level_to_arg(cutout_const: float, translate_const: float): ...@@ -828,10 +1437,27 @@ def level_to_arg(cutout_const: float, translate_const: float):
'Cutout': cutout_arg, 'Cutout': cutout_arg,
'TranslateX': translate_arg, 'TranslateX': translate_arg,
'TranslateY': translate_arg, 'TranslateY': translate_arg,
'Rotate_BBox': _rotate_level_to_arg,
'ShearX_BBox': _shear_level_to_arg,
'ShearY_BBox': _shear_level_to_arg,
# pylint:disable=g-long-lambda
'TranslateX_BBox': lambda level: _translate_level_to_arg(
level, translate_const),
'TranslateY_BBox': lambda level: _translate_level_to_arg(
level, translate_const),
# pylint:enable=g-long-lambda
'TranslateY_Only_BBoxes': translate_bbox_arg,
} }
return args return args
def bbox_wrapper(func):
"""Adds a bboxes function argument to func and returns unchanged bboxes."""
def wrapper(images, bboxes, *args, **kwargs):
return (func(images, *args, **kwargs), bboxes)
return wrapper
def _parse_policy_info(name: Text, def _parse_policy_info(name: Text,
prob: float, prob: float,
level: float, level: float,
...@@ -848,28 +1474,58 @@ def _parse_policy_info(name: Text, ...@@ -848,28 +1474,58 @@ def _parse_policy_info(name: Text,
args = level_to_arg(cutout_const, translate_const)[name](level) args = level_to_arg(cutout_const, translate_const)[name](level)
if name in PROB_FUNCS:
# Add in the prob arg if it is required for the function that is called.
args = tuple([prob] + list(args))
if name in REPLACE_FUNCS: if name in REPLACE_FUNCS:
# Add in replace arg if it is required for the function that is called. # Add in replace arg if it is required for the function that is called.
args = tuple(list(args) + [replace_value]) args = tuple(list(args) + [replace_value])
# Add bboxes as the second positional argument for the function if it does
# not already exist.
if 'bboxes' not in inspect.getfullargspec(func)[0]:
func = bbox_wrapper(func)
return func, prob, args return func, prob, args
class ImageAugment(object): class ImageAugment(object):
"""Image augmentation class for applying image distortions.""" """Image augmentation class for applying image distortions."""
def distort(self, image: tf.Tensor) -> tf.Tensor: def distort(
self,
image: tf.Tensor
) -> tf.Tensor:
"""Given an image tensor, returns a distorted image with the same shape. """Given an image tensor, returns a distorted image with the same shape.
Args: Args:
image: `Tensor` of shape [height, width, 3] or image: `Tensor` of shape [height, width, 3] or
[num_frames, height, width, 3] representing an image or image sequence. [num_frames, height, width, 3] representing an image or image sequence.
Returns: Returns:
The augmented version of `image`. The augmented version of `image`.
""" """
raise NotImplementedError() raise NotImplementedError()
def distort_with_boxes(
self,
image: tf.Tensor,
bboxes: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]:
"""Distorts the image and bounding boxes.
Args:
image: `Tensor` of shape [height, width, 3] or
[num_frames, height, width, 3] representing an image or image sequence.
bboxes: `Tensor` of shape [num_boxes, 4] or [num_frames, num_boxes, 4]
representing bounding boxes for an image or image sequence.
Returns:
The augmented version of `image` and `bboxes`.
"""
raise NotImplementedError
class AutoAugment(ImageAugment): class AutoAugment(ImageAugment):
"""Applies the AutoAugment policy to images. """Applies the AutoAugment policy to images.
...@@ -920,6 +1576,7 @@ class AutoAugment(ImageAugment): ...@@ -920,6 +1576,7 @@ class AutoAugment(ImageAugment):
self.cutout_const = float(cutout_const) self.cutout_const = float(cutout_const)
self.translate_const = float(translate_const) self.translate_const = float(translate_const)
self.available_policies = { self.available_policies = {
'detection_v0': self.detection_policy_v0(),
'v0': self.policy_v0(), 'v0': self.policy_v0(),
'test': self.policy_test(), 'test': self.policy_test(),
'simple': self.policy_simple(), 'simple': self.policy_simple(),
...@@ -954,24 +1611,8 @@ class AutoAugment(ImageAugment): ...@@ -954,24 +1611,8 @@ class AutoAugment(ImageAugment):
raise ValueError('Wrong shape detected for custom policy. Expected ' raise ValueError('Wrong shape detected for custom policy. Expected '
'(:, :, 3) but got {}.'.format(in_shape)) '(:, :, 3) but got {}.'.format(in_shape))
def distort(self, image: tf.Tensor) -> tf.Tensor: def _make_tf_policies(self):
"""Applies the AutoAugment policy to `image`. """Prepares the TF functions for augmentations based on the policies."""
AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
Args:
image: `Tensor` of shape [height, width, 3] representing an image.
Returns:
A version of image that now has data augmentation applied to it based on
the `policies` pass into the function.
"""
input_image_type = image.dtype
if input_image_type != tf.uint8:
image = tf.clip_by_value(image, 0.0, 255.0)
image = tf.cast(image, dtype=tf.uint8)
replace_value = [128] * 3 replace_value = [128] * 3
# func is the string name of the augmentation function, prob is the # func is the string name of the augmentation function, prob is the
...@@ -1000,20 +1641,64 @@ class AutoAugment(ImageAugment): ...@@ -1000,20 +1641,64 @@ class AutoAugment(ImageAugment):
# on image. # on image.
def make_final_policy(tf_policy_): def make_final_policy(tf_policy_):
def final_policy(image_): def final_policy(image_, bboxes_):
for func, prob, args in tf_policy_: for func, prob, args in tf_policy_:
image_ = _apply_func_with_prob(func, image_, args, prob) image_, bboxes_ = _apply_func_with_prob(func, image_, bboxes_, args,
return image_ prob)
return image_, bboxes_
return final_policy return final_policy
with tf.control_dependencies(assert_ranges): with tf.control_dependencies(assert_ranges):
tf_policies.append(make_final_policy(tf_policy)) tf_policies.append(make_final_policy(tf_policy))
image = select_and_apply_random_policy(tf_policies, image) return tf_policies
image = tf.cast(image, dtype=input_image_type)
def distort(self, image: tf.Tensor) -> tf.Tensor:
"""See base class."""
input_image_type = image.dtype
if input_image_type != tf.uint8:
image = tf.clip_by_value(image, 0.0, 255.0)
image = tf.cast(image, dtype=tf.uint8)
tf_policies = self._make_tf_policies()
image, _ = select_and_apply_random_policy(tf_policies, image, bboxes=None)
return image return image
def distort_with_boxes(self, image: tf.Tensor,
bboxes: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""See base class."""
input_image_type = image.dtype
if input_image_type != tf.uint8:
image = tf.clip_by_value(image, 0.0, 255.0)
image = tf.cast(image, dtype=tf.uint8)
tf_policies = self._make_tf_policies()
image, bboxes = select_and_apply_random_policy(tf_policies, image, bboxes)
return image, bboxes
@staticmethod
def detection_policy_v0():
"""Autoaugment policy that was used in AutoAugment Paper for Detection.
https://arxiv.org/pdf/1906.11172
Each tuple is an augmentation operation of the form
(operation, probability, magnitude). Each element in policy is a
sub-policy that will be applied sequentially on the image.
Returns:
the policy.
"""
policy = [
[('TranslateX_BBox', 0.6, 4), ('Equalize', 0.8, 10)],
[('TranslateY_Only_BBoxes', 0.2, 2), ('Cutout', 0.8, 8)],
[('Sharpness', 0.0, 8), ('ShearX_BBox', 0.4, 0)],
[('ShearY_BBox', 1.0, 2), ('TranslateY_Only_BBoxes', 0.6, 6)],
[('Rotate_BBox', 0.6, 10), ('Color', 1.0, 6)],
]
return policy
@staticmethod @staticmethod
def policy_v0(): def policy_v0():
"""Autoaugment policy that was used in AutoAugment Paper. """Autoaugment policy that was used in AutoAugment Paper.
...@@ -1211,6 +1896,10 @@ class AutoAugment(ImageAugment): ...@@ -1211,6 +1896,10 @@ class AutoAugment(ImageAugment):
return policy return policy
def _maybe_identity(x: Optional[tf.Tensor]) -> Optional[tf.Tensor]:
return tf.identity(x) if x is not None else None
class RandAugment(ImageAugment): class RandAugment(ImageAugment):
"""Applies the RandAugment policy to images. """Applies the RandAugment policy to images.
...@@ -1261,15 +1950,12 @@ class RandAugment(ImageAugment): ...@@ -1261,15 +1950,12 @@ class RandAugment(ImageAugment):
op for op in self.available_ops if op not in exclude_ops op for op in self.available_ops if op not in exclude_ops
] ]
def distort(self, image: tf.Tensor) -> tf.Tensor: def _distort_common(
"""Applies the RandAugment policy to `image`. self,
image: tf.Tensor,
Args: bboxes: Optional[tf.Tensor] = None
image: `Tensor` of shape [height, width, 3] representing an image. ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Distorts the image and optionally bounding boxes."""
Returns:
The augmented version of `image`.
"""
input_image_type = image.dtype input_image_type = image.dtype
if input_image_type != tf.uint8: if input_image_type != tf.uint8:
...@@ -1280,6 +1966,7 @@ class RandAugment(ImageAugment): ...@@ -1280,6 +1966,7 @@ class RandAugment(ImageAugment):
min_prob, max_prob = 0.2, 0.8 min_prob, max_prob = 0.2, 0.8
aug_image = image aug_image = image
aug_bboxes = bboxes
for _ in range(self.num_layers): for _ in range(self.num_layers):
op_to_select = tf.random.uniform([], op_to_select = tf.random.uniform([],
...@@ -1300,23 +1987,36 @@ class RandAugment(ImageAugment): ...@@ -1300,23 +1987,36 @@ class RandAugment(ImageAugment):
i, i,
# pylint:disable=g-long-lambda # pylint:disable=g-long-lambda
lambda selected_func=func, selected_args=args: selected_func( lambda selected_func=func, selected_args=args: selected_func(
image, *selected_args))) image, bboxes, *selected_args)))
# pylint:enable=g-long-lambda # pylint:enable=g-long-lambda
aug_image = tf.switch_case( aug_image, aug_bboxes = tf.switch_case(
branch_index=op_to_select, branch_index=op_to_select,
branch_fns=branch_fns, branch_fns=branch_fns,
default=lambda: tf.identity(image)) default=lambda: (tf.identity(image), _maybe_identity(bboxes)))
if self.prob_to_apply is not None: if self.prob_to_apply is not None:
aug_image = tf.cond( aug_image, aug_bboxes = tf.cond(
tf.random.uniform(shape=[], dtype=tf.float32) < self.prob_to_apply, tf.random.uniform(shape=[], dtype=tf.float32) < self.prob_to_apply,
lambda: tf.identity(aug_image), lambda: tf.identity(image)) lambda: (tf.identity(aug_image), _maybe_identity(aug_bboxes)),
lambda: (tf.identity(image), _maybe_identity(bboxes)))
image = aug_image image = aug_image
bboxes = aug_bboxes
image = tf.cast(image, dtype=input_image_type) image = tf.cast(image, dtype=input_image_type)
return image, bboxes
def distort(self, image: tf.Tensor) -> tf.Tensor:
"""See base class."""
image, _ = self._distort_common(image)
return image return image
def distort_with_boxes(self, image: tf.Tensor,
bboxes: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""See base class."""
image, bboxes = self._distort_common(image, bboxes)
return image, bboxes
class RandomErasing(ImageAugment): class RandomErasing(ImageAugment):
"""Applies RandomErasing to a single image. """Applies RandomErasing to a single image.
......
...@@ -95,15 +95,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -95,15 +95,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
'reduced_cifar10', 'reduced_cifar10',
'svhn', 'svhn',
'reduced_imagenet', 'reduced_imagenet',
] 'detection_v0',
AVAILABLE_POLICIES = [
'v0',
'test',
'simple',
'reduced_cifar10',
'svhn',
'reduced_imagenet',
] ]
def test_autoaugment(self): def test_autoaugment(self):
...@@ -116,6 +108,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -116,6 +108,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual((224, 224, 3), aug_image.shape) self.assertEqual((224, 224, 3), aug_image.shape)
def test_autoaugment_with_bboxes(self):
"""Smoke test to be sure there are no syntax errors with bboxes."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 4), dtype=tf.float32)
for policy in self.AVAILABLE_POLICIES:
augmenter = augment.AutoAugment(augmentation_name=policy)
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertEqual((2, 4), aug_bboxes.shape)
def test_randaug(self): def test_randaug(self):
"""Smoke test to be sure there are no syntax errors.""" """Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8) image = tf.zeros((224, 224, 3), dtype=tf.uint8)
...@@ -125,6 +129,17 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -125,6 +129,17 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual((224, 224, 3), aug_image.shape) self.assertEqual((224, 224, 3), aug_image.shape)
def test_randaug_with_bboxes(self):
"""Smoke test to be sure there are no syntax errors with bboxes."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 4), dtype=tf.float32)
augmenter = augment.RandAugment()
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertEqual((2, 4), aug_bboxes.shape)
def test_all_policy_ops(self): def test_all_policy_ops(self):
"""Smoke test to be sure all augmentation functions can execute.""" """Smoke test to be sure all augmentation functions can execute."""
...@@ -135,14 +150,37 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -135,14 +150,37 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const = 250 translate_const = 250
image = tf.ones((224, 224, 3), dtype=tf.uint8) image = tf.ones((224, 224, 3), dtype=tf.uint8)
bboxes = None
for op_name in augment.NAME_TO_FUNC.keys() - augment.REQUIRE_BOXES_FUNCS:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image, bboxes = func(image, bboxes, *args)
self.assertEqual((224, 224, 3), image.shape)
self.assertIsNone(bboxes)
def test_all_policy_ops_with_bboxes(self):
"""Smoke test to be sure all augmentation functions can execute."""
prob = 1
magnitude = 10
replace_value = [128] * 3
cutout_const = 100
translate_const = 250
image = tf.ones((224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 4), dtype=tf.float32)
for op_name in augment.NAME_TO_FUNC: for op_name in augment.NAME_TO_FUNC:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude, func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const, replace_value, cutout_const,
translate_const) translate_const)
image = func(image, *args) image, bboxes = func(image, bboxes, *args)
self.assertEqual((224, 224, 3), image.shape) self.assertEqual((224, 224, 3), image.shape)
self.assertEqual((2, 4), bboxes.shape)
def test_autoaugment_video(self): def test_autoaugment_video(self):
"""Smoke test with video to be sure there are no syntax errors.""" """Smoke test with video to be sure there are no syntax errors."""
...@@ -154,6 +192,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -154,6 +192,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual((2, 224, 224, 3), aug_image.shape) self.assertEqual((2, 224, 224, 3), aug_image.shape)
def test_autoaugment_video_with_boxes(self):
"""Smoke test with video to be sure there are no syntax errors."""
image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 2, 4), dtype=tf.float32)
for policy in self.AVAILABLE_POLICIES:
augmenter = augment.AutoAugment(augmentation_name=policy)
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes)
self.assertEqual((2, 224, 224, 3), aug_image.shape)
self.assertEqual((2, 2, 4), aug_bboxes.shape)
def test_randaug_video(self): def test_randaug_video(self):
"""Smoke test with video to be sure there are no syntax errors.""" """Smoke test with video to be sure there are no syntax errors."""
image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8) image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8)
...@@ -173,14 +223,48 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -173,14 +223,48 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const = 250 translate_const = 250
image = tf.ones((2, 224, 224, 3), dtype=tf.uint8) image = tf.ones((2, 224, 224, 3), dtype=tf.uint8)
bboxes = None
for op_name in augment.NAME_TO_FUNC.keys() - augment.REQUIRE_BOXES_FUNCS:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image, bboxes = func(image, bboxes, *args)
self.assertEqual((2, 224, 224, 3), image.shape)
self.assertIsNone(bboxes)
def test_all_policy_ops_video_with_bboxes(self):
"""Smoke test to be sure all video augmentation functions can execute."""
prob = 1
magnitude = 10
replace_value = [128] * 3
cutout_const = 100
translate_const = 250
image = tf.ones((2, 224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 2, 4), dtype=tf.float32)
for op_name in augment.NAME_TO_FUNC: for op_name in augment.NAME_TO_FUNC:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude, func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const, replace_value, cutout_const,
translate_const) translate_const)
image = func(image, *args) if op_name in {
'Rotate_BBox',
'ShearX_BBox',
'ShearY_BBox',
'TranslateX_BBox',
'TranslateY_BBox',
'TranslateY_Only_BBoxes',
}:
with self.assertRaises(ValueError):
func(image, bboxes, *args)
else:
image, bboxes = func(image, bboxes, *args)
self.assertEqual((2, 224, 224, 3), image.shape) self.assertEqual((2, 224, 224, 3), image.shape)
self.assertEqual((2, 2, 4), bboxes.shape)
def _generate_test_policy(self): def _generate_test_policy(self):
"""Generate a test policy at random.""" """Generate a test policy at random."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment