Commit 53c3f653 authored by Gunho Park

Internal change

parent d4f401e1
......@@ -46,6 +46,7 @@ class DataConfig(cfg.DataConfig):
class BASNetModel(hyperparams.Config):
"""BASNet model config."""
input_size: List[int] = dataclasses.field(default_factory=list)
use_bias: bool = False
norm_activation: common.NormActivation = common.NormActivation()
......@@ -99,6 +100,7 @@ def basnet_duts() -> cfg.ExperimentConfig:
task=BASNetTask(
model=BASNetModel(
input_size=[None, None, 3],
use_bias=True,
norm_activation=common.NormActivation(
activation='relu',
norm_momentum=0.99,
......
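For context, a minimal standalone sketch of the dataclass pattern above; the class below is an illustrative stand-in, not the real Model Garden config, and simply mirrors how the new `use_bias` flag defaults to False and is flipped to True for the DUTS experiment.

```python
# Minimal standalone sketch of the config pattern above; this stand-in class is
# illustrative only, not the actual hyperparams.Config subclass.
import dataclasses
from typing import List, Optional

@dataclasses.dataclass
class BASNetModelSketch:
  input_size: List[Optional[int]] = dataclasses.field(default_factory=list)
  use_bias: bool = False  # new flag, off by default

# Mirrors the basnet_duts() override shown above.
cfg = BASNetModelSketch(input_size=[None, None, 3], use_bias=True)
print(cfg.use_bias)  # True
```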
......@@ -274,11 +274,11 @@ def build_basnet_encoder(
norm_activation_config = model_config.norm_activation
assert backbone_type == 'basnet_encoder', (f'Inconsistent backbone type '
f'{backbone_type}')
return BASNet_Encoder(
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
use_bias=norm_activation_config.use_bias,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
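As a usage note, the builder above now just forwards the normalization config into the encoder; constructing the encoder directly with the same keyword arguments would look roughly like this (illustrative values; the class/module path is assumed from the builder's return statement).

```python
# Hedged sketch: direct construction with the arguments the builder forwards above.
import tensorflow as tf

encoder = BASNet_Encoder(
    input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
    activation='relu',
    use_sync_bn=False,
    use_bias=True,            # forwarded from the config in this change
    norm_momentum=0.99,
    norm_epsilon=0.001,
    kernel_regularizer=None)
```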
......@@ -289,7 +289,6 @@ class BASNet_Decoder(tf.keras.layers.Layer):
"""BASNet decoder."""
def __init__(self,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
use_bias=True,
......@@ -302,11 +301,11 @@ class BASNet_Decoder(tf.keras.layers.Layer):
"""BASNet decoder initialization function.
Args:
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in convolution.
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
......@@ -317,7 +316,6 @@ class BASNet_Decoder(tf.keras.layers.Layer):
"""
super(BASNet_Decoder, self).__init__(**kwargs)
self._config_dict = {
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
......@@ -337,9 +335,6 @@ class BASNet_Decoder(tf.keras.layers.Layer):
def build(self, input_shape):
"""Creates the variables of the BASNet decoder."""
if self._config_dict['use_separable_conv']:
conv_op = tf.keras.layers.SeparableConv2D
else:
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
'kernel_size': 3,
......@@ -362,6 +357,7 @@ class BASNet_Decoder(tf.keras.layers.Layer):
filters=spec[2*j],
dilation_rate=spec[2*j+1],
activation='relu',
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=0.99,
norm_epsilon=0.001,
**conv_kwargs))
......@@ -384,6 +380,7 @@ class BASNet_Decoder(tf.keras.layers.Layer):
filters=spec[2*j],
dilation_rate=spec[2*j+1],
activation='relu',
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=0.99,
norm_epsilon=0.001,
**conv_kwargs))
......
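Each decoder stage above now forwards use_sync_bn into its ConvBlock; a single block built the same way would look roughly like this (placeholder filter/dilation values, and the remaining constructor arguments are assumed to keep their defaults).

```python
# Hedged sketch: one decoder-style ConvBlock with sync BN forwarded explicitly.
block = nn_blocks.ConvBlock(
    filters=64,
    kernel_size=3,
    strides=1,
    dilation_rate=1,
    activation='relu',
    use_sync_bn=False,   # True selects SyncBatchNormalization inside the block
    norm_momentum=0.99,
    norm_epsilon=0.001)
```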
......@@ -58,20 +58,20 @@ class ConvBlock(tf.keras.layers.Layer):
**kwargs: keyword arguments to be passed.
"""
super(ConvBlock, self).__init__(**kwargs)
self._filters = filters
self._kernel_size = kernel_size
self._strides = strides
self._dilation_rate = dilation_rate
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._activation = activation
self._use_bias = use_bias
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._config_dict = {
'filters': filters,
'kernel_size': kernel_size,
'strides': strides,
'dilation_rate': dilation_rate,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon
}
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
......@@ -83,40 +83,29 @@ class ConvBlock(tf.keras.layers.Layer):
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
conv_kwargs = {
'padding': 'same',
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
self._conv0 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=self._kernel_size,
strides=self._strides,
dilation_rate=self._dilation_rate,
padding='same',
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
filters=self._config_dict['filters'],
kernel_size=self._config_dict['kernel_size'],
strides=self._config_dict['strides'],
dilation_rate=self._config_dict['dilation_rate'],
**conv_kwargs)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
super(ConvBlock, self).build(input_shape)
def get_config(self):
config = {
'filters': self._filters,
'kernel_size': self._kernel_size,
'strides': self._strides,
'dilation_rate': self._dilation_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'use_bias': self._use_bias,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(ConvBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
return self._config_dict
def call(self, inputs, training=None):
x = self._conv0(inputs)
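Because get_config() now returns the stored config dict, a block can be round-tripped through its Keras config; a hedged sketch, assuming the constructor arguments not listed here have defaults:

```python
# Hedged sketch: serialize and rebuild a ConvBlock via the simplified get_config().
block = nn_blocks.ConvBlock(
    filters=64, kernel_size=3, strides=1, dilation_rate=1,
    activation='relu', use_bias=True, use_sync_bn=False,
    norm_momentum=0.99, norm_epsilon=0.001)
clone = nn_blocks.ConvBlock.from_config(block.get_config())
assert clone.get_config() == block.get_config()
```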
......@@ -168,19 +157,19 @@ class ResBlock(tf.keras.layers.Layer):
**kwargs: Additional keyword arguments to be passed.
"""
super(ResBlock, self).__init__(**kwargs)
self._filters = filters
self._strides = strides
self._use_projection = use_projection
self._use_sync_bn = use_sync_bn
self._use_bias = use_bias
self._activation = activation
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._config_dict = {
'filters': filters,
'strides': strides,
'use_projection': use_projection,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon
}
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
......@@ -192,70 +181,55 @@ class ResBlock(tf.keras.layers.Layer):
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
if self._use_projection:
conv_kwargs = {
'filters': self._config_dict['filters'],
'padding': 'same',
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
if self._config_dict['use_projection']:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters,
filters=self._config_dict['filters'],
kernel_size=1,
strides=self._strides,
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
strides=self._config_dict['strides'],
use_bias=self._config_dict['use_bias'],
kernel_initializer=self._config_dict['kernel_initializer'],
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'])
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3,
strides=self._strides,
padding='same',
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
strides=self._config_dict['strides'],
**conv_kwargs)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3,
strides=1,
padding='same',
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
**conv_kwargs)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
super(ResBlock, self).build(input_shape)
def get_config(self):
config = {
'filters': self._filters,
'strides': self._strides,
'use_projection': self._use_projection,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'use_bias': self._use_bias,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(ResBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
return self._config_dict
def call(self, inputs, training=None):
shortcut = inputs
if self._use_projection:
if self._config_dict['use_projection']:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
......
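A hedged usage sketch of the residual block above: a strided stage enables the 1x1 projection so the shortcut matches the downsampled main path (module path, argument defaults, and the expected shape are assumptions).

```python
# Hedged sketch: a strided ResBlock with the projection shortcut enabled.
import tensorflow as tf

block = nn_blocks.ResBlock(filters=128, strides=2, use_projection=True)
y = block(tf.ones([1, 64, 64, 64]))
# Expected y.shape == (1, 32, 32, 128): conv1 and the 1x1 shortcut both apply
# strides=2 with 'same' padding before the residual add.
```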
......@@ -27,7 +27,6 @@ class RefUnet(tf.keras.layers.Layer):
BASNet: Boundary-Aware Salient Object Detection.
"""
def __init__(self,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
use_bias=True,
......@@ -40,8 +39,6 @@ class RefUnet(tf.keras.layers.Layer):
"""Residual Refinement Module of BASNet.
Args:
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in conv2d.
......@@ -57,7 +54,6 @@ class RefUnet(tf.keras.layers.Layer):
"""
super(RefUnet, self).__init__(**kwargs)
self._config_dict = {
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
......@@ -83,11 +79,10 @@ class RefUnet(tf.keras.layers.Layer):
def build(self, input_shape):
"""Creates the variables of the BASNet decoder."""
if self._config_dict['use_separable_conv']:
conv_op = tf.keras.layers.SeparableConv2D
else:
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
'dilation_rate': 1,
'activation': self._config_dict['activation'],
'kernel_size': 3,
'strides': 1,
'use_bias': self._config_dict['use_bias'],
......@@ -96,21 +91,44 @@ class RefUnet(tf.keras.layers.Layer):
'bias_regularizer': self._config_dict['bias_regularizer'],
}
self._in_conv = conv_op(filters=64, padding='same',**conv_kwargs)
self._in_conv = conv_op(
filters=64,
padding='same',
**conv_kwargs)
self._en_convs = []
for _ in range(4):
self._en_convs.append(nn_blocks.ConvBlock(filters=64, **conv_kwargs))
self._en_convs.append(nn_blocks.ConvBlock(
filters=64,
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=self._config_dict['norm_momentum'],
norm_epsilon=self._config_dict['norm_epsilon'],
**conv_kwargs))
self._bridge_convs = []
for _ in range(1):
self._bridge_convs.append(nn_blocks.ConvBlock(filters=64, **conv_kwargs))
self._bridge_convs.append(nn_blocks.ConvBlock(
filters=64,
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=self._config_dict['norm_momentum'],
norm_epsilon=self._config_dict['norm_epsilon'],
**conv_kwargs))
self._de_convs = []
for _ in range(4):
self._de_convs.append(nn_blocks.ConvBlock(filters=64, **conv_kwargs))
self._out_conv = conv_op(padding='same', filters=1, **conv_kwargs)
self._de_convs.append(nn_blocks.ConvBlock(
filters=64,
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=self._config_dict['norm_momentum'],
norm_epsilon=self._config_dict['norm_epsilon'],
**conv_kwargs))
self._out_conv = conv_op(
filters=1,
padding='same',
**conv_kwargs)
def call(self, inputs):
endpoints = {}
......
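For reference, constructing the refinement module with the normalization arguments it now accepts, mirroring the builder change below (illustrative values).

```python
# Hedged sketch: RefUnet with explicit activation/normalization settings.
refinement = refunet.RefUnet(
    activation='relu',
    use_sync_bn=False,
    use_bias=True,
    norm_momentum=0.99,
    norm_epsilon=0.001,
    kernel_regularizer=None)
```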
......@@ -36,21 +36,31 @@ def build_basnet_model(
model_config: exp_cfg.BASNetModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds BASNet model."""
backbone = basnet_model.BASNet_Encoder(
input_specs=input_specs)
norm_activation_config = model_config.norm_activation
backbone = basnet_model.BASNet_Encoder(
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
use_bias=model_config.use_bias,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
decoder = basnet_model.BASNet_Decoder(
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
use_bias=model_config.use_bias,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
activation=norm_activation_config.activation,
kernel_regularizer=l2_regularizer)
refinement = refunet.RefUnet()
norm_activation_config = model_config.norm_activation
refinement = refunet.RefUnet(
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
use_bias=model_config.use_bias,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
model = basnet_model.BASNetModel(backbone, decoder, refinement)
return model
......
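An end-to-end sketch of calling the builder with a config that sets use_bias=True; the first parameter name, the InputSpec convention, and the unlisted config fields are assumptions.

```python
# Hedged sketch: wiring the builder; import paths and defaults are assumptions.
import tensorflow as tf

input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
model_config = exp_cfg.BASNetModel(input_size=[None, None, 3], use_bias=True)
model = build_basnet_model(
    input_specs=input_specs,
    model_config=model_config,
    l2_regularizer=tf.keras.regularizers.L2(1e-4))
```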