"vscode:/vscode.git/clone" did not exist on "bafeed46fb76fa337a771ebb41d65bb95039565a"
Commit 06000e73 authored by Gunho Park's avatar Gunho Park
Browse files

From basnet branch

parent dcdd2e40
# BASNet: Boundary-Aware Salient Object Detection
This repository is the unofficial implementation of the following paper. Please see the paper [BASNet: Boundary-Aware Salient Object Detection](https://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html) for more details.
## Requirements
[![TensorFlow 2.4](https://img.shields.io/badge/TensorFlow-2.4-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.4.0)
[![Python 3.7](https://img.shields.io/badge/Python-3.7-3776AB)](https://www.python.org/downloads/release/python-379/)
## Train
```shell
$ python3 train.py \
--experiment=basnet_duts \
--mode=train \
--model_dir=$MODEL_DIR \
--config_file=./configs/experiments/basnet_dut_gpu.yaml
```
## Test
```shell
$ python3 train.py \
--experiment=basnet_duts \
--mode=eval \
--model_dir=$MODEL_DIR \
--config_file=./configs/experiments/basnet_dut_gpu.yaml \
--params_override='runtime.num_gpus=1, runtime.distribution_strategy=one_device, task.model.input_size=[256, 256, 3]'
```
## Results
Dataset | maxF<sub>β</sub> | relaxF<sub>β</sub> | MAE
:--------- | :--------------- | :------------------- | -------:
DUTS-TE | 0.865 | 0.793 | 0.046
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.vision import beta
from official.vision.beta.projects.basnet.configs import basnet
from official.vision.beta.projects.basnet.modeling import basnet_encoder
from official.vision.beta.projects.basnet.modeling import basnet_decoder
from official.vision.beta.projects.basnet.modeling import refunet
from official.vision.beta.projects.basnet.tasks import basnet
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BASNet configuration definition."""
import os
from typing import List, Union, Optional
import dataclasses
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import common
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input data config for BASNet, extending `cfg.DataConfig`.

  The same class is used for both the training and the validation input
  (see `BASNetTask.train_data` / `BASNetTask.validation_data`); `is_training`
  distinguishes the two.
  """
  # Size the image is resized to first (two ints in the `basnet_duts`
  # experiment, e.g. [256, 256]).
  output_size: List[int] = dataclasses.field(default_factory=list)
  # If crop_size is specified, image will be resized first to
  # output_size, then crop of size crop_size will be cropped.
  crop_size: List[int] = dataclasses.field(default_factory=list)
  # Path or file pattern of the input records (a glob in `basnet_duts`).
  input_path: str = ''
  global_batch_size: int = 0
  is_training: bool = True
  dtype: str = 'float32'
  shuffle_buffer_size: int = 1000
  cycle_length: int = 10
  # NOTE(review): presumably controls whether evaluation groundtruth masks are
  # resized along with the image -- confirm against the input parser.
  resize_eval_groundtruth: bool = True
  # NOTE(review): presumably only relevant when groundtruth is not resized --
  # confirm against the input parser.
  groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
  # Enables random horizontal flip augmentation (name-based; the consumer is
  # not visible here).
  aug_rand_hflip: bool = True
@dataclasses.dataclass
class BASNetModel(hyperparams.Config):
  """BASNet model config.

  Attributes are documented inline. Defaults are built per instance through
  `default_factory` so that two configs never share one mutable sub-config.
  """
  # Input image size; `basnet_duts` uses [None, None, 3].
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Normalization/activation settings. A class-level `NormActivation()`
  # instance would be shared by every BASNetModel and is rejected as a mutable
  # default by the dataclasses module on Python 3.11+.
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Loss-related hyperparameters of the BASNet task config."""
  label_smoothing: float = 0.1
  ignore_label: int = 0  # set 0 (background)
  # L2 weight decay factor; `basnet_duts` sets it to 0.
  l2_weight_decay: float = 0.0
  # NOTE(review): consumer not visible here; presumably selects whether the
  # loss is computed at groundtruth resolution -- confirm in the task code.
  use_groundtruth_dimension: bool = True
@dataclasses.dataclass
class BASNetTask(cfg.TaskConfig):
  """Task config for BASNet: model, input data, losses and initialization.

  Sub-config defaults are created through `default_factory` so that every
  BASNetTask owns its own instances. Class-level instances would be shared
  between tasks and are rejected as mutable defaults on Python 3.11+.
  """
  model: BASNetModel = dataclasses.field(default_factory=BASNetModel)
  train_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=True))
  validation_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=False))
  losses: Losses = dataclasses.field(default_factory=Losses)
  gradient_clip_norm: float = 0.0
  # Checkpoint to initialize from; None means no checkpoint is configured.
  init_checkpoint: Optional[str] = None
  # Which parts to restore from `init_checkpoint`.
  init_checkpoint_modules: Union[
      str, List[str]] = 'backbone'  # all, backbone, and/or decoder
@exp_factory.register_config_factory('basnet')
def basnet() -> cfg.ExperimentConfig:
  """Returns the generic BASNet experiment config, registered as 'basnet'.

  The task is a default `BASNetTask`. It must be a task config rather than a
  model config, because the restrictions below refer to
  `task.train_data` and `task.validation_data`.

  Returns:
    A `cfg.ExperimentConfig` with default task and trainer settings.
  """
  return cfg.ExperimentConfig(
      task=BASNetTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
# DUTS Dataset
# Example counts are used by `basnet_duts` to derive steps per epoch and the
# number of validation steps.
DUTS_TRAIN_EXAMPLES = 10553
DUTS_VAL_EXAMPLES = 5019
# NOTE(review): absolute, machine-specific TFRecord directories; override
# `task.train_data.input_path` / `task.validation_data.input_path` elsewhere.
DUTS_INPUT_PATH_BASE_TR = '/home/datasets/DUTS/DUTS_TR_TFRecords/'
DUTS_INPUT_PATH_BASE_VAL = '/home/datasets/DUTS/DUTS_TE_TFRecords/'
@exp_factory.register_config_factory('basnet_duts')
def basnet_duts() -> cfg.ExperimentConfig:
  """Builds the BASNet-on-DUTS experiment config, registered as 'basnet_duts'.

  Training runs for 300 epochs at batch size 16 with a constant-rate Adam
  optimizer; summaries, checkpoints and validation happen once per epoch.

  Returns:
    The fully populated `cfg.ExperimentConfig`.
  """
  batch_train = 16
  batch_eval = 16
  epoch_steps = DUTS_TRAIN_EXAMPLES // batch_train

  model_cfg = BASNetModel(
      input_size=[None, None, 3],
      norm_activation=common.NormActivation(
          activation='relu',
          norm_momentum=0.99,
          norm_epsilon=1e-3,
          use_sync_bn=True))

  # Training images are resized to 256x256 and then cropped to 224x224.
  train_cfg = DataConfig(
      input_path=os.path.join(DUTS_INPUT_PATH_BASE_TR, 'DUTS-TR-*'),
      crop_size=[224, 224],
      output_size=[256, 256],
      is_training=True,
      global_batch_size=batch_train,
  )
  eval_cfg = DataConfig(
      input_path=os.path.join(DUTS_INPUT_PATH_BASE_VAL, 'DUTS-TE-*'),
      output_size=[256, 256],
      is_training=False,
      global_batch_size=batch_eval,
  )

  task_cfg = BASNetTask(
      model=model_cfg,
      losses=Losses(l2_weight_decay=0),
      train_data=train_cfg,
      validation_data=eval_cfg,
      init_checkpoint='',
      init_checkpoint_modules='backbone'
  )

  optimizer_cfg = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'adam',
          'adam': {
              'beta_1': 0.9,
              'beta_2': 0.999,
              'epsilon': 1e-8,
          }
      },
      'learning_rate': {
          'type': 'constant',
          'constant': {'learning_rate': 0.001}
      }
  })

  trainer_cfg = cfg.TrainerConfig(
      steps_per_loop=epoch_steps,
      summary_interval=epoch_steps,
      checkpoint_interval=epoch_steps,
      train_steps=300 * epoch_steps,
      validation_steps=DUTS_VAL_EXAMPLES // batch_eval,
      validation_interval=epoch_steps,
      optimizer_config=optimizer_cfg)

  return cfg.ExperimentConfig(
      task=task_cfg,
      trainer=trainer_cfg,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for basnet."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import exp_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.vision import beta
from official.vision.beta.projects.basnet.configs import basnet as exp_cfg
class BASNetConfigTest(tf.test.TestCase, parameterized.TestCase):
  """Checks the registered BASNet experiment configs."""

  @parameterized.parameters(('basnet_duts',))
  def test_basnet_configs(self, config_name):
    """The config has the expected types and enforces its restrictions."""
    experiment = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(experiment, cfg.ExperimentConfig)

    task = experiment.task
    self.assertIsInstance(task, exp_cfg.BASNetTask)
    self.assertIsInstance(task.model, exp_cfg.BASNetModel)
    self.assertIsInstance(task.train_data, exp_cfg.DataConfig)

    # Violating the `is_training != None` restriction must fail validation.
    task.train_data.is_training = None
    with self.assertRaises(KeyError):
      experiment.validate()
# Run the tests in this module when it is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Experiment overrides: mirrored distribution strategy on 8 GPUs, float32.
runtime:
  distribution_strategy: 'mirrored'
  mixed_precision_dtype: 'float32'
  num_gpus: 8
task:
  # NOTE(review): absolute, user-specific checkpoint path; replace it with a
  # path valid on the machine running the experiment.
  init_checkpoint: '/home/gunho1123/ckpt/ckpt_encoder_wbias01'
  init_checkpoint_modules: 'backbone'
  train_data:
    dtype: 'float32'
  validation_data:
    resize_eval_groundtruth : True
    dtype: 'float32'
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for the BASNet evaluation metrics (mae.py and max_f.py)."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.basnet.evaluation import mae
from official.vision.beta.projects.basnet.evaluation import max_f
class BASNetMetricTest(parameterized.TestCase, tf.test.TestCase):
  """Compares the BASNet MAE and max-F metrics with tf.keras references."""

  def test_mae(self):
    """MAE on random masks is compared with `MeanAbsoluteError` (4 places)."""
    side = 224
    mask_shape = [2, side, side, 1]
    inputs = (tf.random.uniform(mask_shape),)
    labels = (tf.random.uniform(mask_shape),)

    metric = mae.MAE()
    metric.reset_states()
    metric.update_state(labels, inputs)
    output = metric.result()

    reference = tf.keras.metrics.MeanAbsoluteError()
    reference.reset_state()
    reference.update_state(labels[0], inputs[0])
    compare = reference.result().numpy()

    self.assertAlmostEqual(output, compare, places=4)

  def test_max_f(self):
    """max-F on random masks is compared with an F-score at threshold 0.78."""
    side = 224
    beta = 0.3
    mask_shape = [2, side, side, 1]
    inputs = (tf.random.uniform(mask_shape),)
    labels = (tf.random.uniform(mask_shape),)

    metric = max_f.maxFscore()
    metric.reset_states()
    metric.update_state(labels, inputs)
    output = metric.result()

    # Reference precision/recall from Keras at a fixed threshold.
    precision_ref = tf.keras.metrics.Precision(thresholds=0.78)
    recall_ref = tf.keras.metrics.Recall(thresholds=0.78)
    precision_ref.reset_state()
    recall_ref.reset_state()
    precision_ref.update_state(labels[0], inputs[0])
    recall_ref.update_state(labels[0], inputs[0])
    pre_out_tf = precision_ref.result().numpy()
    rec_out_tf = recall_ref.result().numpy()

    numerator = (1 + beta) * pre_out_tf * rec_out_tf
    compare = numerator / (beta * pre_out_tf + rec_out_tf + 1e-8)

    self.assertAlmostEqual(output, compare, places=1)
# Run the tests in this module when it is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This source code is a modified version of
https://github.com/xuebinqin/Binary-Segmentation-Evaluation-Tool
"""
# Import libraries
import numpy as np
class MAE(object):
  """Mean Absolute Error (MAE) metric for BASNet.

  Masks are accumulated per image with `update_state` and scored with
  `evaluate`. Each mask is divided by its own maximum before the error is
  computed, and the per-image errors are averaged.
  """

  def __init__(self):
    """Constructs the MAE metric with empty internal state."""
    self.reset_states()

  @property
  def name(self):
    return 'MAE'

  def reset_states(self):
    """Resets internal states for a fresh run."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Evaluates the accumulated masks, then resets the internal state.

    Returns:
      The value returned by `evaluate`.
    """
    metric_result = self.evaluate()
    # Cleans up the internal variables in order for a fresh eval next time.
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Evaluates with masks from all images.

    Returns:
      average_mae: average MAE as a `np.float32` scalar. 0.0 is returned when
        no masks have been accumulated, matching the other BASNet metrics,
        which also yield 0.0 for an empty state instead of raising.
    """
    if not self._groundtruths:
      return np.float32(0.0)

    mae_total = 0.0
    for true, pred in zip(self._groundtruths, self._predictions):
      mae_total += self._compute_mae(true, pred)

    average_mae = mae_total / len(self._groundtruths)
    return np.float32(average_mae)

  def _mask_normalize(self, mask):
    """Divides `mask` by its maximum (plus 1e-8 to avoid division by zero)."""
    return mask/(np.amax(mask)+1e-8)

  def _compute_mae(self, true, pred):
    """Returns the mean absolute error between two normalized masks.

    The summed absolute difference is divided by height * width, taken from
    the first two dimensions of `true`.
    """
    h, w = true.shape[0], true.shape[1]
    mask1 = self._mask_normalize(true)
    mask2 = self._mask_normalize(pred)
    sum_error = np.sum(np.absolute(mask1.astype(float) - mask2.astype(float)))
    mae_error = sum_error/(float(h)*float(w)+1e-8)
    return mae_error

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays via their `.numpy()` method."""
    numpy_groundtruths = groundtruths.numpy()
    numpy_predictions = predictions.numpy()
    return numpy_groundtruths, numpy_predictions

  def update_state(self, groundtruths, predictions):
    """Update segmentation results and groundtruth data.

    Only the first element of each tuple is used; it is split along the
    batch dimension and stored image by image.

    Args:
      groundtruths : Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks. range [0, 1]
      predictions : Tuple of single Tensor [batch, width, height, 1],
        predicted masks. range [0, 1]
    """
    groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
                                                       predictions[0])
    for true, pred in zip(groundtruths, predictions):
      self._groundtruths.append(true)
      self._predictions.append(pred)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This source code is a modified version of
https://github.com/xuebinqin/Binary-Segmentation-Evaluation-Tool
"""
# Import libraries
import numpy as np
class maxFscore(object):
  """Maximum F-score metric for BASNet.

  Masks are accumulated per image with `update_state`. `evaluate` builds a
  precision/recall curve per image over the integer bins 0..255, averages
  the curves over images and returns the largest F-score along the curve.
  """

  def __init__(self):
    """Constructs the metric with empty internal state."""
    self.reset_states()

  @property
  def name(self):
    return 'maxF'

  def reset_states(self):
    """Clears the accumulated predictions and groundtruths."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Returns `evaluate()` and then resets the internal state."""
    metric_result = self.evaluate()
    # Start from a clean state on the next evaluation.
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Evaluates with masks from all images.

    Returns:
      f_max: maximum F-score value as a `np.float32` scalar.
    """
    num_images = len(self._groundtruths)
    bin_edges = np.arange(0, 256)
    beta = 0.3

    curve_shape = (num_images, len(bin_edges) - 1)
    precisions = np.zeros(curve_shape)
    recalls = np.zeros(curve_shape)

    pairs = zip(self._groundtruths, self._predictions)
    for row, (gt_mask, pred_mask) in enumerate(pairs):
      # Rescale both masks to the 0..255 range used by the histogram bins.
      gt_scaled = self._mask_normalize(gt_mask) * 255.0
      pred_scaled = self._mask_normalize(pred_mask) * 255.0
      precisions[row, :], recalls[row, :] = self._compute_pre_rec(
          gt_scaled, pred_scaled, mybins=bin_edges)

    mean_precision = np.sum(precisions, 0) / (num_images + 1e-8)
    mean_recall = np.sum(recalls, 0) / (num_images + 1e-8)
    f_curve = ((1 + beta) * mean_precision * mean_recall /
               (beta * mean_precision + mean_recall + 1e-8))
    return np.max(f_curve).astype(np.float32)

  def _mask_normalize(self, mask):
    """Divides `mask` by its maximum (plus 1e-8 to avoid division by zero)."""
    peak = np.amax(mask)
    return mask / (peak + 1e-8)

  def _compute_pre_rec(self, true, pred, mybins=np.arange(0, 256)):
    """Computes precision and recall curves for one image.

    Args:
      true: groundtruth mask scaled to 0..255; values above 128 count as
        foreground.
      pred: predicted mask scaled to 0..255.
      mybins: histogram bin edges.

    Returns:
      Two 1-D arrays (precision, recall) with one entry per histogram bin,
      ordered from the highest bin to the lowest.
    """
    foreground = true > 128
    background = true <= 128
    # Number of groundtruth foreground pixels.
    gt_num = true[foreground].size

    # Histograms of predicted values inside / outside the foreground.
    fg_hist, _ = np.histogram(pred[foreground], bins=mybins)
    bg_hist, _ = np.histogram(pred[background], bins=mybins)

    # Reverse cumulative sums count the pixels at or above each bin.
    true_pos = np.cumsum(fg_hist[::-1])
    false_pos = np.cumsum(bg_hist[::-1])

    precision = true_pos / (true_pos + false_pos + 1e-8)  # TP/(TP+FP)
    recall = true_pos / (gt_num + 1e-8)  # TP/(TP+FN)

    precision[np.isnan(precision)] = 0.0
    recall[np.isnan(recall)] = 0.0

    return precision.reshape(len(precision)), recall.reshape(len(recall))

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays via their `.numpy()` method."""
    return groundtruths.numpy(), predictions.numpy()

  def update_state(self, groundtruths, predictions):
    """Update segmentation results and groundtruth data.

    Only the first element of each tuple is used; it is split along the
    batch dimension and stored image by image.

    Args:
      groundtruths : Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks. range [0, 1]
      predictions : Tuple of single Tensor [batch, width, height, 1],
        predicted masks. range [0, 1]
    """
    np_groundtruths, np_predictions = self._convert_to_numpy(
        groundtruths[0], predictions[0])
    for gt_mask, pred_mask in zip(np_groundtruths, np_predictions):
      self._groundtruths.append(gt_mask)
      self._predictions.append(pred_mask)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Metrics for basnet"""
# Import libraries
import numpy as np
from scipy import signal
class relaxedFscore(object):
  """Relaxed boundary F-score metric for BASNet.

  Masks are accumulated per image with `update_state`. For every image the
  one-pixel boundary of the binarized prediction and groundtruth is
  extracted, and precision/recall are computed with a tolerance of `rho`
  pixels. The per-image F-scores are averaged.
  """

  def __init__(self):
    """Constructs the metric with empty internal state."""
    self.reset_states()

  @property
  def name(self):
    return 'relaxF'

  def reset_states(self):
    """Resets internal states for a fresh run."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Evaluates the accumulated masks, then resets the internal state."""
    metric_result = self.evaluate()
    # Cleans up the internal variables in order for a fresh eval next time.
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Evaluates with masks from all images.

    Returns:
      relax_f: mean relaxed F-score as a `np.float32` scalar (0.0 when no
        masks have been accumulated).
    """
    beta = 0.3
    rho = 3
    relax_fs = np.zeros(len(self._groundtruths))
    erode_kernel = np.ones((3, 3))

    for i, (true, pred) in enumerate(zip(self._groundtruths,
                                         self._predictions)):
      true_edge = self._extract_boundary(true, erode_kernel)
      pred_edge = self._extract_boundary(pred, erode_kernel)
      pre, rec = self._compute_relax_pre_rec(true_edge, pred_edge, rho)
      relax_fs[i] = (1+beta)*pre*rec/(beta*pre+rec+1e-8)

    relax_f = np.sum(relax_fs, 0)/(len(self._groundtruths)+1e-8)
    return relax_f.astype(np.float32)

  def _extract_boundary(self, mask, erode_kernel):
    """Returns the 0/1 boundary map of a [height, width, 1] mask.

    The mask is normalized, binarized at 0.5 and eroded; the boundary is the
    set of pixels removed by the erosion (binary mask XOR eroded mask).
    """
    mask = np.squeeze(self._mask_normalize(mask), axis=-1)
    binary = np.where(mask >= 0.5, 1.0, 0.0)
    eroded = self._compute_erosion(binary, erode_kernel)
    # Multiply by 1 to convert True/False to 1/0.
    return np.logical_xor(eroded, binary) * 1

  def _mask_normalize(self, mask):
    """Divides `mask` by its maximum (plus 1e-8 to avoid division by zero)."""
    return mask/(np.amax(mask)+1e-8)

  def _compute_erosion(self, mask, kernel):
    """Erodes a binary mask: a pixel survives only if the kernel is full."""
    kernel_full = np.sum(kernel)
    mask_erd = signal.convolve2d(mask, kernel, mode='same')
    mask_erd[mask_erd < kernel_full] = 0
    mask_erd[mask_erd == kernel_full] = 1
    return mask_erd

  def _compute_relax_pre_rec(self, true, pred, rho):
    """Computes relaxed precision and recall of two 0/1 boundary maps.

    Precision is the fraction of predicted boundary pixels that have a
    groundtruth boundary pixel inside the (2*rho-1) x (2*rho-1) window around
    them; recall is the fraction of groundtruth boundary pixels that have a
    predicted boundary pixel inside that window. Each numerator is counted
    over the same pixel set as its denominator, so both values stay in
    [0, 1].

    Args:
      true: 2-D groundtruth boundary map with values 0/1.
      pred: 2-D predicted boundary map with values 0/1.
      rho: tolerance that sets the window size.

    Returns:
      (precision, recall). A value is 0.0 when its denominator is zero, i.e.
      when the corresponding map has no boundary pixels, instead of NaN.
    """
    kernel = np.ones((2*rho-1, 2*rho-1))
    near_pred = signal.convolve2d(pred, kernel, mode='same') > 0
    near_true = signal.convolve2d(true, kernel, mode='same') > 0

    pred_total = np.sum(pred)
    true_total = np.sum(true)
    # True positives for relaxed precision: predicted pixels close to truth.
    pre_tp = np.sum((pred == 1) & near_true)
    # True positives for relaxed recall: truth pixels close to a prediction.
    rec_tp = np.sum((true == 1) & near_pred)

    precision = pre_tp / pred_total if pred_total > 0 else 0.0
    recall = rec_tp / true_total if true_total > 0 else 0.0
    return precision, recall

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays via their `.numpy()` method."""
    numpy_groundtruths = groundtruths.numpy()
    numpy_predictions = predictions.numpy()
    return numpy_groundtruths, numpy_predictions

  def update_state(self, groundtruths, predictions):
    """Update segmentation results and groundtruth data.

    Only the first element of each tuple is used; it is split along the
    batch dimension and stored image by image.

    Args:
      groundtruths : Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks. range [0, 1]
      predictions : Tuple of single Tensor [batch, width, height, 1],
        predicted masks. range [0, 1]
    """
    groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
                                                       predictions[0])
    for true, pred in zip(groundtruths, predictions):
      self._groundtruths.append(true)
      self._predictions.append(pred)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Losses used for BASNet models."""
# Import libraries
import tensorflow as tf
# Small constant; NOTE(review): not referenced by the loss code visible in
# this module -- presumably intended for numerical stability.
EPSILON = 1e-5
class BASNetLoss:
  """BASNet hybrid loss: binary cross-entropy + SSIM + IoU terms.

  The three terms are computed for every prediction level and summed.
  """

  def __init__(self):
    """Creates the cross-entropy (SUM reduction) and SSIM callables."""
    self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
        reduction=tf.keras.losses.Reduction.SUM, from_logits=False)
    self._ssim = tf.image.ssim

  def __call__(self, sigmoids, labels):
    """Computes the hybrid loss summed over all levels.

    Args:
      sigmoids: dict mapping a level key to a predicted saliency map. Assumed
        to be [batch, height, width, 1] -- TODO confirm against the model.
      labels: groundtruth masks; they are divided by 255 here, and the last
        axis is squeezed for the cross-entropy term.

    Returns:
      The sum of the cross-entropy, (1 - SSIM) and IoU terms.
    """
    scaled = labels / 255
    bce_target = tf.squeeze(scaled, axis=-1)
    target = tf.cast(scaled, tf.float32)

    bce_terms, ssim_terms, iou_terms = [], [], []
    for level in sorted(sigmoids.keys()):
      prediction = sigmoids[level]
      bce_terms.append(self._binary_crossentropy(bce_target, prediction))
      ssim_terms.append(1 - self._ssim(prediction, target, max_val=1.0))
      iou_terms.append(self._iou_loss(prediction, target))

    bce_sum = tf.math.add_n(bce_terms)
    ssim_sum = tf.math.add_n(ssim_terms)
    iou_sum = tf.math.add_n(iou_terms)
    return bce_sum + ssim_sum + iou_sum

  def _iou_loss(self, sigmoids, labels):
    """Returns 1 - IoU, with sums taken over the whole batch."""
    pred_all = sigmoids[:, :, :, :]
    label_all = labels[:, :, :, :]
    intersection = tf.reduce_sum(pred_all * label_all)
    union = (tf.reduce_sum(pred_all) + tf.reduce_sum(label_all)
             - intersection)
    return 0 + (1 - intersection / union)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decoder of BASNet.
Boundary-Aware Network (BASNet) was proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
# One row per decoder stage, in the order the stages are built.
# nf : num_filters, dr : dilation_rate
# (conv1_nf, conv1_dr, convm_nf, convm_dr, conv2_nf, conv2_dr, scale_factor)
# The three (nf, dr) pairs configure the stage's three conv blocks;
# scale_factor is the upsampling size applied to the stage's side output.
BASNET_DECODER_SPECS = [
    (512, 2, 512, 2, 512, 2, 32),  # Bridge(Sup0)
    (512, 1, 512, 2, 512, 2, 32),  # Sup1, stage6d
    (512, 1, 512, 1, 512, 1, 16),  # Sup2, stage5d
    (512, 1, 512, 1, 256, 1, 8),  # Sup3, stage4d
    (256, 1, 256, 1, 128, 1, 4),  # Sup4, stage3d
    (128, 1, 128, 1, 64, 1, 2),  # Sup5, stage2d
    (64, 1, 64, 1, 64, 1, 1)  # Sup6, stage1d
]
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Decoder(tf.keras.Model):
  """BASNet Decoder.

  Functional Keras model that takes the encoder feature pyramid and produces
  one sigmoid side output per entry of `BASNET_DECODER_SPECS`, keyed by the
  strings '1' .. '7'.
  """

  def __init__(self,
               input_specs,
               use_separable_conv=False,
               activation='relu',
               use_sync_bn=False,
               use_bias=True,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """BASNet Decoder initialization function.

    Args:
      input_specs: `dict` input specifications. A dictionary consists of
        {level: TensorShape} from a backbone. Levels '0' .. '5' are read.
      use_separable_conv: `bool`, kept in the config for compatibility. It is
        not applied: every stage is built from `nn_blocks.ConvBlock` and a
        regular `Conv2D`.
      activation: `str` name of the activation function used in conv blocks.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in the convolutions.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      **kwargs: keyword arguments to be passed.
    """
    self._config_dict = {
        'input_specs': input_specs,
        'use_separable_conv': use_separable_conv,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }

    # Get input feature pyramid from backbone.
    inputs = self._build_input_pyramid(input_specs)

    sup = {}

    for i, spec in enumerate(BASNET_DECODER_SPECS):
      if i == 0:
        x = inputs['5']  # Bridge input
      else:
        # Stage i fuses the running features with encoder level 6 - i.
        x = tf.keras.layers.Concatenate(axis=-1)([x, inputs[str(6-i)]])

      for j in range(3):
        # The caller's activation and normalization settings are forwarded
        # instead of being replaced by hard-coded values.
        x = nn_blocks.ConvBlock(
            filters=spec[2*j],
            kernel_size=3,
            strides=1,
            dilation_rate=spec[2*j+1],
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activation=activation,
            use_sync_bn=use_sync_bn,
            use_bias=use_bias,
            norm_momentum=norm_momentum,
            norm_epsilon=norm_epsilon
        )(x)

      # Side output: 1-channel conv, upsampled by the stage's scale factor.
      output = tf.keras.layers.Conv2D(
          filters=1, kernel_size=3, strides=1,
          use_bias=use_bias, padding='same',
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer
      )(x)
      output = tf.keras.layers.UpSampling2D(
          size=spec[6],
          interpolation='bilinear'
      )(output)
      output = tf.keras.layers.Activation(
          activation='sigmoid'
      )(output)
      sup[str(i+1)] = output

      if i != 0:
        x = tf.keras.layers.UpSampling2D(
            size=2,
            interpolation='bilinear'
        )(x)

    # Describe every side output, including the last one.
    self._output_specs = {
        order: output.get_shape() for order, output in sup.items()
    }

    super(BASNet_Decoder, self).__init__(inputs=inputs, outputs=sup, **kwargs)

  def _build_input_pyramid(self, input_specs):
    """Creates one `tf.keras.Input` per level, dropping the batch dimension."""
    assert isinstance(input_specs, dict)

    inputs = {}
    for level, spec in input_specs.items():
      inputs[level] = tf.keras.Input(shape=spec[1:])
    return inputs

  def get_config(self):
    """Returns a copy of the constructor arguments."""
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {order: TensorShape} pairs for the model output."""
    return self._output_specs
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BASNet Encoder
Boundary-Aware Network (BASNet) was proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
# Specifications for BASNet encoder.
# Each element in the block configuration is in the following format:
# (block_fn, num_filters, stride, block_repeats, maxpool)
# A non-zero maxpool flag adds a 2x2 max-pool after the stage's endpoint has
# been recorded.
BASNET_ENCODER_SPECS = [
    ('residual', 64, 1, 3, 0),  # ResNet-34,
    ('residual', 128, 2, 4, 0),  # ResNet-34,
    ('residual', 256, 2, 6, 0),  # ResNet-34,
    ('residual', 512, 2, 3, 1),  # ResNet-34,
    ('residual', 512, 1, 3, 1),  # BASNet,
    ('residual', 512, 1, 3, 0),  # BASNet,
]
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Encoder(tf.keras.Model):
def __init__(self,
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
activation='relu',
use_sync_bn=False,
use_bias=True,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""BASNet_En initialization function.
Args:
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in conv2d.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
**kwargs: keyword arguments to be passed.
"""
self._input_specs = input_specs
self._use_sync_bn = use_sync_bn
self._use_bias = use_bias
self._activation = activation
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
else:
bn_axis = 1
# Build BASNet.
inputs = tf.keras.Input(shape=input_specs.shape[1:])
x = tf.keras.layers.Conv2D(
filters=64, kernel_size=3, strides=1,
use_bias=self._use_bias, padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
endpoints = {}
for i, spec in enumerate(BASNET_ENCODER_SPECS):
if spec[0] == 'residual':
block_fn = nn_blocks.ResBlock
else:
raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
x = self._block_group(
inputs=x,
filters=spec[1],
strides=spec[2],
block_fn=block_fn,
block_repeats=spec[3],
name='block_group_l{}'.format(i + 2))
endpoints[str(i)] = x
if spec[4]:
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same')(x)
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
super(BASNet_Encoder, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)
def _block_group(self,
                 inputs,
                 filters,
                 strides,
                 block_fn,
                 block_repeats=1,
                 name='block_group'):
  """Stacks `block_repeats` blocks of type `block_fn` on top of `inputs`.

  The leading block receives `strides` and a projection shortcut; every
  following block uses stride 1 and no projection.

  Args:
    inputs: input `Tensor` fed to the first block.
    filters: `int` number of filters handed to every block.
    strides: `int` stride handed to the first block only.
    block_fn: block layer class to instantiate; the encoder construction loop
      calling this method passes `nn_blocks.ResBlock`.
    block_repeats: `int` total number of blocks in the group.
    name: `str` name given to the output tensor.

  Returns:
    The output `Tensor` of the last block, passed through `tf.identity` so it
    carries `name`.
  """
  # Constructor arguments shared by every block of the group.
  common_kwargs = dict(
      filters=filters,
      kernel_initializer=self._kernel_initializer,
      kernel_regularizer=self._kernel_regularizer,
      bias_regularizer=self._bias_regularizer,
      activation=self._activation,
      use_sync_bn=self._use_sync_bn,
      use_bias=self._use_bias,
      norm_momentum=self._norm_momentum,
      norm_epsilon=self._norm_epsilon)

  # First block: may change resolution/width, hence the projection shortcut.
  net = block_fn(strides=strides, use_projection=True, **common_kwargs)(inputs)

  # Remaining blocks keep the resolution and use the identity shortcut.
  for _ in range(block_repeats - 1):
    net = block_fn(strides=1, use_projection=False, **common_kwargs)(net)

  return tf.identity(net, name=name)
@classmethod
def from_config(cls, config, custom_objects=None):
  """Instantiates the class by expanding `config` into keyword arguments.

  Args:
    config: mapping of constructor keyword arguments.
    custom_objects: accepted for signature compatibility; not used.

  Returns:
    A new instance of `cls`.
  """
  del custom_objects  # Unused.
  return cls(**config)
@property
def output_specs(self):
  """Returns the `{level: TensorShape}` dict recorded for the model outputs.

  The mapping is stored in `self._output_specs` when the network is built.
  """
  return self._output_specs
@factory.register_backbone_builder('basnet_encoder')
def build_basnet_encoder(
    input_specs: tf.keras.layers.InputSpec,
    model_config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Builds the BASNet encoder backbone from a model config.

  Args:
    input_specs: `tf.keras.layers.InputSpec` describing the input tensor.
    model_config: model config object. `model_config.backbone.type` must be
      'basnet_encoder'; `model_config.norm_activation` supplies the
      activation, sync-BN flag, norm momentum and norm epsilon.
    l2_regularizer: optional kernel regularizer applied to the encoder
      convolutions. Defaults to None.

  Returns:
    A `BASNet_Encoder` instance.

  Raises:
    AssertionError: if the configured backbone type is not 'basnet_encoder'.
  """
  backbone_type = model_config.backbone.type
  assert backbone_type == 'basnet_encoder', (f'Inconsistent backbone type '
                                             f'{backbone_type}')

  # Only the shared norm/activation settings are consumed here; the encoder
  # takes no backbone-specific options, so the backbone sub-config is not read.
  norm_activation_config = model_config.norm_activation
  return BASNet_Encoder(
      input_specs=input_specs,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build BASNet models."""
# Import libraries
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetModel(tf.keras.Model):
  """BASNet model composed of a backbone, a decoder and a refinement module.

  Inputs go through the backbone. The decoder, when provided, consumes the
  backbone features. The refinement module, when provided, consumes the '7'
  entry of the resulting feature dict and its result is stored under 'ref'.
  """

  def __init__(self, backbone, decoder, refinement, **kwargs):
    """Initializes the BASNet model.

    Args:
      backbone: a backbone network, e.g. the BASNet encoder.
      decoder: a decoder network, e.g. the BASNet decoder.
      refinement: a module for salient map refinement.
      **kwargs: keyword arguments forwarded to `tf.keras.Model`.
    """
    super().__init__(**kwargs)
    # Constructor arguments, returned as-is by get_config().
    self._config_dict = dict(
        backbone=backbone, decoder=decoder, refinement=refinement)
    self.backbone = backbone
    self.decoder = decoder
    self.refinement = refinement

  def call(self, inputs, training=None):
    """Runs backbone, then the optional decoder and refinement stages.

    Args:
      inputs: input tensor handed to the backbone.
      training: accepted for Keras compatibility; not forwarded explicitly.

    Returns:
      The feature dict; holds an extra 'ref' entry when a refinement module
      is set.
    """
    outputs = self.backbone(inputs)
    if self.decoder:
      outputs = self.decoder(outputs)
    if self.refinement:
      outputs['ref'] = self.refinement(outputs['7'])
    return outputs

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = {'backbone': self.backbone}
    # Optional sub-networks are only listed when they are present.
    for key in ('decoder', 'refinement'):
      module = getattr(self, key)
      if module is not None:
        items[key] = module
    return items

  def get_config(self):
    """Returns the dict of constructor arguments stored at init time."""
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the model by expanding `config` into keyword arguments."""
    del custom_objects  # Unused.
    return cls(**config)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for basnet network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.basnet.modeling import basnet_encoder
from official.vision.beta.projects.basnet.modeling import basnet_model
from official.vision.beta.projects.basnet.modeling import basnet_decoder
from official.vision.beta.projects.basnet.modeling import refunet
class BASNetNetworkTest(parameterized.TestCase, tf.test.TestCase):
  """Tests construction and (de)serialization of the full BASNet model."""

  def _build_model(self):
    """Assembles encoder, decoder and refinement module with defaults."""
    encoder = basnet_encoder.BASNet_Encoder()
    decoder = basnet_decoder.BASNet_Decoder(
        input_specs=encoder.output_specs)
    refinement = refunet.RefUnet()
    return basnet_model.BASNetModel(
        backbone=encoder,
        decoder=decoder,
        refinement=refinement)

  @parameterized.parameters(256, 512)
  def test_basnet_network_creation(self, input_size):
    """Test for creation of a segmentation network."""
    inputs = np.random.rand(2, input_size, input_size, 3)
    tf.keras.backend.set_image_data_format('channels_last')

    model = self._build_model()
    sigmoids = model(inputs)

    # The refined map keeps the input resolution and has a single channel.
    expected_shape = [2, input_size, input_size, 1]
    self.assertAllEqual(expected_shape, sigmoids['ref'].numpy().shape)

  def test_serialize_deserialize(self):
    """Validate the network can be serialized and deserialized."""
    model = self._build_model()

    restored = basnet_model.BASNetModel.from_config(model.get_config())

    # Validate that the config can be forced to JSON.
    _ = restored.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(model.get_config(), restored.get_config())
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common building blocks for neural networks."""
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Vision')
class ConvBlock(tf.keras.layers.Layer):
  """A single Conv2D followed by batch normalization and an activation."""

  def __init__(self,
               filters,
               strides,
               dilation_rate,
               kernel_size=3,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_bias=False,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """Initializes a Conv+BN+Activation block.

    Args:
      filters: `int` number of filters of the convolution.
      strides: `int` stride of the convolution.
      dilation_rate: dilation rate of the convolution.
      kernel_size: kernel size of the convolution.
      kernel_initializer: kernel_initializer for the convolutional layer.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      activation: `str` name of the activation function.
      use_bias: if True, use bias in the convolution.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)

    # Convolution settings.
    self._filters = filters
    self._kernel_size = kernel_size
    self._strides = strides
    self._dilation_rate = dilation_rate
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._activation = activation
    self._use_bias = use_bias

    # Normalization settings.
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._norm = (
        tf.keras.layers.experimental.SyncBatchNormalization
        if use_sync_bn else tf.keras.layers.BatchNormalization)
    channels_last = tf.keras.backend.image_data_format() == 'channels_last'
    self._bn_axis = -1 if channels_last else 1

    self._activation_fn = tf_utils.get_activation(activation)

  def build(self, input_shape):
    """Creates the convolution and normalization sub-layers."""
    self._conv0 = tf.keras.layers.Conv2D(
        filters=self._filters,
        kernel_size=self._kernel_size,
        strides=self._strides,
        dilation_rate=self._dilation_rate,
        padding='same',
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm0 = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)
    super().build(input_shape)

  def get_config(self):
    """Returns the base layer config extended with this block's arguments."""
    merged = super().get_config()
    merged.update({
        'filters': self._filters,
        'kernel_size': self._kernel_size,
        'strides': self._strides,
        'dilation_rate': self._dilation_rate,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'use_bias': self._use_bias,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    })
    return merged

  def call(self, inputs, training=None):
    """Applies convolution, normalization and activation, in that order."""
    return self._activation_fn(self._norm0(self._conv0(inputs)))
@tf.keras.utils.register_keras_serializable(package='Vision')
class ResBlock(tf.keras.layers.Layer):
  """A residual block made of two 3x3 convolutions and a shortcut."""

  def __init__(self,
               filters,
               strides,
               use_projection=False,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_sync_bn=False,
               use_bias=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """Initializes a residual block with BN after convolutions.

    Args:
      filters: An `int` number of filters used by both 3x3 convolutions and by
        the projection shortcut.
      strides: An `int` stride applied by the first convolution and by the
        projection shortcut; the second convolution always uses stride 1.
      use_projection: A `bool` for whether the shortcut is a 1x1 convolution
        followed by normalization instead of the identity.
      kernel_initializer: A `str` of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      activation: A `str` name of the activation function.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      use_bias: A `bool`. If True, use bias in conv2d.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      **kwargs: Additional keyword arguments to be passed.
    """
    super().__init__(**kwargs)

    self._filters = filters
    self._strides = strides
    self._use_projection = use_projection
    self._use_sync_bn = use_sync_bn
    self._use_bias = use_bias
    self._activation = activation
    self._kernel_initializer = kernel_initializer
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer

    self._norm = (
        tf.keras.layers.experimental.SyncBatchNormalization
        if use_sync_bn else tf.keras.layers.BatchNormalization)
    channels_last = tf.keras.backend.image_data_format() == 'channels_last'
    self._bn_axis = -1 if channels_last else 1
    self._activation_fn = tf_utils.get_activation(activation)

  def _new_conv(self, kernel_size, strides, padding):
    """Creates a Conv2D layer sharing this block's filters and settings."""
    return tf.keras.layers.Conv2D(
        filters=self._filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)

  def _new_norm(self):
    """Creates a normalization layer with this block's settings."""
    return self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)

  def build(self, input_shape):
    """Creates the shortcut (if any), both convolutions and their norms."""
    if self._use_projection:
      # 1x1 projection; 'valid' is the Conv2D default padding.
      self._shortcut = self._new_conv(
          kernel_size=1, strides=self._strides, padding='valid')
      self._norm0 = self._new_norm()

    self._conv1 = self._new_conv(
        kernel_size=3, strides=self._strides, padding='same')
    self._norm1 = self._new_norm()

    self._conv2 = self._new_conv(kernel_size=3, strides=1, padding='same')
    self._norm2 = self._new_norm()

    super().build(input_shape)

  def get_config(self):
    """Returns the base layer config extended with this block's arguments."""
    merged = super().get_config()
    merged.update({
        'filters': self._filters,
        'strides': self._strides,
        'use_projection': self._use_projection,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'use_bias': self._use_bias,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    })
    return merged

  def call(self, inputs, training=None):
    """Computes activation(conv-bn-act-conv-bn(inputs) + shortcut(inputs))."""
    if self._use_projection:
      shortcut = self._norm0(self._shortcut(inputs))
    else:
      shortcut = inputs

    branch = self._activation_fn(self._norm1(self._conv1(inputs)))
    branch = self._norm2(self._conv2(branch))

    return self._activation_fn(branch + shortcut)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Residual Refinement Module of BASNet.
Boundary-Aware network (BASNet) was proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
# Import libraries
import tensorflow as tf
from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
@tf.keras.utils.register_keras_serializable(package='Vision')
class RefUnet(tf.keras.Model):
  """Residual refinement module: a small U-Net added back onto its input.

  The input map goes through four Conv+BN+activation / max-pool stages, a
  bridge block, and four concatenate / Conv+BN+activation stages with
  bilinear upsampling. The single-channel result is added to the input and
  passed through a sigmoid.
  """

  def __init__(self,
               input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 1]),
               activation='relu',
               use_sync_bn=False,
               use_bias=True,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """Residual Refinement Module of BASNet.

    Args:
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      activation: `str` name of the activation function used by every
        Conv+BN+activation block.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in conv2d.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      **kwargs: keyword arguments to be passed.
    """
    self._input_specs = input_specs
    self._use_sync_bn = use_sync_bn
    self._use_bias = use_bias
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer

    # Channel axis; also the axis along which skip connections are joined.
    if tf.keras.backend.image_data_format() == 'channels_last':
      channel_axis = -1
    else:
      channel_axis = 1

    def conv(filters):
      """Creates a plain 3x3, stride-1, 'same'-padded convolution."""
      return tf.keras.layers.Conv2D(
          filters=filters, kernel_size=3, strides=1,
          use_bias=self._use_bias, padding='same',
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)

    def conv_block():
      """Creates a 64-filter Conv+BN+activation block.

      The activation and normalization settings come from the constructor
      arguments (previously they were hard-coded to the default values).
      """
      return nn_blocks.ConvBlock(
          filters=64,
          kernel_size=3,
          strides=1,
          dilation_rate=1,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          use_bias=self._use_bias,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon)

    def upsample():
      """Creates a 2x bilinear upsampling layer."""
      return tf.keras.layers.UpSampling2D(size=2, interpolation='bilinear')

    # Build the refinement network.
    inputs = tf.keras.Input(shape=self._input_specs.shape[1:])

    endpoints = {}
    residual = inputs
    x = conv(filters=64)(inputs)

    # Top-down: keep each stage output for the skip connections.
    num_stages = 4
    for i in range(num_stages):
      x = conv_block()(x)
      endpoints[str(i)] = x
      x = tf.keras.layers.MaxPool2D(
          pool_size=2, strides=2, padding='valid')(x)

    # Bridge
    x = conv_block()(x)
    x = upsample()(x)

    # Bottom-up: consume the skip connections from deepest to shallowest.
    for i in range(num_stages):
      skip = endpoints[str(num_stages - 1 - i)]
      x = tf.keras.layers.Concatenate(axis=channel_axis)([skip, x])
      x = conv_block()(x)
      if i == num_stages - 1:
        # Last stage: project to a single-channel map at full resolution.
        x = conv(filters=1)(x)
      else:
        x = upsample()(x)

    # Residual connection: the network predicts a correction to its input.
    residual = tf.cast(residual, dtype=x.dtype)
    output = x + residual

    output = tf.keras.layers.Activation(
        activation='sigmoid'
    )(output)

    self._output_specs = output.get_shape()
    super(RefUnet, self).__init__(inputs=inputs, outputs=output, **kwargs)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Instantiates the module by expanding `config` into keyword arguments.

    NOTE(review): this class defines no matching `get_config`; confirm that
    the config passed here holds constructor keyword arguments.
    """
    return cls(**config)

  @property
  def output_specs(self):
    """Returns the `TensorShape` of the refined output."""
    return self._output_specs
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Detection input and model functions for serving/inference."""
import tensorflow as tf
from official.vision.beta.projects.basnet.tasks import basnet
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.serving import semantic_segmentation
# Per-channel mean and standard deviation triples, each coefficient scaled by
# 255 (i.e. expressed on a 0-255 pixel scale).
# NOTE(review): neither constant is referenced in the visible part of this
# module -- presumably kept for input normalization elsewhere; confirm.
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class BASNetModule(semantic_segmentation.SegmentationModule):
  """Export module that serves a BASNet model."""

  def _build_model(self):
    """Builds the BASNet model for the configured batch and image size."""
    spec_shape = [self._batch_size] + self._input_image_size + [3]
    return basnet.build_basnet_model(
        input_specs=tf.keras.layers.InputSpec(shape=spec_shape),
        model_config=self.params.task.model,
        l2_regularizer=None)

  def serve(self, images):
    """Cast image to float and run inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]

    Returns:
      A dict whose 'predicted_masks' entry holds the model's 'ref' output
      resized bilinearly to the configured input image size.
    """
    # Per-image preprocessing is pinned to the CPU.
    with tf.device('cpu:0'):
      float_images = tf.cast(images, dtype=tf.float32)
      per_image_spec = tf.TensorSpec(
          shape=self._input_image_size + [3], dtype=tf.float32)
      preprocessed = tf.map_fn(
          self._build_inputs,
          elems=float_images,
          fn_output_signature=per_image_spec,
          parallel_iterations=32)
      model_inputs = tf.nest.map_structure(tf.identity, preprocessed)

    masks = self.inference_step(model_inputs)
    resized = tf.image.resize(
        masks['ref'], self._input_image_size, method='bilinear')
    return dict(predicted_masks=resized)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
r"""Vision models export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE = XX
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.vision import beta
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.serving import export_saved_model_lib
from official.vision.beta.projects.basnet.serving import basnet
# Command-line flags consumed by main() below.
FLAGS = flags.FLAGS
# Name of the registered experiment whose default config is loaded.
flags.DEFINE_string(
'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
# Output location and the checkpoint to export.
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
# Config overrides: files are applied first, in order, then params_override.
flags.DEFINE_multi_string(
'config_file',
default=None,
help='YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.')
# Serving signature settings.
flags.DEFINE_integer(
'batch_size', None, 'The batch size.')
flags.DEFINE_string(
'input_type', 'image_tensor',
'One of `image_tensor`, `image_bytes`, `tf_example`.')
# Parsed in main() by splitting on ',' and converting each piece to int.
flags.DEFINE_string(
'input_image_size', '224,224',
'The comma-separated string of two integers representing the height,width '
'of the input to the model.')
def main(_):
  """Builds the experiment config from flags and exports the BASNet model."""
  params = exp_factory.get_exp_config(FLAGS.experiment)

  # Apply overrides: config files in the given order, then the inline string.
  overrides = list(FLAGS.config_file or [])
  if FLAGS.params_override:
    overrides.append(FLAGS.params_override)
  for override in overrides:
    params = hyperparams.override_params_dict(
        params, override, is_strict=True)

  params.validate()
  params.lock()

  # "H,W" -> [H, W]. The export module receives its own copy of the list.
  image_size = [int(dim) for dim in FLAGS.input_image_size.split(',')]
  export_module = basnet.BASNetModule(
      params=params,
      batch_size=FLAGS.batch_size,
      input_image_size=list(image_size))

  export_saved_model_lib.export_inference_graph(
      input_type=FLAGS.input_type,
      batch_size=FLAGS.batch_size,
      input_image_size=image_size,
      params=params,
      checkpoint_path=FLAGS.checkpoint_path,
      export_dir=FLAGS.export_dir,
      export_module=export_module,
      export_checkpoint_subdir='checkpoint',
      export_saved_model_subdir='saved_model')
# Script entry point: absl parses the flags, then calls main().
if __name__ == '__main__':
app.run(main)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BASNet task definition."""
from typing import Optional
from absl import logging
import tensorflow as tf
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.vision.beta.projects.basnet.configs import basnet as exp_cfg
from official.vision.beta.dataloaders import segmentation_input
from official.vision.beta.projects.basnet.evaluation import max_f
from official.vision.beta.projects.basnet.evaluation import relax_f
from official.vision.beta.projects.basnet.evaluation import mae
from official.vision.beta.projects.basnet.losses import basnet_losses
from official.vision.beta.projects.basnet.modeling import basnet_encoder
from official.vision.beta.projects.basnet.modeling import basnet_model
from official.vision.beta.projects.basnet.modeling import basnet_decoder
from official.vision.beta.projects.basnet.modeling import refunet
def build_basnet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: exp_cfg.BASNetModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds BASNet model.

  The encoder, decoder and refinement module are all configured from the same
  norm/activation settings and kernel regularizer. Previously only the decoder
  received them, so sync-BN, norm hyper-parameters and weight decay were
  silently dropped for the other two sub-networks.

  Args:
    input_specs: `tf.keras.layers.InputSpec` describing the model input.
    model_config: BASNet model config; its `norm_activation` field supplies
      the activation, sync-BN flag, norm momentum and norm epsilon.
    l2_regularizer: optional kernel regularizer applied to every sub-network.
      Defaults to None.

  Returns:
    A `BASNetModel` wrapping the encoder, decoder and refinement module.
  """
  norm_activation_config = model_config.norm_activation
  # Settings shared by the three sub-networks.
  shared_kwargs = dict(
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      activation=norm_activation_config.activation,
      kernel_regularizer=l2_regularizer)

  backbone = basnet_encoder.BASNet_Encoder(
      input_specs=input_specs, **shared_kwargs)

  decoder = basnet_decoder.BASNet_Decoder(
      input_specs=backbone.output_specs, **shared_kwargs)

  # The refinement module keeps its own default input spec.
  refinement = refunet.RefUnet(**shared_kwargs)

  return basnet_model.BASNetModel(backbone, decoder, refinement)
@task_factory.register_task_cls(exp_cfg.BASNetTask)
class BASNetTask(base_task.Task):
"""A task for basnet."""
def build_model(self):
  """Builds the BASNet model described by the task config.

  Returns:
    The model returned by `build_basnet_model`.
  """
  model_config = self.task_config.model
  input_specs = tf.keras.layers.InputSpec(
      shape=[None] + model_config.input_size)

  # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
  # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
  # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
  weight_decay = self.task_config.losses.l2_weight_decay
  if weight_decay:
    l2_regularizer = tf.keras.regularizers.l2(weight_decay / 2.0)
  else:
    l2_regularizer = None

  return build_basnet_model(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)
def initialize(self, model: tf.keras.Model):
  """Loads pretrained checkpoint.

  Does nothing when the task config has no `init_checkpoint`. With 'all' in
  `init_checkpoint_modules` every item of `model.checkpoint_items` is restored
  and must be fully consumed; otherwise only the listed 'backbone' and/or
  'decoder' modules are restored, allowing a partial match.

  Args:
    model: the model whose sub-networks are restored.
  """
  ckpt_path = self.task_config.init_checkpoint
  if not ckpt_path:
    return

  # A directory resolves to its most recent checkpoint.
  if tf.io.gfile.isdir(ckpt_path):
    ckpt_path = tf.train.latest_checkpoint(ckpt_path)

  # Restoring checkpoint.
  requested = self.task_config.init_checkpoint_modules
  if 'all' in requested:
    ckpt = tf.train.Checkpoint(**model.checkpoint_items)
    ckpt.restore(ckpt_path).assert_consumed()
  else:
    ckpt_items = {
        part: getattr(model, part)
        for part in ('backbone', 'decoder')
        if part in requested
    }
    ckpt = tf.train.Checkpoint(**ckpt_items)
    restore_status = ckpt.restore(ckpt_path)
    restore_status.expect_partial().assert_existing_objects_matched()

  logging.info('Finished loading pretrained checkpoint from %s',
               ckpt_path)
def build_inputs(self,
                 params: exp_cfg.DataConfig,
                 input_context: Optional[tf.distribute.InputContext] = None):
  """Builds BASNet input.

  Args:
    params: data config providing the output/crop sizes, flip augmentation,
      dtype and training flag.
    input_context: optional distribution input context forwarded to the
      reader.

  Returns:
    The dataset produced by the input reader.
  """
  example_decoder = segmentation_input.Decoder()
  example_parser = segmentation_input.Parser(
      output_size=params.output_size,
      crop_size=params.crop_size,
      ignore_label=self.task_config.losses.ignore_label,
      aug_rand_hflip=params.aug_rand_hflip,
      dtype=params.dtype)

  # TFRecord files are decoded, then parsed in train or eval mode.
  reader = input_reader.InputReader(
      params,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn=example_decoder.decode,
      parser_fn=example_parser.parse_fn(params.is_training))

  return reader.read(input_context=input_context)
def build_losses(self, label, model_outputs, aux_losses=None):
  """Hybrid loss proposed in BASNet.

  Args:
    label: dict of labels; the ground-truth masks are read from
      `label['masks']`.
    model_outputs: outputs of the BASNet model, passed to the loss as-is.
    aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

  Returns:
    The total loss tensor: the BASNet loss plus the sum of `aux_losses`, if
    any.
  """
  basnet_loss_fn = basnet_losses.BASNetLoss()
  total_loss = basnet_loss_fn(model_outputs, label['masks'])

  # Add regularization / auxiliary losses collected by the model.
  if aux_losses:
    total_loss += tf.add_n(aux_losses)

  return total_loss
def build_metrics(self, training=False):
  """Gets streaming metrics for training/validation.

  Outside of training, the MAE, max-F and relaxed-F evaluation metrics are
  created and stored as attributes of the task rather than returned.

  Args:
    training: whether the metrics are requested for training.

  Returns:
    An empty list in both modes.
  """
  if not training:
    self.mae_metric = mae.MAE()
    self.maxf_metric = max_f.maxFscore()
    self.relaxf_metric = relax_f.relaxedFscore()
  return []
def train_step(self, inputs, model, optimizer, metrics=None):
  """Does forward and backward.

  Args:
    inputs: a `(features, labels)` pair.
    model: the model, forward pass definition.
    optimizer: the optimizer for this training step.
    metrics: a nested structure of metrics objects. Not used by this step.

  Returns:
    A dictionary of logs holding the unscaled per-replica loss under the
    `self.loss` key.
  """
  features, labels = inputs
  num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

  # `tf.keras.mixed_precision.experimental.LossScaleOptimizer` is deprecated.
  # The non-experimental class is used instead; in TF 2.4+ the experimental
  # class derives from it, so optimizers built through either API match.
  uses_loss_scaling = isinstance(
      optimizer, tf.keras.mixed_precision.LossScaleOptimizer)

  with tf.GradientTape() as tape:
    outputs = model(features, training=True)
    # Casting output layer as float32 is necessary when mixed_precision is
    # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
    outputs = tf.nest.map_structure(
        lambda x: tf.cast(x, tf.float32), outputs)

    # Computes per-replica loss.
    loss = self.build_losses(
        model_outputs=outputs, label=labels, aux_losses=model.losses)
    # Scales loss as the default gradients allreduce performs sum inside the
    # optimizer.
    scaled_loss = loss / num_replicas

    # For mixed_precision policy, when LossScaleOptimizer is used, loss is
    # scaled for numerical stability.
    if uses_loss_scaling:
      scaled_loss = optimizer.get_scaled_loss(scaled_loss)

  tvars = model.trainable_variables
  grads = tape.gradient(scaled_loss, tvars)
  # Scales back gradient before apply_gradients when LossScaleOptimizer is
  # used.
  if uses_loss_scaling:
    grads = optimizer.get_unscaled_gradients(grads)

  # Apply gradient clipping.
  if self.task_config.gradient_clip_norm > 0:
    grads, _ = tf.clip_by_global_norm(
        grads, self.task_config.gradient_clip_norm)
  optimizer.apply_gradients(list(zip(grads, tvars)))

  logs = {self.loss: loss}
  return logs
def validation_step(self, inputs, model, metrics=None):
  """Validation step.

  Args:
    inputs: a `(features, labels)` pair; `labels['masks']` is used as the
      ground truth.
    model: the keras.Model.
    metrics: a nested structure of metrics objects. Not used by this step.

  Returns:
    A dictionary of logs. Besides the `self.loss` entry, it maps the name of
    each evaluation metric to a `(labels['masks'], outputs['ref'])` pair,
    which `aggregate_logs` later feeds to that metric.
  """
  features, labels = inputs

  outputs = self.inference_step(features, model)
  outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

  # No loss is computed during validation; 0 is logged as a placeholder.
  loss = 0
  logs = {self.loss: loss}

  # Every metric is evaluated on the same ground-truth/prediction pair.
  eval_pair = (labels['masks'], outputs['ref'])
  logs.update({self.mae_metric.name: eval_pair})
  logs.update({self.maxf_metric.name: eval_pair})
  logs.update({self.relaxf_metric.name: eval_pair})
  return logs
def inference_step(self, inputs, model):
  """Runs `model` on `inputs` with `training=False` and returns its output."""
  predictions = model(inputs, training=False)
  return predictions
def aggregate_logs(self, state=None, step_outputs=None):
  """Feeds one step's outputs into the evaluation metrics.

  Args:
    state: None on the first call of an evaluation run, which resets all
      three metrics; any other value is passed through untouched.
    step_outputs: mapping from each metric's `name` to an indexable pair
      whose two items are handed to that metric's `update_state`.

  Returns:
    `self.mae_metric` when `state` was None, otherwise `state` itself.
  """
  tracked_metrics = (self.mae_metric, self.maxf_metric, self.relaxf_metric)

  if state is None:
    for metric in tracked_metrics:
      metric.reset_states()
    # Any non-None object works as a marker that the reset already happened.
    state = self.mae_metric

  for metric in tracked_metrics:
    pair = step_outputs[metric.name]
    metric.update_state(pair[0], pair[1])

  return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
  """Collects the final value of every evaluation metric.

  Args:
    aggregated_logs: not read; results come from the metric objects stored
      on the task.
    global_step: not read.

  Returns:
    A dictionary with the keys 'MAE', 'maxF' and 'relaxF', each holding the
    `result()` of the corresponding metric.
  """
  return {
      'MAE': self.mae_metric.result(),
      'maxF': self.maxf_metric.result(),
      'relaxF': self.relaxf_metric.result(),
  }
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin
# pylint: disable=unused-import
#from official.common import registry_imports
from official.vision.beta.projects.basnet.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
#import os
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"
FLAGS = flags.FLAGS
def main(_):
  """Builds the task from the parsed flags/config and runs the experiment.

  Args:
    _: ignored positional argument.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir

  # Pure eval modes do not output yaml files. Otherwise continuous eval job
  # may race against the train job for writing the same file.
  if 'train' in FLAGS.mode:
    train_utils.serialize_config(params, model_dir)

  runtime = params.runtime

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16
  if runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        runtime.mixed_precision_dtype,
        runtime.loss_scale,
        use_experimental_api=True)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=runtime.distribution_strategy,
      all_reduce_alg=runtime.all_reduce_alg,
      num_gpus=runtime.num_gpus,
      tpu_address=runtime.tpu)

  # The task is created under the strategy scope.
  with strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)
# Script entry point: the flags are defined first, then absl's `app.run` is
# handed `main`.
if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment