"vscode:/vscode.git/clone" did not exist on "53e6552fed4595edb67f13b88772e6d2eab8dff9"
Commit 09b3f5ab authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 481936708
parent 54a70bac
# MOSAIC: Mobile Segmentation via decoding Aggregated Information and encoded Context
[![Paper](http://img.shields.io/badge/Paper-arXiv.2112.11623-B3181B?logo=arXiv)](https://arxiv.org/abs/2112.11623)
This repository is the official implementation of the following
paper.
* [MOSAIC: Mobile Segmentation via decoding Aggregated Information and encoded Context](https://arxiv.org/abs/2112.11623)
## Description
MOSAIC is a neural network architecture for efficient and accurate semantic
image segmentation on mobile devices. It is built from neural operations that
are commonly supported across diverse mobile hardware platforms, allowing
flexible deployment on a wide range of devices. With a simple asymmetric
encoder-decoder structure, consisting of an efficient multi-scale context
encoder and a light-weight hybrid decoder that recovers spatial details from
aggregated information, MOSAIC achieves a better balance between accuracy and
computational cost. Deployed on top of a tailored feature extraction backbone
based on a searched classification network, MOSAIC achieves a 5% absolute
accuracy gain on ADE20K with similar or lower latency compared to the current
industry-standard MLPerf Mobile v1.0 models and state-of-the-art
architectures.
[MLPerf Mobile v2.0](https://mlcommons.org/en/inference-mobile-20/) included
MOSAIC as a new industry standard benchmark model for image segmentation.
Please see details [here](https://mlcommons.org/en/news/mlperf-inference-1q2022/).
You can also refer to the [MLCommons GitHub repository](https://github.com/mlcommons/mobile_open/tree/main/vision/mosaic).
## History
### Oct 13, 2022
* First release of MOSAIC in TensorFlow 2 including checkpoints that have been
pretrained on Cityscapes.
## Maintainers
* Weijun Wang ([weijunw-g](https://github.com/weijunw-g))
* Fang Yang ([fyangf](https://github.com/fyangf))
* Shixin Luo ([luotigerlsx](https://github.com/luotigerlsx))
## Requirements
[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow)
[![tf-models-official PyPI](https://badge.fury.io/py/tf-models-official.svg)](https://badge.fury.io/py/tf-models-official)
## Results
The following table shows the mIoU measured on the `cityscapes` dataset.
| Config | Backbone | Resolution | branch_filter_depths | pyramid_pool_bin_nums | mIoU | Download |
|-------------------------|:--------------------:|:----------:|:--------------------:|:---------------------:|:-----:|:--------:|
| Paper reference config | MobileNetMultiAVGSeg | 1024x2048 | [32, 32] | [4, 8, 16] | 75.98 | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/mosaic/MobileNetMultiAVGSeg-r1024-ebf32-nogp.tar.gz)<br>[tensorboard](https://tensorboard.dev/experiment/okEog90bSwupajFgJwGEIw//#scalars) |
| Current best config | MobileNetMultiAVGSeg | 1024x2048 | [64, 64] | [1, 4, 8, 16] | 77.24 | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/mosaic/MobileNetMultiAVGSeg-r1024-ebf64-gp.tar.gz)<br>[tensorboard](https://tensorboard.dev/experiment/l5hkV7JaQM23EXeOBT6oJg/#scalars) |
* `branch_filter_depths`: the number of convolution channels in each branch at
  a pyramid level after `Spatial Pyramid Pooling`
* `pyramid_pool_bin_nums`: the number of bins at each level of the `Spatial
  Pyramid Pooling` (see the config excerpt below)
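
These two fields live under `task.model.neck` in the experiment `.yaml`. The
excerpt below mirrors the "Current best config" row above:

```yaml
task:
  model:
    neck:
      branch_filter_depths: [64, 64]
      conv_kernel_sizes: [3, 5]
      pyramid_pool_bin_nums: [1, 4, 8, 16]
```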
## Training
Training can run on Google Cloud Platform using Cloud TPU.
[Here](https://cloud.google.com/tpu/docs/how-to) are the instructions for
setting up Cloud TPU. After the Cloud TPU is set up, launch training with:
```shell
EXP_TYPE=mosaic_mnv35_cityscapes
EXP_NAME="<experiment-name>" # You can give any name to the experiment.
TPU_NAME="<tpu-name>" # The name assigned while creating a Cloud TPU
MODEL_DIR="gs://<path-to-model-directory>"
# Now launch the experiment.
python3 -m official.projects.mosaic.train \
--experiment=$EXP_TYPE \
--mode=train \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--config_file=official/projects/mosaic/configs/experiments/mosaic_mnv35_cityscapes_tdfs_tpu.yaml
```
## Evaluation
Run the following command to evaluate a trained model.
```shell
EXP_TYPE=mosaic_mnv35_cityscapes
EXP_NAME="<experiment-name>" # You can give any name to the experiment.
TPU_NAME="<tpu-name>" # The name assigned while creating a Cloud TPU
MODEL_DIR="gs://<path-to-model-directory>"
# Now launch the experiment.
python3 -m official.projects.mosaic.train \
--experiment=$EXP_TYPE \
--mode=eval \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--config_file=official/projects/mosaic/configs/experiments/mosaic_mnv35_cityscapes_tdfs_tpu.yaml
```
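
Both training and evaluation read their settings from `--config_file`.
Individual values can also be overridden from the command line with the
standard Model Garden `--params_override` flag; for example (a sketch, with a
hypothetical batch-size override):

```shell
python3 -m official.projects.mosaic.train \
  --experiment=$EXP_TYPE \
  --mode=eval \
  --tpu=$TPU_NAME \
  --model_dir=$MODEL_DIR \
  --config_file=official/projects/mosaic/configs/experiments/mosaic_mnv35_cityscapes_tdfs_tpu.yaml \
  --params_override="task.validation_data.global_batch_size=16"
```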
## License
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
This project is licensed under the terms of the **Apache License 2.0**.
## Citation
If you use this repository in your work, please consider citing the paper.
```
@article{weijun2021mosaic,
  title={MOSAIC: Mobile Segmentation via decoding Aggregated Information and
         encoded Context},
  author={Weijun Wang and Andrew Howard},
  journal={arXiv preprint arXiv:2112.11623},
  year={2021},
}
```
# Using TensorFlow Datasets: 'cityscapes/semantic_segmentation'
# Some expected flags to use with xmanager launcher:
# --experiment_type=mosaic_mnv35_cityscapes
# --tpu_topology=4x4
# mIoU: 77.24%
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
num_classes: 19
input_size: [null, null, 3]
backbone:
type: 'mobilenet'
mobilenet:
model_id: 'MobileNetMultiAVGSeg'
output_intermediate_endpoints: true
output_stride: 16
neck:
branch_filter_depths: [64, 64]
conv_kernel_sizes: [3, 5]
pyramid_pool_bin_nums: [1, 4, 8, 16]
dropout_rate: 0.0
head:
num_classes: 19
decoder_input_levels: ['3/depthwise', '2/depthwise']
decoder_stage_merge_styles: ['concat_merge', 'sum_merge']
decoder_filters: [64, 64]
decoder_projected_filters: [19, 19]
norm_activation:
activation: relu
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
init_checkpoint: 'gs://tf_model_garden/vision/mobilenet/v3.5multiavg_seg_float/'
init_checkpoint_modules: 'backbone'
losses:
l2_weight_decay: 1.0e-04
train_data:
output_size: [1024, 2048]
crop_size: [1024, 2048]
input_path: ''
tfds_name: 'cityscapes/semantic_segmentation'
tfds_split: 'train'
is_training: true
global_batch_size: 32
dtype: 'float32'
aug_rand_hflip: true
aug_scale_max: 2.0
aug_scale_min: 0.5
validation_data:
output_size: [1024, 2048]
input_path: ''
tfds_name: 'cityscapes/semantic_segmentation'
tfds_split: 'validation'
is_training: false
global_batch_size: 32
dtype: 'float32'
drop_remainder: false
resize_eval_groundtruth: true
trainer:
optimizer_config:
learning_rate:
polynomial:
decay_steps: 100000
initial_learning_rate: 0.1
power: 0.9
type: polynomial
optimizer:
sgd:
momentum: 0.9
type: sgd
warmup:
linear:
name: linear
warmup_learning_rate: 0
warmup_steps: 925
type: linear
steps_per_loop: 92 # 2975 / 32 = 92
summary_interval: 92
train_steps: 100000
validation_interval: 92
validation_steps: 16 # 500 / 32 = 16
checkpoint_interval: 92
best_checkpoint_export_subdir: 'best_ckpt'
best_checkpoint_eval_metric: 'mean_iou'
best_checkpoint_metric_comp: 'higher'
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration definition for Semantic Segmentation with MOSAIC."""
import dataclasses
import os
from typing import List, Optional, Union
import numpy as np
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.configs import common
from official.vision.configs import semantic_segmentation as seg_cfg
from official.vision.configs import backbones
@dataclasses.dataclass
class MosaicDecoderHead(hyperparams.Config):
"""MOSAIC decoder head config for Segmentation."""
num_classes: int = 19
decoder_input_levels: List[str] = dataclasses.field(default_factory=list)
decoder_stage_merge_styles: List[str] = dataclasses.field(
default_factory=list)
decoder_filters: List[int] = dataclasses.field(default_factory=list)
decoder_projected_filters: List[int] = dataclasses.field(default_factory=list)
encoder_end_level: int = 4
use_additional_classifier_layer: bool = False
classifier_kernel_size: int = 1
activation: str = 'relu'
kernel_initializer: str = 'glorot_uniform'
interpolation: str = 'bilinear'
@dataclasses.dataclass
class MosaicEncoderNeck(hyperparams.Config):
"""MOSAIC encoder neck config for segmentation."""
encoder_input_level: Union[str, int] = '4'
branch_filter_depths: List[int] = dataclasses.field(default_factory=list)
conv_kernel_sizes: List[int] = dataclasses.field(default_factory=list)
pyramid_pool_bin_nums: List[int] = dataclasses.field(default_factory=list)
activation: str = 'relu'
dropout_rate: float = 0.1
kernel_initializer: str = 'glorot_uniform'
interpolation: str = 'bilinear'
use_depthwise_convolution: bool = True
@dataclasses.dataclass
class MosaicSemanticSegmentationModel(hyperparams.Config):
"""MOSAIC semantic segmentation model config."""
num_classes: int = 19
input_size: List[int] = dataclasses.field(default_factory=list)
head: MosaicDecoderHead = MosaicDecoderHead()
backbone: backbones.Backbone = backbones.Backbone(
type='mobilenet', mobilenet=backbones.MobileNet())
neck: MosaicEncoderNeck = MosaicEncoderNeck()
norm_activation: common.NormActivation = common.NormActivation(
use_sync_bn=True, norm_momentum=0.99, norm_epsilon=0.001)
@dataclasses.dataclass
class MosaicSemanticSegmentationTask(seg_cfg.SemanticSegmentationTask):
"""The config for MOSAIC segmentation task."""
model: MosaicSemanticSegmentationModel = MosaicSemanticSegmentationModel()
train_data: seg_cfg.DataConfig = seg_cfg.DataConfig(is_training=True)
validation_data: seg_cfg.DataConfig = seg_cfg.DataConfig(is_training=False)
losses: seg_cfg.Losses = seg_cfg.Losses()
evaluation: seg_cfg.Evaluation = seg_cfg.Evaluation()
train_input_partition_dims: List[int] = dataclasses.field(
default_factory=list)
eval_input_partition_dims: List[int] = dataclasses.field(
default_factory=list)
init_checkpoint: Optional[str] = None
init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone, and/or neck.
export_config: seg_cfg.ExportConfig = seg_cfg.ExportConfig()
# Cityscapes Dataset (Download and process the dataset yourself)
CITYSCAPES_TRAIN_EXAMPLES = 2975
CITYSCAPES_VAL_EXAMPLES = 500
CITYSCAPES_INPUT_PATH_BASE = 'cityscapes/tfrecord'
@exp_factory.register_config_factory('mosaic_mnv35_cityscapes')
def mosaic_mnv35_cityscapes() -> cfg.ExperimentConfig:
"""Instantiates an experiment configuration of image segmentation task.
This image segmentation experiment is conducted on Cityscapes dataset. The
model architecture is a MOSAIC encoder-decoder. The default backbone network is
a mobilenet variant called Mobilenet_v3.5-MultiAvg on top of which the MOSAIC
encoder-decoder can be deployed. All detailed configurations can be overridden
by a .yaml file provided by the user to launch the experiments. Please refer
to .yaml examples in the path of ../configs/experiments/.
Returns:
A particular instance of cfg.ExperimentConfig for MOSAIC model based
image semantic segmentation task.
"""
train_batch_size = 16
eval_batch_size = 16
steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size
output_stride = 16
backbone_output_level = int(np.log2(output_stride))
config = cfg.ExperimentConfig(
task=MosaicSemanticSegmentationTask(
model=MosaicSemanticSegmentationModel(
# Cityscapes uses only 19 semantic classes for train/evaluation.
# The void (background) class is ignored in train and evaluation.
num_classes=19,
input_size=[None, None, 3],
backbone=backbones.Backbone(
type='mobilenet',
mobilenet=backbones.MobileNet(
model_id='MobileNetMultiAVGSeg',
output_intermediate_endpoints=True,
output_stride=output_stride)),
neck=MosaicEncoderNeck(
encoder_input_level=backbone_output_level,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=[1, 4, 8, 16], # paper default
activation='relu',
dropout_rate=0.1,
kernel_initializer='glorot_uniform',
interpolation='bilinear',
use_depthwise_convolution=True),
head=MosaicDecoderHead(
num_classes=19,
decoder_input_levels=['3/depthwise', '2/depthwise'],
decoder_stage_merge_styles=['concat_merge', 'sum_merge'],
decoder_filters=[64, 64],
decoder_projected_filters=[19, 19],
encoder_end_level=backbone_output_level,
use_additional_classifier_layer=False,
classifier_kernel_size=1,
activation='relu',
kernel_initializer='glorot_uniform',
interpolation='bilinear'),
norm_activation=common.NormActivation(
activation='relu',
norm_momentum=0.99,
norm_epsilon=1e-3,
use_sync_bn=True)),
losses=seg_cfg.Losses(l2_weight_decay=4e-5),
train_data=seg_cfg.DataConfig(
input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
'train_fine**'),
crop_size=[1024, 2048],
output_size=[1024, 2048],
is_training=True,
global_batch_size=train_batch_size,
aug_scale_min=0.5,
aug_scale_max=2.0),
validation_data=seg_cfg.DataConfig(
input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE, 'val_fine*'),
output_size=[1024, 2048],
is_training=False,
global_batch_size=eval_batch_size,
resize_eval_groundtruth=True,
drop_remainder=False),
# Imagenet pre-trained Mobilenet_v3.5-MultiAvg checkpoint.
init_checkpoint='gs://tf_model_garden/vision/mobilenet/v3.5multiavg_seg_float/',
init_checkpoint_modules='backbone'),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=100000,
validation_steps=CITYSCAPES_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
best_checkpoint_eval_metric='mean_iou',
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_metric_comp='higher',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 0.1,
'decay_steps': 100000,
'end_learning_rate': 0.0,
'power': 0.9
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
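# Example usage (a sketch): the registered experiment config can be retrieved
# by name and customized before launching, e.g.:
#
#   from official.core import exp_factory
#
#   config = exp_factory.get_exp_config('mosaic_mnv35_cityscapes')
#   config.task.train_data.global_batch_size = 32
#   config.validate()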
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definitions of building blocks for MOSAIC model.
Reference:
[MOSAIC: Mobile Segmentation via decoding Aggregated Information and encoded
Context](https://arxiv.org/pdf/2112.11623.pdf)
"""
from typing import Any, Dict, List, Optional, Tuple, Union
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Vision')
class MultiKernelGroupConvBlock(tf.keras.layers.Layer):
"""A multi-kernel grouped convolution block.
This block is used in the segmentation neck introduced in MOSAIC.
Reference:
[MOSAIC: Mobile Segmentation via decoding Aggregated Information and encoded
Context](https://arxiv.org/pdf/2112.11623.pdf)
"""
def __init__(
self,
output_filter_depths: Optional[List[int]] = None,
kernel_sizes: Optional[List[int]] = None,
use_sync_bn: bool = False,
batchnorm_momentum: float = 0.99,
batchnorm_epsilon: float = 0.001,
activation: str = 'relu',
kernel_initializer: str = 'GlorotUniform',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
use_depthwise_convolution: bool = True,
**kwargs):
"""Initializes a Multi-kernel Grouped Convolution Block.
Args:
output_filter_depths: A list of integers representing the numbers of
output channels or filter depths of convolution groups.
kernel_sizes: A list of integers denoting the convolution kernel sizes in
each convolution group.
use_sync_bn: A bool, whether or not to use sync batch normalization.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99.
batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
0.001.
activation: A `str` for the activation function type. Defaults to 'relu'.
kernel_initializer: Kernel initializer for conv layers. Defaults to
`glorot_uniform`.
kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
use_depthwise_convolution: Whether to use depthwise separable convolutions
instead of full convolutions in the convolution groups.
**kwargs: Other keyword arguments for the layer.
"""
super(MultiKernelGroupConvBlock, self).__init__(**kwargs)
if output_filter_depths is None:
output_filter_depths = [64, 64]
if kernel_sizes is None:
kernel_sizes = [3, 5]
if len(output_filter_depths) != len(kernel_sizes):
raise ValueError('The number of output filter depths must match the '
'number of kernel sizes.')
self._output_filter_depths = output_filter_depths
self._kernel_sizes = kernel_sizes
self._num_groups = len(self._kernel_sizes)
self._use_sync_bn = use_sync_bn
self._batchnorm_momentum = batchnorm_momentum
self._batchnorm_epsilon = batchnorm_epsilon
self._activation = activation
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._use_depthwise_convolution = use_depthwise_convolution
# To apply BN before activation. Putting BN between conv and activation also
# helps quantization where conv+bn+activation are fused into a single op.
self._activation_fn = tf_utils.get_activation(activation)
if self._use_sync_bn:
self._bn_op = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._bn_op = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
self._group_split_axis = -1
else:
self._bn_axis = 1
self._group_split_axis = 1
def build(self, input_shape: tf.TensorShape) -> None:
"""Builds the block with the given input shape."""
input_channels = input_shape[self._group_split_axis]
if input_channels % self._num_groups != 0:
raise ValueError('The number of input channels must be divisible by '
'the number of groups for evenly group split.')
self._conv_branches = []
if self._use_depthwise_convolution:
for i, conv_kernel_size in enumerate(self._kernel_sizes):
depthwise_conv = tf.keras.layers.DepthwiseConv2D(
kernel_size=(conv_kernel_size, conv_kernel_size),
depth_multiplier=1,
padding='same',
depthwise_regularizer=self._kernel_regularizer,
depthwise_initializer=self._kernel_initializer,
use_bias=False)
# Add BN->RELU after depthwise convolution.
batchnorm_op_depthwise = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
activation_depthwise = self._activation_fn
feature_conv = tf.keras.layers.Conv2D(
filters=self._output_filter_depths[i],
kernel_size=(1, 1),
padding='same',
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
activation=None,
use_bias=False)
batchnorm_op = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
# Keep the layers in a plain Python list: the current QAT API does not
# support nested tf.keras.Sequential models, e.g. conv_branch =
# tf.keras.Sequential([depthwise_conv, feature_conv, batchnorm_op,]).
conv_branch = [
depthwise_conv,
batchnorm_op_depthwise,
activation_depthwise,
feature_conv,
batchnorm_op,
]
self._conv_branches.append(conv_branch)
else:
for i, conv_kernel_size in enumerate(self._kernel_sizes):
norm_conv = tf.keras.layers.Conv2D(
filters=self._output_filter_depths[i],
kernel_size=(conv_kernel_size, conv_kernel_size),
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
activation=None,
use_bias=False)
batchnorm_op = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
conv_branch = [norm_conv, batchnorm_op]
self._conv_branches.append(conv_branch)
self._concat_groups = tf.keras.layers.Concatenate(
axis=self._group_split_axis)
def call(self,
inputs: tf.Tensor,
training: Optional[bool] = None) -> tf.Tensor:
"""Calls this group convolution block with the given inputs."""
inputs_splits = tf.split(inputs,
num_or_size_splits=self._num_groups,
axis=self._group_split_axis)
output_branches = []
for i, x in enumerate(inputs_splits):
conv_branch = self._conv_branches[i]
# Apply layers sequentially and manually.
for layer in conv_branch:
if isinstance(layer, tf.keras.layers.Layer):
x = layer(x, training=training)
else:
x = layer(x)
# Apply activation function after BN, which also helps quantization
# where conv+bn+activation are fused into a single op.
x = self._activation_fn(x)
output_branches.append(x)
x = self._concat_groups(output_branches)
return x
def get_config(self) -> Dict[str, Any]:
"""Returns a config dictionary for initialization from serialization."""
config = {
'output_filter_depths': self._output_filter_depths,
'kernel_sizes': self._kernel_sizes,
'num_groups': self._num_groups,
'use_sync_bn': self._use_sync_bn,
'batchnorm_momentum': self._batchnorm_momentum,
'batchnorm_epsilon': self._batchnorm_epsilon,
'activation': self._activation,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'use_depthwise_convolution': self._use_depthwise_convolution,
}
base_config = super(MultiKernelGroupConvBlock, self).get_config()
base_config.update(config)
return base_config
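# Example usage (a sketch mirroring the unit test in mosaic_blocks_test.py):
#
#   block = MultiKernelGroupConvBlock([64, 64], [3, 5])
#   outputs = block(tf.ones([1, 4, 4, 448]))  # 448 channels split into 2 groups
#   # outputs.shape -> [1, 4, 4, 128], i.e. the sum of output_filter_depths.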
@tf.keras.utils.register_keras_serializable(package='Vision')
class MosaicEncoderBlock(tf.keras.layers.Layer):
"""Implements the encoder module/block of MOSAIC model.
The block consists of Spatial Pyramid Pooling followed by Multi-kernel
Convolution layers, i.e. `SpatialPyramidPoolingMultiKernelConv`.
References:
[MOSAIC: Mobile Segmentation via decoding Aggregated Information and encoded
context](https://arxiv.org/pdf/2112.11623.pdf)
"""
def __init__(
self,
encoder_input_level: Optional[Union[str, int]] = '4',
branch_filter_depths: Optional[List[int]] = None,
conv_kernel_sizes: Optional[List[int]] = None,
pyramid_pool_bin_nums: Optional[List[int]] = None,
use_sync_bn: bool = False,
batchnorm_momentum: float = 0.99,
batchnorm_epsilon: float = 0.001,
activation: str = 'relu',
dropout_rate: float = 0.1,
kernel_initializer: str = 'glorot_uniform',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear',
use_depthwise_convolution: bool = True,
**kwargs):
"""Initializes a MOSAIC encoder block which is deployed after a backbone.
Args:
encoder_input_level: An optional `str` or integer specifying the level of
backbone outputs as the input to the encoder.
branch_filter_depths: A list of integers for the number of convolution
channels in each branch at a pyramid level after SpatialPyramidPooling.
conv_kernel_sizes: A list of integers representing the convolution kernel
sizes in the Multi-kernel Convolution blocks in the encoder.
pyramid_pool_bin_nums: A list of integers for the number of bins at each
level of the Spatial Pyramid Pooling.
use_sync_bn: A bool, whether or not to use sync batch normalization.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99.
batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
0.001.
activation: A `str` for the activation function type. Defaults to 'relu'.
dropout_rate: A float between 0 and 1. Fraction of the input units to drop
out, which will be used directly as the `rate` of the Dropout layer at
the end of the encoder. Defaults to 0.1.
kernel_initializer: Kernel initializer for conv layers. Defaults to
`glorot_uniform`.
kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
interpolation: The interpolation method for upsampling. Defaults to
`bilinear`.
use_depthwise_convolution: Use depthwise separable convolutions in the
Multi-kernel Convolution blocks in the encoder.
**kwargs: Other keyword arguments for the layer.
"""
super().__init__(**kwargs)
self._encoder_input_level = str(encoder_input_level)
if branch_filter_depths is None:
branch_filter_depths = [64, 64]
self._branch_filter_depths = branch_filter_depths
if conv_kernel_sizes is None:
conv_kernel_sizes = [3, 5]
self._conv_kernel_sizes = conv_kernel_sizes
if pyramid_pool_bin_nums is None:
pyramid_pool_bin_nums = [1, 4, 8, 16]
self._pyramid_pool_bin_nums = pyramid_pool_bin_nums
self._use_sync_bn = use_sync_bn
self._batchnorm_momentum = batchnorm_momentum
self._batchnorm_epsilon = batchnorm_epsilon
self._activation = activation
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._interpolation = interpolation
self._use_depthwise_convolution = use_depthwise_convolution
self._activation_fn = tf_utils.get_activation(activation)
if self._use_sync_bn:
self._bn_op = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._bn_op = tf.keras.layers.BatchNormalization
self._dropout_rate = dropout_rate
if dropout_rate:
self._encoder_end_dropout_layer = tf.keras.layers.Dropout(
rate=dropout_rate)
else:
self._encoder_end_dropout_layer = None
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
self._channel_axis = -1
else:
self._bn_axis = 1
self._channel_axis = 1
def _get_bin_pool_kernel_and_stride(
self,
input_size: int,
num_of_bin: int) -> Tuple[int, int]:
"""Calculates the kernel size and stride for spatial bin pooling.
Args:
input_size: Input dimension (a scalar).
num_of_bin: The number of bins used for spatial bin pooling.
Returns:
A tuple of the pooling kernel size and stride for spatial bin pooling.
"""
bin_overlap = int(input_size % num_of_bin)
pooling_stride = int(input_size // num_of_bin)
pooling_kernel = pooling_stride + bin_overlap
return pooling_kernel, pooling_stride
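# For example (an arithmetic sketch of the method above): with input_size=31
# and num_of_bin=4, pooling_stride = 31 // 4 = 7 and bin_overlap = 31 % 4 = 3,
# so pooling_kernel = 10; four windows starting at 0, 7, 14, 21 then cover
# positions 0..30 exactly, yielding 4 output bins.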
def build(
self, input_shape: Union[tf.TensorShape, Dict[str,
tf.TensorShape]]) -> None:
"""Builds this MOSAIC encoder block with the given single input shape."""
input_shape = (
input_shape[self._encoder_input_level]
if isinstance(input_shape, dict) else input_shape)
self._data_format = tf.keras.backend.image_data_format()
if self._data_format == 'channels_last':
height = input_shape[1]
width = input_shape[2]
else:
height = input_shape[2]
width = input_shape[3]
self._global_pool_branch = None
self._spatial_pyramid = []
for pyramid_pool_bin_num in self._pyramid_pool_bin_nums:
if pyramid_pool_bin_num == 1:
global_pool = tf.keras.layers.GlobalAveragePooling2D(
data_format=self._data_format, keepdims=True)
global_projection = tf.keras.layers.Conv2D(
filters=max(self._branch_filter_depths),
kernel_size=(1, 1),
padding='same',
activation=None,
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
use_bias=False)
batch_norm_global_branch = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
# Use list manually instead of tf.keras.Sequential([])
self._global_pool_branch = [
global_pool,
global_projection,
batch_norm_global_branch,
]
else:
if height < pyramid_pool_bin_num or width < pyramid_pool_bin_num:
raise ValueError('The number of pooling bins must be smaller than '
'input sizes.')
assert pyramid_pool_bin_num >= 2, (
'Except for the global pooling, the number of bins in pyramid '
'pooling must be at least two.')
pool_height, stride_height = self._get_bin_pool_kernel_and_stride(
height, pyramid_pool_bin_num)
pool_width, stride_width = self._get_bin_pool_kernel_and_stride(
width, pyramid_pool_bin_num)
bin_pool_level = tf.keras.layers.AveragePooling2D(
pool_size=(pool_height, pool_width),
strides=(stride_height, stride_width),
padding='valid',
data_format=self._data_format)
self._spatial_pyramid.append(bin_pool_level)
# Grouped multi-kernel Convolution.
self._multi_kernel_group_conv = MultiKernelGroupConvBlock(
output_filter_depths=self._branch_filter_depths,
kernel_sizes=self._conv_kernel_sizes,
use_sync_bn=self._use_sync_bn,
batchnorm_momentum=self._batchnorm_momentum,
batchnorm_epsilon=self._batchnorm_epsilon,
activation=self._activation,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
use_depthwise_convolution=self._use_depthwise_convolution)
# Encoder's final 1x1 feature projection.
# Considering the relatively large #channels merged before projection,
# enlarge the projection #channels to the sum of the filter depths of
# branches.
self._output_channels = sum(self._branch_filter_depths)
# Use list manually instead of tf.keras.Sequential([]).
self._encoder_projection = [
tf.keras.layers.Conv2D(
filters=self._output_channels,
kernel_size=(1, 1),
padding='same',
activation=None,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
use_bias=False),
self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon),
]
# Use the TF2 default feature alignment rule for bilinear resizing.
self._upsample = tf.keras.layers.Resizing(
height,
width,
interpolation=self._interpolation,
crop_to_aspect_ratio=False)
self._concat_layer = tf.keras.layers.Concatenate(axis=self._channel_axis)
def call(self,
inputs: Union[tf.Tensor, Dict[str, tf.Tensor]],
training: Optional[bool] = None) -> tf.Tensor:
"""Calls this MOSAIC encoder block with the given input."""
if training is None:
training = tf.keras.backend.learning_phase()
input_from_backbone_output = (
inputs[self._encoder_input_level]
if isinstance(inputs, dict) else inputs)
branches = []
# Original features from the final output of the backbone.
branches.append(input_from_backbone_output)
if self._spatial_pyramid:
for bin_pool_level in self._spatial_pyramid:
x = input_from_backbone_output
x = bin_pool_level(x)
x = self._multi_kernel_group_conv(x, training=training)
x = self._upsample(x)
branches.append(x)
if self._global_pool_branch is not None:
x = input_from_backbone_output
for layer in self._global_pool_branch:
x = layer(x, training=training)
x = self._activation_fn(x)
x = self._upsample(x)
branches.append(x)
x = self._concat_layer(branches)
for layer in self._encoder_projection:
x = layer(x, training=training)
x = self._activation_fn(x)
if self._encoder_end_dropout_layer is not None:
x = self._encoder_end_dropout_layer(x, training=training)
return x
def get_config(self) -> Dict[str, Any]:
"""Returns a config dictionary for initialization from serialization."""
config = {
'encoder_input_level': self._encoder_input_level,
'branch_filter_depths': self._branch_filter_depths,
'conv_kernel_sizes': self._conv_kernel_sizes,
'pyramid_pool_bin_nums': self._pyramid_pool_bin_nums,
'use_sync_bn': self._use_sync_bn,
'batchnorm_momentum': self._batchnorm_momentum,
'batchnorm_epsilon': self._batchnorm_epsilon,
'activation': self._activation,
'dropout_rate': self._dropout_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'interpolation': self._interpolation,
'use_depthwise_convolution': self._use_depthwise_convolution,
}
base_config = super().get_config()
base_config.update(config)
return base_config
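# Example usage (a sketch mirroring the unit tests in mosaic_blocks_test.py):
#
#   block = MosaicEncoderBlock(
#       encoder_input_level=4,
#       branch_filter_depths=[64, 64],
#       conv_kernel_sizes=[3, 5],
#       pyramid_pool_bin_nums=[1, 4, 8, 16])
#   outputs = block(tf.ones([1, 32, 32, 448]))
#   # outputs.shape -> [1, 32, 32, 128], i.e. sum(branch_filter_depths).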
@tf.keras.utils.register_keras_serializable(package='Vision')
class DecoderSumMergeBlock(tf.keras.layers.Layer):
"""Implements the decoder feature sum merge block of MOSAIC model.
This block is used in the decoder of segmentation head introduced in MOSAIC.
It essentially merges a high-resolution feature map of a low semantic level
and a low-resolution feature map of a higher semantic level by 'Sum-Merge'.
"""
def __init__(
self,
decoder_projected_depth: int,
output_size: Tuple[int, int] = (0, 0),
use_sync_bn: bool = False,
batchnorm_momentum: float = 0.99,
batchnorm_epsilon: float = 0.001,
activation: str = 'relu',
kernel_initializer: str = 'GlorotUniform',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear',
**kwargs):
"""Initialize a sum-merge block for one decoder stage.
Args:
decoder_projected_depth: An integer representing the number of output
channels of this sum-merge block in the decoder.
output_size: A Tuple of integers representing the output height and width
of the feature maps from this sum-merge block. Defaults to (0, 0),
where the output size is set the same as the high-resolution branch.
use_sync_bn: A bool, whether or not to use sync batch normalization.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99.
batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
0.001.
activation: A `str` for the activation function type. Defaults to 'relu'.
kernel_initializer: Kernel initializer for conv layers. Defaults to
`glorot_uniform`.
kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
interpolation: The interpolation method for upsampling. Defaults to
`bilinear`.
**kwargs: Other keyword arguments for the layer.
"""
super(DecoderSumMergeBlock, self).__init__(**kwargs)
self._decoder_projected_depth = decoder_projected_depth
self._output_size = output_size
self._low_res_branch = []
self._upsample_low_res = None
self._high_res_branch = []
self._upsample_high_res = None
self._use_sync_bn = use_sync_bn
self._batchnorm_momentum = batchnorm_momentum
self._batchnorm_epsilon = batchnorm_epsilon
self._activation = activation
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._interpolation = interpolation
# Apply BN before activation. Putting BN between conv and activation also
# helps quantization where conv+bn+activation are fused into a single op.
self._activation_fn = tf_utils.get_activation(activation)
if self._use_sync_bn:
self._bn_op = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._bn_op = tf.keras.layers.BatchNormalization
self._bn_axis = (
-1
if tf.keras.backend.image_data_format() == 'channels_last' else 1)
self._channel_axis = (
-1
if tf.keras.backend.image_data_format() == 'channels_last' else 1)
self._add_layer = tf.keras.layers.Add()
def build(
self,
input_shape: Tuple[tf.TensorShape, tf.TensorShape]) -> None:
"""Builds the block with the given input shape."""
# Assume backbone features of the same level are concatenated before input.
low_res_input_shape = input_shape[0]
high_res_input_shape = input_shape[1]
low_res_channels = low_res_input_shape[self._channel_axis]
high_res_channels = high_res_input_shape[self._channel_axis]
if low_res_channels != self._decoder_projected_depth:
low_res_feature_conv = tf.keras.layers.Conv2D(
filters=self._decoder_projected_depth,
kernel_size=(1, 1),
padding='same',
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
activation=None,
use_bias=False)
batchnorm_op = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
self._low_res_branch.extend([
low_res_feature_conv,
batchnorm_op,
])
if high_res_channels != self._decoder_projected_depth:
high_res_feature_conv = tf.keras.layers.Conv2D(
filters=self._decoder_projected_depth,
kernel_size=(1, 1),
padding='same',
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
activation=None,
use_bias=False)
batchnorm_op_high = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
self._high_res_branch.extend([
high_res_feature_conv,
batchnorm_op_high,
])
# Resize feature maps.
if tf.keras.backend.image_data_format() == 'channels_last':
low_res_height = low_res_input_shape[1]
low_res_width = low_res_input_shape[2]
high_res_height = high_res_input_shape[1]
high_res_width = high_res_input_shape[2]
else:
low_res_height = low_res_input_shape[2]
low_res_width = low_res_input_shape[3]
high_res_height = high_res_input_shape[2]
high_res_width = high_res_input_shape[3]
if (self._output_size[0] == 0 or self._output_size[1] == 0):
self._output_size = (high_res_height, high_res_width)
if (low_res_height != self._output_size[0] or
low_res_width != self._output_size[1]):
self._upsample_low_res = tf.keras.layers.Resizing(
self._output_size[0],
self._output_size[1],
interpolation=self._interpolation,
crop_to_aspect_ratio=False)
if (high_res_height != self._output_size[0] or
high_res_width != self._output_size[1]):
self._upsample_high_res = tf.keras.layers.Resizing(
self._output_size[0],
self._output_size[1],
interpolation=self._interpolation,
crop_to_aspect_ratio=False)
def call(self,
inputs: Tuple[tf.Tensor, tf.Tensor],
training: Optional[bool] = None) -> tf.Tensor:
"""Calls this decoder sum-merge block with the given input.
Args:
inputs: A Tuple of tensors consisting of a low-resolution higher-semantic
level feature map from the encoder as the first item and a higher
resolution lower-level feature map from the backbone as the second item.
training: a `bool` indicating whether it is in `training` mode.
Note: the first item of the input Tuple takes a lower-resolution feature map
and the second item of the input Tuple takes a higher-resolution branch.
Returns:
A tensor representing the sum-merged decoder feature map.
"""
if training is None:
training = tf.keras.backend.learning_phase()
x_low_res = inputs[0]
x_high_res = inputs[1]
if self._low_res_branch:
for layer in self._low_res_branch:
x_low_res = layer(x_low_res, training=training)
x_low_res = self._activation_fn(x_low_res)
if self._high_res_branch:
for layer in self._high_res_branch:
x_high_res = layer(x_high_res, training=training)
x_high_res = self._activation_fn(x_high_res)
if self._upsample_low_res is not None:
x_low_res = self._upsample_low_res(x_low_res)
if self._upsample_high_res is not None:
x_high_res = self._upsample_high_res(x_high_res)
output = self._add_layer([x_low_res, x_high_res])
return output
def get_config(self) -> Dict[str, Any]:
"""Returns a config dictionary for initialization from serialization."""
config = {
'decoder_projected_depth': self._decoder_projected_depth,
'output_size': self._output_size,
'use_sync_bn': self._use_sync_bn,
'batchnorm_momentum': self._batchnorm_momentum,
'batchnorm_epsilon': self._batchnorm_epsilon,
'activation': self._activation,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'interpolation': self._interpolation,
}
base_config = super(DecoderSumMergeBlock, self).get_config()
base_config.update(config)
return base_config
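# Example usage (a sketch mirroring the unit tests in mosaic_blocks_test.py):
#
#   block = DecoderSumMergeBlock(decoder_projected_depth=32)
#   out = block([tf.ones([1, 64, 64, 32]),     # low-res, higher-level input
#                tf.ones([1, 128, 128, 64])])  # high-res, lower-level input
#   # out.shape -> [1, 128, 128, 32]; the output size defaults to the
#   # high-res branch when output_size is left as (0, 0).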
@tf.keras.utils.register_keras_serializable(package='Vision')
class DecoderConcatMergeBlock(tf.keras.layers.Layer):
"""Implements the decoder feature concat merge block of MOSAIC model.
This block is used in the decoder of segmentation head introduced in MOSAIC.
It essentially merges a high-resolution feature map of a low semantic level
and a low-resolution feature of a higher semantic level by 'Concat-Merge'.
"""
def __init__(
self,
decoder_internal_depth: int,
decoder_projected_depth: int,
output_size: Tuple[int, int] = (0, 0),
use_sync_bn: bool = False,
batchnorm_momentum: float = 0.99,
batchnorm_epsilon: float = 0.001,
activation: str = 'relu',
kernel_initializer: str = 'GlorotUniform',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear',
**kwargs):
"""Initializes a concat-merge block for one decoder stage.
Args:
decoder_internal_depth: An integer representing the number of internal
channels of this concat-merge block in the decoder.
decoder_projected_depth: An integer representing the number of output
channels of this concat-merge block in the decoder.
output_size: A Tuple of integers representing the output height and width
of the feature maps from this concat-merge block. Defaults to (0, 0),
where the output size is set the same as the high-resolution branch.
use_sync_bn: A bool, whether or not to use sync batch normalization.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99.
batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
0.001.
activation: A `str` for the activation function type. Defaults to 'relu'.
kernel_initializer: Kernel initializer for conv layers. Defaults to
`glorot_uniform`.
kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
interpolation: The interpolation method for upsampling. Defaults to
`bilinear`.
**kwargs: Other keyword arguments for the layer.
"""
super(DecoderConcatMergeBlock, self).__init__(**kwargs)
self._decoder_internal_depth = decoder_internal_depth
self._decoder_projected_depth = decoder_projected_depth
self._output_size = output_size
self._upsample_low_res = None
self._upsample_high_res = None
self._use_sync_bn = use_sync_bn
self._batchnorm_momentum = batchnorm_momentum
self._batchnorm_epsilon = batchnorm_epsilon
self._activation = activation
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._interpolation = interpolation
# Apply BN before activation. Putting BN between conv and activation also
# helps quantization where conv+bn+activation are fused into a single op.
self._activation_fn = tf_utils.get_activation(activation)
if self._use_sync_bn:
self._bn_op = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._bn_op = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
self._channel_axis = -1
else:
self._bn_axis = 1
self._channel_axis = 1
def build(
self,
input_shape: Tuple[tf.TensorShape, tf.TensorShape]) -> None:
"""Builds this block with the given input shape."""
# Assume backbone features of the same level are concatenated before input.
low_res_input_shape = input_shape[0]
high_res_input_shape = input_shape[1]
# Set up resizing feature maps before concat.
if tf.keras.backend.image_data_format() == 'channels_last':
low_res_height = low_res_input_shape[1]
low_res_width = low_res_input_shape[2]
high_res_height = high_res_input_shape[1]
high_res_width = high_res_input_shape[2]
else:
low_res_height = low_res_input_shape[2]
low_res_width = low_res_input_shape[3]
high_res_height = high_res_input_shape[2]
high_res_width = high_res_input_shape[3]
if (self._output_size[0] == 0 or self._output_size[1] == 0):
self._output_size = (high_res_height, high_res_width)
if (low_res_height != self._output_size[0] or
low_res_width != self._output_size[1]):
self._upsample_low_res = tf.keras.layers.Resizing(
self._output_size[0],
self._output_size[1],
interpolation=self._interpolation,
crop_to_aspect_ratio=False)
if (high_res_height != self._output_size[0] or
high_res_width != self._output_size[1]):
self._upsample_high_res = tf.keras.layers.Resizing(
self._output_size[0],
self._output_size[1],
interpolation=self._interpolation,
crop_to_aspect_ratio=False)
# Set up a 3-layer separable convolution blocks, i.e.
# 1x1->BN->RELU + Depthwise->BN->RELU + 1x1->BN->RELU.
initial_feature_conv = tf.keras.layers.Conv2D(
filters=self._decoder_internal_depth,
kernel_size=(1, 1),
padding='same',
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
activation=None,
use_bias=False)
batchnorm_op1 = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
activation1 = self._activation_fn
depthwise_conv = tf.keras.layers.DepthwiseConv2D(
kernel_size=(3, 3),
depth_multiplier=1,
padding='same',
depthwise_regularizer=self._kernel_regularizer,
depthwise_initializer=self._kernel_initializer,
use_bias=False)
batchnorm_op2 = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
activation2 = self._activation_fn
project_feature_conv = tf.keras.layers.Conv2D(
filters=self._decoder_projected_depth,
kernel_size=(1, 1),
padding='same',
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
activation=None,
use_bias=False)
batchnorm_op3 = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
activation3 = self._activation_fn
self._feature_fusion_block = [
initial_feature_conv,
batchnorm_op1,
activation1,
depthwise_conv,
batchnorm_op2,
activation2,
project_feature_conv,
batchnorm_op3,
activation3,
]
self._concat_layer = tf.keras.layers.Concatenate(axis=self._channel_axis)
def call(self,
inputs: Tuple[tf.Tensor, tf.Tensor],
training: Optional[bool] = None) -> tf.Tensor:
"""Calls this concat-merge block with the given inputs.
Args:
inputs: A Tuple of tensors consisting of a lower-resolution higher-level
feature map (e.g. from the encoder) as the first item and a
higher-resolution lower-level feature map (e.g. from the backbone) as the
second item.
training: a `Boolean` indicating whether it is in `training` mode.
Returns:
A tensor representing the concat-merged decoder feature map.
"""
low_res_input = inputs[0]
high_res_input = inputs[1]
if self._upsample_low_res is not None:
low_res_input = self._upsample_low_res(low_res_input)
if self._upsample_high_res is not None:
high_res_input = self._upsample_high_res(high_res_input)
decoder_feature_list = [low_res_input, high_res_input]
x = self._concat_layer(decoder_feature_list)
for layer in self._feature_fusion_block:
if isinstance(layer, tf.keras.layers.Layer):
x = layer(x, training=training)
else:
x = layer(x)
return x
def get_config(self) -> Dict[str, Any]:
"""Returns a config dictionary for initialization from serialization."""
config = {
'decoder_internal_depth': self._decoder_internal_depth,
'decoder_projected_depth': self._decoder_projected_depth,
'output_size': self._output_size,
'use_sync_bn': self._use_sync_bn,
'batchnorm_momentum': self._batchnorm_momentum,
'batchnorm_epsilon': self._batchnorm_epsilon,
'activation': self._activation,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'interpolation': self._interpolation,
}
base_config = super(DecoderConcatMergeBlock, self).get_config()
base_config.update(config)
return base_config
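# Example usage (a sketch mirroring the unit tests in mosaic_blocks_test.py):
#
#   block = DecoderConcatMergeBlock(
#       decoder_internal_depth=64, decoder_projected_depth=32)
#   out = block([tf.ones([1, 32, 32, 128]),   # low-res input
#                tf.ones([1, 64, 64, 192])])  # high-res input
#   # out.shape -> [1, 64, 64, 32].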
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mosaic_blocks."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.projects.mosaic.modeling import mosaic_blocks
class MosaicBlocksTest(parameterized.TestCase, tf.test.TestCase):
def test_multi_kernel_group_conv_block(self):
block = mosaic_blocks.MultiKernelGroupConvBlock([64, 64], [3, 5])
inputs = tf.ones([1, 4, 4, 448])
outputs = block(inputs)
self.assertAllEqual(outputs.shape, [1, 4, 4, 128])
def test_mosaic_encoder_block(self):
block = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=4,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=[1, 4, 8, 16])
inputs = tf.ones([1, 32, 32, 448])
outputs = block(inputs)
self.assertAllEqual(outputs.shape, [1, 32, 32, 128])
def test_mosaic_encoder_block_odd_input_overlap_pool(self):
block = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=4,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=[1, 4, 8, 16])
inputs = tf.ones([1, 31, 31, 448])
outputs = block(inputs)
self.assertAllEqual(outputs.shape, [1, 31, 31, 128])
def test_mosaic_encoder_non_separable_block(self):
block = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=4,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=[1, 4, 8, 16],
use_depthwise_convolution=False)
inputs = tf.ones([1, 32, 32, 448])
outputs = block(inputs)
self.assertAllEqual(outputs.shape, [1, 32, 32, 128])
def test_mosaic_decoder_concat_merge_block(self):
concat_merge_block = mosaic_blocks.DecoderConcatMergeBlock(64, 32, [64, 64])
inputs = [tf.ones([1, 32, 32, 128]), tf.ones([1, 64, 64, 192])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 64, 64, 32])
def test_mosaic_decoder_concat_merge_block_default_output_size(self):
concat_merge_block = mosaic_blocks.DecoderConcatMergeBlock(64, 32)
inputs = [tf.ones([1, 32, 32, 128]), tf.ones([1, 64, 64, 192])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 64, 64, 32])
def test_mosaic_decoder_concat_merge_block_default_output_size_4x(self):
concat_merge_block = mosaic_blocks.DecoderConcatMergeBlock(64, 32)
inputs = [tf.ones([1, 32, 32, 128]), tf.ones([1, 128, 128, 192])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 128, 128, 32])
def test_mosaic_decoder_concat_merge_block_default_output_size_4x_rec(self):
concat_merge_block = mosaic_blocks.DecoderConcatMergeBlock(64, 32)
inputs = [tf.ones([1, 32, 64, 128]), tf.ones([1, 128, 256, 64])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 128, 256, 32])
def test_mosaic_decoder_sum_merge_block(self):
concat_merge_block = mosaic_blocks.DecoderSumMergeBlock(32, [128, 128])
inputs = [tf.ones([1, 64, 64, 32]), tf.ones([1, 128, 128, 64])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 128, 128, 32])
def test_mosaic_decoder_sum_merge_block_default_output_size(self):
concat_merge_block = mosaic_blocks.DecoderSumMergeBlock(32)
inputs = [tf.ones([1, 64, 64, 32]), tf.ones([1, 128, 128, 64])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 128, 128, 32])
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of segmentation head of the MOSAIC model."""
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import tensorflow as tf
from official.modeling import tf_utils
from official.projects.mosaic.modeling import mosaic_blocks
@tf.keras.utils.register_keras_serializable(package='Vision')
class MosaicDecoderHead(tf.keras.layers.Layer):
"""Creates a MOSAIC decoder in segmentation head.
Reference:
[MOSAIC: Mobile Segmentation via decoding Aggregated Information and encoded
Context](https://arxiv.org/pdf/2112.11623.pdf)
"""
def __init__(
self,
num_classes: int,
decoder_input_levels: Optional[List[str]] = None,
decoder_stage_merge_styles: Optional[List[str]] = None,
decoder_filters: Optional[List[int]] = None,
decoder_projected_filters: Optional[List[int]] = None,
encoder_end_level: Optional[int] = 4,
use_additional_classifier_layer: bool = False,
classifier_kernel_size: int = 1,
activation: str = 'relu',
use_sync_bn: bool = False,
batchnorm_momentum: float = 0.99,
batchnorm_epsilon: float = 0.001,
kernel_initializer: str = 'GlorotUniform',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear',
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
**kwargs):
"""Initializes a MOSAIC segmentation head.
Args:
num_classes: An `int` number of mask classification categories. The number
of classes does not include background class.
decoder_input_levels: A list of `str` specifying additional
input levels from the backbone outputs for mask refinement in decoder.
decoder_stage_merge_styles: A list of `str` specifying the merge style at
each stage of the decoder, merge styles can be 'concat_merge' or
'sum_merge'.
decoder_filters: A list of integers specifying the number of channels used
at each decoder stage. Note: this only has an effect if the decoder merge
style is 'concat_merge'.
decoder_projected_filters: A list of integers specifying the number of
projected channels at the end of each decoder stage.
encoder_end_level: An optional integer specifying the output level of the
encoder stage, which is used if the input from the encoder to the
decoder head is a dictionary.
use_additional_classifier_layer: A `bool` specifying whether to use an
additional classifier layer or not. It must be True if the final decoder
projected filter count does not match `num_classes`.
classifier_kernel_size: An `int` number to specify the kernel size of the
classifier layer.
activation: A `str` that indicates which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: A `bool` that indicates whether to use synchronized batch
normalization across different replicas.
batchnorm_momentum: A `float` of normalization momentum for the moving
average.
batchnorm_epsilon: A `float` added to variance to avoid dividing by zero.
kernel_initializer: Kernel initializer for conv layers. Defaults to
`glorot_uniform`.
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default is None.
interpolation: The interpolation method for upsampling. Defaults to
`bilinear`.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
**kwargs: Additional keyword arguments to be passed.
"""
super(MosaicDecoderHead, self).__init__(**kwargs)
# Assuming 'decoder_input_levels' are sorted in descending order and the
# other settings are listed in the order given by 'decoder_input_levels'.
if decoder_input_levels is None:
decoder_input_levels = ['3', '2']
if decoder_stage_merge_styles is None:
decoder_stage_merge_styles = ['concat_merge', 'sum_merge']
if decoder_filters is None:
decoder_filters = [64, 64]
if decoder_projected_filters is None:
decoder_projected_filters = [32, 32]
self._decoder_input_levels = decoder_input_levels
self._decoder_stage_merge_styles = decoder_stage_merge_styles
self._decoder_filters = decoder_filters
self._decoder_projected_filters = decoder_projected_filters
if (len(decoder_input_levels) != len(decoder_stage_merge_styles) or
len(decoder_input_levels) != len(decoder_filters) or
len(decoder_input_levels) != len(decoder_projected_filters)):
raise ValueError('The number of Decoder inputs and settings must match.')
self._merge_stages = []
for (stage_merge_style, decoder_filter,
decoder_projected_filter) in zip(decoder_stage_merge_styles,
decoder_filters,
decoder_projected_filters):
if stage_merge_style == 'concat_merge':
concat_merge_stage = mosaic_blocks.DecoderConcatMergeBlock(
decoder_internal_depth=decoder_filter,
decoder_projected_depth=decoder_projected_filter,
output_size=(0, 0),
use_sync_bn=use_sync_bn,
batchnorm_momentum=batchnorm_momentum,
batchnorm_epsilon=batchnorm_epsilon,
activation=activation,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
interpolation=interpolation)
self._merge_stages.append(concat_merge_stage)
elif stage_merge_style == 'sum_merge':
sum_merge_stage = mosaic_blocks.DecoderSumMergeBlock(
decoder_projected_depth=decoder_projected_filter,
output_size=(0, 0),
use_sync_bn=use_sync_bn,
batchnorm_momentum=batchnorm_momentum,
batchnorm_epsilon=batchnorm_epsilon,
activation=activation,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
interpolation=interpolation)
self._merge_stages.append(sum_merge_stage)
else:
raise ValueError(
'A stage merge style in MOSAIC Decoder can only be concat_merge '
'or sum_merge.')
# Concat merge or sum merge does not require an additional classifier layer
# unless the final decoder projected filter count does not match num_classes.
final_decoder_projected_filter = decoder_projected_filters[-1]
if (final_decoder_projected_filter != num_classes and
not use_additional_classifier_layer):
raise ValueError('An additional classifier layer is needed if the final '
'decoder projected filter count does not match num_classes!')
self._use_additional_classifier_layer = use_additional_classifier_layer
if use_additional_classifier_layer:
# This additional classification layer uses different kernel
# initializers and bias compared to earlier blocks.
self._pixelwise_classifier = tf.keras.layers.Conv2D(
name='pixelwise_classifier',
filters=num_classes,
kernel_size=classifier_kernel_size,
padding='same',
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
use_bias=True)
self._activation_fn = tf_utils.get_activation(activation)
self._config_dict = {
'num_classes': num_classes,
'decoder_input_levels': decoder_input_levels,
'decoder_stage_merge_styles': decoder_stage_merge_styles,
'decoder_filters': decoder_filters,
'decoder_projected_filters': decoder_projected_filters,
'encoder_end_level': encoder_end_level,
'use_additional_classifier_layer': use_additional_classifier_layer,
'classifier_kernel_size': classifier_kernel_size,
'activation': activation,
'use_sync_bn': use_sync_bn,
'batchnorm_momentum': batchnorm_momentum,
'batchnorm_epsilon': batchnorm_epsilon,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'interpolation': interpolation,
'bias_regularizer': bias_regularizer
}
def call(self,
inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
Union[tf.Tensor, Mapping[str, tf.Tensor]]],
training: Optional[bool] = None) -> tf.Tensor:
"""Forward pass of the segmentation head.
    It supports a tuple of 2 elements. Each element is a tensor or a tensor
    dictionary. The first one holds the final (low-resolution) encoder
    endpoints, and the second one holds the higher-resolution backbone
    endpoints. When inputs are tensors, they come from a single level of
    feature maps. When inputs are dictionaries, they contain multiple levels
    of feature maps, where each key is the level/index of a feature map.
    Note: 'level' denotes the number of 2x downsamplings applied, as defined
    in the backbone.
Args:
      inputs: A tuple of 2 elements; each element can either be a tensor
        representing feature maps or a dictionary of tensors:
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors.
        The first element is the encoder endpoints, and the second is the
        backbone endpoints.
      training: A `bool` indicating whether it is in `training` mode.
Returns:
      segmentation mask prediction logits: A `tf.Tensor` of the output
        logits, computed before the final segmentation mask is derived.
"""
encoder_outputs = inputs[0]
backbone_outputs = inputs[1]
    if isinstance(encoder_outputs, dict):
      y = encoder_outputs[str(self._config_dict['encoder_end_level'])]
    else:
      y = encoder_outputs
if isinstance(backbone_outputs, dict):
for level, merge_stage in zip(
self._decoder_input_levels, self._merge_stages):
x = backbone_outputs[str(level)]
y = merge_stage([y, x], training=training)
else:
x = backbone_outputs
y = self._merge_stages[0]([y, x], training=training)
if self._use_additional_classifier_layer:
y = self._pixelwise_classifier(y)
y = self._activation_fn(y)
return y
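  # Example `inputs` (hypothetical shapes, assuming encoder_end_level=4 and
  # decoder_input_levels=['3', '2']):
  #   ({'4': <tf.Tensor [B, H/16, W/16, C]>},
  #    {'3': <tf.Tensor [B, H/8, W/8, C3]>,
  #     '2': <tf.Tensor [B, H/4, W/4, C2]>})
  # The returned logits have the spatial resolution of the last merged level,
  # i.e. level '2' (H/4 x W/4) in this sketch.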
def get_config(self) -> Dict[str, Any]:
"""Returns a config dictionary for initialization from serialization."""
base_config = super().get_config()
base_config.update(self._config_dict)
return base_config
@classmethod
def from_config(cls, config: Dict[str, Any]):
return cls(**config)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mosaic_head."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.projects.mosaic.modeling import mosaic_head
class MosaicHeadTest(parameterized.TestCase, tf.test.TestCase):
def test_mosaic_head(self):
decoder_head = mosaic_head.MosaicDecoderHead(
num_classes=32,
decoder_input_levels=['3', '2'],
decoder_stage_merge_styles=['concat_merge', 'sum_merge'],
decoder_filters=[64, 64],
decoder_projected_filters=[32, 32])
inputs = [
tf.ones([1, 32, 32, 128]), {
'2': tf.ones([1, 128, 128, 64]),
'3': tf.ones([1, 64, 64, 192])
}
]
outputs = decoder_head(inputs)
self.assertAllEqual(outputs.shape, [1, 128, 128, 32])
def test_mosaic_head_3laterals(self):
decoder_head = mosaic_head.MosaicDecoderHead(
num_classes=32,
decoder_input_levels=[3, 2, 1],
decoder_stage_merge_styles=[
'concat_merge', 'concat_merge', 'sum_merge'
],
decoder_filters=[64, 64, 64],
decoder_projected_filters=[32, 32, 32])
inputs = [
tf.ones([1, 32, 32, 128]), {
'1': tf.ones([1, 256, 256, 64]),
'2': tf.ones([1, 128, 128, 64]),
'3': tf.ones([1, 64, 64, 192])
}
]
outputs = decoder_head(inputs)
self.assertAllEqual(outputs.shape, [1, 256, 256, 32])
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Builds the overall MOSAIC segmentation models."""
from typing import Any, Dict, Optional, Union
import tensorflow as tf
from official.projects.mosaic.configs import mosaic_config
from official.projects.mosaic.modeling import mosaic_blocks
from official.projects.mosaic.modeling import mosaic_head
from official.vision.modeling import backbones
@tf.keras.utils.register_keras_serializable(package='Vision')
class MosaicSegmentationModel(tf.keras.Model):
"""A model class for segmentation using MOSAIC.
Input images are passed through a backbone first. A MOSAIC neck encoder
  network is then applied, and finally a MOSAIC segmentation head is applied
  to the outputs of the backbone and neck encoder network. Feature fusion and
  decoding are performed in the segmentation head.
Reference:
[MOSAIC: Mobile Segmentation via decoding Aggregated Information and encoded
Context](https://arxiv.org/pdf/2112.11623.pdf)
"""
def __init__(self,
backbone: tf.keras.Model,
head: tf.keras.layers.Layer,
neck: Optional[tf.keras.layers.Layer] = None,
**kwargs):
"""Segmentation initialization function.
Args:
backbone: A backbone network.
head: A segmentation head, e.g. MOSAIC decoder.
      neck: An optional neck encoder network, e.g. MOSAIC encoder. If it is
        not provided, the decoder head is connected directly to the backbone.
**kwargs: keyword arguments to be passed.
"""
super(MosaicSegmentationModel, self).__init__(**kwargs)
self._config_dict = {
'backbone': backbone,
'neck': neck,
'head': head,
}
self.backbone = backbone
self.neck = neck
self.head = head
  def call(self,
           inputs: tf.Tensor,
           training: Optional[bool] = None) -> Dict[str, tf.Tensor]:
    """Runs the backbone, the optional neck, and the head on `inputs`."""
backbone_features = self.backbone(inputs)
if self.neck is not None:
neck_features = self.neck(backbone_features, training=training)
else:
neck_features = backbone_features
logits = self.head([neck_features, backbone_features], training=training)
outputs = {'logits': logits}
return outputs
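  # For example, `model(images)['logits']` has shape
  # [batch, h / 2**level, w / 2**level, num_classes], where `level` is the
  # last (lowest) decoder input level; see mosaic_model_test for concrete
  # shapes.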
@property
def checkpoint_items(
self) -> Dict[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(backbone=self.backbone, head=self.head)
if self.neck is not None:
items.update(neck=self.neck)
return items
def get_config(self) -> Dict[str, Any]:
"""Returns a config dictionary for initialization from serialization."""
    base_config = super().get_config()
    base_config.update(self._config_dict)
    return base_config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
def build_mosaic_segmentation_model(
input_specs: tf.keras.layers.InputSpec,
model_config: mosaic_config.MosaicSemanticSegmentationModel,
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
backbone: Optional[tf.keras.Model] = None,
neck: Optional[tf.keras.layers.Layer] = None
) -> tf.keras.Model:
"""Builds MOSAIC Segmentation model."""
norm_activation_config = model_config.norm_activation
if backbone is None:
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
if neck is None:
neck_config = model_config.neck
neck = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=neck_config.encoder_input_level,
branch_filter_depths=neck_config.branch_filter_depths,
conv_kernel_sizes=neck_config.conv_kernel_sizes,
pyramid_pool_bin_nums=neck_config.pyramid_pool_bin_nums,
use_sync_bn=norm_activation_config.use_sync_bn,
batchnorm_momentum=norm_activation_config.norm_momentum,
batchnorm_epsilon=norm_activation_config.norm_epsilon,
activation=neck_config.activation,
dropout_rate=neck_config.dropout_rate,
kernel_initializer=neck_config.kernel_initializer,
kernel_regularizer=l2_regularizer,
interpolation=neck_config.interpolation,
use_depthwise_convolution=neck_config.use_depthwise_convolution)
head_config = model_config.head
head = mosaic_head.MosaicDecoderHead(
num_classes=model_config.num_classes,
decoder_input_levels=head_config.decoder_input_levels,
decoder_stage_merge_styles=head_config.decoder_stage_merge_styles,
decoder_filters=head_config.decoder_filters,
decoder_projected_filters=head_config.decoder_projected_filters,
encoder_end_level=head_config.encoder_end_level,
use_additional_classifier_layer=head_config
.use_additional_classifier_layer,
classifier_kernel_size=head_config.classifier_kernel_size,
activation=head_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
batchnorm_momentum=norm_activation_config.norm_momentum,
batchnorm_epsilon=norm_activation_config.norm_epsilon,
kernel_initializer=head_config.kernel_initializer,
kernel_regularizer=l2_regularizer,
interpolation=head_config.interpolation)
model = MosaicSegmentationModel(
backbone=backbone, neck=neck, head=head)
return model
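# A minimal usage sketch (hypothetical values; the registered experiment
# configs under official/projects/mosaic/configs define the canonical
# settings):
#   input_specs = tf.keras.layers.InputSpec(shape=[None, 1024, 2048, 3])
#   model_config = mosaic_config.MosaicSemanticSegmentationModel(
#       num_classes=19)
#   model = build_mosaic_segmentation_model(input_specs, model_config)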
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the overall MOSAIC segmentation network modeling."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.projects.mosaic.modeling import mosaic_blocks
from official.projects.mosaic.modeling import mosaic_head
from official.projects.mosaic.modeling import mosaic_model
from official.vision.modeling import backbones
class SegmentationNetworkTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(128, [4, 8], [3, 2], ['concat_merge', 'sum_merge']),
(128, [1, 4, 8], [3, 2], ['concat_merge', 'sum_merge']),
(128, [1, 4, 8], [3, 2], ['sum_merge', 'sum_merge']),
(128, [1, 4, 8], [3, 2], ['concat_merge', 'concat_merge']),
(512, [1, 4, 8, 16], [3, 2], ['concat_merge', 'sum_merge']),
(256, [4, 8], [3, 2], ['concat_merge', 'sum_merge']),
(256, [1, 4, 8], [3, 2], ['concat_merge', 'sum_merge']),
(256, [1, 4, 8, 16], [3, 2], ['concat_merge', 'sum_merge']),
)
def test_mosaic_segmentation_model(self,
input_size,
pyramid_pool_bin_nums,
decoder_input_levels,
decoder_stage_merge_styles):
"""Test for building and calling of a MOSAIC segmentation network."""
num_classes = 32
inputs = np.random.rand(2, input_size, input_size, 3)
tf.keras.backend.set_image_data_format('channels_last')
backbone = backbones.MobileNet(model_id='MobileNetMultiAVGSeg')
encoder_input_level = 4
neck = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=encoder_input_level,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=pyramid_pool_bin_nums)
head = mosaic_head.MosaicDecoderHead(
num_classes=num_classes,
decoder_input_levels=decoder_input_levels,
decoder_stage_merge_styles=decoder_stage_merge_styles,
decoder_filters=[64, 64],
decoder_projected_filters=[32, 32])
model = mosaic_model.MosaicSegmentationModel(
backbone=backbone,
head=head,
neck=neck,
)
# Calls the MOSAIC model.
outputs = model(inputs)
level = min(decoder_input_levels)
self.assertAllEqual(
[2, input_size // (2**level), input_size // (2**level), num_classes],
outputs['logits'].numpy().shape)
def test_serialize_deserialize(self):
"""Validate the mosaic network can be serialized and deserialized."""
num_classes = 8
backbone = backbones.ResNet(model_id=50)
neck = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=4,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=[1, 4, 8, 16])
head = mosaic_head.MosaicDecoderHead(
num_classes=num_classes,
decoder_input_levels=[3, 2],
decoder_stage_merge_styles=['concat_merge', 'sum_merge'],
decoder_filters=[64, 64],
decoder_projected_filters=[32, 8])
model = mosaic_model.MosaicSegmentationModel(
backbone=backbone,
head=head,
neck=neck,
)
config = model.get_config()
new_model = mosaic_model.MosaicSegmentationModel.from_config(config)
    # Validate that the config can be serialized to JSON.
_ = new_model.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(model.get_config(), new_model.get_config())
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task definition for image semantic segmentation with MOSAIC models."""
from absl import logging
import tensorflow as tf
from official.core import task_factory
from official.projects.mosaic.configs import mosaic_config
from official.projects.mosaic.modeling import mosaic_model
from official.vision.tasks import semantic_segmentation as seg_tasks
@task_factory.register_task_cls(mosaic_config.MosaicSemanticSegmentationTask)
class MosaicSemanticSegmentationTask(seg_tasks.SemanticSegmentationTask):
"""A task for semantic segmentation using MOSAIC model."""
  # Note: `build_model` is overridden to add an additional `training` flag,
  # which indicates whether the model is built for `training` or `eval`. This
  # makes sure the model is initialized with the proper `input_shape` when it
  # is trained and evaluated with different `input_shape`s. For example, the
  # model may be trained with cropping but evaluated at the original shape.
def build_model(self, training: bool = True) -> tf.keras.Model:
"""Builds MOSAIC segmentation model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self.task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
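    # For example, with l2_weight_decay = 1e-4 the resulting penalty,
    # tf.keras.regularizers.l2(5e-5)(w) = 5e-5 * sum(w ** 2), equals
    # 1e-4 * tf.nn.l2_loss(w) = 1e-4 * sum(w ** 2) / 2.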
model = mosaic_model.build_mosaic_segmentation_model(
input_specs=input_specs,
model_config=self.task_config.model,
l2_regularizer=l2_regularizer)
    # Note: Create a dummy input and call the model instance to initialize it.
# This ensures all the layers are built; otherwise some layers may be
# missing from the model and cannot be associated with variables from
# a loaded checkpoint. The input size is determined by whether the model
# is built for performing training or eval.
if training:
input_size = self.task_config.train_data.output_size
crop_size = self.task_config.train_data.crop_size
if crop_size:
input_size = crop_size
else:
input_size = self.task_config.validation_data.output_size
dummy_input = tf.ones(shape=[1] + input_size + [3])
model(dummy_input)
return model
def initialize(self, model: tf.keras.Model):
"""Loads pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if 'all' in self.task_config.init_checkpoint_modules:
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'neck' in self.task_config.init_checkpoint_modules:
ckpt_items.update(neck=model.neck)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
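  # Example (hypothetical path): setting, in the experiment config,
  #   task.init_checkpoint = '/path/to/pretrained_checkpoint'
  #   task.init_checkpoint_modules = ['backbone']
  # restores only the backbone weights, while ['all'] restores every item in
  # `model.checkpoint_items`.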
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mosaic task."""
# pylint: disable=unused-import
import os
from absl.testing import parameterized
import orbit
import tensorflow as tf
from official import vision
from official.core import exp_factory
from official.modeling import optimization
from official.projects.mosaic import mosaic_tasks
from official.projects.mosaic.configs import mosaic_config as exp_cfg
from official.vision.dataloaders import tfexample_utils
class MosaicTaskTest(parameterized.TestCase, tf.test.TestCase):
def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
examples = [example] * num_samples
tfexample_utils.dump_to_tfrecord(
record_file=tfrecord_file, tf_examples=examples)
@parameterized.parameters(
('mosaic_mnv35_cityscapes', True),
('mosaic_mnv35_cityscapes', False),
)
def test_semantic_segmentation_task(self, test_config, is_training):
"""Tests mosaic task for training and eval using toy configs."""
input_image_size = [1024, 2048]
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'seg_test.tfrecord')
example = tfexample_utils.create_segmentation_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
config = exp_factory.get_exp_config(test_config)
    # Modify the config to suit local testing.
config.task.model.input_size = [None, None, 3]
config.trainer.steps_per_loop = 1
config.task.train_data.global_batch_size = 1
config.task.validation_data.global_batch_size = 1
config.task.train_data.output_size = [1024, 2048]
config.task.validation_data.output_size = [1024, 2048]
config.task.train_data.crop_size = [512, 512]
config.task.train_data.shuffle_buffer_size = 2
config.task.validation_data.shuffle_buffer_size = 2
config.task.validation_data.input_path = test_tfrecord_file
config.task.train_data.input_path = test_tfrecord_file
config.train_steps = 1
config.task.model.num_classes = 256
config.task.model.head.num_classes = 256
config.task.model.head.decoder_projected_filters = [256, 256]
task = mosaic_tasks.MosaicSemanticSegmentationTask(config.task)
model = task.build_model(training=is_training)
metrics = task.build_metrics(training=is_training)
strategy = tf.distribute.get_strategy()
    data_config = (
        config.task.train_data if is_training else config.task.validation_data)
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
data_config)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
if is_training:
logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
else:
logs = task.validation_step(next(iterator), model, metrics=metrics)
self.assertIn('loss', logs)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training driver for MOSAIC models."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import base_trainer
from official.core import config_definitions
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# Import MOSAIC libraries to register the model into tf.vision
# model garden factory.
# pylint: disable=unused-import
from official.projects.mosaic import mosaic_tasks
from official.projects.mosaic.modeling import mosaic_model
from official.vision import registry_imports
# pylint: enable=unused-import
FLAGS = flags.FLAGS
# Note: `build_trainer` is overridden here because of the customized
# `build_model` method in `MosaicSemanticSegmentationTask`.
def _build_mosaic_trainer(params: config_definitions.ExperimentConfig,
task: mosaic_tasks.MosaicSemanticSegmentationTask,
model_dir: str, train: bool,
evaluate: bool) -> base_trainer.Trainer:
"""Creates custom trainer."""
checkpoint_exporter = train_lib.maybe_create_best_ckpt_exporter(
params, model_dir)
model = task.build_model(train)
optimizer = train_utils.create_optimizer(task, params)
trainer = base_trainer.Trainer(
params,
task,
model=model,
optimizer=optimizer,
train=train,
evaluate=evaluate,
checkpoint_exporter=checkpoint_exporter)
return trainer
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise, a continuous eval
    # job may race against the train job when writing the same file.
train_utils.serialize_config(params, model_dir)
  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly speed up models by utilizing float16 on GPUs and
  # bfloat16 on TPUs. `loss_scale` takes effect only when the dtype is
  # float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
mosaic_trainer = _build_mosaic_trainer(
task=task,
params=params,
model_dir=model_dir,
train='train' in FLAGS.mode,
evaluate='eval' in FLAGS.mode)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir,
trainer=mosaic_trainer)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main)
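# Example invocation (hypothetical paths; `experiment`, `mode`, and
# `model_dir` are the required flags marked above):
#   python3 -m official.projects.mosaic.train \
#     --experiment=mosaic_mnv35_cityscapes \
#     --mode=train_and_eval \
#     --model_dir=/tmp/mosaic_model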