Commit 48bc47ce authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #10072 from gunho1123:master

PiperOrigin-RevId: 394727655
parents f82ff3b3 2ab8af9f
# BASNet: Boundary-Aware Salient Object Detection
This repository is an unofficial implementation of the following paper. Please
see the paper
[BASNet: Boundary-Aware Salient Object Detection](https://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html)
for more details.
## Requirements
[![TensorFlow 2.4](https://img.shields.io/badge/TensorFlow-2.4-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.4.0)
[![Python 3.7](https://img.shields.io/badge/Python-3.7-3776AB)](https://www.python.org/downloads/release/python-379/)
## Train
```shell
$ python3 train.py \
--experiment=basnet_duts \
--mode=train \
--model_dir=$MODEL_DIR \
--config_file=./configs/experiments/basnet_dut_gpu.yaml
```
## Test
```shell
$ python3 train.py \
--experiment=basnet_duts \
--mode=eval \
--model_dir=$MODEL_DIR \
--config_file=./configs/experiments/basnet_dut_gpu.yaml \
--params_override='runtime.num_gpus=1, runtime.distribution_strategy=one_device, task.model.input_size=[256, 256, 3]'
```
## Results
Dataset | maxF<sub>β</sub> | relaxF<sub>β</sub> | MAE
:--------- | :--------------- | :------------------- | -------:
DUTS-TE | 0.865 | 0.793 | 0.046
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BASNet configuration definition."""
import dataclasses
import os
from typing import List, Optional, Union
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import common
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training."""
  # Spatial size [height, width] the input images are resized to.
  output_size: List[int] = dataclasses.field(default_factory=list)
  # If crop_size is specified, image will be resized first to
  # output_size, then crop of size crop_size will be cropped.
  crop_size: List[int] = dataclasses.field(default_factory=list)
  # Path (or pattern) of the input TFRecord files.
  input_path: str = ''
  # Total batch size summed across all replicas.
  global_batch_size: int = 0
  is_training: bool = True
  dtype: str = 'float32'
  shuffle_buffer_size: int = 1000
  # Number of input files read in parallel (tf.data interleave cycle length).
  cycle_length: int = 10
  # If True, groundtruth masks are resized to match the prediction size during
  # eval; otherwise they are padded to groundtruth_padded_size.
  resize_eval_groundtruth: bool = True
  groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
  # Random horizontal-flip augmentation (training only).
  aug_rand_hflip: bool = True
  file_type: str = 'tfrecord'
@dataclasses.dataclass
class BASNetModel(hyperparams.Config):
  """BASNet model config."""
  # Input shape [height, width, channels]; None allows variable spatial size.
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Whether convolution layers use a bias term.
  use_bias: bool = False
  # Normalization and activation settings shared across the network.
  norm_activation: common.NormActivation = common.NormActivation()
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Loss-related config for the BASNet task."""
  label_smoothing: float = 0.1
  ignore_label: int = 0  # will be treated as background
  # L2 regularization weight; 0 disables weight decay.
  l2_weight_decay: float = 0.0
  use_groundtruth_dimension: bool = True
@dataclasses.dataclass
class BASNetTask(cfg.TaskConfig):
  """The model config."""
  # Network architecture settings.
  model: BASNetModel = BASNetModel()
  # Input pipelines for the train and eval phases.
  train_data: DataConfig = DataConfig(is_training=True)
  validation_data: DataConfig = DataConfig(is_training=False)
  losses: Losses = Losses()
  # Global-norm gradient clipping threshold; 0 disables clipping.
  gradient_clip_norm: float = 0.0
  # Checkpoint to warm-start from; None trains from scratch.
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[
      str, List[str]] = 'backbone'  # all, backbone, and/or decoder
@exp_factory.register_config_factory('basnet')
def basnet() -> cfg.ExperimentConfig:
  """BASNet general experiment template.

  Returns:
    An `ExperimentConfig` with a default `BASNetTask` and trainer; dataset
    paths and hyperparameters are expected to be overridden by the caller.
  """
  return cfg.ExperimentConfig(
      # Fix: `task` must be a `cfg.TaskConfig` (BASNetTask), not the
      # BASNetModel hyperparams config, which is only one field of the task.
      task=BASNetTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
# DUTS Dataset
# Number of examples in the DUTS-TR (train) and DUTS-TE (eval) splits.
DUTS_TRAIN_EXAMPLES = 10553
DUTS_VAL_EXAMPLES = 5019
# Placeholder base paths; replace with the real dataset location or override
# via --params_override / a config file.
DUTS_INPUT_PATH_BASE_TR = 'DUTS_DATASET'
DUTS_INPUT_PATH_BASE_VAL = 'DUTS_DATASET'
@exp_factory.register_config_factory('basnet_duts')
def basnet_duts() -> cfg.ExperimentConfig:
  """Image segmentation on duts with basnet.

  Returns:
    An `ExperimentConfig` for training/evaluating BASNet on the DUTS salient
    object detection dataset, warm-starting the encoder from an
    ImageNet-pretrained checkpoint.
  """
  train_batch_size = 64
  eval_batch_size = 16
  # One "epoch" of loop/summary/checkpoint/validation cadence.
  steps_per_epoch = DUTS_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=BASNetTask(
          model=BASNetModel(
              # Variable spatial size; actual size comes from the data config.
              input_size=[None, None, 3],
              use_bias=True,
              norm_activation=common.NormActivation(
                  activation='relu',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=0),
          train_data=DataConfig(
              input_path=os.path.join(DUTS_INPUT_PATH_BASE_TR,
                                      'tf_record_train'),
              file_type='tfrecord',
              # Resize to output_size, then random-crop to crop_size.
              crop_size=[224, 224],
              output_size=[256, 256],
              is_training=True,
              global_batch_size=train_batch_size,
          ),
          validation_data=DataConfig(
              input_path=os.path.join(DUTS_INPUT_PATH_BASE_VAL,
                                      'tf_record_test'),
              file_type='tfrecord',
              output_size=[256, 256],
              is_training=False,
              global_batch_size=eval_batch_size,
          ),
          # ImageNet-pretrained BASNet encoder; only the backbone is loaded.
          init_checkpoint='gs://cloud-basnet-checkpoints/basnet_encoder_imagenet/ckpt-340306',
          init_checkpoint_modules='backbone'
      ),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          # 300 epochs of training.
          train_steps=300 * steps_per_epoch,
          validation_steps=DUTS_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          # Adam with a constant learning rate, per the BASNet paper setup.
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adam',
                  'adam': {
                      'beta_1': 0.9,
                      'beta_2': 0.999,
                      'epsilon': 1e-8,
                  }
              },
              'learning_rate': {
                  'type': 'constant',
                  'constant': {'learning_rate': 0.001}
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for basnet configs."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import exp_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.projects.basnet.configs import basnet as exp_cfg
class BASNetConfigTest(tf.test.TestCase, parameterized.TestCase):
  """Validates registered BASNet experiment configurations."""

  @parameterized.parameters(('basnet_duts',))
  def test_basnet_configs(self, config_name):
    """Checks factory lookup, config types, and restriction validation."""
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.BASNetTask)
    self.assertIsInstance(config.task.model,
                          exp_cfg.BASNetModel)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    # Violating the `is_training != None` restriction must fail validation.
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()
if __name__ == '__main__':
tf.test.main()
# GPU experiment overrides for BASNet on DUTS: 8 mirrored GPUs, full float32.
runtime:
  distribution_strategy: 'mirrored'
  mixed_precision_dtype: 'float32'
  num_gpus: 8
task:
  train_data:
    dtype: 'float32'
  validation_data:
    resize_eval_groundtruth: true
    dtype: 'float32'
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation metrics for BASNet.
The MAE and maxFscore implementations are a modified version of
https://github.com/xuebinqin/Binary-Segmentation-Evaluation-Tool
"""
import numpy as np
import scipy.signal
class MAE:
  """Mean Absolute Error(MAE) metric for basnet."""

  def __init__(self):
    """Constructs MAE metric class."""
    self.reset_states()

  @property
  def name(self):
    return 'MAE'

  def reset_states(self):
    """Clears all accumulated groundtruth/prediction pairs."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Returns the average MAE and resets the internal state."""
    metric_result = self.evaluate()
    # Start fresh so the next evaluation round is independent.
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Computes the mean of per-image MAEs over all stored masks.

    Returns:
      average_mae: average MAE as a float.
    """
    per_image_maes = [
        self._compute_mae(true, pred)
        for true, pred in zip(self._groundtruths, self._predictions)
    ]
    return sum(per_image_maes) / len(self._groundtruths)

  def _mask_normalize(self, mask):
    # Scale into [0, 1] by the mask maximum; epsilon avoids divide-by-zero.
    return mask / (np.amax(mask) + 1e-8)

  def _compute_mae(self, true, pred):
    """Mean absolute difference between two max-normalized masks."""
    height, width = true.shape[0], true.shape[1]
    diff = self._mask_normalize(true).astype(float) - self._mask_normalize(
        pred).astype(float)
    return np.sum(np.absolute(diff)) / (float(height) * float(width) + 1e-8)

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays."""
    return groundtruths.numpy(), predictions.numpy()

  def update_state(self, groundtruths, predictions):
    """Accumulates a batch of groundtruth and predicted masks.

    Args:
      groundtruths : Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks. range [0, 1]
      predictions : Tuple of single Tensor [batch, width, height, 1],
        predicted masks. range [0, 1]
    """
    gts, preds = self._convert_to_numpy(groundtruths[0], predictions[0])
    # Iterating a numpy array yields one mask per batch element.
    self._groundtruths.extend(gts)
    self._predictions.extend(preds)
class MaxFscore:
  """Maximum F-score metric for basnet.

  Accumulates (groundtruth, prediction) mask pairs, then sweeps all 255
  saliency thresholds and reports the maximum F-measure over thresholds.
  """

  def __init__(self):
    """Constructs BASNet evaluation class."""
    self.reset_states()

  @property
  def name(self):
    return 'MaxFScore'

  def reset_states(self):
    """Resets internal states for a fresh run."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Evaluates segmentation results, and reset_states."""
    metric_result = self.evaluate()
    # Cleans up the internal variables in order for a fresh eval next time.
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Evaluates with masks from all images.

    Returns:
      f_max: maximum F-score value over all thresholds, as float32.
    """
    mybins = np.arange(0, 256)
    beta = 0.3  # Acts as beta^2 in the F-measure formula below.
    precisions = np.zeros((len(self._groundtruths), len(mybins) - 1))
    recalls = np.zeros((len(self._groundtruths), len(mybins) - 1))
    for i, (true, pred) in enumerate(zip(self._groundtruths,
                                         self._predictions)):
      # Compute F-score: scale masks to [0, 255] to align with the bins.
      true = self._mask_normalize(true) * 255.0
      pred = self._mask_normalize(pred) * 255.0
      pre, rec = self._compute_pre_rec(true, pred, mybins=np.arange(0, 256))
      precisions[i, :] = pre
      recalls[i, :] = rec
    # Average precision/recall per threshold across images.
    precisions = np.sum(precisions, 0) / (len(self._groundtruths) + 1e-8)
    recalls = np.sum(recalls, 0) / (len(self._groundtruths) + 1e-8)
    f = (1 + beta) * precisions * recalls / (beta * precisions + recalls + 1e-8)
    f_max = np.max(f)
    f_max = f_max.astype(np.float32)
    return f_max

  def _mask_normalize(self, mask):
    # Scale into [0, 1] by the mask maximum; epsilon avoids divide-by-zero.
    return mask / (np.amax(mask) + 1e-8)

  def _compute_pre_rec(self, true, pred, mybins=None):
    """Computes precision and recall at every saliency threshold.

    Args:
      true: groundtruth mask with values in [0, 255]; > 128 is foreground.
      pred: predicted mask with values in [0, 255].
      mybins: histogram bin edges; defaults to np.arange(0, 256). Fix: the
        default is now created per call rather than as a mutable default
        argument evaluated once at function definition time.

    Returns:
      Tuple of (precision, recall) 1-D arrays with one entry per threshold,
      ordered from the highest threshold to the lowest.
    """
    if mybins is None:
      mybins = np.arange(0, 256)
    # pixel number of ground truth foreground regions
    gt_num = true[true > 128].size
    # mask predicted pixel values in the ground truth foreground region
    pp = pred[true > 128]
    # mask predicted pixel values in the ground truth background region
    nn = pred[true <= 128]
    pp_hist, _ = np.histogram(pp, bins=mybins)
    nn_hist, _ = np.histogram(nn, bins=mybins)
    # Flip + cumsum yields, for each threshold from high to low, the number
    # of pixels predicted at or above that threshold.
    pp_hist_flip = np.flipud(pp_hist)
    nn_hist_flip = np.flipud(nn_hist)
    pp_hist_flip_cum = np.cumsum(pp_hist_flip)
    nn_hist_flip_cum = np.cumsum(nn_hist_flip)
    precision = pp_hist_flip_cum / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-8
                                   )  # TP/(TP+FP)
    recall = pp_hist_flip_cum / (gt_num + 1e-8)  # TP/(TP+FN)
    precision[np.isnan(precision)] = 0.0
    recall[np.isnan(recall)] = 0.0
    pre_len = len(precision)
    rec_len = len(recall)
    return np.reshape(precision, (pre_len)), np.reshape(recall, (rec_len))

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays."""
    numpy_groundtruths = groundtruths.numpy()
    numpy_predictions = predictions.numpy()
    return numpy_groundtruths, numpy_predictions

  def update_state(self, groundtruths, predictions):
    """Update segmentation results and groundtruth data.

    Args:
      groundtruths : Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks. range [0, 1]
      predictions : Tuple of single Tensor [batch, width, height, 1],
        predicted masks. range [0, 1]
    """
    groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
                                                       predictions[0])
    for (true, pred) in zip(groundtruths, predictions):
      self._groundtruths.append(true)
      self._predictions.append(pred)
class RelaxedFscore:
  """Relaxed F-score metric for basnet.

  Binarizes both masks, extracts one-pixel-wide boundaries (mask XOR its
  erosion), and computes boundary precision/recall with a slack of `rho`
  pixels between the two boundaries.
  """

  def __init__(self):
    """Constructs BASNet evaluation class."""
    self.reset_states()

  @property
  def name(self):
    return 'RelaxFScore'

  def reset_states(self):
    """Resets internal states for a fresh run."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Evaluates segmentation results, and reset_states."""
    metric_result = self.evaluate()
    # Cleans up the internal variables in order for a fresh eval next time.
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Evaluates with masks from all images.

    Returns:
      relax_f: relaxed F-score value averaged over all images, as float32.
    """
    beta = 0.3  # Acts as beta^2, matching the other F-measure metrics here.
    rho = 3  # Pixel slack allowed when matching the two boundaries.
    relax_fs = np.zeros(len(self._groundtruths))
    erode_kernel = np.ones((3, 3))
    for i, (true,
            pred) in enumerate(zip(self._groundtruths, self._predictions)):
      true = self._mask_normalize(true)
      pred = self._mask_normalize(pred)
      # Stored masks carry a trailing channel axis; drop it for 2-D convs.
      true = np.squeeze(true, axis=-1)
      pred = np.squeeze(pred, axis=-1)
      # binary saliency mask (S_bw), threshold 0.5
      # NOTE: thresholding mutates the stored arrays in place; this is safe
      # because result()/reset_states() clears the state after evaluation.
      pred[pred >= 0.5] = 1
      pred[pred < 0.5] = 0
      # compute eroded binary mask (S_erd) of S_bw
      pred_erd = self._compute_erosion(pred, erode_kernel)
      # XOR of a binary mask with its erosion keeps only boundary pixels.
      pred_xor = np.logical_xor(pred_erd, pred)
      # convert True/False to 1/0
      pred_xor = pred_xor * 1
      # same method for ground truth
      true[true >= 0.5] = 1
      true[true < 0.5] = 0
      true_erd = self._compute_erosion(true, erode_kernel)
      true_xor = np.logical_xor(true_erd, true)
      true_xor = true_xor * 1
      pre, rec = self._compute_relax_pre_rec(true_xor, pred_xor, rho)
      relax_fs[i] = (1 + beta) * pre * rec / (beta * pre + rec + 1e-8)
    relax_f = np.sum(relax_fs, 0) / (len(self._groundtruths) + 1e-8)
    relax_f = relax_f.astype(np.float32)
    return relax_f

  def _mask_normalize(self, mask):
    # Scale into [0, 1] by the mask maximum; epsilon avoids divide-by-zero.
    return mask/(np.amax(mask)+1e-8)

  def _compute_erosion(self, mask, kernel):
    """Binary erosion: keeps pixels whose whole kernel neighborhood is set."""
    kernel_full = np.sum(kernel)
    # A pixel survives only if the convolution covered every kernel cell.
    mask_erd = scipy.signal.convolve2d(mask, kernel, mode='same')
    mask_erd[mask_erd < kernel_full] = 0
    mask_erd[mask_erd == kernel_full] = 1
    return mask_erd

  def _compute_relax_pre_rec(self, true, pred, rho):
    """Computes relaxed precision and recall.

    Args:
      true: binary groundtruth boundary map.
      pred: binary predicted boundary map.
      rho: pixel slack; convolving with a (2*rho-1)^2 ones kernel marks every
        pixel within that window of the other boundary.

    Returns:
      Tuple of (relaxed precision, relaxed recall) scalars.
    """
    kernel = np.ones((2 * rho - 1, 2 * rho - 1))
    map_zeros = np.zeros_like(pred)
    map_ones = np.ones_like(pred)
    # Dilate the predicted boundary by the slack window.
    pred_filtered = scipy.signal.convolve2d(pred, kernel, mode='same')
    # True positive for relaxed precision
    relax_pre_tp = np.where((true == 1) & (pred_filtered > 0), map_ones,
                            map_zeros)
    true_filtered = scipy.signal.convolve2d(true, kernel, mode='same')
    # True positive for relaxed recall
    relax_rec_tp = np.where((pred == 1) & (true_filtered > 0), map_ones,
                            map_zeros)
    # NOTE(review): denominators are the boundary pixel counts; an empty
    # boundary would divide by zero here — presumably masks are non-trivial.
    return np.sum(relax_pre_tp) / np.sum(pred), np.sum(relax_rec_tp) / np.sum(
        true)

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays."""
    numpy_groundtruths = groundtruths.numpy()
    numpy_predictions = predictions.numpy()
    return numpy_groundtruths, numpy_predictions

  def update_state(self, groundtruths, predictions):
    """Update segmentation results and groundtruth data.

    Args:
      groundtruths : Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks. range [0, 1]
      predictions : Tuple of single Tensor [batch, width, height, 1],
        predicted masks. range [0, 1]
    """
    groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
                                                       predictions[0])
    for (true, pred) in zip(groundtruths, predictions):
      self._groundtruths.append(true)
      self._predictions.append(pred)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for metrics.py."""
from absl.testing import parameterized
import tensorflow as tf
from official.projects.basnet.evaluation import metrics
class BASNetMetricTest(parameterized.TestCase, tf.test.TestCase):
  """Sanity-checks the numpy BASNet metrics against tf.keras metrics."""

  def test_mae(self):
    """MAE should match tf.keras MeanAbsoluteError on random masks."""
    input_size = 224
    inputs = (tf.random.uniform([2, input_size, input_size, 1]),)
    labels = (tf.random.uniform([2, input_size, input_size, 1]),)
    mae_obj = metrics.MAE()
    mae_obj.reset_states()
    mae_obj.update_state(labels, inputs)
    output = mae_obj.result()
    mae_tf = tf.keras.metrics.MeanAbsoluteError()
    mae_tf.reset_state()
    mae_tf.update_state(labels[0], inputs[0])
    compare = mae_tf.result().numpy()
    # places=4: MAE normalizes each mask by its max, so values only agree
    # approximately with the unnormalized tf metric on uniform inputs.
    self.assertAlmostEqual(output, compare, places=4)

  def test_max_f(self):
    """MaxFscore should roughly match a fixed-threshold tf F-score."""
    input_size = 224
    beta = 0.3
    inputs = (tf.random.uniform([2, input_size, input_size, 1]),)
    labels = (tf.random.uniform([2, input_size, input_size, 1]),)
    max_f_obj = metrics.MaxFscore()
    max_f_obj.reset_states()
    max_f_obj.update_state(labels, inputs)
    output = max_f_obj.result()
    # 0.78 is an empirically chosen threshold; maxF sweeps all thresholds,
    # hence the loose places=1 comparison below.
    pre_tf = tf.keras.metrics.Precision(thresholds=0.78)
    rec_tf = tf.keras.metrics.Recall(thresholds=0.78)
    pre_tf.reset_state()
    rec_tf.reset_state()
    pre_tf.update_state(labels[0], inputs[0])
    rec_tf.update_state(labels[0], inputs[0])
    pre_out_tf = pre_tf.result().numpy()
    rec_out_tf = rec_tf.result().numpy()
    compare = (1+beta)*pre_out_tf*rec_out_tf/(beta*pre_out_tf+rec_out_tf+1e-8)
    self.assertAlmostEqual(output, compare, places=1)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses used for BASNet models."""
import tensorflow as tf
EPSILON = 1e-5
class BASNetLoss:
  """BASNet hybrid loss.

  Sums binary cross-entropy, SSIM and IoU losses for every side-output
  level and averages the total over the number of levels.
  """

  def __init__(self):
    # SUM reduction: per-pixel BCE is summed, not averaged, over the batch.
    self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
        reduction=tf.keras.losses.Reduction.SUM, from_logits=False)
    self._ssim = tf.image.ssim

  def __call__(self, sigmoids, labels):
    """Computes the hybrid loss.

    Args:
      sigmoids: dict mapping level keys to predicted masks in [0, 1] with
        shape [batch, height, width, 1].
      labels: groundtruth masks, shape [batch, height, width, 1].

    Returns:
      Scalar total loss averaged over the side-output levels.
    """
    levels = sorted(sigmoids.keys())
    labels_bce = tf.squeeze(labels, axis=-1)
    labels = tf.cast(labels, tf.float32)

    bce_losses = []
    ssim_losses = []
    iou_losses = []
    for level in levels:
      bce_losses.append(
          self._binary_crossentropy(labels_bce, sigmoids[level]))
      ssim_losses.append(
          1 - self._ssim(sigmoids[level], labels, max_val=1.0))
      iou_losses.append(
          self._iou_loss(sigmoids[level], labels))

    total_loss = (tf.math.add_n(bce_losses) + tf.math.add_n(ssim_losses) +
                  tf.math.add_n(iou_losses))
    return total_loss / len(levels)

  def _iou_loss(self, sigmoids, labels):
    """Returns 1 - soft IoU between predicted and groundtruth masks."""
    # Fix: dropped the redundant no-op `[:, :, :, :]` slices and the
    # pointless zero-initialized accumulator of the original.
    intersection = tf.reduce_sum(sigmoids * labels)
    union = tf.reduce_sum(sigmoids) + tf.reduce_sum(labels) - intersection
    return 1 - intersection / union
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build BASNet models."""
from typing import Mapping
import tensorflow as tf
from official.modeling import tf_utils
from official.projects.basnet.modeling import nn_blocks
from official.vision.beta.modeling.backbones import factory
# Specifications for BASNet encoder.
# Each element in the block configuration is in the following format:
# (num_filters, stride, block_repeats, maxpool)
# maxpool == 1 appends a 2x2 max-pool after the stage (see BASNetEncoder).
BASNET_ENCODER_SPECS = [
    (64, 1, 3, 0),  # ResNet-34,
    (128, 2, 4, 0),  # ResNet-34,
    (256, 2, 6, 0),  # ResNet-34,
    (512, 2, 3, 1),  # ResNet-34,
    (512, 1, 3, 1),  # BASNet,
    (512, 1, 3, 0),  # BASNet,
]

# Specifications for BASNet decoder.
# Each element in the block configuration is in the following format:
# (conv1_nf, conv1_dr, convm_nf, convm_dr, conv2_nf, conv2_dr, scale_factor)
# nf : num_filters, dr : dilation_rate
# scale_factor is the bilinear upsampling factor for that side output.
BASNET_BRIDGE_SPECS = [
    (512, 2, 512, 2, 512, 2, 32),  # Sup0, Bridge
]

BASNET_DECODER_SPECS = [
    (512, 1, 512, 2, 512, 2, 32),  # Sup1, stage6d
    (512, 1, 512, 1, 512, 1, 16),  # Sup2, stage5d
    (512, 1, 512, 1, 256, 1, 8),  # Sup3, stage4d
    (256, 1, 256, 1, 128, 1, 4),  # Sup4, stage3d
    (128, 1, 128, 1, 64, 1, 2),  # Sup5, stage2d
    (64, 1, 64, 1, 64, 1, 1)  # Sup6, stage1d
]
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetModel(tf.keras.Model):
  """A BASNet model.

  Boundary-Aware network (BASNet) were proposed in:
  [1] Qin, Xuebin, et al.
      Basnet: Boundary-aware salient object detection.

  Input images are passed through backbone first. Decoder network is then
  applied, and finally, refinement module is applied on the output of the
  decoder network.
  """

  def __init__(self,
               backbone,
               decoder,
               refinement=None,
               **kwargs):
    """BASNet initialization function.

    Args:
      backbone: a backbone network. basnet_encoder.
      decoder: a decoder network. basnet_decoder.
      refinement: a module for salient map refinement.
      **kwargs: keyword arguments to be passed.
    """
    super(BASNetModel, self).__init__(**kwargs)
    # Keep the constructor arguments for get_config()/from_config().
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'refinement': refinement,
    }
    self.backbone = backbone
    self.decoder = decoder
    self.refinement = refinement

  def call(self, inputs, training=None):
    """Runs backbone, decoder and the optional refinement module."""
    features = self.backbone(inputs)
    if self.decoder:
      features = self.decoder(features)
    if self.refinement:
      # Refine the last (finest) decoder output and append it under the
      # next sequential string key.
      level_keys = sorted(features.keys())
      features[str(len(level_keys))] = self.refinement(
          features[level_keys[-1]])
    return features

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = {'backbone': self.backbone}
    if self.decoder is not None:
      items['decoder'] = self.decoder
    if self.refinement is not None:
      items['refinement'] = self.refinement
    return items

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetEncoder(tf.keras.Model):
  """BASNet encoder.

  A ResNet-34-style backbone built with the Keras functional API; each of
  the six stages in BASNET_ENCODER_SPECS is exposed as an endpoint keyed
  '0'..'5'.
  """

  def __init__(
      self,
      input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
      activation='relu',
      use_sync_bn=False,
      use_bias=True,
      norm_momentum=0.99,
      norm_epsilon=0.001,
      kernel_initializer='VarianceScaling',
      kernel_regularizer=None,
      bias_regularizer=None,
      **kwargs):
    """BASNet encoder initialization function.

    Args:
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in conv2d.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
        Default to None.
      **kwargs: keyword arguments to be passed.
    """
    # Attributes are set before super().__init__ because this class builds a
    # functional-API graph and calls super() at the very end with the
    # resulting inputs/outputs (standard TF model-garden backbone pattern).
    self._input_specs = input_specs
    self._use_sync_bn = use_sync_bn
    self._use_bias = use_bias
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    # Build BASNet Encoder.
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    # Stem: 3x3 stride-1 convolution (keeps input resolution) + BN + act.
    x = tf.keras.layers.Conv2D(
        filters=64, kernel_size=3, strides=1,
        use_bias=self._use_bias, padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            inputs)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
            x)
    x = tf_utils.get_activation(activation)(x)

    endpoints = {}
    for i, spec in enumerate(BASNET_ENCODER_SPECS):
      x = self._block_group(
          inputs=x,
          filters=spec[0],
          strides=spec[1],
          block_repeats=spec[2],
          name='block_group_l{}'.format(i + 2))
      endpoints[str(i)] = x
      # spec[3] flags a trailing 2x2 max-pool after this stage.
      if spec[3]:
        x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same')(x)

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

    super(BASNetEncoder, self).__init__(
        inputs=inputs, outputs=endpoints, **kwargs)

  def _block_group(self,
                   inputs,
                   filters,
                   strides,
                   block_repeats=1,
                   name='block_group'):
    """Creates one group of residual blocks for the BASNet encoder model.

    Args:
      inputs: `Tensor` of size `[batch, channels, height, width]`.
      filters: `int` number of filters for the first convolution of the layer.
      strides: `int` stride to use for the first convolution of the layer. If
        greater than 1, this layer will downsample the input.
      block_repeats: `int` number of blocks contained in the layer.
      name: `str`name for the block.

    Returns:
      The output `Tensor` of the block layer.
    """
    # First block may downsample and always projects the shortcut.
    x = nn_blocks.ResBlock(
        filters=filters,
        strides=strides,
        use_projection=True,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activation=self._activation,
        use_sync_bn=self._use_sync_bn,
        use_bias=self._use_bias,
        norm_momentum=self._norm_momentum,
        norm_epsilon=self._norm_epsilon)(
            inputs)
    # Remaining blocks keep the resolution and reuse the identity shortcut.
    for _ in range(1, block_repeats):
      x = nn_blocks.ResBlock(
          filters=filters,
          strides=1,
          use_projection=False,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          use_bias=self._use_bias,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon)(
              x)
    return tf.identity(x, name=name)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
@factory.register_backbone_builder('basnet_encoder')
def build_basnet_encoder(
    input_specs: tf.keras.layers.InputSpec,
    model_config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Builds a BASNet encoder backbone from the given model config.

  Args:
    input_specs: `tf.keras.layers.InputSpec` for the model input tensor.
    model_config: model config with `backbone` and `norm_activation` fields.
    l2_regularizer: optional regularizer for the convolution kernels.

  Returns:
    A `BASNetEncoder` keras model.
  """
  backbone_type = model_config.backbone.type
  assert backbone_type == 'basnet_encoder', (f'Inconsistent backbone type '
                                             f'{backbone_type}')
  norm_cfg = model_config.norm_activation
  return BASNetEncoder(
      input_specs=input_specs,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      use_bias=norm_cfg.use_bias,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetDecoder(tf.keras.layers.Layer):
"""BASNet decoder."""
def __init__(self,
activation='relu',
use_sync_bn=False,
use_bias=True,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""BASNet decoder initialization function.
Args:
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in convolution.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
**kwargs: keyword arguments to be passed.
"""
super(BASNetDecoder, self).__init__(**kwargs)
self._config_dict = {
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
}
self._activation = tf_utils.get_activation(activation)
self._concat = tf.keras.layers.Concatenate(axis=-1)
self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')
def build(self, input_shape):
"""Creates the variables of the BASNet decoder."""
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
'kernel_size': 3,
'strides': 1,
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
self._out_convs = []
self._out_usmps = []
# Bridge layers.
self._bdg_convs = []
for spec in BASNET_BRIDGE_SPECS:
blocks = []
for j in range(3):
blocks.append(nn_blocks.ConvBlock(
filters=spec[2*j],
dilation_rate=spec[2*j+1],
activation='relu',
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=0.99,
norm_epsilon=0.001,
**conv_kwargs))
self._bdg_convs.append(blocks)
self._out_convs.append(conv_op(
filters=1,
padding='same',
**conv_kwargs))
self._out_usmps.append(tf.keras.layers.UpSampling2D(
size=spec[6],
interpolation='bilinear'
))
# Decoder layers.
self._dec_convs = []
for spec in BASNET_DECODER_SPECS:
blocks = []
for j in range(3):
blocks.append(nn_blocks.ConvBlock(
filters=spec[2*j],
dilation_rate=spec[2*j+1],
activation='relu',
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=0.99,
norm_epsilon=0.001,
**conv_kwargs))
self._dec_convs.append(blocks)
self._out_convs.append(conv_op(
filters=1,
padding='same',
**conv_kwargs))
self._out_usmps.append(tf.keras.layers.UpSampling2D(
size=spec[6],
interpolation='bilinear'
))
  def call(self, backbone_output: Mapping[str, tf.Tensor]):
    """Forward pass of the BASNet decoder.

    Args:
      backbone_output: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
          [batch, height_l, width_l, channels].

    Returns:
      sup: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
          [batch, height_l, width_l, channels].
    """
    # Deepest backbone level first; the bridge output is stored as key '0'.
    levels = sorted(backbone_output.keys(), reverse=True)
    sup = {}
    x = backbone_output[levels[0]]
    for blocks in self._bdg_convs:
      for block in blocks:
        x = block(x)
    sup['0'] = x
    for i, blocks in enumerate(self._dec_convs):
      # Skip connection: concatenate with the matching backbone feature map.
      x = self._concat([x, backbone_output[levels[i]]])
      for block in blocks:
        x = block(x)
      sup[str(i+1)] = x
      # NOTE(review): a fresh UpSampling2D layer is instantiated on every
      # call and every loop iteration. The layer is stateless so behavior is
      # unaffected, but consider creating it once in build().
      x = tf.keras.layers.UpSampling2D(
          size=2,
          interpolation='bilinear'
      )(x)
    # Project each stage to a single-channel map, upsample it by the
    # per-stage factor from build(), and squash to [0, 1].
    for i, (conv, usmp) in enumerate(zip(self._out_convs, self._out_usmps)):
      sup[str(i)] = self._sigmoid(usmp(conv(sup[str(i)])))
    # Record output shapes for the output_specs property.
    self._output_specs = {
        str(order): sup[str(order)].get_shape()
        for order in range(0, len(BASNET_DECODER_SPECS))
    }

    return sup
  def get_config(self):
    """Returns the constructor config used for layer serialization."""
    return self._config_dict
  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the layer from the output of `get_config`."""
    return cls(**config)
  @property
  def output_specs(self):
    """A dict of {order: TensorShape} pairs for the model output.

    Only valid after `call` has executed at least once, since the shapes
    are recorded during the forward pass.
    """
    return self._output_specs
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for basnet network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.projects.basnet.modeling import basnet_model
from official.projects.basnet.modeling import refunet
class BASNetNetworkTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for assembling and serializing the BASNet model."""

  @parameterized.parameters(
      (256),
      (512),
  )
  def test_basnet_network_creation(
      self, input_size):
    """Test for creation of a segmentation network."""
    inputs = np.random.rand(2, input_size, input_size, 3)
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = basnet_model.BASNetEncoder()
    decoder = basnet_model.BASNetDecoder()
    refinement = refunet.RefUnet()

    model = basnet_model.BASNetModel(
        backbone=backbone,
        decoder=decoder,
        refinement=refinement
    )

    sigmoids = model(inputs)
    levels = sorted(sigmoids.keys())
    # The final output must keep the input spatial resolution and have a
    # single (saliency) channel.
    self.assertAllEqual(
        [2, input_size, input_size, 1],
        sigmoids[levels[-1]].numpy().shape)

  def test_serialize_deserialize(self):
    """Validate the network can be serialized and deserialized."""
    backbone = basnet_model.BASNetEncoder()
    decoder = basnet_model.BASNetDecoder()
    refinement = refunet.RefUnet()

    model = basnet_model.BASNetModel(
        backbone=backbone,
        decoder=decoder,
        refinement=refinement
    )

    config = model.get_config()
    new_model = basnet_model.BASNetModel.from_config(config)

    # Validate that the config can be forced to JSON.
    _ = new_model.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(model.get_config(), new_model.get_config())
# Run the test suite when invoked directly.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common building blocks for BasNet model."""
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Vision')
class ConvBlock(tf.keras.layers.Layer):
  """A (Conv+BN+Activation) block."""

  def __init__(self,
               filters,
               strides,
               dilation_rate=1,
               kernel_size=3,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_bias=False,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """A vgg block with BN after convolutions.

    Args:
      filters: `int` number of filters for the convolution in this block.
      strides: `int` block stride. If greater than 1, this block will
        downsample the input.
      dilation_rate: `int`, dilation rate for conv layers.
      kernel_size: `int`, kernel size of conv layers.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
        Default to None.
      activation: `str` name of the activation function.
      use_bias: `bool`, whether or not use bias in conv layers.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      **kwargs: keyword arguments to be passed.
    """
    super(ConvBlock, self).__init__(**kwargs)
    # Stored verbatim so get_config() can serialize the layer.
    self._config_dict = {
        'filters': filters,
        'kernel_size': kernel_size,
        'strides': strides,
        'dilation_rate': dilation_rate,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon
    }
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    # Batch-norm axis follows the image data format (channels last vs first).
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def build(self, input_shape):
    """Creates the conv and batch-norm variables."""
    conv_kwargs = {
        'padding': 'same',
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }
    self._conv0 = tf.keras.layers.Conv2D(
        filters=self._config_dict['filters'],
        kernel_size=self._config_dict['kernel_size'],
        strides=self._config_dict['strides'],
        dilation_rate=self._config_dict['dilation_rate'],
        **conv_kwargs)
    self._norm0 = self._norm(
        axis=self._bn_axis,
        momentum=self._config_dict['norm_momentum'],
        epsilon=self._config_dict['norm_epsilon'])

    super(ConvBlock, self).build(input_shape)

  def get_config(self):
    """Returns the constructor config used for layer serialization."""
    return self._config_dict

  def call(self, inputs, training=None):
    """Applies conv -> batch norm -> activation."""
    x = self._conv0(inputs)
    x = self._norm0(x)
    x = self._activation_fn(x)
    return x
@tf.keras.utils.register_keras_serializable(package='Vision')
class ResBlock(tf.keras.layers.Layer):
  """A residual block."""

  def __init__(self,
               filters,
               strides,
               use_projection=False,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_sync_bn=False,
               use_bias=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """Initializes a residual block with BN after convolutions.

    Args:
      filters: An `int` number of filters used by both convolutions in this
        block.
      strides: An `int` block stride. If greater than 1, this block will
        ultimately downsample the input.
      use_projection: A `bool` for whether this block should use a projection
        shortcut (versus the default identity shortcut). This is usually `True`
        for the first block of a block group, which may change the number of
        filters and the resolution.
      kernel_initializer: A `str` of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2d.
        Default to None.
      activation: A `str` name of the activation function.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      use_bias: A `bool`. If True, use bias in conv2d.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(ResBlock, self).__init__(**kwargs)
    # Stored verbatim so get_config() can serialize the layer.
    self._config_dict = {
        'filters': filters,
        'strides': strides,
        'use_projection': use_projection,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon
    }
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    # Batch-norm axis follows the image data format (channels last vs first).
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def build(self, input_shape):
    """Creates the shortcut (optional), conv, and batch-norm variables."""
    conv_kwargs = {
        'filters': self._config_dict['filters'],
        'padding': 'same',
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }
    if self._config_dict['use_projection']:
      # 1x1 projection so the shortcut matches the main path's filter count
      # and stride.
      self._shortcut = tf.keras.layers.Conv2D(
          filters=self._config_dict['filters'],
          kernel_size=1,
          strides=self._config_dict['strides'],
          use_bias=self._config_dict['use_bias'],
          kernel_initializer=self._config_dict['kernel_initializer'],
          kernel_regularizer=self._config_dict['kernel_regularizer'],
          bias_regularizer=self._config_dict['bias_regularizer'])
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._config_dict['norm_momentum'],
          epsilon=self._config_dict['norm_epsilon'])
    # Only the first conv applies the block stride; the second keeps stride 1.
    self._conv1 = tf.keras.layers.Conv2D(
        kernel_size=3,
        strides=self._config_dict['strides'],
        **conv_kwargs)
    self._norm1 = self._norm(
        axis=self._bn_axis,
        momentum=self._config_dict['norm_momentum'],
        epsilon=self._config_dict['norm_epsilon'])
    self._conv2 = tf.keras.layers.Conv2D(
        kernel_size=3,
        strides=1,
        **conv_kwargs)
    self._norm2 = self._norm(
        axis=self._bn_axis,
        momentum=self._config_dict['norm_momentum'],
        epsilon=self._config_dict['norm_epsilon'])

    super(ResBlock, self).build(input_shape)

  def get_config(self):
    """Returns the constructor config used for layer serialization."""
    return self._config_dict

  def call(self, inputs, training=None):
    """Applies conv-BN-act, conv-BN, then adds the shortcut and activates."""
    shortcut = inputs
    if self._config_dict['use_projection']:
      shortcut = self._shortcut(shortcut)
      shortcut = self._norm0(shortcut)

    x = self._conv1(inputs)
    x = self._norm1(x)
    x = self._activation_fn(x)

    x = self._conv2(x)
    x = self._norm2(x)
    return self._activation_fn(x + shortcut)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RefUNet model."""
import tensorflow as tf
from official.projects.basnet.modeling import nn_blocks
@tf.keras.utils.register_keras_serializable(package='Vision')
class RefUnet(tf.keras.layers.Layer):
  """Residual Refinement Module of BASNet.

  Boundary-Aware network (BASNet) were proposed in:
  [1] Qin, Xuebin, et al.
      Basnet: Boundary-aware salient object detection.
  """

  def __init__(self,
               activation='relu',
               use_sync_bn=False,
               use_bias=True,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """Residual Refinement Module of BASNet.

    Args:
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in conv2d.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
        Default to None.
      **kwargs: keyword arguments to be passed.
    """
    super(RefUnet, self).__init__(**kwargs)
    # Stored verbatim so get_config() can serialize the layer.
    self._config_dict = {
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    self._concat = tf.keras.layers.Concatenate(axis=-1)
    self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')
    # Shared 2x downsampling / upsampling layers reused at every stage.
    self._maxpool = tf.keras.layers.MaxPool2D(
        pool_size=2,
        strides=2,
        padding='valid')
    self._upsample = tf.keras.layers.UpSampling2D(
        size=2,
        interpolation='bilinear')

  def build(self, input_shape):
    """Creates the variables of the RefUnet refinement module."""
    conv_op = tf.keras.layers.Conv2D
    # Arguments shared by every convolution created below.
    conv_kwargs = {
        'kernel_size': 3,
        'strides': 1,
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }
    self._in_conv = conv_op(
        filters=64,
        padding='same',
        **conv_kwargs)

    # Encoder: four 64-filter conv blocks, each followed by maxpool in call().
    self._en_convs = []
    for _ in range(4):
      self._en_convs.append(nn_blocks.ConvBlock(
          filters=64,
          use_sync_bn=self._config_dict['use_sync_bn'],
          norm_momentum=self._config_dict['norm_momentum'],
          norm_epsilon=self._config_dict['norm_epsilon'],
          **conv_kwargs))

    # Bridge between encoder and decoder (single conv block).
    self._bridge_convs = []
    for _ in range(1):
      self._bridge_convs.append(nn_blocks.ConvBlock(
          filters=64,
          use_sync_bn=self._config_dict['use_sync_bn'],
          norm_momentum=self._config_dict['norm_momentum'],
          norm_epsilon=self._config_dict['norm_epsilon'],
          **conv_kwargs))

    # Decoder: four conv blocks consuming upsampled features concatenated
    # with the matching encoder endpoints.
    self._de_convs = []
    for _ in range(4):
      self._de_convs.append(nn_blocks.ConvBlock(
          filters=64,
          use_sync_bn=self._config_dict['use_sync_bn'],
          norm_momentum=self._config_dict['norm_momentum'],
          norm_epsilon=self._config_dict['norm_epsilon'],
          **conv_kwargs))

    self._out_conv = conv_op(
        filters=1,
        padding='same',
        **conv_kwargs)

  def call(self, inputs):
    """Refines a coarse mask: U-Net residual is added to the input."""
    endpoints = {}
    residual = inputs
    x = self._in_conv(inputs)

    # Top-down
    for i, block in enumerate(self._en_convs):
      x = block(x)
      endpoints[str(i)] = x
      x = self._maxpool(x)

    # Bridge
    for i, block in enumerate(self._bridge_convs):
      x = block(x)

    # Bottom-up
    for i, block in enumerate(self._de_convs):
      # Upsampling is done in float32 and cast back afterwards —
      # presumably a mixed-precision workaround; TODO confirm.
      dtype = x.dtype
      x = tf.cast(x, tf.float32)
      x = self._upsample(x)
      x = tf.cast(x, dtype)
      # Concatenate with the mirrored encoder endpoint (skip connection).
      x = self._concat([endpoints[str(3-i)], x])
      x = block(x)

    x = self._out_conv(x)

    residual = tf.cast(residual, dtype=x.dtype)
    # Residual refinement: the U-Net predicts a correction to the input mask.
    output = self._sigmoid(x + residual)

    self._output_specs = output.get_shape()

    return output

  def get_config(self):
    """Returns the constructor config used for layer serialization."""
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the layer from the output of `get_config`."""
    return cls(**config)

  @property
  def output_specs(self):
    """Output `TensorShape`; only valid after `call` has run once."""
    return self._output_specs
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export module for BASNet."""
import tensorflow as tf
from official.projects.basnet.tasks import basnet
from official.vision.beta.serving import semantic_segmentation
# Per-channel RGB mean / stddev scaled to the [0, 255] pixel range
# (standard ImageNet normalization statistics).
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class BASNetModule(semantic_segmentation.SegmentationModule):
  """BASNet Module."""

  def _build_model(self):
    """Builds a BASNet model from the task config (no regularization)."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[self._batch_size] + self._input_image_size + [3])

    return basnet.build_basnet_model(
        input_specs=input_specs,
        model_config=self.params.task.model,
        l2_regularizer=None)

  def serve(self, images):
    """Cast image to float and run inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]
    Returns:
      A dict mapping 'predicted_masks' to a mask tensor resized to the
      model input size.
    """
    # Preprocessing runs on CPU so the accelerator only sees model compute.
    with tf.device('cpu:0'):
      images = tf.cast(images, dtype=tf.float32)

      images = tf.nest.map_structure(
          tf.identity,
          tf.map_fn(
              self._build_inputs, elems=images,
              fn_output_signature=tf.TensorSpec(
                  shape=self._input_image_size + [3], dtype=tf.float32),
              parallel_iterations=32
          )
      )

    masks = self.inference_step(images)
    keys = sorted(masks.keys())
    # Take the last (highest-key) output and resize back to the input size.
    output = tf.image.resize(
        masks[keys[-1]],
        self._input_image_size, method='bilinear')
    return dict(predicted_masks=output)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Export binary for BASNet.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE = XX
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.basnet.serving import basnet
from official.vision.beta.serving import export_saved_model_lib
# Command-line flags for the export binary.
FLAGS = flags.FLAGS

flags.DEFINE_string(
    'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help='YAML/JSON files which specifies overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be overriden'
    ' on top of `config_file` template.')
flags.DEFINE_integer(
    'batch_size', None, 'The batch size.')
flags.DEFINE_string(
    'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example`.')
flags.DEFINE_string(
    'input_image_size', '224,224',
    'The comma-separated string of two integers representing the height,width '
    'of the input to the model.')
def main(_):
  """Builds the experiment config from flags and exports the saved model.

  Resolves the experiment config, applies YAML/JSON file overrides followed
  by `--params_override`, then exports an inference graph (checkpoint +
  saved_model) for the BASNet serving module.

  Args:
    _: Unused positional CLI arguments from absl.
  """
  params = exp_factory.get_exp_config(FLAGS.experiment)
  for config_file in FLAGS.config_file or []:
    params = hyperparams.override_params_dict(
        params, config_file, is_strict=True)
  if FLAGS.params_override:
    params = hyperparams.override_params_dict(
        params, FLAGS.params_override, is_strict=True)

  params.validate()
  params.lock()

  # Parse the 'H,W' flag once instead of re-splitting it for each consumer.
  input_image_size = [int(x) for x in FLAGS.input_image_size.split(',')]

  export_saved_model_lib.export_inference_graph(
      input_type=FLAGS.input_type,
      batch_size=FLAGS.batch_size,
      input_image_size=input_image_size,
      params=params,
      checkpoint_path=FLAGS.checkpoint_path,
      export_dir=FLAGS.export_dir,
      export_module=basnet.BASNetModule(
          params=params,
          batch_size=FLAGS.batch_size,
          input_image_size=input_image_size),
      export_checkpoint_subdir='checkpoint',
      export_saved_model_subdir='saved_model')
# Entry point for the export binary.
if __name__ == '__main__':
  app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BASNet task definition."""
from typing import Optional
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.projects.basnet.configs import basnet as exp_cfg
from official.projects.basnet.evaluation import metrics as basnet_metrics
from official.projects.basnet.losses import basnet_losses
from official.projects.basnet.modeling import basnet_model
from official.projects.basnet.modeling import refunet
from official.vision.beta.dataloaders import segmentation_input
def build_basnet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: exp_cfg.BASNetModel,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
  """Builds BASNet model.

  Args:
    input_specs: A `tf.keras.layers.InputSpec` describing the input image.
    model_config: A BASNet model config providing `norm_activation` settings
      and the `use_bias` flag shared by all submodules.
    l2_regularizer: Optional kernel regularizer applied to the encoder,
      decoder, and refinement convolutions.

  Returns:
    A `basnet_model.BASNetModel` composed of encoder, decoder, and the
    residual refinement module.
  """
  norm_activation_config = model_config.norm_activation
  backbone = basnet_model.BASNetEncoder(
      input_specs=input_specs,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      use_bias=model_config.use_bias,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  decoder = basnet_model.BASNetDecoder(
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      use_bias=model_config.use_bias,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  refinement = refunet.RefUnet(
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      use_bias=model_config.use_bias,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  model = basnet_model.BASNetModel(backbone, decoder, refinement)

  return model
@task_factory.register_task_cls(exp_cfg.BASNetTask)
class BASNetTask(base_task.Task):
  """A task for basnet."""

  def build_model(self):
    """Builds basnet model."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = build_basnet_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loads pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if 'all' in self.task_config.init_checkpoint_modules:
      # Full restore: every tracked object must be matched.
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    else:
      # Partial restore of only the requested submodules.
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds BASNet input."""

    ignore_label = self.task_config.losses.ignore_label

    decoder = segmentation_input.Decoder()
    parser = segmentation_input.Parser(
        output_size=params.output_size,
        crop_size=params.crop_size,
        ignore_label=ignore_label,
        aug_rand_hflip=params.aug_rand_hflip,
        dtype=params.dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self, label, model_outputs, aux_losses=None):
    """Hybrid loss proposed in BASNet.

    Args:
      label: label.
      model_outputs: Output logits of the classifier.
      aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    basnet_loss_fn = basnet_losses.BASNetLoss()
    total_loss = basnet_loss_fn(model_outputs, label['masks'])

    if aux_losses:
      # Include model regularization losses (e.g. L2 weight decay).
      total_loss += tf.add_n(aux_losses)

    return total_loss

  def build_metrics(self, training=False):
    """Gets streaming metrics for training/validation."""
    # Eval metrics are updated manually in aggregate_logs rather than being
    # returned to the training loop, so the returned list stays empty.
    evaluations = []
    if training:
      evaluations = []
    else:
      self.mae_metric = basnet_metrics.MAE()
      self.maxf_metric = basnet_metrics.MaxFscore()
      self.relaxf_metric = basnet_metrics.RelaxedFscore()

    return evaluations

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs, label=labels, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(
          optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(
        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)

    # Apply gradient clipping.
    if self.task_config.gradient_clip_norm > 0:
      grads, _ = tf.clip_by_global_norm(
          grads, self.task_config.gradient_clip_norm)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}

    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Validatation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    # Loss is not computed during eval; metrics below carry the signal.
    loss = 0

    logs = {self.loss: loss}

    # Feed the final (highest-level) output into each metric; the actual
    # metric update happens later in aggregate_logs.
    levels = sorted(outputs.keys())

    logs.update(
        {self.mae_metric.name: (labels['masks'], outputs[levels[-1]])})
    logs.update(
        {self.maxf_metric.name: (labels['masks'], outputs[levels[-1]])})
    logs.update(
        {self.relaxf_metric.name: (labels['masks'], outputs[levels[-1]])})

    return logs

  def inference_step(self, inputs, model):
    """Performs the forward step."""
    return model(inputs, training=False)

  def aggregate_logs(self, state=None, step_outputs=None):
    """Resets metrics on the first call, then updates them per step."""
    if state is None:
      self.mae_metric.reset_states()
      self.maxf_metric.reset_states()
      self.relaxf_metric.reset_states()
      # Any non-None sentinel works here; the metrics live on self.
      state = self.mae_metric
    self.mae_metric.update_state(
        step_outputs[self.mae_metric.name][0],
        step_outputs[self.mae_metric.name][1])
    self.maxf_metric.update_state(
        step_outputs[self.maxf_metric.name][0],
        step_outputs[self.maxf_metric.name][1])
    self.relaxf_metric.update_state(
        step_outputs[self.relaxf_metric.name][0],
        step_outputs[self.relaxf_metric.name][1])

    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Returns the final MAE / maxF / relaxF results for the eval epoch."""
    result = {}
    result['MAE'] = self.mae_metric.result()
    result['maxF'] = self.maxf_metric.result()
    result['relaxF'] = self.relaxf_metric.result()

    return result
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision training driver."""
from absl import app
# pylint: disable=unused-import
from official.common import flags as tfm_flags
from official.projects.basnet.configs import basnet as basnet_cfg
from official.projects.basnet.modeling import basnet_model
from official.projects.basnet.modeling import refunet
from official.projects.basnet.tasks import basnet as basenet_task
from official.vision.beta import train
# Entry point for the training driver: register flags, then hand control to
# the shared Model Garden training loop.
if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(train.main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment