Commit 09b3f5ab authored by A. Unique TensorFlower's avatar A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 481936708
parent 54a70bac
# MOSAIC: Mobile Segmentation via decoding Aggregated Information and encoded Context
[![Paper](http://img.shields.io/badge/Paper-arXiv.2112.11623-B3181B?logo=arXiv)](https://arxiv.org/abs/2112.11623)
This repository is the official implementation of the following
paper.
* [MOSAIC: Mobile Segmentation via decoding Aggregated Information and encoded Context](https://arxiv.org/abs/2112.11623)
## Description
MOSAIC is a neural network architecture for efficient and accurate semantic
image segmentation on mobile devices. It is built from neural operations that
are commonly supported across diverse mobile hardware platforms, enabling
flexible deployment. With a simple asymmetric encoder-decoder structure,
consisting of an efficient multi-scale context encoder and a lightweight
hybrid decoder that recovers spatial details from the aggregated information,
MOSAIC achieves a better balance between accuracy and computational cost.
Deployed on top of a tailored feature extraction backbone based on a searched
classification network, MOSAIC achieves a 5% absolute accuracy gain on ADE20K
with similar or lower latency compared to the current industry-standard
MLPerf Mobile v1.0 models and state-of-the-art architectures.
[MLPerf Mobile v2.0](https://mlcommons.org/en/inference-mobile-20/) included
MOSAIC as a new industry standard benchmark model for image segmentation.
Please see details [here](https://mlcommons.org/en/news/mlperf-inference-1q2022/).
You can also refer to the [MLCommons GitHub repository](https://github.com/mlcommons/mobile_open/tree/main/vision/mosaic).
## History
### Oct 13, 2022
* First release of MOSAIC in TensorFlow 2 including checkpoints that have been
pretrained on Cityscapes.
## Maintainers
* Weijun Wang ([weijunw-g](https://github.com/weijunw-g))
* Fang Yang ([fyangf](https://github.com/fyangf))
* Shixin Luo ([luotigerlsx](https://github.com/luotigerlsx))
## Requirements
[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow)
[![tf-models-official PyPI](https://badge.fury.io/py/tf-models-official.svg)](https://badge.fury.io/py/tf-models-official)
## Results
The following table shows the mIoU measured on the `cityscapes` dataset.
| Config | Backbone | Resolution | branch_filter_depths | pyramid_pool_bin_nums | mIoU | Download |
|-------------------------|:--------------------:|:----------:|:--------------------:|:---------------------:|:-----:|:--------:|
| Paper reference config | MobileNetMultiAVGSeg | 1024x2048 | [32, 32] | [4, 8, 16] | 75.98 | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/mosaic/MobileNetMultiAVGSeg-r1024-ebf32-nogp.tar.gz)<br>[tensorboard](https://tensorboard.dev/experiment/okEog90bSwupajFgJwGEIw//#scalars) |
| Current best config | MobileNetMultiAVGSeg | 1024x2048 | [64, 64] | [1, 4, 8, 16] | 77.24 | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/mosaic/MobileNetMultiAVGSeg-r1024-ebf64-gp.tar.gz)<br>[tensorboard](https://tensorboard.dev/experiment/l5hkV7JaQM23EXeOBT6oJg/#scalars) |
* `branch_filter_depths`: the number of convolution channels in each branch at
a pyramid level after `Spatial Pyramid Pooling`
* `pyramid_pool_bin_nums`: the number of bins at each level of the `Spatial
Pyramid Pooling`
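For intuition about how these settings shape the encoder output: the branch
outputs of the multi-kernel group convolution are concatenated channel-wise,
so the encoder emits `sum(branch_filter_depths)` channels regardless of how
many pyramid bins are used. A minimal sketch of that arithmetic (plain Python,
mirroring the shapes asserted by the unit tests in this change; the helper
name is illustrative, not part of the library):

```python
def encoder_output_channels(branch_filter_depths):
    # Each branch of the multi-kernel group convolution contributes its
    # configured depth; the branch outputs are concatenated channel-wise.
    return sum(branch_filter_depths)

# Paper reference config vs. the current best config from the table above.
assert encoder_output_channels([32, 32]) == 64
assert encoder_output_channels([64, 64]) == 128
```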
## Training
Training can run on Google Cloud Platform using Cloud TPU.
[Here](https://cloud.google.com/tpu/docs/how-to) are the instructions for
using Cloud TPU. Follow the instructions to set up a Cloud TPU, then
launch training with:
```shell
EXP_TYPE=mosaic_mnv35_cityscapes
EXP_NAME="<experiment-name>" # You can give any name to the experiment.
TPU_NAME="<tpu-name>" # The name assigned while creating a Cloud TPU
MODEL_DIR="gs://<path-to-model-directory>"
# Now launch the experiment.
python3 -m official.projects.mosaic.train \
--experiment=$EXP_TYPE \
--mode=train \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--config_file=official/projects/mosaic/configs/experiments/mosaic_mnv35_cityscapes_tdfs_tpu.yaml
```
## Evaluation
Run the following command for evaluation:
```shell
EXP_TYPE=mosaic_mnv35_cityscapes
EXP_NAME="<experiment-name>" # You can give any name to the experiment.
TPU_NAME="<tpu-name>" # The name assigned while creating a Cloud TPU
MODEL_DIR="gs://<path-to-model-directory>"
# Now launch the experiment.
python3 -m official.projects.mosaic.train \
--experiment=$EXP_TYPE \
--mode=eval \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--config_file=official/projects/mosaic/configs/experiments/mosaic_mnv35_cityscapes_tdfs_tpu.yaml
```
## License
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
This project is licensed under the terms of the **Apache License 2.0**.
## Citation
If you use this repository in your work, please consider citing the
paper.
```
@article{weijun2021mosaic,
title={MOSAIC: Mobile Segmentation via decoding Aggregated Information and
encoded Context},
author={Weijun Wang and Andrew Howard},
journal={arXiv preprint arXiv:2112.11623},
year={2021},
}
```
# Uses TensorFlow Datasets: 'cityscapes/semantic_segmentation'
# Some expected flags to use with xmanager launcher:
# --experiment_type=mosaic_mnv35_cityscapes
# --tpu_topology=4x4
# mIoU: 77.24%
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
num_classes: 19
input_size: [null, null, 3]
backbone:
type: 'mobilenet'
mobilenet:
model_id: 'MobileNetMultiAVGSeg'
output_intermediate_endpoints: true
output_stride: 16
neck:
branch_filter_depths: [64, 64]
conv_kernel_sizes: [3, 5]
pyramid_pool_bin_nums: [1, 4, 8, 16]
dropout_rate: 0.0
head:
num_classes: 19
decoder_input_levels: ['3/depthwise', '2/depthwise']
decoder_stage_merge_styles: ['concat_merge', 'sum_merge']
decoder_filters: [64, 64]
decoder_projected_filters: [19, 19]
norm_activation:
activation: relu
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
init_checkpoint: 'gs://tf_model_garden/vision/mobilenet/v3.5multiavg_seg_float/'
init_checkpoint_modules: 'backbone'
losses:
l2_weight_decay: 1.0e-04
train_data:
output_size: [1024, 2048]
crop_size: [1024, 2048]
input_path: ''
tfds_name: 'cityscapes/semantic_segmentation'
tfds_split: 'train'
is_training: true
global_batch_size: 32
dtype: 'float32'
aug_rand_hflip: true
aug_scale_max: 2.0
aug_scale_min: 0.5
validation_data:
output_size: [1024, 2048]
input_path: ''
tfds_name: 'cityscapes/semantic_segmentation'
tfds_split: 'validation'
is_training: false
global_batch_size: 32
dtype: 'float32'
drop_remainder: false
resize_eval_groundtruth: true
trainer:
optimizer_config:
learning_rate:
polynomial:
decay_steps: 100000
initial_learning_rate: 0.1
power: 0.9
type: polynomial
optimizer:
sgd:
momentum: 0.9
type: sgd
warmup:
linear:
name: linear
warmup_learning_rate: 0
warmup_steps: 925
type: linear
steps_per_loop: 92 # floor(2975 / 32) = 92
summary_interval: 92
train_steps: 100000
validation_interval: 92
validation_steps: 16 # ceil(500 / 32) = 16
checkpoint_interval: 92
best_checkpoint_export_subdir: 'best_ckpt'
best_checkpoint_eval_metric: 'mean_iou'
best_checkpoint_metric_comp: 'higher'
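The per-epoch step counts in this config follow from the dataset sizes and
the global batch size of 32 (a quick check of the arithmetic, assuming the
standard Cityscapes train/val splits of 2975/500 fine-annotated images):

```python
train_examples, val_examples, batch = 2975, 500, 32

# Used for steps_per_loop and the summary/validation/checkpoint intervals.
steps_per_epoch = train_examples // batch

# Rounded up so the whole validation split is covered (drop_remainder: false).
validation_steps = -(-val_examples // batch)  # ceiling division

assert steps_per_epoch == 92
assert validation_steps == 16
```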
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration definition for Semantic Segmentation with MOSAIC."""
import dataclasses
import os
from typing import List, Optional, Union
import numpy as np
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.configs import common
from official.vision.configs import semantic_segmentation as seg_cfg
from official.vision.configs.google import backbones
@dataclasses.dataclass
class MosaicDecoderHead(hyperparams.Config):
"""MOSAIC decoder head config for Segmentation."""
num_classes: int = 19
decoder_input_levels: List[str] = dataclasses.field(default_factory=list)
decoder_stage_merge_styles: List[str] = dataclasses.field(
default_factory=list)
decoder_filters: List[int] = dataclasses.field(default_factory=list)
decoder_projected_filters: List[int] = dataclasses.field(default_factory=list)
encoder_end_level: int = 4
use_additional_classifier_layer: bool = False
classifier_kernel_size: int = 1
activation: str = 'relu'
kernel_initializer: str = 'glorot_uniform'
interpolation: str = 'bilinear'
@dataclasses.dataclass
class MosaicEncoderNeck(hyperparams.Config):
"""MOSAIC encoder neck config for segmentation."""
encoder_input_level: Union[str, int] = '4'
branch_filter_depths: List[int] = dataclasses.field(default_factory=list)
conv_kernel_sizes: List[int] = dataclasses.field(default_factory=list)
pyramid_pool_bin_nums: List[int] = dataclasses.field(default_factory=list)
activation: str = 'relu'
dropout_rate: float = 0.1
kernel_initializer: str = 'glorot_uniform'
interpolation: str = 'bilinear'
use_depthwise_convolution: bool = True
@dataclasses.dataclass
class MosaicSemanticSegmentationModel(hyperparams.Config):
"""MOSAIC semantic segmentation model config."""
num_classes: int = 19
input_size: List[int] = dataclasses.field(default_factory=list)
head: MosaicDecoderHead = MosaicDecoderHead()
backbone: backbones.Backbone = backbones.Backbone(
type='mobilenet', mobilenet=backbones.MobileNet())
neck: MosaicEncoderNeck = MosaicEncoderNeck()
norm_activation: common.NormActivation = common.NormActivation(
use_sync_bn=True, norm_momentum=0.99, norm_epsilon=0.001)
@dataclasses.dataclass
class MosaicSemanticSegmentationTask(seg_cfg.SemanticSegmentationTask):
"""The config for MOSAIC segmentation task."""
model: MosaicSemanticSegmentationModel = MosaicSemanticSegmentationModel()
train_data: seg_cfg.DataConfig = seg_cfg.DataConfig(is_training=True)
validation_data: seg_cfg.DataConfig = seg_cfg.DataConfig(is_training=False)
losses: seg_cfg.Losses = seg_cfg.Losses()
evaluation: seg_cfg.Evaluation = seg_cfg.Evaluation()
train_input_partition_dims: List[int] = dataclasses.field(
default_factory=list)
eval_input_partition_dims: List[int] = dataclasses.field(
default_factory=list)
init_checkpoint: Optional[str] = None
init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone, and/or neck.
export_config: seg_cfg.ExportConfig = seg_cfg.ExportConfig()
# Cityscapes Dataset (Download and process the dataset yourself)
CITYSCAPES_TRAIN_EXAMPLES = 2975
CITYSCAPES_VAL_EXAMPLES = 500
CITYSCAPES_INPUT_PATH_BASE = 'cityscapes/tfrecord'
@exp_factory.register_config_factory('mosaic_mnv35_cityscapes')
def mosaic_mnv35_cityscapes() -> cfg.ExperimentConfig:
"""Instantiates an experiment configuration of image segmentation task.
This image segmentation experiment is conducted on Cityscapes dataset. The
model architecture is a MOSAIC encoder-decoer. The default backbone network is
a mobilenet variant called Mobilenet_v3.5-MultiAvg on top of which the MOSAIC
encoder-decoder can be deployed. All detailed configurations can be overridden
by a .yaml file provided by the user to launch the experiments. Please refer
to .yaml examples in the path of ../configs/experiments/.
Returns:
A particular instance of cfg.ExperimentConfig for MOSAIC model based
image semantic segmentation task.
"""
train_batch_size = 16
eval_batch_size = 16
steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size
output_stride = 16
backbone_output_level = int(np.log2(output_stride))
config = cfg.ExperimentConfig(
task=MosaicSemanticSegmentationTask(
model=MosaicSemanticSegmentationModel(
# Cityscapes uses only 19 semantic classes for train/evaluation.
# The void (background) class is ignored in train and evaluation.
num_classes=19,
input_size=[None, None, 3],
backbone=backbones.Backbone(
type='mobilenet',
mobilenet=backbones.MobileNet(
model_id='MobileNetMultiAVGSeg',
output_intermediate_endpoints=True,
output_stride=output_stride)),
neck=MosaicEncoderNeck(
encoder_input_level=backbone_output_level,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=[1, 4, 8, 16], # paper default
activation='relu',
dropout_rate=0.1,
kernel_initializer='glorot_uniform',
interpolation='bilinear',
use_depthwise_convolution=True),
head=MosaicDecoderHead(
num_classes=19,
decoder_input_levels=['3/depthwise', '2/depthwise'],
decoder_stage_merge_styles=['concat_merge', 'sum_merge'],
decoder_filters=[64, 64],
decoder_projected_filters=[19, 19],
encoder_end_level=backbone_output_level,
use_additional_classifier_layer=False,
classifier_kernel_size=1,
activation='relu',
kernel_initializer='glorot_uniform',
interpolation='bilinear'),
norm_activation=common.NormActivation(
activation='relu',
norm_momentum=0.99,
norm_epsilon=1e-3,
use_sync_bn=True)),
losses=seg_cfg.Losses(l2_weight_decay=4e-5),
train_data=seg_cfg.DataConfig(
input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
'train_fine**'),
crop_size=[1024, 2048],
output_size=[1024, 2048],
is_training=True,
global_batch_size=train_batch_size,
aug_scale_min=0.5,
aug_scale_max=2.0),
validation_data=seg_cfg.DataConfig(
input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE, 'val_fine*'),
output_size=[1024, 2048],
is_training=False,
global_batch_size=eval_batch_size,
resize_eval_groundtruth=True,
drop_remainder=False),
# Imagenet pre-trained Mobilenet_v3.5-MultiAvg checkpoint.
init_checkpoint='gs://tf_model_garden/vision/mobilenet/v3.5multiavg_seg_float/',
init_checkpoint_modules='backbone'),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=100000,
validation_steps=CITYSCAPES_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
best_checkpoint_eval_metric='mean_iou',
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_metric_comp='higher',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 0.1,
'decay_steps': 100000,
'end_learning_rate': 0.0,
'power': 0.9
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
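The `backbone_output_level` derivation above ties the neck's input level to
the backbone's output stride: a feature level is the number of 2x
downsamplings from the input, so stride 16 corresponds to level 4, which also
matches the decoder head's default `encoder_end_level`. A standalone check of
that relation (the helper name is illustrative only):

```python
import math

def level_for_output_stride(output_stride):
    # A feature level counts 2x downsamplings, so level = log2(stride).
    return int(math.log2(output_stride))

assert level_for_output_stride(16) == 4  # matches encoder_end_level=4
assert level_for_output_stride(32) == 5
```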
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mosaic_blocks."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.projects.mosaic.modeling import mosaic_blocks
class MosaicBlocksTest(parameterized.TestCase, tf.test.TestCase):
def test_multi_kernel_group_conv_block(self):
block = mosaic_blocks.MultiKernelGroupConvBlock([64, 64], [3, 5])
inputs = tf.ones([1, 4, 4, 448])
outputs = block(inputs)
self.assertAllEqual(outputs.shape, [1, 4, 4, 128])
def test_mosaic_encoder_block(self):
block = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=4,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=[1, 4, 8, 16])
inputs = tf.ones([1, 32, 32, 448])
outputs = block(inputs)
self.assertAllEqual(outputs.shape, [1, 32, 32, 128])
def test_mosaic_encoder_block_odd_input_overlap_pool(self):
block = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=4,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=[1, 4, 8, 16])
inputs = tf.ones([1, 31, 31, 448])
outputs = block(inputs)
self.assertAllEqual(outputs.shape, [1, 31, 31, 128])
def test_mosaic_encoder_non_separable_block(self):
block = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=4,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=[1, 4, 8, 16],
use_depthwise_convolution=False)
inputs = tf.ones([1, 32, 32, 448])
outputs = block(inputs)
self.assertAllEqual(outputs.shape, [1, 32, 32, 128])
def test_mosaic_decoder_concat_merge_block(self):
concat_merge_block = mosaic_blocks.DecoderConcatMergeBlock(64, 32, [64, 64])
inputs = [tf.ones([1, 32, 32, 128]), tf.ones([1, 64, 64, 192])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 64, 64, 32])
def test_mosaic_decoder_concat_merge_block_default_output_size(self):
concat_merge_block = mosaic_blocks.DecoderConcatMergeBlock(64, 32)
inputs = [tf.ones([1, 32, 32, 128]), tf.ones([1, 64, 64, 192])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 64, 64, 32])
def test_mosaic_decoder_concat_merge_block_default_output_size_4x(self):
concat_merge_block = mosaic_blocks.DecoderConcatMergeBlock(64, 32)
inputs = [tf.ones([1, 32, 32, 128]), tf.ones([1, 128, 128, 192])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 128, 128, 32])
def test_mosaic_decoder_concat_merge_block_default_output_size_4x_rec(self):
concat_merge_block = mosaic_blocks.DecoderConcatMergeBlock(64, 32)
inputs = [tf.ones([1, 32, 64, 128]), tf.ones([1, 128, 256, 64])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 128, 256, 32])
def test_mosaic_decoder_sum_merge_block(self):
concat_merge_block = mosaic_blocks.DecoderSumMergeBlock(32, [128, 128])
inputs = [tf.ones([1, 64, 64, 32]), tf.ones([1, 128, 128, 64])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 128, 128, 32])
def test_mosaic_decoder_sum_merge_block_default_output_size(self):
concat_merge_block = mosaic_blocks.DecoderSumMergeBlock(32)
inputs = [tf.ones([1, 64, 64, 32]), tf.ones([1, 128, 128, 64])]
outputs = concat_merge_block(inputs)
self.assertAllEqual(outputs.shape, [1, 128, 128, 32])
if __name__ == '__main__':
tf.test.main()
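The sum-merge tests above pin down a simple shape contract: the
lower-resolution input is upsampled to the spatial size of the
higher-resolution one, and the output takes `decoder_projected_depth`
channels. A hypothetical shape helper (not part of the library) that mirrors
what the tests assert:

```python
def sum_merge_output_shape(low_res_shape, high_res_shape, projected_depth):
    # Hypothetical helper: the lower-resolution feature map is upsampled to
    # the spatial size of the higher-resolution one, and both are projected
    # to `projected_depth` channels before being summed.
    batch, h1, w1, _ = low_res_shape
    _, h2, w2, _ = high_res_shape
    return (batch, max(h1, h2), max(w1, w2), projected_depth)

# Mirrors test_mosaic_decoder_sum_merge_block above.
assert sum_merge_output_shape(
    (1, 64, 64, 32), (1, 128, 128, 64), 32) == (1, 128, 128, 32)
```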
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of segmentation head of the MOSAIC model."""
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import tensorflow as tf
from official.modeling import tf_utils
from official.projects.mosaic.modeling import mosaic_blocks
@tf.keras.utils.register_keras_serializable(package='Vision')
class MosaicDecoderHead(tf.keras.layers.Layer):
"""Creates a MOSAIC decoder in segmentation head.
Reference:
[MOSAIC: Mobile Segmentation via decoding Aggregated Information and encoded
Context](https://arxiv.org/pdf/2112.11623.pdf)
"""
def __init__(
self,
num_classes: int,
decoder_input_levels: Optional[List[str]] = None,
decoder_stage_merge_styles: Optional[List[str]] = None,
decoder_filters: Optional[List[int]] = None,
decoder_projected_filters: Optional[List[int]] = None,
encoder_end_level: Optional[int] = 4,
use_additional_classifier_layer: bool = False,
classifier_kernel_size: int = 1,
activation: str = 'relu',
use_sync_bn: bool = False,
batchnorm_momentum: float = 0.99,
batchnorm_epsilon: float = 0.001,
kernel_initializer: str = 'GlorotUniform',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear',
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
**kwargs):
"""Initializes a MOSAIC segmentation head.
Args:
num_classes: An `int` number of mask classification categories. The number
of classes does not include background class.
decoder_input_levels: A list of `str` specifying additional
input levels from the backbone outputs for mask refinement in decoder.
decoder_stage_merge_styles: A list of `str` specifying the merge style at
each stage of the decoder, merge styles can be 'concat_merge' or
'sum_merge'.
decoder_filters: A list of integers specifying the number of channels used
at each decoder stage. Note: this only has an effect if the decoder merge
style is 'concat_merge'.
decoder_projected_filters: A list of integers specifying the number of
projected channels at the end of each decoder stage.
encoder_end_level: An optional integer specifying the output level of the
encoder stage, which is used if the input from the encoder to the
decoder head is a dictionary.
use_additional_classifier_layer: A `bool` specifying whether to use an
additional classifier layer or not. It must be True if the final decoder
projected filters does not match the `num_classes`.
classifier_kernel_size: An `int` number to specify the kernel size of the
classifier layer.
activation: A `str` that indicates which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: A `bool` that indicates whether to use synchronized batch
normalization across different replicas.
batchnorm_momentum: A `float` of normalization momentum for the moving
average.
batchnorm_epsilon: A `float` added to variance to avoid dividing by zero.
kernel_initializer: Kernel initializer for conv layers. Defaults to
`glorot_uniform`.
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default is None.
interpolation: The interpolation method for upsampling. Defaults to
`bilinear`.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
**kwargs: Additional keyword arguments to be passed.
"""
super(MosaicDecoderHead, self).__init__(**kwargs)
# Assuming 'decoder_input_levels' are sorted in descending order and the
# other settings are listed in the order according to 'decoder_input_levels'.
if decoder_input_levels is None:
decoder_input_levels = ['3', '2']
if decoder_stage_merge_styles is None:
decoder_stage_merge_styles = ['concat_merge', 'sum_merge']
if decoder_filters is None:
decoder_filters = [64, 64]
if decoder_projected_filters is None:
decoder_projected_filters = [32, 32]
self._decoder_input_levels = decoder_input_levels
self._decoder_stage_merge_styles = decoder_stage_merge_styles
self._decoder_filters = decoder_filters
self._decoder_projected_filters = decoder_projected_filters
if (len(decoder_input_levels) != len(decoder_stage_merge_styles) or
len(decoder_input_levels) != len(decoder_filters) or
len(decoder_input_levels) != len(decoder_projected_filters)):
raise ValueError('The number of Decoder inputs and settings must match.')
self._merge_stages = []
for (stage_merge_style, decoder_filter,
decoder_projected_filter) in zip(decoder_stage_merge_styles,
decoder_filters,
decoder_projected_filters):
if stage_merge_style == 'concat_merge':
concat_merge_stage = mosaic_blocks.DecoderConcatMergeBlock(
decoder_internal_depth=decoder_filter,
decoder_projected_depth=decoder_projected_filter,
output_size=(0, 0),
use_sync_bn=use_sync_bn,
batchnorm_momentum=batchnorm_momentum,
batchnorm_epsilon=batchnorm_epsilon,
activation=activation,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
interpolation=interpolation)
self._merge_stages.append(concat_merge_stage)
elif stage_merge_style == 'sum_merge':
sum_merge_stage = mosaic_blocks.DecoderSumMergeBlock(
decoder_projected_depth=decoder_projected_filter,
output_size=(0, 0),
use_sync_bn=use_sync_bn,
batchnorm_momentum=batchnorm_momentum,
batchnorm_epsilon=batchnorm_epsilon,
activation=activation,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
interpolation=interpolation)
self._merge_stages.append(sum_merge_stage)
else:
raise ValueError(
'A stage merge style in MOSAIC Decoder can only be concat_merge '
'or sum_merge.')
# Concat merge or sum merge does not require an additional classifier layer
# unless the final decoder projected filter does not match num_classes.
final_decoder_projected_filter = decoder_projected_filters[-1]
if (final_decoder_projected_filter != num_classes and
not use_additional_classifier_layer):
raise ValueError('Additional classifier layer is needed if final decoder '
'projected filters does not match num_classes!')
self._use_additional_classifier_layer = use_additional_classifier_layer
if use_additional_classifier_layer:
# This additional classification layer uses different kernel
# initializers and bias compared to earlier blocks.
self._pixelwise_classifier = tf.keras.layers.Conv2D(
name='pixelwise_classifier',
filters=num_classes,
kernel_size=classifier_kernel_size,
padding='same',
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
use_bias=True)
self._activation_fn = tf_utils.get_activation(activation)
self._config_dict = {
'num_classes': num_classes,
'decoder_input_levels': decoder_input_levels,
'decoder_stage_merge_styles': decoder_stage_merge_styles,
'decoder_filters': decoder_filters,
'decoder_projected_filters': decoder_projected_filters,
'encoder_end_level': encoder_end_level,
'use_additional_classifier_layer': use_additional_classifier_layer,
'classifier_kernel_size': classifier_kernel_size,
'activation': activation,
'use_sync_bn': use_sync_bn,
'batchnorm_momentum': batchnorm_momentum,
'batchnorm_epsilon': batchnorm_epsilon,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'interpolation': interpolation,
'bias_regularizer': bias_regularizer
}
def call(self,
inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
Union[tf.Tensor, Mapping[str, tf.Tensor]]],
training: Optional[bool] = None) -> tf.Tensor:
"""Forward pass of the segmentation head.
It supports a tuple of 2 elements. Each element is a tensor or a tensor
dictionary. The first one is the final (low-resolution) encoder endpoints,
and the second one is higher-resolution backbone endpoints.
When inputs are tensors, they are from a single level of feature maps.
When inputs are dictionaries, they contain multiple levels of feature maps,
where the key is the level/index of feature map.
Note: 'level' denotes the number of 2x downsampling, defined in backbone.
Args:
inputs: A tuple of 2 elements, each element can either be a tensor
representing feature maps or 1 dictionary of tensors:
- key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor` of the feature map tensors.
The first is encoder endpoints, and the second is backbone endpoints.
training: a `Boolean` indicating whether it is in `training` mode.
Returns:
segmentation mask prediction logits: A `tf.Tensor` representing the
output logits before the final segmentation mask.
"""
encoder_outputs = inputs[0]
backbone_outputs = inputs[1]
y = encoder_outputs[str(
self._config_dict['encoder_end_level'])] if isinstance(
encoder_outputs, dict) else encoder_outputs
if isinstance(backbone_outputs, dict):
for level, merge_stage in zip(
self._decoder_input_levels, self._merge_stages):
x = backbone_outputs[str(level)]
y = merge_stage([y, x], training=training)
else:
x = backbone_outputs
y = self._merge_stages[0]([y, x], training=training)
if self._use_additional_classifier_layer:
y = self._pixelwise_classifier(y)
y = self._activation_fn(y)
return y
def get_config(self) -> Dict[str, Any]:
"""Returns a config dictionary for initialization from serialization."""
base_config = super().get_config()
base_config.update(self._config_dict)
return base_config
@classmethod
def from_config(cls, config: Dict[str, Any]):
return cls(**config)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mosaic_head."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.projects.mosaic.modeling import mosaic_head
class MosaicHeadTest(parameterized.TestCase, tf.test.TestCase):
def test_mosaic_head(self):
decoder_head = mosaic_head.MosaicDecoderHead(
num_classes=32,
decoder_input_levels=['3', '2'],
decoder_stage_merge_styles=['concat_merge', 'sum_merge'],
decoder_filters=[64, 64],
decoder_projected_filters=[32, 32])
inputs = [
tf.ones([1, 32, 32, 128]), {
'2': tf.ones([1, 128, 128, 64]),
'3': tf.ones([1, 64, 64, 192])
}
]
outputs = decoder_head(inputs)
self.assertAllEqual(outputs.shape, [1, 128, 128, 32])
def test_mosaic_head_3laterals(self):
decoder_head = mosaic_head.MosaicDecoderHead(
num_classes=32,
decoder_input_levels=[3, 2, 1],
decoder_stage_merge_styles=[
'concat_merge', 'concat_merge', 'sum_merge'
],
decoder_filters=[64, 64, 64],
decoder_projected_filters=[32, 32, 32])
inputs = [
tf.ones([1, 32, 32, 128]), {
'1': tf.ones([1, 256, 256, 64]),
'2': tf.ones([1, 128, 128, 64]),
'3': tf.ones([1, 64, 64, 192])
}
]
outputs = decoder_head(inputs)
self.assertAllEqual(outputs.shape, [1, 256, 256, 32])
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Builds the overall MOSAIC segmentation models."""
from typing import Any, Dict, Optional, Union
import tensorflow as tf
from official.projects.mosaic.configs import mosaic_config
from official.projects.mosaic.modeling import mosaic_blocks
from official.projects.mosaic.modeling import mosaic_head
from official.vision.modeling import backbones
@tf.keras.utils.register_keras_serializable(package='Vision')
class MosaicSegmentationModel(tf.keras.Model):
"""A model class for segmentation using MOSAIC.
Input images are passed through a backbone first. A MOSAIC neck encoder
network is then applied, and finally a MOSAIC segmentation head is applied on
the outputs of the backbone and neck encoder network. Feature fusion and
decoding is done in the segmentation head.
Reference:
[MOSAIC: Mobile Segmentation via decoding Aggregated Information and encoded
Context](https://arxiv.org/pdf/2112.11623.pdf)
"""
def __init__(self,
backbone: tf.keras.Model,
head: tf.keras.layers.Layer,
neck: Optional[tf.keras.layers.Layer] = None,
**kwargs):
"""Segmentation initialization function.
Args:
backbone: A backbone network.
head: A segmentation head, e.g. MOSAIC decoder.
neck: An optional neck encoder network, e.g. MOSAIC encoder. If it is not
provided, the decoder head will be connected directly with the backbone.
**kwargs: keyword arguments to be passed.
"""
super(MosaicSegmentationModel, self).__init__(**kwargs)
self._config_dict = {
'backbone': backbone,
'neck': neck,
'head': head,
}
self.backbone = backbone
self.neck = neck
self.head = head
def call(self,
inputs: tf.Tensor,
           training: Optional[bool] = None) -> Dict[str, tf.Tensor]:
backbone_features = self.backbone(inputs)
if self.neck is not None:
neck_features = self.neck(backbone_features, training=training)
else:
neck_features = backbone_features
logits = self.head([neck_features, backbone_features], training=training)
outputs = {'logits': logits}
return outputs
@property
def checkpoint_items(
self) -> Dict[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(backbone=self.backbone, head=self.head)
if self.neck is not None:
items.update(neck=self.neck)
return items
def get_config(self) -> Dict[str, Any]:
"""Returns a config dictionary for initialization from serialization."""
base_config = super().get_config()
model_config = base_config
model_config.update(self._config_dict)
return model_config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
def build_mosaic_segmentation_model(
input_specs: tf.keras.layers.InputSpec,
model_config: mosaic_config.MosaicSemanticSegmentationModel,
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
backbone: Optional[tf.keras.Model] = None,
neck: Optional[tf.keras.layers.Layer] = None
) -> tf.keras.Model:
"""Builds MOSAIC Segmentation model."""
norm_activation_config = model_config.norm_activation
if backbone is None:
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
if neck is None:
neck_config = model_config.neck
neck = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=neck_config.encoder_input_level,
branch_filter_depths=neck_config.branch_filter_depths,
conv_kernel_sizes=neck_config.conv_kernel_sizes,
pyramid_pool_bin_nums=neck_config.pyramid_pool_bin_nums,
use_sync_bn=norm_activation_config.use_sync_bn,
batchnorm_momentum=norm_activation_config.norm_momentum,
batchnorm_epsilon=norm_activation_config.norm_epsilon,
activation=neck_config.activation,
dropout_rate=neck_config.dropout_rate,
kernel_initializer=neck_config.kernel_initializer,
kernel_regularizer=l2_regularizer,
interpolation=neck_config.interpolation,
use_depthwise_convolution=neck_config.use_depthwise_convolution)
head_config = model_config.head
head = mosaic_head.MosaicDecoderHead(
num_classes=model_config.num_classes,
decoder_input_levels=head_config.decoder_input_levels,
decoder_stage_merge_styles=head_config.decoder_stage_merge_styles,
decoder_filters=head_config.decoder_filters,
decoder_projected_filters=head_config.decoder_projected_filters,
encoder_end_level=head_config.encoder_end_level,
use_additional_classifier_layer=head_config
.use_additional_classifier_layer,
classifier_kernel_size=head_config.classifier_kernel_size,
activation=head_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
batchnorm_momentum=norm_activation_config.norm_momentum,
batchnorm_epsilon=norm_activation_config.norm_epsilon,
kernel_initializer=head_config.kernel_initializer,
kernel_regularizer=l2_regularizer,
interpolation=head_config.interpolation)
model = MosaicSegmentationModel(
backbone=backbone, neck=neck, head=head)
return model
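The model above composes a backbone, an optional neck encoder, and a decoder head; when no neck is provided, the head consumes backbone features directly. A minimal pure-Python sketch of that `call` wiring, with toy callables standing in for the real Keras networks (illustration only, not the actual layers):

```python
from typing import Any, Callable, Dict, Optional


class SegmentationWiring:
  """Pure-Python sketch of MosaicSegmentationModel's forward wiring."""

  def __init__(self, backbone: Callable, head: Callable,
               neck: Optional[Callable] = None):
    self.backbone = backbone
    self.neck = neck
    self.head = head

  def __call__(self, inputs: Any) -> Dict[str, Any]:
    backbone_features = self.backbone(inputs)
    # With no neck, the head is connected directly to the backbone.
    neck_features = (self.neck(backbone_features)
                     if self.neck is not None else backbone_features)
    # The head fuses neck features with the multi-level backbone features.
    logits = self.head([neck_features, backbone_features])
    return {'logits': logits}


# Toy stand-ins: backbone emits a feature dict keyed by level, like the
# real backbones do.
model = SegmentationWiring(
    backbone=lambda x: {'4': x + 1},
    neck=lambda feats: feats['4'] * 2,
    head=lambda pair: pair[0] + sum(pair[1].values()))
print(model(3)['logits'])  # neck: (3 + 1) * 2 = 8; head: 8 + 4 = 12
```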
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the overall MOSAIC segmentation network modeling."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.projects.mosaic.modeling import mosaic_blocks
from official.projects.mosaic.modeling import mosaic_head
from official.projects.mosaic.modeling import mosaic_model
from official.vision.modeling import backbones
class SegmentationNetworkTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(128, [4, 8], [3, 2], ['concat_merge', 'sum_merge']),
(128, [1, 4, 8], [3, 2], ['concat_merge', 'sum_merge']),
(128, [1, 4, 8], [3, 2], ['sum_merge', 'sum_merge']),
(128, [1, 4, 8], [3, 2], ['concat_merge', 'concat_merge']),
(512, [1, 4, 8, 16], [3, 2], ['concat_merge', 'sum_merge']),
(256, [4, 8], [3, 2], ['concat_merge', 'sum_merge']),
(256, [1, 4, 8], [3, 2], ['concat_merge', 'sum_merge']),
(256, [1, 4, 8, 16], [3, 2], ['concat_merge', 'sum_merge']),
)
def test_mosaic_segmentation_model(self,
input_size,
pyramid_pool_bin_nums,
decoder_input_levels,
decoder_stage_merge_styles):
"""Test for building and calling of a MOSAIC segmentation network."""
num_classes = 32
inputs = np.random.rand(2, input_size, input_size, 3)
tf.keras.backend.set_image_data_format('channels_last')
backbone = backbones.MobileNet(model_id='MobileNetMultiAVGSeg')
encoder_input_level = 4
neck = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=encoder_input_level,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=pyramid_pool_bin_nums)
head = mosaic_head.MosaicDecoderHead(
num_classes=num_classes,
decoder_input_levels=decoder_input_levels,
decoder_stage_merge_styles=decoder_stage_merge_styles,
decoder_filters=[64, 64],
decoder_projected_filters=[32, 32])
model = mosaic_model.MosaicSegmentationModel(
backbone=backbone,
head=head,
neck=neck,
)
# Calls the MOSAIC model.
outputs = model(inputs)
level = min(decoder_input_levels)
self.assertAllEqual(
[2, input_size // (2**level), input_size // (2**level), num_classes],
outputs['logits'].numpy().shape)
def test_serialize_deserialize(self):
"""Validate the mosaic network can be serialized and deserialized."""
num_classes = 8
backbone = backbones.ResNet(model_id=50)
neck = mosaic_blocks.MosaicEncoderBlock(
encoder_input_level=4,
branch_filter_depths=[64, 64],
conv_kernel_sizes=[3, 5],
pyramid_pool_bin_nums=[1, 4, 8, 16])
head = mosaic_head.MosaicDecoderHead(
num_classes=num_classes,
decoder_input_levels=[3, 2],
decoder_stage_merge_styles=['concat_merge', 'sum_merge'],
decoder_filters=[64, 64],
decoder_projected_filters=[32, 8])
model = mosaic_model.MosaicSegmentationModel(
backbone=backbone,
head=head,
neck=neck,
)
config = model.get_config()
new_model = mosaic_model.MosaicSegmentationModel.from_config(config)
    # Validate that the config can be serialized to JSON.
_ = new_model.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(model.get_config(), new_model.get_config())
if __name__ == '__main__':
tf.test.main()
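The shape assertion in `test_mosaic_segmentation_model` above encodes the model's output stride: the finest (smallest) entry of `decoder_input_levels` determines the logits resolution, `input_size / 2**min(levels)`. A small helper that makes the expected shape explicit (hypothetical name, shown for illustration):

```python
def expected_logits_shape(batch, input_size, decoder_input_levels, num_classes):
  """Expected logits shape for a square input.

  The decoder upsamples to the finest input level, so the output stride is
  2**min(decoder_input_levels).
  """
  stride = 2 ** min(decoder_input_levels)
  side = input_size // stride
  return [batch, side, side, num_classes]


# Mirrors the test cases above: levels [3, 2] give an output stride of 4.
print(expected_logits_shape(2, 128, [3, 2], 32))  # [2, 32, 32, 32]
```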
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task definition for image semantic segmentation with MOSAIC models."""
from absl import logging
import tensorflow as tf
from official.core import task_factory
from official.projects.mosaic.configs import mosaic_config
from official.projects.mosaic.modeling import mosaic_model
from official.vision.tasks import semantic_segmentation as seg_tasks
@task_factory.register_task_cls(mosaic_config.MosaicSemanticSegmentationTask)
class MosaicSemanticSegmentationTask(seg_tasks.SemanticSegmentationTask):
"""A task for semantic segmentation using MOSAIC model."""
  # Note: `build_model` is overridden to add a `training` flag indicating
  # whether the model is built for `training` or `eval`. This ensures the
  # model is initialized with the proper `input_shape` when it is trained
  # and evaluated with different input shapes, e.g. trained with cropping
  # but evaluated at the original shape.
def build_model(self, training: bool = True) -> tf.keras.Model:
"""Builds MOSAIC segmentation model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self.task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
model = mosaic_model.build_mosaic_segmentation_model(
input_specs=input_specs,
model_config=self.task_config.model,
l2_regularizer=l2_regularizer)
# Note: Create a dummy input and call model instance to initialize.
# This ensures all the layers are built; otherwise some layers may be
# missing from the model and cannot be associated with variables from
# a loaded checkpoint. The input size is determined by whether the model
# is built for performing training or eval.
if training:
input_size = self.task_config.train_data.output_size
crop_size = self.task_config.train_data.crop_size
if crop_size:
input_size = crop_size
else:
input_size = self.task_config.validation_data.output_size
dummy_input = tf.ones(shape=[1] + input_size + [3])
model(dummy_input)
return model
def initialize(self, model: tf.keras.Model):
"""Loads pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if 'all' in self.task_config.init_checkpoint_modules:
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'neck' in self.task_config.init_checkpoint_modules:
ckpt_items.update(neck=model.neck)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
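The `build_model` above divides `l2_weight_decay` by 2.0 because `tf.keras.regularizers.l2` computes `l2 * sum(w**2)`, while `tf.nn.l2_loss` computes `sum(w**2) / 2`. A small numeric check of that identity in plain Python (toy weights, no TensorFlow required):

```python
def keras_l2_penalty(weights, l2):
  # Matches tf.keras.regularizers.l2: l2 * sum(w ** 2).
  return l2 * sum(w * w for w in weights)


def nn_l2_loss(weights):
  # Matches tf.nn.l2_loss: sum(w ** 2) / 2.
  return sum(w * w for w in weights) / 2.0


weights = [0.5, -1.0, 2.0]
l2_weight_decay = 1e-4
# Passing l2_weight_decay / 2.0 to the Keras regularizer reproduces
# l2_weight_decay * tf.nn.l2_loss(w), which is the convention the task
# configs assume.
a = keras_l2_penalty(weights, l2_weight_decay / 2.0)
b = l2_weight_decay * nn_l2_loss(weights)
assert abs(a - b) < 1e-12
print(round(a, 10))  # 0.0002625
```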
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mosaic task."""
# pylint: disable=unused-import
import os
from absl.testing import parameterized
import orbit
import tensorflow as tf
from official import vision
from official.core import exp_factory
from official.modeling import optimization
from official.projects.mosaic import mosaic_tasks
from official.projects.mosaic.configs import mosaic_config as exp_cfg
from official.vision.dataloaders import tfexample_utils
class MosaicTaskTest(parameterized.TestCase, tf.test.TestCase):
def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
examples = [example] * num_samples
tfexample_utils.dump_to_tfrecord(
record_file=tfrecord_file, tf_examples=examples)
@parameterized.parameters(
('mosaic_mnv35_cityscapes', True),
('mosaic_mnv35_cityscapes', False),
)
def test_semantic_segmentation_task(self, test_config, is_training):
"""Tests mosaic task for training and eval using toy configs."""
input_image_size = [1024, 2048]
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'seg_test.tfrecord')
example = tfexample_utils.create_segmentation_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
config = exp_factory.get_exp_config(test_config)
    # Modify the config to suit local testing.
config.task.model.input_size = [None, None, 3]
config.trainer.steps_per_loop = 1
config.task.train_data.global_batch_size = 1
config.task.validation_data.global_batch_size = 1
config.task.train_data.output_size = [1024, 2048]
config.task.validation_data.output_size = [1024, 2048]
config.task.train_data.crop_size = [512, 512]
config.task.train_data.shuffle_buffer_size = 2
config.task.validation_data.shuffle_buffer_size = 2
config.task.validation_data.input_path = test_tfrecord_file
config.task.train_data.input_path = test_tfrecord_file
config.train_steps = 1
config.task.model.num_classes = 256
config.task.model.head.num_classes = 256
config.task.model.head.decoder_projected_filters = [256, 256]
task = mosaic_tasks.MosaicSemanticSegmentationTask(config.task)
model = task.build_model(training=is_training)
metrics = task.build_metrics(training=is_training)
strategy = tf.distribute.get_strategy()
data_config = config.task.train_data if is_training else config.task.validation_data
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
data_config)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
if is_training:
logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
else:
logs = task.validation_step(next(iterator), model, metrics=metrics)
self.assertIn('loss', logs)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training driver for MOSAIC models."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import base_trainer
from official.core import config_definitions
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# Import MOSAIC libraries to register the model into tf.vision
# model garden factory.
# pylint: disable=unused-import
from official.projects.mosaic import mosaic_tasks
from official.projects.mosaic.modeling import mosaic_model
from official.vision import registry_imports
# pylint: enable=unused-import
FLAGS = flags.FLAGS
# Note: the trainer construction is customized here because of the overridden
# `build_model` method in `MosaicSemanticSegmentationTask`.
def _build_mosaic_trainer(params: config_definitions.ExperimentConfig,
task: mosaic_tasks.MosaicSemanticSegmentationTask,
model_dir: str, train: bool,
evaluate: bool) -> base_trainer.Trainer:
"""Creates custom trainer."""
checkpoint_exporter = train_lib.maybe_create_best_ckpt_exporter(
params, model_dir)
model = task.build_model(train)
optimizer = train_utils.create_optimizer(task, params)
trainer = base_trainer.Trainer(
params,
task,
model=model,
optimizer=optimizer,
train=train,
evaluate=evaluate,
checkpoint_exporter=checkpoint_exporter)
return trainer
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
  # Sets the mixed-precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly speed up models by utilizing float16 on GPUs and
  # bfloat16 on TPUs. loss_scale takes effect only when the dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
mosaic_trainer = _build_mosaic_trainer(
task=task,
params=params,
model_dir=model_dir,
train='train' in FLAGS.mode,
evaluate='eval' in FLAGS.mode)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir,
trainer=mosaic_trainer)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main)