Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9c069a70
Commit
9c069a70
authored
Sep 03, 2021
by
A. Unique TensorFlower
Browse files
Merge pull request #10072 from gunho1123:master
PiperOrigin-RevId: 394727655
parents
bacf03e3
2ab8af9f
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
2117 additions
and
0 deletions
+2117
-0
official/projects/basnet/README.md
official/projects/basnet/README.md
+35
-0
official/projects/basnet/configs/basnet.py
official/projects/basnet/configs/basnet.py
+156
-0
official/projects/basnet/configs/basnet_test.py
official/projects/basnet/configs/basnet_test.py
+42
-0
official/projects/basnet/configs/experiments/basnet_dut_gpu.yaml
...l/projects/basnet/configs/experiments/basnet_dut_gpu.yaml
+10
-0
official/projects/basnet/evaluation/metrics.py
official/projects/basnet/evaluation/metrics.py
+329
-0
official/projects/basnet/evaluation/metrics_test.py
official/projects/basnet/evaluation/metrics_test.py
+68
-0
official/projects/basnet/losses/basnet_losses.py
official/projects/basnet/losses/basnet_losses.py
+65
-0
official/projects/basnet/modeling/basnet_model.py
official/projects/basnet/modeling/basnet_model.py
+442
-0
official/projects/basnet/modeling/basnet_model_test.py
official/projects/basnet/modeling/basnet_model_test.py
+76
-0
official/projects/basnet/modeling/nn_blocks.py
official/projects/basnet/modeling/nn_blocks.py
+245
-0
official/projects/basnet/modeling/refunet.py
official/projects/basnet/modeling/refunet.py
+165
-0
official/projects/basnet/serving/basnet.py
official/projects/basnet/serving/basnet.py
+66
-0
official/projects/basnet/serving/export_saved_model.py
official/projects/basnet/serving/export_saved_model.py
+106
-0
official/projects/basnet/tasks/basnet.py
official/projects/basnet/tasks/basnet.py
+281
-0
official/projects/basnet/train.py
official/projects/basnet/train.py
+31
-0
No files found.
official/projects/basnet/README.md
0 → 100644
View file @
9c069a70
# BASNet: Boundary-Aware Salient Object Detection
This repository is the unofficial implementation of the following paper. Please
see the paper
[
BASNet: Boundary-Aware Salient Object Detection
](
https://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html
)
for more details.
## Requirements
[

](https://github.com/tensorflow/tensorflow/releases/tag/v2.4.0)
[

](https://www.python.org/downloads/release/python-379/)
## Train
```
shell
$
python3 train.py
\
--experiment
=
basnet_duts
\
--mode
=
train
\
--model_dir
=
$MODEL_DIR
\
--config_file
=
./configs/experiments/basnet_dut_gpu.yaml
```
## Test
```
shell
$
python3 train.py
\
--experiment
=
basnet_duts
\
--mode
=
eval
\
--model_dir
=
$MODEL_DIR
\
--config_file
=
./configs/experiments/basnet_dut_gpu.yaml
--params_override
=
'runtime.num_gpus=1, runtime.distribution_strategy=one_device, task.model.input_size=[256, 256, 3]'
```
## Results
Dataset | maxF
<sub>
β
</sub>
| relaxF
<sub>
β
</sub>
| MAE
:--------- | :--------------- | :------------------- | -------:
DUTS-TE | 0.865 | 0.793 | 0.046
official/projects/basnet/configs/basnet.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BASNet configuration definition."""
import
dataclasses
import
os
from
typing
import
List
,
Optional
,
Union
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.modeling
import
optimization
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.vision.beta.configs
import
common
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for BASNet training and evaluation."""
  output_size: List[int] = dataclasses.field(default_factory=list)
  # If crop_size is specified, image will be resized first to
  # output_size, then crop of size crop_size will be cropped.
  crop_size: List[int] = dataclasses.field(default_factory=list)
  input_path: str = ''
  global_batch_size: int = 0
  is_training: bool = True
  dtype: str = 'float32'
  shuffle_buffer_size: int = 1000
  cycle_length: int = 10
  resize_eval_groundtruth: bool = True
  groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
  aug_rand_hflip: bool = True
  file_type: str = 'tfrecord'
@dataclasses.dataclass
class BASNetModel(hyperparams.Config):
  """BASNet model config."""
  input_size: List[int] = dataclasses.field(default_factory=list)
  use_bias: bool = False
  norm_activation: common.NormActivation = common.NormActivation()
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Loss-related hyperparameters for BASNet."""
  label_smoothing: float = 0.1
  ignore_label: int = 0  # will be treated as background
  l2_weight_decay: float = 0.0
  use_groundtruth_dimension: bool = True
@dataclasses.dataclass
class BASNetTask(cfg.TaskConfig):
  """The BASNet task config: model, data, losses and checkpointing."""
  model: BASNetModel = BASNetModel()
  train_data: DataConfig = DataConfig(is_training=True)
  validation_data: DataConfig = DataConfig(is_training=False)
  losses: Losses = Losses()
  gradient_clip_norm: float = 0.0
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[str, List[str]] = 'backbone'  # all, backbone, and/or decoder
@exp_factory.register_config_factory('basnet')
def basnet() -> cfg.ExperimentConfig:
  """BASNet general experiment config.

  Returns:
    A `cfg.ExperimentConfig` with a default `BASNetTask`.
  """
  # Bug fix: the task must be a `cfg.TaskConfig` (BASNetTask), not the bare
  # `BASNetModel` hyperparams config — the restrictions below reference
  # `task.train_data`/`task.validation_data`, which only BASNetTask defines,
  # so validate() could never succeed with the model config as the task.
  return cfg.ExperimentConfig(
      task=BASNetTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
# DUTS dataset sizes and input locations.
DUTS_TRAIN_EXAMPLES = 10553
DUTS_VAL_EXAMPLES = 5019
# Placeholder base paths — replaced with the real dataset root at run time.
DUTS_INPUT_PATH_BASE_TR = 'DUTS_DATASET'
DUTS_INPUT_PATH_BASE_VAL = 'DUTS_DATASET'
@exp_factory.register_config_factory('basnet_duts')
def basnet_duts() -> cfg.ExperimentConfig:
  """Image segmentation on DUTS with BASNet."""
  train_batch_size = 64
  eval_batch_size = 16
  # One "epoch" worth of optimizer steps at the train batch size.
  steps_per_epoch = DUTS_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=BASNetTask(
          model=BASNetModel(
              input_size=[None, None, 3],
              use_bias=True,
              norm_activation=common.NormActivation(
                  activation='relu',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=0),
          train_data=DataConfig(
              input_path=os.path.join(DUTS_INPUT_PATH_BASE_TR,
                                      'tf_record_train'),
              file_type='tfrecord',
              crop_size=[224, 224],
              output_size=[256, 256],
              is_training=True,
              global_batch_size=train_batch_size,
          ),
          validation_data=DataConfig(
              input_path=os.path.join(DUTS_INPUT_PATH_BASE_VAL,
                                      'tf_record_test'),
              file_type='tfrecord',
              output_size=[256, 256],
              is_training=False,
              global_batch_size=eval_batch_size,
          ),
          # Encoder warm-started from an ImageNet-pretrained checkpoint.
          init_checkpoint='gs://cloud-basnet-checkpoints/basnet_encoder_imagenet/ckpt-340306',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=300 * steps_per_epoch,
          validation_steps=DUTS_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adam',
                  'adam': {
                      'beta_1': 0.9,
                      'beta_2': 0.999,
                      'epsilon': 1e-8,
                  }
              },
              'learning_rate': {
                  'type': 'constant',
                  'constant': {
                      'learning_rate': 0.001
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
official/projects/basnet/configs/basnet_test.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for basnet configs."""
# pylint: disable=unused-import
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.projects.basnet.configs
import
basnet
as
exp_cfg
class BASNetConfigTest(tf.test.TestCase, parameterized.TestCase):
  """Smoke tests for the registered BASNet experiment configs."""

  @parameterized.parameters(('basnet_duts',))
  def test_basnet_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    # The factory must produce a fully-typed experiment config.
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.BASNetTask)
    self.assertIsInstance(config.task.model, exp_cfg.BASNetModel)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    # Violating the `is_training != None` restriction must fail validation.
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()
if __name__ == '__main__':
  tf.test.main()
official/projects/basnet/configs/experiments/basnet_dut_gpu.yaml
0 → 100644
View file @
9c069a70
runtime:
  distribution_strategy: 'mirrored'
  mixed_precision_dtype: 'float32'
  num_gpus: 8
task:
  train_data:
    dtype: 'float32'
  validation_data:
    resize_eval_groundtruth: true
    dtype: 'float32'
official/projects/basnet/evaluation/metrics.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation metrics for BASNet.
The MAE and maxFscore implementations are a modified version of
https://github.com/xuebinqin/Binary-Segmentation-Evaluation-Tool
"""
import
numpy
as
np
import
scipy.signal
class MAE:
  """Mean Absolute Error (MAE) metric for BASNet.

  Accumulates (groundtruth, prediction) mask pairs via `update_state` and
  averages the per-image mean absolute error in `evaluate`/`result`.
  """

  def __init__(self):
    """Constructs the MAE metric."""
    self.reset_states()

  @property
  def name(self):
    return 'MAE'

  def reset_states(self):
    """Resets internal states for a fresh run."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Evaluates accumulated masks, then clears state for the next eval."""
    metric_result = self.evaluate()
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Evaluates with masks from all images.

    Returns:
      average_mae: average MAE as a float.
    """
    mae_total = 0.0
    for true, pred in zip(self._groundtruths, self._predictions):
      mae_total += self._compute_mae(true, pred)
    return mae_total / len(self._groundtruths)

  def _mask_normalize(self, mask):
    # Scale by the mask maximum; epsilon guards an all-zero mask.
    return mask / (np.amax(mask) + 1e-8)

  def _compute_mae(self, true, pred):
    # Per-image MAE over the spatial grid (h x w).
    height, width = true.shape[0], true.shape[1]
    norm_true = self._mask_normalize(true)
    norm_pred = self._mask_normalize(pred)
    abs_error = np.sum(np.absolute(norm_true.astype(float) -
                                   norm_pred.astype(float)))
    return abs_error / (float(height) * float(width) + 1e-8)

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays."""
    return groundtruths.numpy(), predictions.numpy()

  def update_state(self, groundtruths, predictions):
    """Updates segmentation results and groundtruth data.

    Args:
      groundtruths : Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks. range [0, 1]
      predictions : Tuple of single Tensor [batch, width, height, 1],
        predicted masks. range [0, 1]
    """
    groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
                                                       predictions[0])
    for true, pred in zip(groundtruths, predictions):
      self._groundtruths.append(true)
      self._predictions.append(pred)
class MaxFscore:
  """Maximum F-score metric for BASNet.

  Accumulates (groundtruth, prediction) saliency-mask pairs and reports the
  maximum F-beta score over 255 binarization thresholds.
  """

  def __init__(self):
    """Constructs the BASNet max-F evaluation class."""
    self.reset_states()

  @property
  def name(self):
    return 'MaxFScore'

  def reset_states(self):
    """Resets internal states for a fresh run."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Evaluates accumulated masks, then clears state for the next eval."""
    metric_result = self.evaluate()
    # Cleans up the internal variables in order for a fresh eval next time.
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Evaluates with masks from all images.

    Returns:
      f_max: maximum F-score value across thresholds, as numpy float32.
    """
    mybins = np.arange(0, 256)
    beta = 0.3
    num_images = len(self._groundtruths)
    precisions = np.zeros((num_images, len(mybins) - 1))
    recalls = np.zeros((num_images, len(mybins) - 1))
    for i, (true, pred) in enumerate(
        zip(self._groundtruths, self._predictions)):
      # Rescale masks to [0, 255] so the histogram bins act as thresholds.
      true = self._mask_normalize(true) * 255.0
      pred = self._mask_normalize(pred) * 255.0
      pre, rec = self._compute_pre_rec(true, pred, mybins=mybins)
      precisions[i, :] = pre
      recalls[i, :] = rec
    # Average the per-threshold curves over images, then take the best F.
    precisions = np.sum(precisions, 0) / (num_images + 1e-8)
    recalls = np.sum(recalls, 0) / (num_images + 1e-8)
    f = (1 + beta) * precisions * recalls / (beta * precisions + recalls + 1e-8)
    f_max = np.max(f)
    f_max = f_max.astype(np.float32)
    return f_max

  def _mask_normalize(self, mask):
    # Scale by the mask maximum; epsilon guards an all-zero mask.
    return mask / (np.amax(mask) + 1e-8)

  def _compute_pre_rec(self, true, pred, mybins=None):
    """Computes precision/recall at each of 255 descending thresholds.

    Args:
      true: groundtruth mask scaled to [0, 255].
      pred: predicted mask scaled to [0, 255].
      mybins: histogram bin edges; defaults to np.arange(0, 256).

    Returns:
      Tuple of 1-D arrays (precision, recall), one entry per threshold.
    """
    # Fix: the original declared `mybins=np.arange(0, 256)` directly in the
    # signature — a default evaluated once at definition time (the classic
    # Python mutable-default pitfall). Use a None sentinel instead.
    if mybins is None:
      mybins = np.arange(0, 256)
    # pixel number of ground truth foreground regions
    gt_num = true[true > 128].size
    # predicted pixel values in the ground truth foreground region
    pp = pred[true > 128]
    # predicted pixel values in the ground truth background region
    nn = pred[true <= 128]
    pp_hist, _ = np.histogram(pp, bins=mybins)
    nn_hist, _ = np.histogram(nn, bins=mybins)
    # Flip + cumulative sum gives counts above each descending threshold.
    pp_hist_flip = np.flipud(pp_hist)
    nn_hist_flip = np.flipud(nn_hist)
    pp_hist_flip_cum = np.cumsum(pp_hist_flip)
    nn_hist_flip_cum = np.cumsum(nn_hist_flip)
    precision = pp_hist_flip_cum / (
        pp_hist_flip_cum + nn_hist_flip_cum + 1e-8)  # TP/(TP+FP)
    recall = pp_hist_flip_cum / (gt_num + 1e-8)  # TP/(TP+FN)
    precision[np.isnan(precision)] = 0.0
    recall[np.isnan(recall)] = 0.0
    return (np.reshape(precision, (len(precision))),
            np.reshape(recall, (len(recall))))

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays."""
    numpy_groundtruths = groundtruths.numpy()
    numpy_predictions = predictions.numpy()
    return numpy_groundtruths, numpy_predictions

  def update_state(self, groundtruths, predictions):
    """Updates segmentation results and groundtruth data.

    Args:
      groundtruths : Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks. range [0, 1]
      predictions : Tuple of single Tensor [batch, width, height, 1],
        predicted masks. range [0, 1]
    """
    groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
                                                       predictions[0])
    for true, pred in zip(groundtruths, predictions):
      self._groundtruths.append(true)
      self._predictions.append(pred)
class RelaxedFscore:
  """Relaxed F-score metric for BASNet.

  Compares boundary maps (mask XOR its erosion) of prediction and
  groundtruth, counting a boundary pixel as matched when a counterpart lies
  within `rho` pixels.
  """

  def __init__(self):
    """Constructs the BASNet relaxed-F evaluation class."""
    self.reset_states()

  @property
  def name(self):
    return 'RelaxFScore'

  def reset_states(self):
    """Resets internal states for a fresh run."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Evaluates accumulated masks, then clears state for the next eval."""
    metric_result = self.evaluate()
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Evaluates with masks from all images.

    Returns:
      relax_f: relaxed F-score value as numpy float32.
    """
    beta = 0.3
    rho = 3
    num_images = len(self._groundtruths)
    relax_fs = np.zeros(num_images)
    erode_kernel = np.ones((3, 3))
    for i, (true, pred) in enumerate(
        zip(self._groundtruths, self._predictions)):
      true = np.squeeze(self._mask_normalize(true), axis=-1)
      pred = np.squeeze(self._mask_normalize(pred), axis=-1)
      # Binary saliency mask (S_bw), threshold 0.5.
      pred[pred >= 0.5] = 1
      pred[pred < 0.5] = 0
      # Boundary = S_bw XOR erode(S_bw), converted from bool to 0/1.
      pred_erd = self._compute_erosion(pred, erode_kernel)
      pred_xor = np.logical_xor(pred_erd, pred) * 1
      # Same construction for the ground truth.
      true[true >= 0.5] = 1
      true[true < 0.5] = 0
      true_erd = self._compute_erosion(true, erode_kernel)
      true_xor = np.logical_xor(true_erd, true) * 1
      pre, rec = self._compute_relax_pre_rec(true_xor, pred_xor, rho)
      relax_fs[i] = (1 + beta) * pre * rec / (beta * pre + rec + 1e-8)
    relax_f = np.sum(relax_fs, 0) / (num_images + 1e-8)
    return relax_f.astype(np.float32)

  def _mask_normalize(self, mask):
    # Scale by the mask maximum; epsilon guards an all-zero mask.
    return mask / (np.amax(mask) + 1e-8)

  def _compute_erosion(self, mask, kernel):
    # A pixel survives erosion only when every pixel under the kernel is 1,
    # i.e. the box-filter response equals the full kernel sum.
    kernel_full = np.sum(kernel)
    mask_erd = scipy.signal.convolve2d(mask, kernel, mode='same')
    mask_erd[mask_erd < kernel_full] = 0
    mask_erd[mask_erd == kernel_full] = 1
    return mask_erd

  def _compute_relax_pre_rec(self, true, pred, rho):
    """Computes relaxed precision and recall between boundary maps."""
    kernel = np.ones((2 * rho - 1, 2 * rho - 1))
    map_zeros = np.zeros_like(pred)
    map_ones = np.ones_like(pred)
    # Dilate the predicted boundary by rho via a box-filter convolution;
    # a true boundary pixel with any predicted pixel in range is a TP.
    pred_filtered = scipy.signal.convolve2d(pred, kernel, mode='same')
    relax_pre_tp = np.where((true == 1) & (pred_filtered > 0),
                            map_ones, map_zeros)
    # Symmetric construction for recall.
    true_filtered = scipy.signal.convolve2d(true, kernel, mode='same')
    relax_rec_tp = np.where((pred == 1) & (true_filtered > 0),
                            map_ones, map_zeros)
    return (np.sum(relax_pre_tp) / np.sum(pred),
            np.sum(relax_rec_tp) / np.sum(true))

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays."""
    return groundtruths.numpy(), predictions.numpy()

  def update_state(self, groundtruths, predictions):
    """Updates segmentation results and groundtruth data.

    Args:
      groundtruths : Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks. range [0, 1]
      predictions : Tuple of single Tensor [batch, width, height, 1],
        predicted masks. range [0, 1]
    """
    groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
                                                       predictions[0])
    for true, pred in zip(groundtruths, predictions):
      self._groundtruths.append(true)
      self._predictions.append(pred)
official/projects/basnet/evaluation/metrics_test.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for metrics.py."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.projects.basnet.evaluation
import
metrics
class BASNetMetricTest(parameterized.TestCase, tf.test.TestCase):
  """Cross-checks BASNet metrics against tf.keras reference metrics."""

  def test_mae(self):
    input_size = 224
    inputs = (tf.random.uniform([2, input_size, input_size, 1]),)
    labels = (tf.random.uniform([2, input_size, input_size, 1]),)

    mae_obj = metrics.MAE()
    mae_obj.reset_states()
    mae_obj.update_state(labels, inputs)
    output = mae_obj.result()

    # Reference value from the stock Keras MAE metric.
    mae_tf = tf.keras.metrics.MeanAbsoluteError()
    mae_tf.reset_state()
    mae_tf.update_state(labels[0], inputs[0])
    compare = mae_tf.result().numpy()

    self.assertAlmostEqual(output, compare, places=4)

  def test_max_f(self):
    input_size = 224
    beta = 0.3
    inputs = (tf.random.uniform([2, input_size, input_size, 1]),)
    labels = (tf.random.uniform([2, input_size, input_size, 1]),)

    max_f_obj = metrics.MaxFscore()
    max_f_obj.reset_states()
    max_f_obj.update_state(labels, inputs)
    output = max_f_obj.result()

    # Approximate reference: F-beta from Keras precision/recall at a fixed
    # threshold (0.78), compared loosely (1 decimal place).
    pre_tf = tf.keras.metrics.Precision(thresholds=0.78)
    rec_tf = tf.keras.metrics.Recall(thresholds=0.78)
    pre_tf.reset_state()
    rec_tf.reset_state()
    pre_tf.update_state(labels[0], inputs[0])
    rec_tf.update_state(labels[0], inputs[0])
    pre_out_tf = pre_tf.result().numpy()
    rec_out_tf = rec_tf.result().numpy()
    compare = (1 + beta) * pre_out_tf * rec_out_tf / (
        beta * pre_out_tf + rec_out_tf + 1e-8)

    self.assertAlmostEqual(output, compare, places=1)
if __name__ == '__main__':
  tf.test.main()
official/projects/basnet/losses/basnet_losses.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses used for BASNet models."""
import
tensorflow
as
tf
EPSILON = 1e-5


class BASNetLoss:
  """BASNet hybrid loss: BCE + SSIM + IoU summed over all side outputs."""

  def __init__(self):
    # Sum-reduced BCE over sigmoid outputs (inputs are already probabilities).
    self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
        reduction=tf.keras.losses.Reduction.SUM, from_logits=False)
    self._ssim = tf.image.ssim

  def __call__(self, sigmoids, labels):
    """Computes the hybrid loss.

    Args:
      sigmoids: dict mapping level key -> sigmoid prediction tensor.
      labels: groundtruth masks [batch, height, width, 1], range [0, 1].

    Returns:
      Scalar loss, averaged over the number of prediction levels.
    """
    levels = sorted(sigmoids.keys())
    # BCE compares against the channel-squeezed labels.
    labels_bce = tf.squeeze(labels, axis=-1)
    labels = tf.cast(labels, tf.float32)

    bce_losses = []
    ssim_losses = []
    iou_losses = []
    for level in levels:
      bce_losses.append(
          self._binary_crossentropy(labels_bce, sigmoids[level]))
      ssim_losses.append(1 - self._ssim(sigmoids[level], labels, max_val=1.0))
      iou_losses.append(self._iou_loss(sigmoids[level], labels))

    total_loss = (tf.math.add_n(bce_losses) + tf.math.add_n(ssim_losses) +
                  tf.math.add_n(iou_losses))
    return total_loss / len(levels)

  def _iou_loss(self, sigmoids, labels):
    # Soft IoU loss: 1 - intersection/union over all pixels in the batch.
    intersection = tf.reduce_sum(sigmoids * labels)
    union = tf.reduce_sum(sigmoids) + tf.reduce_sum(labels) - intersection
    return 1 - intersection / union
official/projects/basnet/modeling/basnet_model.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build BASNet models."""
from
typing
import
Mapping
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.projects.basnet.modeling
import
nn_blocks
from
official.vision.beta.modeling.backbones
import
factory
# Specifications for BASNet encoder.
# Each element in the block configuration is in the following format:
# (num_filters, stride, block_repeats, maxpool)
BASNET_ENCODER_SPECS = [
    (64, 1, 3, 0),   # ResNet-34,
    (128, 2, 4, 0),  # ResNet-34,
    (256, 2, 6, 0),  # ResNet-34,
    (512, 2, 3, 1),  # ResNet-34,
    (512, 1, 3, 1),  # BASNet,
    (512, 1, 3, 0),  # BASNet,
]

# Specifications for BASNet decoder.
# Each element in the block configuration is in the following format:
# (conv1_nf, conv1_dr, convm_nf, convm_dr, conv2_nf, conv2_dr, scale_factor)
# nf : num_filters, dr : dilation_rate
BASNET_BRIDGE_SPECS = [
    (512, 2, 512, 2, 512, 2, 32),  # Sup0, Bridge
]

BASNET_DECODER_SPECS = [
    (512, 1, 512, 2, 512, 2, 32),  # Sup1, stage6d
    (512, 1, 512, 1, 512, 1, 16),  # Sup2, stage5d
    (512, 1, 512, 1, 256, 1, 8),   # Sup3, stage4d
    (256, 1, 256, 1, 128, 1, 4),   # Sup4, stage3d
    (128, 1, 128, 1, 64, 1, 2),    # Sup5, stage2d
    (64, 1, 64, 1, 64, 1, 1)       # Sup6, stage1d
]
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetModel(tf.keras.Model):
  """A BASNet model.

  Boundary-Aware network (BASNet) were proposed in:
  [1] Qin, Xuebin, et al.
      Basnet: Boundary-aware salient object detection.

  Input images are passed through the backbone first. The decoder network is
  then applied, and finally the refinement module is applied on the output
  of the decoder network.
  """

  def __init__(self, backbone, decoder, refinement=None, **kwargs):
    """BASNet initialization function.

    Args:
      backbone: a backbone network. basnet_encoder.
      decoder: a decoder network. basnet_decoder.
      refinement: a module for salient map refinement.
      **kwargs: keyword arguments to be passed.
    """
    super(BASNetModel, self).__init__(**kwargs)
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'refinement': refinement,
    }
    self.backbone = backbone
    self.decoder = decoder
    self.refinement = refinement

  def call(self, inputs, training=None):
    features = self.backbone(inputs)
    if self.decoder:
      features = self.decoder(features)
    # The refined map is appended under the next free integer-string key.
    levels = sorted(features.keys())
    new_key = str(len(levels))
    if self.refinement:
      features[new_key] = self.refinement(features[levels[-1]])
    return features

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(backbone=self.backbone)
    if self.decoder is not None:
      items.update(decoder=self.decoder)
    if self.refinement is not None:
      items.update(refinement=self.refinement)
    return items

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
BASNetEncoder
(
tf
.
keras
.
Model
):
"""BASNet encoder."""
def
__init__
(
self
,
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
activation
=
'relu'
,
use_sync_bn
=
False
,
use_bias
=
True
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
kernel_initializer
=
'VarianceScaling'
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
**
kwargs
):
"""BASNet encoder initialization function.
Args:
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in conv2d.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
**kwargs: keyword arguments to be passed.
"""
self
.
_input_specs
=
input_specs
self
.
_use_sync_bn
=
use_sync_bn
self
.
_use_bias
=
use_bias
self
.
_activation
=
activation
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
if
use_sync_bn
:
self
.
_norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
else
:
self
.
_norm
=
tf
.
keras
.
layers
.
BatchNormalization
self
.
_kernel_initializer
=
kernel_initializer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
bn_axis
=
-
1
else
:
bn_axis
=
1
# Build BASNet Encoder.
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
x
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
64
,
kernel_size
=
3
,
strides
=
1
,
use_bias
=
self
.
_use_bias
,
padding
=
'same'
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
inputs
)
x
=
self
.
_norm
(
axis
=
bn_axis
,
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
activation
)(
x
)
endpoints
=
{}
for
i
,
spec
in
enumerate
(
BASNET_ENCODER_SPECS
):
x
=
self
.
_block_group
(
inputs
=
x
,
filters
=
spec
[
0
],
strides
=
spec
[
1
],
block_repeats
=
spec
[
2
],
name
=
'block_group_l{}'
.
format
(
i
+
2
))
endpoints
[
str
(
i
)]
=
x
if
spec
[
3
]:
x
=
tf
.
keras
.
layers
.
MaxPool2D
(
pool_size
=
2
,
strides
=
2
,
padding
=
'same'
)(
x
)
self
.
_output_specs
=
{
l
:
endpoints
[
l
].
get_shape
()
for
l
in
endpoints
}
super
(
BASNetEncoder
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
endpoints
,
**
kwargs
)
def
_block_group
(
self
,
inputs
,
filters
,
strides
,
block_repeats
=
1
,
name
=
'block_group'
):
"""Creates one group of residual blocks for the BASNet encoder model.
Args:
inputs: `Tensor` of size `[batch, channels, height, width]`.
filters: `int` number of filters for the first convolution of the layer.
strides: `int` stride to use for the first convolution of the layer. If
greater than 1, this layer will downsample the input.
block_repeats: `int` number of blocks contained in the layer.
name: `str`name for the block.
Returns:
The output `Tensor` of the block layer.
"""
x
=
nn_blocks
.
ResBlock
(
filters
=
filters
,
strides
=
strides
,
use_projection
=
True
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
self
.
_activation
,
use_sync_bn
=
self
.
_use_sync_bn
,
use_bias
=
self
.
_use_bias
,
norm_momentum
=
self
.
_norm_momentum
,
norm_epsilon
=
self
.
_norm_epsilon
)(
inputs
)
for
_
in
range
(
1
,
block_repeats
):
x
=
nn_blocks
.
ResBlock
(
filters
=
filters
,
strides
=
1
,
use_projection
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
self
.
_activation
,
use_sync_bn
=
self
.
_use_sync_bn
,
use_bias
=
self
.
_use_bias
,
norm_momentum
=
self
.
_norm_momentum
,
norm_epsilon
=
self
.
_norm_epsilon
)(
x
)
return
tf
.
identity
(
x
,
name
=
name
)
  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the layer from its `get_config` output.

    Args:
      config: A `dict` as returned by `get_config`.
      custom_objects: Unused; kept for Keras deserialization API compatibility.

    Returns:
      A new instance constructed from `config`.
    """
    return cls(**config)
  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
@factory.register_backbone_builder('basnet_encoder')
def build_basnet_encoder(
    input_specs: tf.keras.layers.InputSpec,
    model_config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Builds BASNet Encoder backbone from a config.

  Args:
    input_specs: A `tf.keras.layers.InputSpec` for the input tensor.
    model_config: A model config providing `backbone.type` and
      `norm_activation` settings.
    l2_regularizer: Optional `tf.keras.regularizers.Regularizer` applied to
      convolution kernels.

  Returns:
    A `tf.keras.Model` implementing the BASNet encoder.
  """
  backbone_type = model_config.backbone.type
  assert backbone_type == 'basnet_encoder', (
      f'Inconsistent backbone type {backbone_type}')

  norm_cfg = model_config.norm_activation
  return BASNetEncoder(
      input_specs=input_specs,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      use_bias=norm_cfg.use_bias,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetDecoder(tf.keras.layers.Layer):
  """BASNet decoder.

  Builds a bridge stage followed by a stack of decoder stages; each stage
  also produces a 1-channel side output (upsampled and passed through a
  sigmoid) used for deep supervision.
  """

  def __init__(self,
               activation='relu',
               use_sync_bn=False,
               use_bias=True,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """BASNet decoder initialization function.

    Args:
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in convolution.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
      **kwargs: keyword arguments to be passed.
    """
    super(BASNetDecoder, self).__init__(**kwargs)
    # Config is kept verbatim so get_config()/from_config() round-trips.
    self._config_dict = {
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    self._activation = tf_utils.get_activation(activation)
    # Channel-wise concat of decoder features with encoder skip features.
    self._concat = tf.keras.layers.Concatenate(axis=-1)
    # Sigmoid applied to every side output.
    self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')

  def build(self, input_shape):
    """Creates the variables of the BASNet decoder."""
    conv_op = tf.keras.layers.Conv2D
    conv_kwargs = {
        'kernel_size': 3,
        'strides': 1,
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }
    # One (1-channel conv, upsample) pair per stage — indices line up with
    # the bridge stage first, then each decoder stage.
    self._out_convs = []
    self._out_usmps = []

    # Bridge layers.
    # Each spec packs (filters, dilation) pairs for 3 conv blocks at
    # spec[2*j] / spec[2*j+1], plus the side-output upsample factor at spec[6].
    self._bdg_convs = []
    for spec in BASNET_BRIDGE_SPECS:
      blocks = []
      for j in range(3):
        blocks.append(nn_blocks.ConvBlock(
            filters=spec[2*j],
            dilation_rate=spec[2*j+1],
            activation='relu',
            use_sync_bn=self._config_dict['use_sync_bn'],
            norm_momentum=0.99,
            norm_epsilon=0.001,
            **conv_kwargs))
      self._bdg_convs.append(blocks)
      self._out_convs.append(conv_op(
          filters=1,
          padding='same',
          **conv_kwargs))
      self._out_usmps.append(tf.keras.layers.UpSampling2D(
          size=spec[6],
          interpolation='bilinear'))

    # Decoder layers.
    # Same spec layout as the bridge specs above.
    self._dec_convs = []
    for spec in BASNET_DECODER_SPECS:
      blocks = []
      for j in range(3):
        blocks.append(nn_blocks.ConvBlock(
            filters=spec[2*j],
            dilation_rate=spec[2*j+1],
            activation='relu',
            use_sync_bn=self._config_dict['use_sync_bn'],
            norm_momentum=0.99,
            norm_epsilon=0.001,
            **conv_kwargs))
      self._dec_convs.append(blocks)
      self._out_convs.append(conv_op(
          filters=1,
          padding='same',
          **conv_kwargs))
      self._out_usmps.append(tf.keras.layers.UpSampling2D(
          size=spec[6],
          interpolation='bilinear'))

  def call(self, backbone_output: Mapping[str, tf.Tensor]):
    """Forward pass of the BASNet decoder.

    Args:
      backbone_output: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].

    Returns:
      sup: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].
    """
    # Process levels deepest-first (reverse-sorted keys).
    levels = sorted(backbone_output.keys(), reverse=True)
    sup = {}
    x = backbone_output[levels[0]]
    # Bridge stage on the deepest feature map.
    for blocks in self._bdg_convs:
      for block in blocks:
        x = block(x)
    sup['0'] = x
    # Decoder stages: concat with the matching encoder skip, convolve,
    # record the side output, then upsample for the next (shallower) stage.
    for i, blocks in enumerate(self._dec_convs):
      x = self._concat([x, backbone_output[levels[i]]])
      for block in blocks:
        x = block(x)
      sup[str(i+1)] = x
      x = tf.keras.layers.UpSampling2D(
          size=2,
          interpolation='bilinear'
          )(x)
    # Turn each stored stage feature into a 1-channel sigmoid mask at the
    # original resolution.
    for i, (conv, usmp) in enumerate(zip(self._out_convs, self._out_usmps)):
      sup[str(i)] = self._sigmoid(usmp(conv(sup[str(i)])))
    # NOTE(review): this records specs only for keys 0..len(decoder specs)-1,
    # while `sup` also holds one extra entry — confirm the omission of the
    # last key is intentional.
    self._output_specs = {
        str(order): sup[str(order)].get_shape()
        for order in range(0, len(BASNET_DECODER_SPECS))
    }

    return sup

  def get_config(self):
    """Returns the layer configuration used for serialization."""
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the layer from its `get_config` output."""
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {order: TensorShape} pairs for the model output."""
    return self._output_specs
official/projects/basnet/modeling/basnet_model_test.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for basnet network."""
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.projects.basnet.modeling
import
basnet_model
from
official.projects.basnet.modeling
import
refunet
class BASNetNetworkTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the assembled BASNet model (encoder + decoder + refinement)."""

  @parameterized.parameters(
      (256),
      (512),
  )
  def test_basnet_network_creation(self, input_size):
    """Test for creation of a segmentation network."""
    # Random NHWC batch of 2 images.
    inputs = np.random.rand(2, input_size, input_size, 3)
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = basnet_model.BASNetEncoder()
    decoder = basnet_model.BASNetDecoder()
    refinement = refunet.RefUnet()

    model = basnet_model.BASNetModel(
        backbone=backbone,
        decoder=decoder,
        refinement=refinement
    )

    sigmoids = model(inputs)
    levels = sorted(sigmoids.keys())
    # The final (refined) output must be a full-resolution 1-channel mask.
    self.assertAllEqual(
        [2, input_size, input_size, 1],
        sigmoids[levels[-1]].numpy().shape)

  def test_serialize_deserialize(self):
    """Validate the network can be serialized and deserialized."""
    backbone = basnet_model.BASNetEncoder()
    decoder = basnet_model.BASNetDecoder()
    refinement = refunet.RefUnet()

    model = basnet_model.BASNetModel(
        backbone=backbone,
        decoder=decoder,
        refinement=refinement
    )

    config = model.get_config()
    new_model = basnet_model.BASNetModel.from_config(config)

    # Validate that the config can be forced to JSON.
    _ = new_model.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(model.get_config(), new_model.get_config())
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
official/projects/basnet/modeling/nn_blocks.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common building blocks for BasNet model."""
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
@tf.keras.utils.register_keras_serializable(package='Vision')
class ConvBlock(tf.keras.layers.Layer):
  """A (Conv+BN+Activation) block."""

  def __init__(self,
               filters,
               strides,
               dilation_rate=1,
               kernel_size=3,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_bias=False,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """A single Conv2D followed by batch norm and an activation.

    Args:
      filters: `int` number of filters of the convolution.
      strides: `int` stride of the convolution. If greater than 1, this block
        will ultimately downsample the input.
      dilation_rate: `int`, dilation rate for the conv layer.
      kernel_size: `int`, kernel size of the conv layer.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
        Default to None.
      activation: `str` name of the activation function.
      use_bias: `bool`, whether or not use bias in conv layers.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      **kwargs: keyword arguments to be passed.
    """
    super(ConvBlock, self).__init__(**kwargs)
    # Stored verbatim so get_config() round-trips through serialization.
    self._config_dict = {
        'filters': filters,
        'kernel_size': kernel_size,
        'strides': strides,
        'dilation_rate': dilation_rate,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon
    }
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    # BN axis follows the global image data format (channels last vs. first).
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def build(self, input_shape):
    """Creates the conv and norm sublayers."""
    conv_kwargs = {
        'padding': 'same',
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }
    self._conv0 = tf.keras.layers.Conv2D(
        filters=self._config_dict['filters'],
        kernel_size=self._config_dict['kernel_size'],
        strides=self._config_dict['strides'],
        dilation_rate=self._config_dict['dilation_rate'],
        **conv_kwargs)
    self._norm0 = self._norm(
        axis=self._bn_axis,
        momentum=self._config_dict['norm_momentum'],
        epsilon=self._config_dict['norm_epsilon'])

    super(ConvBlock, self).build(input_shape)

  def get_config(self):
    """Returns the layer configuration used for serialization."""
    return self._config_dict

  def call(self, inputs, training=None):
    """Runs conv -> batch norm -> activation.

    NOTE(review): `training` is not forwarded to the norm layer, so BN uses
    the Keras learning-phase default — confirm this is intentional.
    """
    x = self._conv0(inputs)
    x = self._norm0(x)
    x = self._activation_fn(x)

    return x
@tf.keras.utils.register_keras_serializable(package='Vision')
class ResBlock(tf.keras.layers.Layer):
  """A residual block."""

  def __init__(self,
               filters,
               strides,
               use_projection=False,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_sync_bn=False,
               use_bias=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """Initializes a residual block with BN after convolutions.

    The block applies two 3x3 convolutions (each with `filters` output
    channels) and adds a shortcut connection before the final activation.

    Args:
      filters: An `int` number of filters for both convolutions.
      strides: An `int` block stride. If greater than 1, this block will
        ultimately downsample the input.
      use_projection: A `bool` for whether this block should use a projection
        shortcut (versus the default identity shortcut). This is usually `True`
        for the first block of a block group, which may change the number of
        filters and the resolution.
      kernel_initializer: A `str` of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2d.
        Default to None.
      activation: A `str` name of the activation function.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      use_bias: A `bool`. If True, use bias in conv2d.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(ResBlock, self).__init__(**kwargs)
    # Stored verbatim so get_config() round-trips through serialization.
    self._config_dict = {
        'filters': filters,
        'strides': strides,
        'use_projection': use_projection,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon
    }
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    # BN axis follows the global image data format (channels last vs. first).
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def build(self, input_shape):
    """Creates the shortcut (if any), conv and norm sublayers."""
    conv_kwargs = {
        'filters': self._config_dict['filters'],
        'padding': 'same',
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }
    if self._config_dict['use_projection']:
      # 1x1 projection shortcut matches channel count and resolution.
      self._shortcut = tf.keras.layers.Conv2D(
          filters=self._config_dict['filters'],
          kernel_size=1,
          strides=self._config_dict['strides'],
          use_bias=self._config_dict['use_bias'],
          kernel_initializer=self._config_dict['kernel_initializer'],
          kernel_regularizer=self._config_dict['kernel_regularizer'],
          bias_regularizer=self._config_dict['bias_regularizer'])
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._config_dict['norm_momentum'],
          epsilon=self._config_dict['norm_epsilon'])
    self._conv1 = tf.keras.layers.Conv2D(
        kernel_size=3,
        strides=self._config_dict['strides'],
        **conv_kwargs)
    self._norm1 = self._norm(
        axis=self._bn_axis,
        momentum=self._config_dict['norm_momentum'],
        epsilon=self._config_dict['norm_epsilon'])
    self._conv2 = tf.keras.layers.Conv2D(
        kernel_size=3,
        strides=1,
        **conv_kwargs)
    self._norm2 = self._norm(
        axis=self._bn_axis,
        momentum=self._config_dict['norm_momentum'],
        epsilon=self._config_dict['norm_epsilon'])

    super(ResBlock, self).build(input_shape)

  def get_config(self):
    """Returns the layer configuration used for serialization."""
    return self._config_dict

  def call(self, inputs, training=None):
    """Runs conv/BN twice and adds the shortcut before the final activation.

    NOTE(review): `training` is not forwarded to the norm layers, so BN uses
    the Keras learning-phase default — confirm this is intentional.
    """
    shortcut = inputs
    if self._config_dict['use_projection']:
      shortcut = self._shortcut(shortcut)
      shortcut = self._norm0(shortcut)

    x = self._conv1(inputs)
    x = self._norm1(x)
    x = self._activation_fn(x)

    x = self._conv2(x)
    x = self._norm2(x)

    return self._activation_fn(x + shortcut)
official/projects/basnet/modeling/refunet.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RefUNet model."""
import
tensorflow
as
tf
from
official.projects.basnet.modeling
import
nn_blocks
@tf.keras.utils.register_keras_serializable(package='Vision')
class RefUnet(tf.keras.layers.Layer):
  """Residual Refinement Module of BASNet.

  Boundary-Aware network (BASNet) were proposed in:
  [1] Qin, Xuebin, et al.
      Basnet: Boundary-aware salient object detection.
  """

  def __init__(self,
               activation='relu',
               use_sync_bn=False,
               use_bias=True,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """Residual Refinement Module of BASNet.

    Args:
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in conv2d.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
        Default to None.
      **kwargs: keyword arguments to be passed.
    """
    super(RefUnet, self).__init__(**kwargs)
    # Stored verbatim so get_config() round-trips through serialization.
    self._config_dict = {
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    self._concat = tf.keras.layers.Concatenate(axis=-1)
    self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')
    # One shared maxpool/upsample pair reused across all stages.
    self._maxpool = tf.keras.layers.MaxPool2D(
        pool_size=2,
        strides=2,
        padding='valid')
    self._upsample = tf.keras.layers.UpSampling2D(
        size=2,
        interpolation='bilinear')

  def build(self, input_shape):
    """Creates the variables of the BASNet decoder."""
    conv_op = tf.keras.layers.Conv2D
    conv_kwargs = {
        'kernel_size': 3,
        'strides': 1,
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }
    # Stem conv producing the 64-channel working width used throughout.
    self._in_conv = conv_op(
        filters=64,
        padding='same',
        **conv_kwargs)

    # Encoder path: 4 conv blocks, each followed by a maxpool in call().
    self._en_convs = []
    for _ in range(4):
      self._en_convs.append(nn_blocks.ConvBlock(
          filters=64,
          use_sync_bn=self._config_dict['use_sync_bn'],
          norm_momentum=self._config_dict['norm_momentum'],
          norm_epsilon=self._config_dict['norm_epsilon'],
          **conv_kwargs))

    # Single bridge block at the bottleneck.
    self._bridge_convs = []
    for _ in range(1):
      self._bridge_convs.append(nn_blocks.ConvBlock(
          filters=64,
          use_sync_bn=self._config_dict['use_sync_bn'],
          norm_momentum=self._config_dict['norm_momentum'],
          norm_epsilon=self._config_dict['norm_epsilon'],
          **conv_kwargs))

    # Decoder path: 4 conv blocks consuming upsampled + skip features.
    self._de_convs = []
    for _ in range(4):
      self._de_convs.append(nn_blocks.ConvBlock(
          filters=64,
          use_sync_bn=self._config_dict['use_sync_bn'],
          norm_momentum=self._config_dict['norm_momentum'],
          norm_epsilon=self._config_dict['norm_epsilon'],
          **conv_kwargs))

    # 1-channel residual that is added to the input mask.
    self._out_conv = conv_op(
        filters=1,
        padding='same',
        **conv_kwargs)

  def call(self, inputs):
    """Refines a coarse mask by predicting a residual correction."""
    endpoints = {}
    residual = inputs

    x = self._in_conv(inputs)

    # Top-down
    for i, block in enumerate(self._en_convs):
      x = block(x)
      endpoints[str(i)] = x
      x = self._maxpool(x)

    # Bridge
    for i, block in enumerate(self._bridge_convs):
      x = block(x)

    # Bottom-up
    for i, block in enumerate(self._de_convs):
      # Bilinear upsampling is done in float32 for numerical stability under
      # mixed precision, then cast back to the working dtype.
      dtype = x.dtype
      x = tf.cast(x, tf.float32)
      x = self._upsample(x)
      x = tf.cast(x, dtype)
      # Concatenate the mirrored encoder endpoint (deepest first).
      x = self._concat([endpoints[str(3-i)], x])
      x = block(x)

    x = self._out_conv(x)

    residual = tf.cast(residual, dtype=x.dtype)
    # Refined mask = sigmoid(input + predicted residual).
    output = self._sigmoid(x + residual)
    self._output_specs = output.get_shape()

    return output

  def get_config(self):
    """Returns the layer configuration used for serialization."""
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the layer from its `get_config` output."""
    return cls(**config)

  @property
  def output_specs(self):
    """The `TensorShape` of the most recent output."""
    return self._output_specs
official/projects/basnet/serving/basnet.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export module for BASNet."""
import
tensorflow
as
tf
from
official.projects.basnet.tasks
import
basnet
from
official.vision.beta.serving
import
semantic_segmentation
# Per-channel RGB normalization constants scaled to the [0, 255] pixel range.
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class BASNetModule(semantic_segmentation.SegmentationModule):
  """BASNet Module."""

  def _build_model(self):
    """Builds the BASNet model with a fixed batch and image size."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[self._batch_size] + self._input_image_size + [3])

    return basnet.build_basnet_model(
        input_specs=input_specs,
        model_config=self.params.task.model,
        l2_regularizer=None)

  def serve(self, images):
    """Cast image to float and run inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]
    Returns:
      Tensor holding classification output logits.
    """
    # Preprocessing is pinned to CPU so it is not placed on the accelerator.
    with tf.device('cpu:0'):
      images = tf.cast(images, dtype=tf.float32)

      # Preprocess each image in the batch independently.
      images = tf.nest.map_structure(
          tf.identity,
          tf.map_fn(
              self._build_inputs, elems=images,
              fn_output_signature=tf.TensorSpec(
                  shape=self._input_image_size + [3], dtype=tf.float32),
              parallel_iterations=32
              )
          )

    masks = self.inference_step(images)
    # The highest-sorted key is the final (refined) mask.
    keys = sorted(masks.keys())
    output = tf.image.resize(
        masks[keys[-1]],
        self._input_image_size, method='bilinear')

    return dict(predicted_masks=output)
official/projects/basnet/serving/export_saved_model.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r
"""Export binary for BASNet.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE = XX
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from
absl
import
app
from
absl
import
flags
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.projects.basnet.serving
import
basnet
from
official.vision.beta.serving
import
export_saved_model_lib
# Command-line flags controlling the export: experiment selection, config
# overrides, checkpoint location, and serving input specification.
FLAGS = flags.FLAGS

flags.DEFINE_string('experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help='YAML/JSON files which specifies overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be overriden'
    ' on top of `config_file` template.')
flags.DEFINE_integer('batch_size', None, 'The batch size.')
flags.DEFINE_string(
    'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example`.')
flags.DEFINE_string(
    'input_image_size', '224,224',
    'The comma-separated string of two integers representing the height,width '
    'of the input to the model.')
def main(_):
  """Builds the experiment config from flags and exports the saved model."""
  params = exp_factory.get_exp_config(FLAGS.experiment)

  # Apply YAML/JSON file overrides first, then string overrides on top.
  for config_file in FLAGS.config_file or []:
    params = hyperparams.override_params_dict(
        params, config_file, is_strict=True)
  if FLAGS.params_override:
    params = hyperparams.override_params_dict(
        params, FLAGS.params_override, is_strict=True)

  params.validate()
  params.lock()

  # Parse "H,W" once and reuse for both the export graph and the module.
  image_size = [int(x) for x in FLAGS.input_image_size.split(',')]

  export_module = basnet.BASNetModule(
      params=params,
      batch_size=FLAGS.batch_size,
      input_image_size=image_size)

  export_saved_model_lib.export_inference_graph(
      input_type=FLAGS.input_type,
      batch_size=FLAGS.batch_size,
      input_image_size=image_size,
      params=params,
      checkpoint_path=FLAGS.checkpoint_path,
      export_dir=FLAGS.export_dir,
      export_module=export_module,
      export_checkpoint_subdir='checkpoint',
      export_saved_model_subdir='saved_model')
# Script entry point: parse flags and run the export.
if __name__ == '__main__':
  app.run(main)
official/projects/basnet/tasks/basnet.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BASNet task definition."""
from
typing
import
Optional
from
absl
import
logging
import
tensorflow
as
tf
from
official.common
import
dataset_fn
from
official.core
import
base_task
from
official.core
import
input_reader
from
official.core
import
task_factory
from
official.projects.basnet.configs
import
basnet
as
exp_cfg
from
official.projects.basnet.evaluation
import
metrics
as
basnet_metrics
from
official.projects.basnet.losses
import
basnet_losses
from
official.projects.basnet.modeling
import
basnet_model
from
official.projects.basnet.modeling
import
refunet
from
official.vision.beta.dataloaders
import
segmentation_input
def build_basnet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: exp_cfg.BASNetModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds BASNet model.

  Args:
    input_specs: A `tf.keras.layers.InputSpec` for the input tensor.
    model_config: A `BASNetModel` config providing `norm_activation` and
      `use_bias` settings.
    l2_regularizer: Optional `tf.keras.regularizers.Regularizer` applied to
      convolution kernels.

  Returns:
    A `BASNetModel` composed of encoder, decoder and refinement stages.
  """
  norm_cfg = model_config.norm_activation
  # The three sub-networks share identical normalization/regularization
  # settings, so build the keyword set once.
  shared_kwargs = dict(
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      use_bias=model_config.use_bias,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  backbone = basnet_model.BASNetEncoder(
      input_specs=input_specs, **shared_kwargs)
  decoder = basnet_model.BASNetDecoder(**shared_kwargs)
  refinement = refunet.RefUnet(**shared_kwargs)

  return basnet_model.BASNetModel(backbone, decoder, refinement)
@task_factory.register_task_cls(exp_cfg.BASNetTask)
class BASNetTask(base_task.Task):
  """A task for basnet.

  Wires together the BASNet model, input pipeline, hybrid loss, and the
  salient-object-detection metrics (MAE / maxF / relaxF) for the orchestrated
  train/eval loops in `official.core`.
  """

  def build_model(self):
    """Builds basnet model."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = build_basnet_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loads pretrained checkpoint.

    No-op when `init_checkpoint` is unset. `init_checkpoint_modules` selects
    which submodules are restored: 'all' restores every checkpoint item and
    requires full consumption; otherwise only the listed modules ('backbone'
    and/or 'decoder') are restored and a partial match is tolerated.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    # Accept either a directory (resolve to its latest checkpoint) or a
    # direct checkpoint file path.
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if 'all' in self.task_config.init_checkpoint_modules:
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    else:
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.restore(ckpt_dir_or_file)
      # Partial restore is expected here: the checkpoint may carry more than
      # the selected modules, but every selected object must be matched.
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds BASNet input.

    Args:
      params: Data config with file paths, output/crop sizes, augmentation
        and dtype settings.
      input_context: Optional `tf.distribute.InputContext` for sharding
        under a distribution strategy.

    Returns:
      A `tf.data.Dataset` of (features, labels) pairs.
    """

    ignore_label = self.task_config.losses.ignore_label

    # Reuses the generic segmentation decoder/parser; BASNet masks are
    # treated as single-channel segmentation labels.
    decoder = segmentation_input.Decoder()
    parser = segmentation_input.Parser(
        output_size=params.output_size,
        crop_size=params.crop_size,
        ignore_label=ignore_label,
        aug_rand_hflip=params.aug_rand_hflip,
        dtype=params.dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self, label, model_outputs, aux_losses=None):
    """Hybrid loss proposed in BASNet.

    Args:
      label: Label dict; the ground-truth mask is read from key 'masks'.
      model_outputs: Output logits of the classifier.
      aux_losses: Auxiliary loss tensors, i.e. `losses` in keras.Model
        (e.g. regularization losses), summed into the total when present.

    Returns:
      The total loss tensor.
    """
    basnet_loss_fn = basnet_losses.BASNetLoss()
    total_loss = basnet_loss_fn(model_outputs, label['masks'])

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    return total_loss

  def build_metrics(self, training=False):
    """Gets streaming metrics for training/validation.

    Returns an empty list in both modes: the eval metrics (MAE, max F-score,
    relaxed F-score) are stateful custom evaluators stored as attributes and
    driven manually via `aggregate_logs`/`reduce_aggregated_logs`, rather
    than being returned as keras metrics.
    """
    evaluations = []
    if training:
      evaluations = []
    else:
      self.mae_metric = basnet_metrics.MAE()
      self.maxf_metric = basnet_metrics.MaxFscore()
      self.relaxf_metric = basnet_metrics.RelaxedFscore()

    return evaluations

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs, label=labels, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(
          optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(
        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)

    # Apply gradient clipping.
    if self.task_config.gradient_clip_norm > 0:
      grads, _ = tf.clip_by_global_norm(
          grads, self.task_config.gradient_clip_norm)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}

    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Validatation step.

    Runs inference only (no loss is computed for eval; `self.loss` is logged
    as 0) and stashes (ground-truth mask, final-level prediction) pairs under
    each metric's name for later consumption by `aggregate_logs`.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    loss = 0
    logs = {self.loss: loss}

    # Outputs are keyed by decoder level; the last sorted key is the final
    # (refined) prediction used for all three metrics.
    levels = sorted(outputs.keys())
    logs.update(
        {self.mae_metric.name: (labels['masks'], outputs[levels[-1]])})
    logs.update(
        {self.maxf_metric.name: (labels['masks'], outputs[levels[-1]])})
    logs.update(
        {self.relaxf_metric.name: (labels['masks'], outputs[levels[-1]])})

    return logs

  def inference_step(self, inputs, model):
    """Performs the forward step."""
    return model(inputs, training=False)

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates eval metrics from one validation step's outputs.

    On the first call (`state is None`) all three metric accumulators are
    reset; `state` is an opaque sentinel (the MAE metric object) indicating
    aggregation has started.
    """
    if state is None:
      self.mae_metric.reset_states()
      self.maxf_metric.reset_states()
      self.relaxf_metric.reset_states()
      state = self.mae_metric
    # Each entry is a (ground_truth, prediction) pair stored by
    # `validation_step` under the metric's name.
    self.mae_metric.update_state(
        step_outputs[self.mae_metric.name][0],
        step_outputs[self.mae_metric.name][1])
    self.maxf_metric.update_state(
        step_outputs[self.maxf_metric.name][0],
        step_outputs[self.maxf_metric.name][1])
    self.relaxf_metric.update_state(
        step_outputs[self.relaxf_metric.name][0],
        step_outputs[self.relaxf_metric.name][1])
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Finalizes accumulated metrics into the reported eval results."""
    result = {}
    result['MAE'] = self.mae_metric.result()
    result['maxF'] = self.maxf_metric.result()
    result['relaxF'] = self.relaxf_metric.result()
    return result
official/projects/basnet/train.py
0 → 100644
View file @
9c069a70
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision training driver."""
from
absl
import
app
# pylint: disable=unused-import
from
official.common
import
flags
as
tfm_flags
from
official.projects.basnet.configs
import
basnet
as
basnet_cfg
from
official.projects.basnet.modeling
import
basnet_model
from
official.projects.basnet.modeling
import
refunet
from
official.projects.basnet.tasks
import
basnet
as
basenet_task
from
official.vision.beta
import
train
if __name__ == '__main__':
  # Define the standard Model Garden flags, then hand off to the shared
  # vision training driver. The unused imports above register the BASNet
  # task, configs, and model classes with the task factory as a side effect.
  tfm_flags.define_flags()
  app.run(train.main)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment