Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
53c3f653
Commit
53c3f653
authored
Jun 28, 2021
by
Gunho Park
Browse files
Internal change
parent
d4f401e1
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
133 additions
and
132 deletions
+133
-132
official/vision/beta/projects/basnet/configs/basnet.py
official/vision/beta/projects/basnet/configs/basnet.py
+2
-0
official/vision/beta/projects/basnet/modeling/basnet_model.py
...cial/vision/beta/projects/basnet/modeling/basnet_model.py
+6
-9
official/vision/beta/projects/basnet/modeling/layers/nn_blocks.py
.../vision/beta/projects/basnet/modeling/layers/nn_blocks.py
+70
-96
official/vision/beta/projects/basnet/modeling/refunet.py
official/vision/beta/projects/basnet/modeling/refunet.py
+32
-14
official/vision/beta/projects/basnet/tasks/basnet.py
official/vision/beta/projects/basnet/tasks/basnet.py
+23
-13
No files found.
official/vision/beta/projects/basnet/configs/basnet.py
View file @
53c3f653
...
@@ -46,6 +46,7 @@ class DataConfig(cfg.DataConfig):
...
@@ -46,6 +46,7 @@ class DataConfig(cfg.DataConfig):
class
BASNetModel
(
hyperparams
.
Config
):
class
BASNetModel
(
hyperparams
.
Config
):
"""BASNet model config."""
"""BASNet model config."""
input_size
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
input_size
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
use_bias
:
bool
=
False
norm_activation
:
common
.
NormActivation
=
common
.
NormActivation
()
norm_activation
:
common
.
NormActivation
=
common
.
NormActivation
()
...
@@ -99,6 +100,7 @@ def basnet_duts() -> cfg.ExperimentConfig:
...
@@ -99,6 +100,7 @@ def basnet_duts() -> cfg.ExperimentConfig:
task
=
BASNetTask
(
task
=
BASNetTask
(
model
=
BASNetModel
(
model
=
BASNetModel
(
input_size
=
[
None
,
None
,
3
],
input_size
=
[
None
,
None
,
3
],
use_bias
=
True
,
norm_activation
=
common
.
NormActivation
(
norm_activation
=
common
.
NormActivation
(
activation
=
'relu'
,
activation
=
'relu'
,
norm_momentum
=
0.99
,
norm_momentum
=
0.99
,
...
...
official/vision/beta/projects/basnet/modeling/basnet_model.py
View file @
53c3f653
...
@@ -274,11 +274,11 @@ def build_basnet_encoder(
...
@@ -274,11 +274,11 @@ def build_basnet_encoder(
norm_activation_config
=
model_config
.
norm_activation
norm_activation_config
=
model_config
.
norm_activation
assert
backbone_type
==
'basnet_encoder'
,
(
f
'Inconsistent backbone type '
assert
backbone_type
==
'basnet_encoder'
,
(
f
'Inconsistent backbone type '
f
'
{
backbone_type
}
'
)
f
'
{
backbone_type
}
'
)
return
BASNet_Encoder
(
return
BASNet_Encoder
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
activation
=
norm_activation_config
.
activation
,
activation
=
norm_activation_config
.
activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
use_bias
=
norm_activation_config
.
use_bias
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
kernel_regularizer
=
l2_regularizer
)
kernel_regularizer
=
l2_regularizer
)
...
@@ -289,7 +289,6 @@ class BASNet_Decoder(tf.keras.layers.Layer):
...
@@ -289,7 +289,6 @@ class BASNet_Decoder(tf.keras.layers.Layer):
"""BASNet decoder."""
"""BASNet decoder."""
def
__init__
(
self
,
def
__init__
(
self
,
use_separable_conv
=
False
,
activation
=
'relu'
,
activation
=
'relu'
,
use_sync_bn
=
False
,
use_sync_bn
=
False
,
use_bias
=
True
,
use_bias
=
True
,
...
@@ -302,11 +301,11 @@ class BASNet_Decoder(tf.keras.layers.Layer):
...
@@ -302,11 +301,11 @@ class BASNet_Decoder(tf.keras.layers.Layer):
"""BASNet decoder initialization function.
"""BASNet decoder initialization function.
Args:
Args:
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in convolution.
use_bias: if True, use bias in convolution.
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
norm_momentum: `float` normalization omentum for the moving average.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
zero.
...
@@ -317,7 +316,6 @@ class BASNet_Decoder(tf.keras.layers.Layer):
...
@@ -317,7 +316,6 @@ class BASNet_Decoder(tf.keras.layers.Layer):
"""
"""
super
(
BASNet_Decoder
,
self
).
__init__
(
**
kwargs
)
super
(
BASNet_Decoder
,
self
).
__init__
(
**
kwargs
)
self
.
_config_dict
=
{
self
.
_config_dict
=
{
'use_separable_conv'
:
use_separable_conv
,
'activation'
:
activation
,
'activation'
:
activation
,
'use_sync_bn'
:
use_sync_bn
,
'use_sync_bn'
:
use_sync_bn
,
'use_bias'
:
use_bias
,
'use_bias'
:
use_bias
,
...
@@ -337,9 +335,6 @@ class BASNet_Decoder(tf.keras.layers.Layer):
...
@@ -337,9 +335,6 @@ class BASNet_Decoder(tf.keras.layers.Layer):
def
build
(
self
,
input_shape
):
def
build
(
self
,
input_shape
):
"""Creates the variables of the BASNet decoder."""
"""Creates the variables of the BASNet decoder."""
if
self
.
_config_dict
[
'use_separable_conv'
]:
conv_op
=
tf
.
keras
.
layers
.
SeparableConv2D
else
:
conv_op
=
tf
.
keras
.
layers
.
Conv2D
conv_op
=
tf
.
keras
.
layers
.
Conv2D
conv_kwargs
=
{
conv_kwargs
=
{
'kernel_size'
:
3
,
'kernel_size'
:
3
,
...
@@ -362,6 +357,7 @@ class BASNet_Decoder(tf.keras.layers.Layer):
...
@@ -362,6 +357,7 @@ class BASNet_Decoder(tf.keras.layers.Layer):
filters
=
spec
[
2
*
j
],
filters
=
spec
[
2
*
j
],
dilation_rate
=
spec
[
2
*
j
+
1
],
dilation_rate
=
spec
[
2
*
j
+
1
],
activation
=
'relu'
,
activation
=
'relu'
,
use_sync_bn
=
self
.
_config_dict
[
'use_sync_bn'
],
norm_momentum
=
0.99
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
norm_epsilon
=
0.001
,
**
conv_kwargs
))
**
conv_kwargs
))
...
@@ -384,6 +380,7 @@ class BASNet_Decoder(tf.keras.layers.Layer):
...
@@ -384,6 +380,7 @@ class BASNet_Decoder(tf.keras.layers.Layer):
filters
=
spec
[
2
*
j
],
filters
=
spec
[
2
*
j
],
dilation_rate
=
spec
[
2
*
j
+
1
],
dilation_rate
=
spec
[
2
*
j
+
1
],
activation
=
'relu'
,
activation
=
'relu'
,
use_sync_bn
=
self
.
_config_dict
[
'use_sync_bn'
],
norm_momentum
=
0.99
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
norm_epsilon
=
0.001
,
**
conv_kwargs
))
**
conv_kwargs
))
...
...
official/vision/beta/projects/basnet/modeling/layers/nn_blocks.py
View file @
53c3f653
...
@@ -58,20 +58,20 @@ class ConvBlock(tf.keras.layers.Layer):
...
@@ -58,20 +58,20 @@ class ConvBlock(tf.keras.layers.Layer):
**kwargs: keyword arguments to be passed.
**kwargs: keyword arguments to be passed.
"""
"""
super
(
ConvBlock
,
self
).
__init__
(
**
kwargs
)
super
(
ConvBlock
,
self
).
__init__
(
**
kwargs
)
self
.
_config_dict
=
{
self
.
_
filters
=
filters
'
filters
'
:
filters
,
self
.
_
kernel_size
=
kernel_size
'
kernel_size
'
:
kernel_size
,
self
.
_
strides
=
strides
'
strides
'
:
strides
,
self
.
_
dilation_rate
=
dilation_rate
'
dilation_rate
'
:
dilation_rate
,
self
.
_
kernel_initializer
=
kernel_initializer
'
kernel_initializer
'
:
kernel_initializer
,
self
.
_
kernel_regularizer
=
kernel_regularizer
'
kernel_regularizer
'
:
kernel_regularizer
,
self
.
_
bias_regularizer
=
bias_regularizer
'
bias_regularizer
'
:
bias_regularizer
,
self
.
_
activation
=
activation
'
activation
'
:
activation
,
self
.
_use_bias
=
use_bias
'use_sync_bn'
:
use_sync_bn
,
self
.
_use_sync_bn
=
use_sync_bn
'use_bias'
:
use_bias
,
self
.
_
norm_momentum
=
norm_momentum
'
norm_momentum
'
:
norm_momentum
,
self
.
_
norm_epsilon
=
norm_epsilon
'
norm_epsilon
'
:
norm_epsilon
}
if
use_sync_bn
:
if
use_sync_bn
:
self
.
_norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
self
.
_norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
else
:
else
:
...
@@ -83,40 +83,29 @@ class ConvBlock(tf.keras.layers.Layer):
...
@@ -83,40 +83,29 @@ class ConvBlock(tf.keras.layers.Layer):
self
.
_activation_fn
=
tf_utils
.
get_activation
(
activation
)
self
.
_activation_fn
=
tf_utils
.
get_activation
(
activation
)
def
build
(
self
,
input_shape
):
def
build
(
self
,
input_shape
):
conv_kwargs
=
{
'padding'
:
'same'
,
'use_bias'
:
self
.
_config_dict
[
'use_bias'
],
'kernel_initializer'
:
self
.
_config_dict
[
'kernel_initializer'
],
'kernel_regularizer'
:
self
.
_config_dict
[
'kernel_regularizer'
],
'bias_regularizer'
:
self
.
_config_dict
[
'bias_regularizer'
],
}
self
.
_conv0
=
tf
.
keras
.
layers
.
Conv2D
(
self
.
_conv0
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
_filters
,
filters
=
self
.
_config_dict
[
'filters'
],
kernel_size
=
self
.
_kernel_size
,
kernel_size
=
self
.
_config_dict
[
'kernel_size'
],
strides
=
self
.
_strides
,
strides
=
self
.
_config_dict
[
'strides'
],
dilation_rate
=
self
.
_dilation_rate
,
dilation_rate
=
self
.
_config_dict
[
'dilation_rate'
],
padding
=
'same'
,
**
conv_kwargs
)
use_bias
=
self
.
_use_bias
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm0
=
self
.
_norm
(
self
.
_norm0
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
momentum
=
self
.
_
config_dict
[
'
norm_momentum
'
]
,
epsilon
=
self
.
_norm_epsilon
)
epsilon
=
self
.
_
config_dict
[
'
norm_epsilon
'
]
)
super
(
ConvBlock
,
self
).
build
(
input_shape
)
super
(
ConvBlock
,
self
).
build
(
input_shape
)
def
get_config
(
self
):
def
get_config
(
self
):
config
=
{
return
self
.
_config_dict
'filters'
:
self
.
_filters
,
'kernel_size'
:
self
.
_kernel_size
,
'strides'
:
self
.
_strides
,
'dilation_rate'
:
self
.
_dilation_rate
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
'activation'
:
self
.
_activation
,
'use_sync_bn'
:
self
.
_use_sync_bn
,
'use_bias'
:
self
.
_use_bias
,
'norm_momentum'
:
self
.
_norm_momentum
,
'norm_epsilon'
:
self
.
_norm_epsilon
}
base_config
=
super
(
ConvBlock
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
,
training
=
None
):
def
call
(
self
,
inputs
,
training
=
None
):
x
=
self
.
_conv0
(
inputs
)
x
=
self
.
_conv0
(
inputs
)
...
@@ -168,19 +157,19 @@ class ResBlock(tf.keras.layers.Layer):
...
@@ -168,19 +157,19 @@ class ResBlock(tf.keras.layers.Layer):
**kwargs: Additional keyword arguments to be passed.
**kwargs: Additional keyword arguments to be passed.
"""
"""
super
(
ResBlock
,
self
).
__init__
(
**
kwargs
)
super
(
ResBlock
,
self
).
__init__
(
**
kwargs
)
self
.
_config_dict
=
{
self
.
_
filters
=
filters
'
filters
'
:
filters
,
self
.
_
strides
=
strides
'
strides
'
:
strides
,
self
.
_
use_projection
=
use_projection
'
use_projection
'
:
use_projection
,
self
.
_use_sync_bn
=
use_sync_bn
'kernel_initializer'
:
kernel_initializer
,
self
.
_use_bias
=
use_bias
'kernel_regularizer'
:
kernel_regularizer
,
self
.
_activation
=
activation
'bias_regularizer'
:
bias_regularizer
,
self
.
_kernel_initializer
=
kernel_initializer
'activation'
:
activation
,
self
.
_norm_momentum
=
norm_momentum
'use_sync_bn'
:
use_sync_bn
,
self
.
_norm_epsilon
=
norm_epsilon
'use_bias'
:
use_bias
,
self
.
_kernel_regularizer
=
kernel_regularizer
'norm_momentum'
:
norm_momentum
,
self
.
_bias_regularizer
=
bias_regularizer
'norm_epsilon'
:
norm_epsilon
}
if
use_sync_bn
:
if
use_sync_bn
:
self
.
_norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
self
.
_norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
else
:
else
:
...
@@ -192,70 +181,55 @@ class ResBlock(tf.keras.layers.Layer):
...
@@ -192,70 +181,55 @@ class ResBlock(tf.keras.layers.Layer):
self
.
_activation_fn
=
tf_utils
.
get_activation
(
activation
)
self
.
_activation_fn
=
tf_utils
.
get_activation
(
activation
)
def
build
(
self
,
input_shape
):
def
build
(
self
,
input_shape
):
if
self
.
_use_projection
:
conv_kwargs
=
{
'filters'
:
self
.
_config_dict
[
'filters'
],
'padding'
:
'same'
,
'use_bias'
:
self
.
_config_dict
[
'use_bias'
],
'kernel_initializer'
:
self
.
_config_dict
[
'kernel_initializer'
],
'kernel_regularizer'
:
self
.
_config_dict
[
'kernel_regularizer'
],
'bias_regularizer'
:
self
.
_config_dict
[
'bias_regularizer'
],
}
if
self
.
_config_dict
[
'use_projection'
]:
self
.
_shortcut
=
tf
.
keras
.
layers
.
Conv2D
(
self
.
_shortcut
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
_filters
,
filters
=
self
.
_
config_dict
[
'
filters
'
]
,
kernel_size
=
1
,
kernel_size
=
1
,
strides
=
self
.
_strides
,
strides
=
self
.
_
config_dict
[
'
strides
'
]
,
use_bias
=
self
.
_use_bias
,
use_bias
=
self
.
_
config_dict
[
'
use_bias
'
]
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_
config_dict
[
'
kernel_initializer
'
]
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_
config_dict
[
'
kernel_regularizer
'
]
,
bias_regularizer
=
self
.
_bias_regularizer
)
bias_regularizer
=
self
.
_
config_dict
[
'
bias_regularizer
'
]
)
self
.
_norm0
=
self
.
_norm
(
self
.
_norm0
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
momentum
=
self
.
_
config_dict
[
'
norm_momentum
'
]
,
epsilon
=
self
.
_norm_epsilon
)
epsilon
=
self
.
_
config_dict
[
'
norm_epsilon
'
]
)
self
.
_conv1
=
tf
.
keras
.
layers
.
Conv2D
(
self
.
_conv1
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
_filters
,
kernel_size
=
3
,
kernel_size
=
3
,
strides
=
self
.
_strides
,
strides
=
self
.
_config_dict
[
'strides'
],
padding
=
'same'
,
**
conv_kwargs
)
use_bias
=
self
.
_use_bias
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm1
=
self
.
_norm
(
self
.
_norm1
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
momentum
=
self
.
_
config_dict
[
'
norm_momentum
'
]
,
epsilon
=
self
.
_norm_epsilon
)
epsilon
=
self
.
_
config_dict
[
'
norm_epsilon
'
]
)
self
.
_conv2
=
tf
.
keras
.
layers
.
Conv2D
(
self
.
_conv2
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
_filters
,
kernel_size
=
3
,
kernel_size
=
3
,
strides
=
1
,
strides
=
1
,
padding
=
'same'
,
**
conv_kwargs
)
use_bias
=
self
.
_use_bias
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm2
=
self
.
_norm
(
self
.
_norm2
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
momentum
=
self
.
_
config_dict
[
'
norm_momentum
'
]
,
epsilon
=
self
.
_norm_epsilon
)
epsilon
=
self
.
_
config_dict
[
'
norm_epsilon
'
]
)
super
(
ResBlock
,
self
).
build
(
input_shape
)
super
(
ResBlock
,
self
).
build
(
input_shape
)
def
get_config
(
self
):
def
get_config
(
self
):
config
=
{
return
self
.
_config_dict
'filters'
:
self
.
_filters
,
'strides'
:
self
.
_strides
,
'use_projection'
:
self
.
_use_projection
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
'activation'
:
self
.
_activation
,
'use_sync_bn'
:
self
.
_use_sync_bn
,
'use_bias'
:
self
.
_use_bias
,
'norm_momentum'
:
self
.
_norm_momentum
,
'norm_epsilon'
:
self
.
_norm_epsilon
}
base_config
=
super
(
ResBlock
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
,
training
=
None
):
def
call
(
self
,
inputs
,
training
=
None
):
shortcut
=
inputs
shortcut
=
inputs
if
self
.
_use_projection
:
if
self
.
_
config_dict
[
'
use_projection
'
]
:
shortcut
=
self
.
_shortcut
(
shortcut
)
shortcut
=
self
.
_shortcut
(
shortcut
)
shortcut
=
self
.
_norm0
(
shortcut
)
shortcut
=
self
.
_norm0
(
shortcut
)
...
...
official/vision/beta/projects/basnet/modeling/refunet.py
View file @
53c3f653
...
@@ -27,7 +27,6 @@ class RefUnet(tf.keras.layers.Layer):
...
@@ -27,7 +27,6 @@ class RefUnet(tf.keras.layers.Layer):
Basnet: Boundary-aware salient object detection.
Basnet: Boundary-aware salient object detection.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
use_separable_conv
=
False
,
activation
=
'relu'
,
activation
=
'relu'
,
use_sync_bn
=
False
,
use_sync_bn
=
False
,
use_bias
=
True
,
use_bias
=
True
,
...
@@ -40,8 +39,6 @@ class RefUnet(tf.keras.layers.Layer):
...
@@ -40,8 +39,6 @@ class RefUnet(tf.keras.layers.Layer):
"""Residual Refinement Module of BASNet.
"""Residual Refinement Module of BASNet.
Args:
Args:
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in conv2d.
use_bias: if True, use bias in conv2d.
...
@@ -57,7 +54,6 @@ class RefUnet(tf.keras.layers.Layer):
...
@@ -57,7 +54,6 @@ class RefUnet(tf.keras.layers.Layer):
"""
"""
super
(
RefUnet
,
self
).
__init__
(
**
kwargs
)
super
(
RefUnet
,
self
).
__init__
(
**
kwargs
)
self
.
_config_dict
=
{
self
.
_config_dict
=
{
'use_separable_conv'
:
use_separable_conv
,
'activation'
:
activation
,
'activation'
:
activation
,
'use_sync_bn'
:
use_sync_bn
,
'use_sync_bn'
:
use_sync_bn
,
'use_bias'
:
use_bias
,
'use_bias'
:
use_bias
,
...
@@ -83,11 +79,10 @@ class RefUnet(tf.keras.layers.Layer):
...
@@ -83,11 +79,10 @@ class RefUnet(tf.keras.layers.Layer):
def
build
(
self
,
input_shape
):
def
build
(
self
,
input_shape
):
"""Creates the variables of the BASNet decoder."""
"""Creates the variables of the BASNet decoder."""
if
self
.
_config_dict
[
'use_separable_conv'
]:
conv_op
=
tf
.
keras
.
layers
.
SeparableConv2D
else
:
conv_op
=
tf
.
keras
.
layers
.
Conv2D
conv_op
=
tf
.
keras
.
layers
.
Conv2D
conv_kwargs
=
{
conv_kwargs
=
{
'dilation_rate'
:
1
,
'activation'
:
self
.
_config_dict
[
'activation'
],
'kernel_size'
:
3
,
'kernel_size'
:
3
,
'strides'
:
1
,
'strides'
:
1
,
'use_bias'
:
self
.
_config_dict
[
'use_bias'
],
'use_bias'
:
self
.
_config_dict
[
'use_bias'
],
...
@@ -96,21 +91,44 @@ class RefUnet(tf.keras.layers.Layer):
...
@@ -96,21 +91,44 @@ class RefUnet(tf.keras.layers.Layer):
'bias_regularizer'
:
self
.
_config_dict
[
'bias_regularizer'
],
'bias_regularizer'
:
self
.
_config_dict
[
'bias_regularizer'
],
}
}
self
.
_in_conv
=
conv_op
(
filters
=
64
,
padding
=
'same'
,
**
conv_kwargs
)
self
.
_in_conv
=
conv_op
(
filters
=
64
,
padding
=
'same'
,
**
conv_kwargs
)
self
.
_en_convs
=
[]
self
.
_en_convs
=
[]
for
_
in
range
(
4
):
for
_
in
range
(
4
):
self
.
_en_convs
.
append
(
nn_blocks
.
ConvBlock
(
filters
=
64
,
**
conv_kwargs
))
self
.
_en_convs
.
append
(
nn_blocks
.
ConvBlock
(
filters
=
64
,
use_sync_bn
=
self
.
_config_dict
[
'use_sync_bn'
],
norm_momentum
=
self
.
_config_dict
[
'norm_momentum'
],
norm_epsilon
=
self
.
_config_dict
[
'norm_epsilon'
],
**
conv_kwargs
))
self
.
_bridge_convs
=
[]
self
.
_bridge_convs
=
[]
for
_
in
range
(
1
):
for
_
in
range
(
1
):
self
.
_bridge_convs
.
append
(
nn_blocks
.
ConvBlock
(
filters
=
64
,
**
conv_kwargs
))
self
.
_bridge_convs
.
append
(
nn_blocks
.
ConvBlock
(
filters
=
64
,
use_sync_bn
=
self
.
_config_dict
[
'use_sync_bn'
],
norm_momentum
=
self
.
_config_dict
[
'norm_momentum'
],
norm_epsilon
=
self
.
_config_dict
[
'norm_epsilon'
],
**
conv_kwargs
))
self
.
_de_convs
=
[]
self
.
_de_convs
=
[]
for
_
in
range
(
4
):
for
_
in
range
(
4
):
self
.
_de_convs
.
append
(
nn_blocks
.
ConvBlock
(
filters
=
64
,
**
conv_kwargs
))
self
.
_de_convs
.
append
(
nn_blocks
.
ConvBlock
(
filters
=
64
,
self
.
_out_conv
=
conv_op
(
padding
=
'same'
,
filters
=
1
,
**
conv_kwargs
)
use_sync_bn
=
self
.
_config_dict
[
'use_sync_bn'
],
norm_momentum
=
self
.
_config_dict
[
'norm_momentum'
],
norm_epsilon
=
self
.
_config_dict
[
'norm_epsilon'
],
**
conv_kwargs
))
self
.
_out_conv
=
conv_op
(
filters
=
1
,
padding
=
'same'
,
**
conv_kwargs
)
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
endpoints
=
{}
endpoints
=
{}
...
...
official/vision/beta/projects/basnet/tasks/basnet.py
View file @
53c3f653
...
@@ -36,21 +36,31 @@ def build_basnet_model(
...
@@ -36,21 +36,31 @@ def build_basnet_model(
model_config
:
exp_cfg
.
BASNetModel
,
model_config
:
exp_cfg
.
BASNetModel
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
):
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
):
"""Builds BASNet model."""
"""Builds BASNet model."""
backbone
=
basnet_model
.
BASNet_Encoder
(
input_specs
=
input_specs
)
norm_activation_config
=
model_config
.
norm_activation
norm_activation_config
=
model_config
.
norm_activation
backbone
=
basnet_model
.
BASNet_Encoder
(
input_specs
=
input_specs
,
activation
=
norm_activation_config
.
activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
use_bias
=
model_config
.
use_bias
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
kernel_regularizer
=
l2_regularizer
)
decoder
=
basnet_model
.
BASNet_Decoder
(
decoder
=
basnet_model
.
BASNet_Decoder
(
activation
=
norm_activation_config
.
activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
use_bias
=
model_config
.
use_bias
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
activation
=
norm_activation_config
.
activation
,
kernel_regularizer
=
l2_regularizer
)
kernel_regularizer
=
l2_regularizer
)
refinement
=
refunet
.
RefUnet
()
refinement
=
refunet
.
RefUnet
(
activation
=
norm_activation_config
.
activation
,
norm_activation_config
=
model_config
.
norm_activation
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
use_bias
=
model_config
.
use_bias
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
kernel_regularizer
=
l2_regularizer
)
model
=
basnet_model
.
BASNetModel
(
backbone
,
decoder
,
refinement
)
model
=
basnet_model
.
BASNetModel
(
backbone
,
decoder
,
refinement
)
return
model
return
model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment