Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
25bf4592
"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "96236b79684a8b883fd71b95a85a9bfeec3c53aa"
Commit
25bf4592
authored
Jun 15, 2021
by
Gunho Park
Browse files
keras.Model to keras.layer
parent
13dffa31
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
191 additions
and
217 deletions
+191
-217
official/vision/beta/projects/basnet/modeling/basnet_decoder.py
...al/vision/beta/projects/basnet/modeling/basnet_decoder.py
+92
-75
official/vision/beta/projects/basnet/modeling/basnet_encoder.py
...al/vision/beta/projects/basnet/modeling/basnet_encoder.py
+22
-29
official/vision/beta/projects/basnet/modeling/refunet.py
official/vision/beta/projects/basnet/modeling/refunet.py
+77
-113
No files found.
official/vision/beta/projects/basnet/modeling/basnet_decoder.py
View file @
25bf4592
...
...
@@ -12,14 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decoder of BASNet.
Boundary-Awar network (BASNet) were proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
# Import libraries
from
typing
import
Mapping
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
...
...
@@ -27,8 +22,11 @@ from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
# nf : num_filters, dr : dilation_rate
# (conv1_nf, conv1_dr, convm_nf, convm_dr, conv2_nf, conv2_dr, scale_factor)
BASNET_BRIDGE_SPECS
=
[
(
512
,
2
,
512
,
2
,
512
,
2
,
32
),
#Sup0, Bridge
]
BASNET_DECODER_SPECS
=
[
(
512
,
2
,
512
,
2
,
512
,
2
,
32
),
#Bridge(Sup0)
(
512
,
1
,
512
,
2
,
512
,
2
,
32
),
#Sup1, stage6d
(
512
,
1
,
512
,
1
,
512
,
1
,
16
),
#Sup2, stage5d
(
512
,
1
,
512
,
1
,
256
,
1
,
8
),
#Sup3, stage4d
...
...
@@ -37,12 +35,17 @@ BASNET_DECODER_SPECS = [
(
64
,
1
,
64
,
1
,
64
,
1
,
1
)
#Sup6, stage1d
]
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
BASNet_Decoder
(
tf
.
keras
.
Model
):
"""BASNet Decoder."""
class
BASNet_Decoder
(
tf
.
keras
.
layers
.
Layer
):
"""Decoder of BASNet.
Boundary-Awar network (BASNet) were proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
def
__init__
(
self
,
input_specs
,
use_separable_conv
=
False
,
activation
=
'relu'
,
use_sync_bn
=
False
,
...
...
@@ -56,12 +59,11 @@ class BASNet_Decoder(tf.keras.Model):
"""BASNet Decoder initialization function.
Args:
input_specs: `dict` input specifications. A dictionary consists of
{level: TensorShape} from a backbone.
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in convolution.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
...
...
@@ -70,8 +72,8 @@ class BASNet_Decoder(tf.keras.Model):
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
**kwargs: keyword arguments to be passed.
"""
super
(
BASNet_Decoder
,
self
).
__init__
(
**
kwargs
)
self
.
_config_dict
=
{
'input_specs'
:
input_specs
,
'use_separable_conv'
:
use_separable_conv
,
'activation'
:
activation
,
'use_sync_bn'
:
use_sync_bn
,
...
...
@@ -82,89 +84,104 @@ class BASNet_Decoder(tf.keras.Model):
'kernel_regularizer'
:
kernel_regularizer
,
'bias_regularizer'
:
bias_regularizer
,
}
if
use_separable_conv
:
conv2d
=
tf
.
keras
.
layers
.
SeparableConv2D
else
:
conv2d
=
tf
.
keras
.
layers
.
Conv2D
if
use_sync_bn
:
norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
else
:
norm
=
tf
.
keras
.
layers
.
BatchNormalization
activation_fn
=
tf
.
keras
.
layers
.
Activation
(
tf_utils
.
get_activation
(
activation
))
# Build input feature pyramid.
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
bn_axis
=
-
1
else
:
bn_axis
=
1
self
.
_activation
=
tf_utils
.
get_activation
(
activation
)
self
.
_concat
=
tf
.
keras
.
layers
.
Concatenate
(
axis
=-
1
)
self
.
_sigmoid
=
tf
.
keras
.
layers
.
Activation
(
activation
=
'sigmoid'
)
def
build
(
self
,
input_shape
):
"""Creates the variables of the BASNet decoder."""
if
self
.
_config_dict
[
'use_separable_conv'
]:
conv_op
=
tf
.
keras
.
layers
.
SeparableConv2D
else
:
conv_op
=
tf
.
keras
.
layers
.
Conv2D
conv_kwargs
=
{
'kernel_size'
:
3
,
'strides'
:
1
,
'use_bias'
:
self
.
_config_dict
[
'use_bias'
],
'kernel_initializer'
:
self
.
_config_dict
[
'kernel_initializer'
],
'kernel_regularizer'
:
self
.
_config_dict
[
'kernel_regularizer'
],
'bias_regularizer'
:
self
.
_config_dict
[
'bias_regularizer'
],
}
# Get input feature pyramid from backbone.
inputs
=
self
.
_build_input_pyramid
(
input_specs
)
levels
=
sorted
(
inputs
.
keys
(),
reverse
=
True
)
self
.
_out_convs
=
[]
self
.
_out_usmps
=
[]
sup
=
{}
# Bridge layers.
self
.
_bdg_convs
=
[]
for
i
,
spec
in
enumerate
(
BASNET_BRIDGE_SPECS
):
blocks
=
[]
for
j
in
range
(
3
):
blocks
.
append
(
nn_blocks
.
ConvBlock
(
filters
=
spec
[
2
*
j
],
dilation_rate
=
spec
[
2
*
j
+
1
],
activation
=
'relu'
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
**
conv_kwargs
))
self
.
_bdg_convs
.
append
(
blocks
)
self
.
_out_convs
.
append
(
conv_op
(
filters
=
1
,
padding
=
'same'
,
**
conv_kwargs
))
self
.
_out_usmps
.
append
(
tf
.
keras
.
layers
.
UpSampling2D
(
size
=
spec
[
6
],
interpolation
=
'bilinear'
))
# Decoder layers.
self
.
_dec_convs
=
[]
for
i
,
spec
in
enumerate
(
BASNET_DECODER_SPECS
):
if
i
==
0
:
#x = inputs['5'] # Bridge input
x
=
inputs
[
levels
[
0
]]
# Bridge input
# str(levels[-1]) ??
else
:
x
=
tf
.
keras
.
layers
.
Concatenate
(
axis
=-
1
)([
x
,
inputs
[
levels
[
i
-
1
]]])
blocks
=
[]
for
j
in
range
(
3
):
x
=
nn_blocks
.
ConvBlock
(
blocks
.
append
(
nn_blocks
.
ConvBlock
(
filters
=
spec
[
2
*
j
],
kernel_size
=
3
,
strides
=
1
,
dilation_rate
=
spec
[
2
*
j
+
1
],
kernel_initializer
=
kernel_initializer
,
kernel_regularizer
=
kernel_regularizer
,
bias_regularizer
=
bias_regularizer
,
activation
=
'relu'
,
use_sync_bn
=
use_sync_bn
,
use_bias
=
use_bias
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
)(
x
)
output
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
1
,
kernel_size
=
3
,
strides
=
1
,
use_bias
=
use_bias
,
padding
=
'same'
,
kernel_initializer
=
kernel_initializer
,
kernel_regularizer
=
kernel_regularizer
,
bias_regularizer
=
bias_regularizer
)(
x
)
output
=
tf
.
keras
.
layers
.
UpSampling2D
(
norm_epsilon
=
0.001
,
**
conv_kwargs
))
self
.
_dec_convs
.
append
(
blocks
)
self
.
_out_convs
.
append
(
conv_op
(
filters
=
1
,
padding
=
'same'
,
**
conv_kwargs
))
self
.
_out_usmps
.
append
(
tf
.
keras
.
layers
.
UpSampling2D
(
size
=
spec
[
6
],
interpolation
=
'bilinear'
)(
output
)
output
=
tf
.
keras
.
layers
.
Activation
(
activation
=
'sigmoid'
)(
output
)
sup
[
str
(
i
)]
=
output
if
i
!=
0
:
x
=
tf
.
keras
.
layers
.
UpSampling2D
(
size
=
2
,
interpolation
=
'bilinear'
)(
x
)
))
def
call
(
self
,
backbone_output
:
Mapping
[
str
,
tf
.
Tensor
]):
levels
=
sorted
(
backbone_output
.
keys
(),
reverse
=
True
)
sup
=
{}
x
=
backbone_output
[
levels
[
0
]]
for
blocks
in
self
.
_bdg_convs
:
for
block
in
blocks
:
x
=
block
(
x
)
sup
[
'0'
]
=
x
for
i
,
blocks
in
enumerate
(
self
.
_dec_convs
):
x
=
self
.
_concat
([
x
,
backbone_output
[
levels
[
i
]]])
for
block
in
blocks
:
x
=
block
(
x
)
sup
[
str
(
i
+
1
)]
=
x
x
=
tf
.
keras
.
layers
.
UpSampling2D
(
size
=
2
,
interpolation
=
'bilinear'
)(
x
)
for
i
,
(
conv
,
usmp
)
in
enumerate
(
zip
(
self
.
_out_convs
,
self
.
_out_usmps
)):
sup
[
str
(
i
)]
=
self
.
_sigmoid
(
usmp
(
conv
(
sup
[
str
(
i
)])))
self
.
_output_specs
=
{
str
(
order
):
sup
[
str
(
order
)].
get_shape
()
for
order
in
range
(
0
,
len
(
BASNET_DECODER_SPECS
))
}
super
(
BASNet_Decoder
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
sup
,
**
kwargs
)
def
_build_input_pyramid
(
self
,
input_specs
):
assert
isinstance
(
input_specs
,
dict
)
inputs
=
{}
for
level
,
spec
in
input_specs
.
items
():
inputs
[
level
]
=
tf
.
keras
.
Input
(
shape
=
spec
[
1
:])
return
inputs
return
sup
def
get_config
(
self
):
return
self
.
_config_dict
...
...
official/vision/beta/projects/basnet/modeling/basnet_encoder.py
View file @
25bf4592
...
...
@@ -12,13 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BASNet Encoder
Boundary-Awar network (BASNet) were proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
# Import libraries
import
tensorflow
as
tf
...
...
@@ -29,19 +22,26 @@ from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
# Specifications for BASNet encoder.
# Each element in the block configuration is in the following format:
# (
block_fn,
num_filters, stride, block_repeats, maxpool)
# (num_filters, stride, block_repeats, maxpool)
BASNET_ENCODER_SPECS
=
[
(
'residual'
,
64
,
1
,
3
,
0
),
#ResNet-34,
(
'residual'
,
128
,
2
,
4
,
0
),
#ResNet-34,
(
'residual'
,
256
,
2
,
6
,
0
),
#ResNet-34,
(
'residual'
,
512
,
2
,
3
,
1
),
#ResNet-34,
(
'residual'
,
512
,
1
,
3
,
1
),
#BASNet,
(
'residual'
,
512
,
1
,
3
,
0
),
#BASNet,
(
64
,
1
,
3
,
0
),
#ResNet-34,
(
128
,
2
,
4
,
0
),
#ResNet-34,
(
256
,
2
,
6
,
0
),
#ResNet-34,
(
512
,
2
,
3
,
1
),
#ResNet-34,
(
512
,
1
,
3
,
1
),
#BASNet,
(
512
,
1
,
3
,
0
),
#BASNet,
]
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
BASNet_Encoder
(
tf
.
keras
.
Model
):
"""BASNet Encoder
Boundary-Awar network (BASNet) were proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
def
__init__
(
self
,
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
...
...
@@ -54,7 +54,7 @@ class BASNet_Encoder(tf.keras.Model):
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
**
kwargs
):
"""BASNet_En initialization function.
"""BASNet_En
coder
initialization function.
Args:
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
...
...
@@ -109,19 +109,14 @@ class BASNet_Encoder(tf.keras.Model):
endpoints
=
{}
for
i
,
spec
in
enumerate
(
BASNET_ENCODER_SPECS
):
if
spec
[
0
]
==
'residual'
:
block_fn
=
nn_blocks
.
ResBlock
else
:
raise
ValueError
(
'Block fn `{}` is not supported.'
.
format
(
spec
[
0
]))
x
=
self
.
_block_group
(
inputs
=
x
,
filters
=
spec
[
1
],
strides
=
spec
[
2
],
block_fn
=
block_fn
,
block_repeats
=
spec
[
3
],
filters
=
spec
[
0
],
strides
=
spec
[
1
],
block_repeats
=
spec
[
2
],
name
=
'block_group_l{}'
.
format
(
i
+
2
))
endpoints
[
str
(
i
)]
=
x
if
spec
[
4
]:
if
spec
[
3
]:
x
=
tf
.
keras
.
layers
.
MaxPool2D
(
pool_size
=
2
,
strides
=
2
,
padding
=
'same'
)(
x
)
self
.
_output_specs
=
{
l
:
endpoints
[
l
].
get_shape
()
for
l
in
endpoints
}
...
...
@@ -131,24 +126,22 @@ class BASNet_Encoder(tf.keras.Model):
inputs
,
filters
,
strides
,
block_fn
,
block_repeats
=
1
,
name
=
'block_group'
):
"""Creates one group of blocks for the
ResNet
model.
"""Creates one group of
residual
blocks for the
BASNet encoder
model.
Args:
inputs: `Tensor` of size `[batch, channels, height, width]`.
filters: `int` number of filters for the first convolution of the layer.
strides: `int` stride to use for the first convolution of the layer. If
greater than 1, this layer will downsample the input.
block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`.
block_repeats: `int` number of blocks contained in the layer.
name: `str`name for the block.
Returns:
The output `Tensor` of the block layer.
"""
x
=
block
_fn
(
x
=
nn_
block
s
.
ResBlock
(
filters
=
filters
,
strides
=
strides
,
use_projection
=
True
,
...
...
@@ -163,7 +156,7 @@ class BASNet_Encoder(tf.keras.Model):
inputs
)
for
_
in
range
(
1
,
block_repeats
):
x
=
block
_fn
(
x
=
nn_
block
.
ResBlock
(
filters
=
filters
,
strides
=
1
,
use_projection
=
False
,
...
...
official/vision/beta/projects/basnet/modeling/refunet.py
View file @
25bf4592
...
...
@@ -12,13 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Residual Refinement Module of BASNet.
Boundary-Awar network (BASNet) were proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
# Import libraries
import
tensorflow
as
tf
...
...
@@ -26,10 +19,15 @@ from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
RefUnet
(
tf
.
keras
.
Model
):
class
RefUnet
(
tf
.
keras
.
layers
.
Layer
):
"""Residual Refinement Module of BASNet.
Boundary-Awar network (BASNet) were proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
def
__init__
(
self
,
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
1
])
,
use_separable_conv
=
False
,
activation
=
'relu'
,
use_sync_bn
=
False
,
use_bias
=
True
,
...
...
@@ -42,7 +40,8 @@ class RefUnet(tf.keras.Model):
"""Residual Refinement Module of BASNet.
Args:
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in conv2d.
...
...
@@ -56,126 +55,91 @@ class RefUnet(tf.keras.Model):
Default to None.
**kwargs: keyword arguments to be passed.
"""
self
.
_input_specs
=
input_specs
self
.
_use_sync_bn
=
use_sync_bn
self
.
_use_bias
=
use_bias
self
.
_activation
=
activation
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
if
use_sync_bn
:
self
.
_norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
else
:
self
.
_norm
=
tf
.
keras
.
layers
.
BatchNormalization
self
.
_kernel_initializer
=
kernel_initializer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
super
(
RefUnet
,
self
).
__init__
(
**
kwargs
)
self
.
_config_dict
=
{
'use_separable_conv'
:
use_separable_conv
,
'activation'
:
activation
,
'use_sync_bn'
:
use_sync_bn
,
'use_bias'
:
use_bias
,
'norm_momentum'
:
norm_momentum
,
'norm_epsilon'
:
norm_epsilon
,
'kernel_initializer'
:
kernel_initializer
,
'kernel_regularizer'
:
kernel_regularizer
,
'bias_regularizer'
:
bias_regularizer
,
}
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
bn_axis
=
-
1
else
:
bn_axis
=
1
self
.
_concat
=
tf
.
keras
.
layers
.
Concatenate
(
axis
=-
1
)
self
.
_sigmoid
=
tf
.
keras
.
layers
.
Activation
(
activation
=
'sigmoid'
)
self
.
_maxpool
=
tf
.
keras
.
layers
.
MaxPool2D
(
pool_size
=
2
,
strides
=
2
,
padding
=
'valid'
)
self
.
_upsample
=
tf
.
keras
.
layers
.
UpSampling2D
(
size
=
2
,
interpolation
=
'bilinear'
)
# Build ResNet.
inputs
=
tf
.
keras
.
Input
(
shape
=
self
.
_input_specs
.
shape
[
1
:])
endpoints
=
{}
def
build
(
self
,
input_shape
):
"""Creates the variables of the BASNet decoder."""
if
self
.
_config_dict
[
'use_separable_conv'
]:
conv_op
=
tf
.
keras
.
layers
.
SeparableConv2D
else
:
conv_op
=
tf
.
keras
.
layers
.
Conv2D
conv_kwargs
=
{
'kernel_size'
:
3
,
'strides'
:
1
,
'use_bias'
:
self
.
_config_dict
[
'use_bias'
],
'kernel_initializer'
:
self
.
_config_dict
[
'kernel_initializer'
],
'kernel_regularizer'
:
self
.
_config_dict
[
'kernel_regularizer'
],
'bias_regularizer'
:
self
.
_config_dict
[
'bias_regularizer'
],
}
self
.
_in_conv
=
conv_op
(
filters
=
64
,
padding
=
'same'
,
**
conv_kwargs
)
self
.
_en_convs
=
[]
for
_
in
range
(
4
):
self
.
_en_convs
.
append
(
nn_blocks
.
ConvBlock
(
filters
=
64
,
**
conv_kwargs
))
self
.
_bridge_convs
=
[]
for
_
in
range
(
1
):
self
.
_bridge_convs
.
append
(
nn_blocks
.
ConvBlock
(
filters
=
64
,
**
conv_kwargs
))
self
.
_de_convs
=
[]
for
_
in
range
(
4
):
self
.
_de_convs
.
append
(
nn_blocks
.
ConvBlock
(
filters
=
64
,
**
conv_kwargs
))
self
.
_out_conv
=
conv_op
(
padding
=
'same'
,
filters
=
1
,
**
conv_kwargs
)
def
call
(
self
,
inputs
):
endpoints
=
{}
residual
=
inputs
x
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
64
,
kernel_size
=
3
,
strides
=
1
,
use_bias
=
self
.
_use_bias
,
padding
=
'same'
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
inputs
)
x
=
self
.
_in_conv
(
inputs
)
# Top-down
for
i
in
range
(
4
):
x
=
nn_blocks
.
ConvBlock
(
filters
=
64
,
kernel_size
=
3
,
strides
=
1
,
dilation_rate
=
1
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
'relu'
,
use_sync_bn
=
self
.
_use_sync_bn
,
use_bias
=
self
.
_use_bias
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
)(
x
)
for
i
,
block
in
enumerate
(
self
.
_en_convs
):
x
=
block
(
x
)
endpoints
[
str
(
i
)]
=
x
x
=
tf
.
keras
.
layers
.
MaxPool2D
(
pool_size
=
2
,
strides
=
2
,
padding
=
'valid'
)(
x
)
x
=
self
.
_maxpool
(
x
)
# Bridge
x
=
nn_blocks
.
ConvBlock
(
filters
=
64
,
kernel_size
=
3
,
strides
=
1
,
dilation_rate
=
1
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
'relu'
,
use_sync_bn
=
self
.
_use_sync_bn
,
use_bias
=
self
.
_use_bias
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
)(
x
)
x
=
tf
.
keras
.
layers
.
UpSampling2D
(
size
=
2
,
interpolation
=
'bilinear'
)(
x
)
for
i
,
block
in
enumerate
(
self
.
_bridge_convs
):
x
=
block
(
x
)
# Bottom-up
for
i
,
block
in
enumerate
(
self
.
_de_convs
):
x
=
self
.
_upsample
(
x
)
x
=
self
.
_concat
([
endpoints
[
str
(
3
-
i
)],
x
])
x
=
block
(
x
)
for
i
in
range
(
4
):
x
=
tf
.
keras
.
layers
.
Concatenate
(
axis
=-
1
)([
endpoints
[
str
(
3
-
i
)],
x
])
x
=
nn_blocks
.
ConvBlock
(
filters
=
64
,
kernel_size
=
3
,
strides
=
1
,
dilation_rate
=
1
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
'relu'
,
use_sync_bn
=
self
.
_use_sync_bn
,
use_bias
=
self
.
_use_bias
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
)(
x
)
if
i
==
3
:
x
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
1
,
kernel_size
=
3
,
strides
=
1
,
use_bias
=
self
.
_use_bias
,
padding
=
'same'
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
x
)
else
:
x
=
tf
.
keras
.
layers
.
UpSampling2D
(
size
=
2
,
interpolation
=
'bilinear'
)(
x
)
x
=
self
.
_out_conv
(
x
)
residual
=
tf
.
cast
(
residual
,
dtype
=
x
.
dtype
)
output
=
x
+
residual
output
=
tf
.
keras
.
layers
.
Activation
(
activation
=
'sigmoid'
)(
output
)
output
=
self
.
_sigmoid
(
x
+
residual
)
self
.
_output_specs
=
output
.
get_shape
()
super
(
RefUnet
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
output
,
**
kwargs
)
return
output
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment