ModelZoo / ResNet50_tensorflow
Commit 6ddd627a, authored Jun 16, 2021 by Gunho Park
Parent: e570fda5

Merge en_de to model

Showing 7 changed files with 384 additions and 418 deletions (+384 −418).
Changed files:
- official/vision/beta/projects/basnet/common/registry_imports.py (+1 −2)
- official/vision/beta/projects/basnet/modeling/basnet_decoder.py (+0 −196)
- official/vision/beta/projects/basnet/modeling/basnet_encoder.py (+0 −205)
- official/vision/beta/projects/basnet/modeling/basnet_model.py (+371 −0)
- official/vision/beta/projects/basnet/modeling/basnet_model_test.py (+6 −10)
- official/vision/beta/projects/basnet/modeling/refunet.py (+4 −1)
- official/vision/beta/projects/basnet/tasks/basnet.py (+2 −4)
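The commit message "Merge en_de to model" is terse; as the diffs below show, it deletes the separate basnet_encoder.py and basnet_decoder.py modules and moves both classes into basnet_model.py. A minimal sketch of the resulting import change for downstream code (module paths taken from the diffs below):

# Before this commit, encoder and decoder lived in their own modules:
#   from official.vision.beta.projects.basnet.modeling import basnet_encoder
#   from official.vision.beta.projects.basnet.modeling import basnet_decoder
# After this commit, both classes are imported from basnet_model:
from official.vision.beta.projects.basnet.modeling import basnet_model

backbone = basnet_model.BASNet_Encoder()
decoder = basnet_model.BASNet_Decoder()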
official/vision/beta/projects/basnet/common/registry_imports.py (+1 −2)

@@ -17,7 +17,6 @@
 # pylint: disable=unused-import
 from official.vision import beta
 from official.vision.beta.projects.basnet.configs import basnet
-from official.vision.beta.projects.basnet.modeling import basnet_encoder
-from official.vision.beta.projects.basnet.modeling import basnet_decoder
+from official.vision.beta.projects.basnet.modeling import basnet_model
 from official.vision.beta.projects.basnet.modeling import refunet
 from official.vision.beta.projects.basnet.tasks import basnet
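registry_imports.py exists only for its import side effects: importing a module executes its registration decorators, which is what makes the components available to the framework's factories (basnet_encoder.py below, for example, registers its builder with @factory.register_backbone_builder('basnet_encoder')). A minimal sketch of that pattern with hypothetical names (MY_REGISTRY, register, build_widget), just to illustrate why the unused-looking imports matter:

# Hypothetical registry: the decorator runs at import time, so merely importing
# the defining module is enough to make the builder discoverable by name.
MY_REGISTRY = {}

def register(name):
  def wrapper(fn):
    MY_REGISTRY[name] = fn
    return fn
  return wrapper

@register('widget_builder')
def build_widget(config):
  return {'widget': config}

builder = MY_REGISTRY['widget_builder']   # looked up by name, not imported directly
print(builder({'size': 3}))               # {'widget': {'size': 3}}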
official/vision/beta/projects/basnet/modeling/basnet_decoder.py (deleted, +0 −196)

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Import libraries
from typing import Mapping

import tensorflow as tf

from official.modeling import tf_utils
from official.vision.beta.projects.basnet.modeling.layers import nn_blocks

# nf : num_filters, dr : dilation_rate
# (conv1_nf, conv1_dr, convm_nf, convm_dr, conv2_nf, conv2_dr, scale_factor)
BASNET_BRIDGE_SPECS = [
    (512, 2, 512, 2, 512, 2, 32),  # Sup0, Bridge
]

BASNET_DECODER_SPECS = [
    (512, 1, 512, 2, 512, 2, 32),  # Sup1, stage6d
    (512, 1, 512, 1, 512, 1, 16),  # Sup2, stage5d
    (512, 1, 512, 1, 256, 1, 8),   # Sup3, stage4d
    (256, 1, 256, 1, 128, 1, 4),   # Sup4, stage3d
    (128, 1, 128, 1, 64, 1, 2),    # Sup5, stage2d
    (64, 1, 64, 1, 64, 1, 1),      # Sup6, stage1d
]


@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Decoder(tf.keras.layers.Layer):
  """Decoder of BASNet.

  Boundary-Aware network (BASNet) was proposed in:
  [1] Qin, Xuebin, et al.
      Basnet: Boundary-aware salient object detection.
  """

  def __init__(self,
               use_separable_conv=False,
               activation='relu',
               use_sync_bn=False,
               use_bias=True,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """BASNet Decoder initialization function.

    Args:
      use_separable_conv: `bool`, if True use separable convolution for
        convolution in BASNet layers.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in convolution.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      **kwargs: keyword arguments to be passed.
    """
    super(BASNet_Decoder, self).__init__(**kwargs)
    self._config_dict = {
        'use_separable_conv': use_separable_conv,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1
    self._activation = tf_utils.get_activation(activation)
    self._concat = tf.keras.layers.Concatenate(axis=-1)
    self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')

  def build(self, input_shape):
    """Creates the variables of the BASNet decoder."""
    if self._config_dict['use_separable_conv']:
      conv_op = tf.keras.layers.SeparableConv2D
    else:
      conv_op = tf.keras.layers.Conv2D
    conv_kwargs = {
        'kernel_size': 3,
        'strides': 1,
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }

    self._out_convs = []
    self._out_usmps = []

    # Bridge layers.
    self._bdg_convs = []
    for i, spec in enumerate(BASNET_BRIDGE_SPECS):
      blocks = []
      for j in range(3):
        blocks.append(nn_blocks.ConvBlock(
            filters=spec[2 * j],
            dilation_rate=spec[2 * j + 1],
            activation='relu',
            norm_momentum=0.99,
            norm_epsilon=0.001,
            **conv_kwargs))
      self._bdg_convs.append(blocks)
      self._out_convs.append(conv_op(
          filters=1,
          padding='same',
          **conv_kwargs))
      self._out_usmps.append(tf.keras.layers.UpSampling2D(
          size=spec[6],
          interpolation='bilinear'))

    # Decoder layers.
    self._dec_convs = []
    for i, spec in enumerate(BASNET_DECODER_SPECS):
      blocks = []
      for j in range(3):
        blocks.append(nn_blocks.ConvBlock(
            filters=spec[2 * j],
            dilation_rate=spec[2 * j + 1],
            activation='relu',
            norm_momentum=0.99,
            norm_epsilon=0.001,
            **conv_kwargs))
      self._dec_convs.append(blocks)
      self._out_convs.append(conv_op(
          filters=1,
          padding='same',
          **conv_kwargs))
      self._out_usmps.append(tf.keras.layers.UpSampling2D(
          size=spec[6],
          interpolation='bilinear'))

  def call(self, backbone_output: Mapping[str, tf.Tensor]):
    levels = sorted(backbone_output.keys(), reverse=True)
    sup = {}
    x = backbone_output[levels[0]]
    for blocks in self._bdg_convs:
      for block in blocks:
        x = block(x)
    sup['0'] = x
    for i, blocks in enumerate(self._dec_convs):
      x = self._concat([x, backbone_output[levels[i]]])
      for block in blocks:
        x = block(x)
      sup[str(i + 1)] = x
      x = tf.keras.layers.UpSampling2D(
          size=2, interpolation='bilinear')(x)
    for i, (conv, usmp) in enumerate(zip(self._out_convs, self._out_usmps)):
      sup[str(i)] = self._sigmoid(usmp(conv(sup[str(i)])))
    self._output_specs = {
        str(order): sup[str(order)].get_shape()
        for order in range(0, len(BASNET_DECODER_SPECS))
    }
    return sup

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {order: TensorShape} pairs for the model output."""
    return self._output_specs
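Each entry in BASNET_BRIDGE_SPECS and BASNET_DECODER_SPECS packs three (filters, dilation_rate) pairs plus a final up-sampling factor, and build() above reads them positionally. A minimal sketch of that unpacking for one spec tuple (values copied from BASNET_DECODER_SPECS):

# How build() reads a decoder spec: even indices are filter counts, odd indices
# are dilation rates, and index 6 is the bilinear up-sampling factor for the
# side-output ("Sup") map.
spec = (512, 1, 512, 2, 512, 2, 32)  # Sup1, stage6d

for j in range(3):
  filters = spec[2 * j]            # 512, 512, 512
  dilation_rate = spec[2 * j + 1]  # 1, 2, 2
  print(f'conv {j}: filters={filters}, dilation_rate={dilation_rate}')
print(f'side-output up-sampling factor: {spec[6]}')  # 32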
official/vision/beta/projects/basnet/modeling/basnet_encoder.py (deleted, +0 −205)

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Import libraries
import tensorflow as tf

from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.basnet.modeling.layers import nn_blocks

# Specifications for BASNet encoder.
# Each element in the block configuration is in the following format:
# (num_filters, stride, block_repeats, maxpool)
BASNET_ENCODER_SPECS = [
    (64, 1, 3, 0),   # ResNet-34
    (128, 2, 4, 0),  # ResNet-34
    (256, 2, 6, 0),  # ResNet-34
    (512, 2, 3, 1),  # ResNet-34
    (512, 1, 3, 1),  # BASNet
    (512, 1, 3, 0),  # BASNet
]


@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Encoder(tf.keras.Model):
  """BASNet Encoder.

  Boundary-Aware network (BASNet) was proposed in:
  [1] Qin, Xuebin, et al.
      Basnet: Boundary-aware salient object detection.
  """

  def __init__(self,
               input_specs=tf.keras.layers.InputSpec(
                   shape=[None, None, None, 3]),
               activation='relu',
               use_sync_bn=False,
               use_bias=True,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """BASNet_Encoder initialization function.

    Args:
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in conv2d.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      **kwargs: keyword arguments to be passed.
    """
    self._input_specs = input_specs
    self._use_sync_bn = use_sync_bn
    self._use_bias = use_bias
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    # Build BASNet.
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    x = tf.keras.layers.Conv2D(
        filters=64,
        kernel_size=3,
        strides=1,
        use_bias=self._use_bias,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(inputs)
    x = self._norm(
        axis=bn_axis,
        momentum=norm_momentum,
        epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)

    endpoints = {}
    for i, spec in enumerate(BASNET_ENCODER_SPECS):
      x = self._block_group(
          inputs=x,
          filters=spec[0],
          strides=spec[1],
          block_repeats=spec[2],
          name='block_group_l{}'.format(i + 2))
      endpoints[str(i)] = x
      if spec[3]:
        x = tf.keras.layers.MaxPool2D(
            pool_size=2, strides=2, padding='same')(x)

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

    super(BASNet_Encoder, self).__init__(
        inputs=inputs, outputs=endpoints, **kwargs)

  def _block_group(self,
                   inputs,
                   filters,
                   strides,
                   block_repeats=1,
                   name='block_group'):
    """Creates one group of residual blocks for the BASNet encoder model.

    Args:
      inputs: `Tensor` of size `[batch, channels, height, width]`.
      filters: `int` number of filters for the first convolution of the layer.
      strides: `int` stride to use for the first convolution of the layer. If
        greater than 1, this layer will downsample the input.
      block_repeats: `int` number of blocks contained in the layer.
      name: `str` name for the block.

    Returns:
      The output `Tensor` of the block layer.
    """
    x = nn_blocks.ResBlock(
        filters=filters,
        strides=strides,
        use_projection=True,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activation=self._activation,
        use_sync_bn=self._use_sync_bn,
        use_bias=self._use_bias,
        norm_momentum=self._norm_momentum,
        norm_epsilon=self._norm_epsilon)(inputs)

    for _ in range(1, block_repeats):
      x = nn_blocks.ResBlock(
          filters=filters,
          strides=1,
          use_projection=False,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          use_bias=self._use_bias,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon)(x)

    return tf.identity(x, name=name)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_backbone_builder('basnet_encoder')
def build_basnet_encoder(
    input_specs: tf.keras.layers.InputSpec,
    model_config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Builds BASNet Encoder backbone from a config."""
  backbone_type = model_config.backbone.type
  backbone_cfg = model_config.backbone.get()
  norm_activation_config = model_config.norm_activation
  assert backbone_type == 'basnet_encoder', (f'Inconsistent backbone type '
                                             f'{backbone_type}')

  return BASNet_Encoder(
      input_specs=input_specs,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)
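BASNET_ENCODER_SPECS defines a ResNet-34-style layout: each tuple creates one block group, the group's first block applies the stride, the endpoint is recorded, and a 2x2 max-pool follows whenever the maxpool flag is set. A small sketch computing the cumulative down-sampling factor seen at each endpoint (assuming the stride-1 stem convolution shown above):

# Cumulative down-sampling factor at each encoder endpoint, derived from
# BASNET_ENCODER_SPECS entries of the form (num_filters, stride, block_repeats, maxpool).
BASNET_ENCODER_SPECS = [
    (64, 1, 3, 0),
    (128, 2, 4, 0),
    (256, 2, 6, 0),
    (512, 2, 3, 1),
    (512, 1, 3, 1),
    (512, 1, 3, 0),
]

scale = 1  # the stem convolution uses stride 1
for i, (filters, stride, repeats, maxpool) in enumerate(BASNET_ENCODER_SPECS):
  scale *= stride                  # the first block of the group may downsample
  print(f'endpoint {i}: {filters} filters at input/{scale}')
  if maxpool:
    scale *= 2                     # MaxPool2D(pool_size=2, strides=2) after the group
# endpoints 0..5 end up at input/1, /2, /4, /8, /16 and /32 respectively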
official/vision/beta/projects/basnet/modeling/basnet_model.py (+371 −0)

@@ -15,13 +15,53 @@ (new content of the region)

"""Build BASNet models."""

# Import libraries
from typing import Mapping

import tensorflow as tf

from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.basnet.modeling.layers import nn_blocks

# Specifications for BASNet encoder.
# Each element in the block configuration is in the following format:
# (num_filters, stride, block_repeats, maxpool)
BASNET_ENCODER_SPECS = [
    (64, 1, 3, 0),   # ResNet-34
    (128, 2, 4, 0),  # ResNet-34
    (256, 2, 6, 0),  # ResNet-34
    (512, 2, 3, 1),  # ResNet-34
    (512, 1, 3, 1),  # BASNet
    (512, 1, 3, 0),  # BASNet
]

# Specifications for BASNet decoder.
# Each element in the block configuration is in the following format:
# (conv1_nf, conv1_dr, convm_nf, convm_dr, conv2_nf, conv2_dr, scale_factor)
# nf : num_filters, dr : dilation_rate
BASNET_BRIDGE_SPECS = [
    (512, 2, 512, 2, 512, 2, 32),  # Sup0, Bridge
]

BASNET_DECODER_SPECS = [
    (512, 1, 512, 2, 512, 2, 32),  # Sup1, stage6d
    (512, 1, 512, 1, 512, 1, 16),  # Sup2, stage5d
    (512, 1, 512, 1, 256, 1, 8),   # Sup3, stage4d
    (256, 1, 256, 1, 128, 1, 4),   # Sup4, stage3d
    (128, 1, 128, 1, 64, 1, 2),    # Sup5, stage2d
    (64, 1, 64, 1, 64, 1, 1),      # Sup6, stage1d
]


@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetModel(tf.keras.Model):
  """A BASNet model.

  Boundary-Aware network (BASNet) was proposed in:
  [1] Qin, Xuebin, et al.
      Basnet: Boundary-aware salient object detection.

  Input images are passed through backbone first. Decoder network is then
  applied, and finally, refinement module is applied on the output of the
  decoder network.
  """

@@ -80,3 +120,334 @@ class BASNetModel(tf.keras.Model): (new content of the region)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)


@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Encoder(tf.keras.Model):
  """BASNet encoder."""

  def __init__(self,
               input_specs=tf.keras.layers.InputSpec(
                   shape=[None, None, None, 3]),
               activation='relu',
               use_sync_bn=False,
               use_bias=True,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """BASNet encoder initialization function.

    Args:
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in conv2d.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      **kwargs: keyword arguments to be passed.
    """
    self._input_specs = input_specs
    self._use_sync_bn = use_sync_bn
    self._use_bias = use_bias
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    # Build BASNet.
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    x = tf.keras.layers.Conv2D(
        filters=64,
        kernel_size=3,
        strides=1,
        use_bias=self._use_bias,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(inputs)
    x = self._norm(
        axis=bn_axis,
        momentum=norm_momentum,
        epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)

    endpoints = {}
    for i, spec in enumerate(BASNET_ENCODER_SPECS):
      x = self._block_group(
          inputs=x,
          filters=spec[0],
          strides=spec[1],
          block_repeats=spec[2],
          name='block_group_l{}'.format(i + 2))
      endpoints[str(i)] = x
      if spec[3]:
        x = tf.keras.layers.MaxPool2D(
            pool_size=2, strides=2, padding='same')(x)

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

    super(BASNet_Encoder, self).__init__(
        inputs=inputs, outputs=endpoints, **kwargs)

  def _block_group(self,
                   inputs,
                   filters,
                   strides,
                   block_repeats=1,
                   name='block_group'):
    """Creates one group of residual blocks for the BASNet encoder model.

    Args:
      inputs: `Tensor` of size `[batch, channels, height, width]`.
      filters: `int` number of filters for the first convolution of the layer.
      strides: `int` stride to use for the first convolution of the layer. If
        greater than 1, this layer will downsample the input.
      block_repeats: `int` number of blocks contained in the layer.
      name: `str` name for the block.

    Returns:
      The output `Tensor` of the block layer.
    """
    x = nn_blocks.ResBlock(
        filters=filters,
        strides=strides,
        use_projection=True,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activation=self._activation,
        use_sync_bn=self._use_sync_bn,
        use_bias=self._use_bias,
        norm_momentum=self._norm_momentum,
        norm_epsilon=self._norm_epsilon)(inputs)

    for _ in range(1, block_repeats):
      x = nn_blocks.ResBlock(
          filters=filters,
          strides=1,
          use_projection=False,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          use_bias=self._use_bias,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon)(x)

    return tf.identity(x, name=name)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_backbone_builder('basnet_encoder')
def build_basnet_encoder(
    input_specs: tf.keras.layers.InputSpec,
    model_config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Builds BASNet Encoder backbone from a config."""
  backbone_type = model_config.backbone.type
  backbone_cfg = model_config.backbone.get()
  norm_activation_config = model_config.norm_activation
  assert backbone_type == 'basnet_encoder', (f'Inconsistent backbone type '
                                             f'{backbone_type}')

  return BASNet_Encoder(
      input_specs=input_specs,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)


@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Decoder(tf.keras.layers.Layer):
  """BASNet decoder."""

  def __init__(self,
               use_separable_conv=False,
               activation='relu',
               use_sync_bn=False,
               use_bias=True,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """BASNet decoder initialization function.

    Args:
      use_separable_conv: `bool`, if True use separable convolution for
        convolution in BASNet layers.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in convolution.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      **kwargs: keyword arguments to be passed.
    """
    super(BASNet_Decoder, self).__init__(**kwargs)
    self._config_dict = {
        'use_separable_conv': use_separable_conv,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1
    self._activation = tf_utils.get_activation(activation)
    self._concat = tf.keras.layers.Concatenate(axis=-1)
    self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')

  def build(self, input_shape):
    """Creates the variables of the BASNet decoder."""
    if self._config_dict['use_separable_conv']:
      conv_op = tf.keras.layers.SeparableConv2D
    else:
      conv_op = tf.keras.layers.Conv2D
    conv_kwargs = {
        'kernel_size': 3,
        'strides': 1,
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }

    self._out_convs = []
    self._out_usmps = []

    # Bridge layers.
    self._bdg_convs = []
    for i, spec in enumerate(BASNET_BRIDGE_SPECS):
      blocks = []
      for j in range(3):
        blocks.append(nn_blocks.ConvBlock(
            filters=spec[2 * j],
            dilation_rate=spec[2 * j + 1],
            activation='relu',
            norm_momentum=0.99,
            norm_epsilon=0.001,
            **conv_kwargs))
      self._bdg_convs.append(blocks)
      self._out_convs.append(conv_op(
          filters=1,
          padding='same',
          **conv_kwargs))
      self._out_usmps.append(tf.keras.layers.UpSampling2D(
          size=spec[6],
          interpolation='bilinear'))

    # Decoder layers.
    self._dec_convs = []
    for i, spec in enumerate(BASNET_DECODER_SPECS):
      blocks = []
      for j in range(3):
        blocks.append(nn_blocks.ConvBlock(
            filters=spec[2 * j],
            dilation_rate=spec[2 * j + 1],
            activation='relu',
            norm_momentum=0.99,
            norm_epsilon=0.001,
            **conv_kwargs))
      self._dec_convs.append(blocks)
      self._out_convs.append(conv_op(
          filters=1,
          padding='same',
          **conv_kwargs))
      self._out_usmps.append(tf.keras.layers.UpSampling2D(
          size=spec[6],
          interpolation='bilinear'))

  def call(self, backbone_output: Mapping[str, tf.Tensor]):
    """Forward pass of the BASNet decoder.

    Args:
      backbone_output: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].

    Returns:
      sup: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].
    """
    levels = sorted(backbone_output.keys(), reverse=True)
    sup = {}
    x = backbone_output[levels[0]]
    for blocks in self._bdg_convs:
      for block in blocks:
        x = block(x)
    sup['0'] = x
    for i, blocks in enumerate(self._dec_convs):
      x = self._concat([x, backbone_output[levels[i]]])
      for block in blocks:
        x = block(x)
      sup[str(i + 1)] = x
      x = tf.keras.layers.UpSampling2D(
          size=2, interpolation='bilinear')(x)
    for i, (conv, usmp) in enumerate(zip(self._out_convs, self._out_usmps)):
      sup[str(i)] = self._sigmoid(usmp(conv(sup[str(i)])))
    self._output_specs = {
        str(order): sup[str(order)].get_shape()
        for order in range(0, len(BASNET_DECODER_SPECS))
    }
    return sup

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {order: TensorShape} pairs for the model output."""
    return self._output_specs
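After this commit the encoder, decoder, and model wrapper all live in basnet_model.py, and the model docstring above describes the flow: backbone first, then the decoder, then the refinement module on the decoder output. A minimal wiring sketch under those assumptions (BASNetModel's own constructor arguments are elided in the diff, so the components are wired by hand here; treating the finest side output, key '6', as the map passed to RefUnet is an assumption):

import tensorflow as tf

from official.vision.beta.projects.basnet.modeling import basnet_model
from official.vision.beta.projects.basnet.modeling import refunet

backbone = basnet_model.BASNet_Encoder()
decoder = basnet_model.BASNet_Decoder()
refinement = refunet.RefUnet()

images = tf.random.uniform([2, 224, 224, 3])
endpoints = backbone(images)     # dict of feature maps keyed '0'..'5'
sups = decoder(endpoints)        # dict of sigmoid side outputs keyed '0'..'6'
refined = refinement(sups['6'])  # assumption: refinement refines the finest map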
official/vision/beta/projects/basnet/modeling/basnet_model_test.py (+6 −10)

@@ -19,9 +19,7 @@ from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
 
-from official.vision.beta.projects.basnet.modeling import basnet_encoder
 from official.vision.beta.projects.basnet.modeling import basnet_model
-from official.vision.beta.projects.basnet.modeling import basnet_decoder
 from official.vision.beta.projects.basnet.modeling import refunet
@@ -38,9 +36,8 @@ class BASNetNetworkTest(parameterized.TestCase, tf.test.TestCase):
     inputs = np.random.rand(2, input_size, input_size, 3)
 
     tf.keras.backend.set_image_data_format('channels_last')
-    backbone = basnet_encoder.BASNet_Encoder()
-    decoder = basnet_decoder.BASNet_Decoder(
-        input_specs=backbone.output_specs)
+    backbone = basnet_model.BASNet_Encoder()
+    decoder = basnet_model.BASNet_Decoder()
     refinement = refunet.RefUnet()
     model = basnet_model.BASNetModel(
@@ -50,16 +47,15 @@ class BASNetNetworkTest(parameterized.TestCase, tf.test.TestCase):
     )
 
     sigmoids = model(inputs)
-    #print(sigmoids['ref'].numpy().shape)
+    levels = sorted(sigmoids.keys())
     self.assertAllEqual(
         [2, input_size, input_size, 1],
-        sigmoids['ref'].numpy().shape)
+        sigmoids[levels[-1]].numpy().shape)
 
   def test_serialize_deserialize(self):
     """Validate the network can be serialized and deserialized."""
-    backbone = basnet_encoder.BASNet_Encoder()
-    decoder = basnet_decoder.BASNet_Decoder(
-        input_specs=backbone.output_specs)
+    backbone = basnet_model.BASNet_Encoder()
+    decoder = basnet_model.BASNet_Decoder()
     refinement = refunet.RefUnet()
     model = basnet_model.BASNetModel(
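The test no longer hard-codes the 'ref' key; it sorts the output keys and checks the last one. A small sketch of why that still selects the refined map when the outputs are keyed by numeric level plus 'ref' (the exact key set is an assumption; only the sorting behaviour is being illustrated):

# ASCII ordering puts letters after digits, so 'ref' sorts last among
# level keys such as '0'..'6'.
sigmoids = {'0': None, '1': None, '6': None, 'ref': None}  # illustrative keys only
levels = sorted(sigmoids.keys())
print(levels)      # ['0', '1', '6', 'ref']
print(levels[-1])  # 'ref'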
official/vision/beta/projects/basnet/modeling/refunet.py (+4 −1)

@@ -22,7 +22,7 @@ from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
 class RefUnet(tf.keras.layers.Layer):
   """Residual Refinement Module of BASNet.
 
-  Boundary-Awar network (BASNet) were proposed in:
+  Boundary-Aware network (BASNet) were proposed in:
   [1] Qin, Xuebin, et al.
       Basnet: Boundary-aware salient object detection.
   """
@@ -141,6 +141,9 @@ class RefUnet(tf.keras.layers.Layer):
     self._output_specs = output.get_shape()
     return output
 
+  def get_config(self):
+    return self._config_dict
+
   @classmethod
   def from_config(cls, config, custom_objects=None):
     return cls(**config)
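The second hunk adds get_config, which pairs with the existing from_config so Keras can rebuild the layer from its stored constructor arguments. A minimal sketch of the round-trip this pair enables, using a toy layer so it runs without the project's nn_blocks:

import tensorflow as tf

class ToyLayer(tf.keras.layers.Layer):
  """Toy layer mirroring the get_config/from_config pattern added to RefUnet."""

  def __init__(self, units=8, **kwargs):
    super().__init__(**kwargs)
    self._config_dict = {'units': units}

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

layer = ToyLayer(units=16)
rebuilt = ToyLayer.from_config(layer.get_config())  # round-trip via the config dict
print(rebuilt.get_config())                         # {'units': 16}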
official/vision/beta/projects/basnet/tasks/basnet.py (+2 −4)

@@ -30,9 +30,7 @@ from official.vision.beta.projects.basnet.evaluation import relax_f
 from official.vision.beta.projects.basnet.evaluation import mae
 from official.vision.beta.projects.basnet.losses import basnet_losses
-from official.vision.beta.projects.basnet.modeling import basnet_encoder
 from official.vision.beta.projects.basnet.modeling import basnet_model
-from official.vision.beta.projects.basnet.modeling import basnet_decoder
 from official.vision.beta.projects.basnet.modeling import refunet
 
 
 def build_basnet_model(
@@ -40,12 +38,12 @@ def build_basnet_model(
     model_config: exp_cfg.BASNetModel,
     l2_regularizer: tf.keras.regularizers.Regularizer = None):
   """Builds BASNet model."""
-  backbone = basnet_encoder.BASNet_Encoder(
+  backbone = basnet_model.BASNet_Encoder(
       input_specs=input_specs)
 
   norm_activation_config = model_config.norm_activation
-  decoder = basnet_decoder.BASNet_Decoder(
+  decoder = basnet_model.BASNet_Decoder(
       use_sync_bn=norm_activation_config.use_sync_bn,
       norm_momentum=norm_activation_config.norm_momentum,
       norm_epsilon=norm_activation_config.norm_epsilon,
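build_basnet_model now constructs both components from basnet_model and threads the shared norm_activation settings into the decoder. A minimal illustrative call of just that decoder construction (the surrounding build_basnet_model arguments are elided in the diff, so only the keyword arguments visible in the hunk are reproduced; _NormActivation is a hypothetical stand-in for model_config.norm_activation):

from official.vision.beta.projects.basnet.modeling import basnet_model

class _NormActivation:          # hypothetical stand-in for model_config.norm_activation
  use_sync_bn = False
  norm_momentum = 0.99
  norm_epsilon = 0.001

norm_activation_config = _NormActivation()
decoder = basnet_model.BASNet_Decoder(
    use_sync_bn=norm_activation_config.use_sync_bn,
    norm_momentum=norm_activation_config.norm_momentum,
    norm_epsilon=norm_activation_config.norm_epsilon)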