Commit 53c3f653 authored by Gunho Park's avatar Gunho Park
Browse files

Internal change

parent d4f401e1
...@@ -46,6 +46,7 @@ class DataConfig(cfg.DataConfig): ...@@ -46,6 +46,7 @@ class DataConfig(cfg.DataConfig):
class BASNetModel(hyperparams.Config): class BASNetModel(hyperparams.Config):
"""BASNet model config.""" """BASNet model config."""
input_size: List[int] = dataclasses.field(default_factory=list) input_size: List[int] = dataclasses.field(default_factory=list)
use_bias: bool = False
norm_activation: common.NormActivation = common.NormActivation() norm_activation: common.NormActivation = common.NormActivation()
...@@ -99,6 +100,7 @@ def basnet_duts() -> cfg.ExperimentConfig: ...@@ -99,6 +100,7 @@ def basnet_duts() -> cfg.ExperimentConfig:
task=BASNetTask( task=BASNetTask(
model=BASNetModel( model=BASNetModel(
input_size=[None, None, 3], input_size=[None, None, 3],
use_bias=True,
norm_activation=common.NormActivation( norm_activation=common.NormActivation(
activation='relu', activation='relu',
norm_momentum=0.99, norm_momentum=0.99,
......
...@@ -274,11 +274,11 @@ def build_basnet_encoder( ...@@ -274,11 +274,11 @@ def build_basnet_encoder(
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
assert backbone_type == 'basnet_encoder', (f'Inconsistent backbone type ' assert backbone_type == 'basnet_encoder', (f'Inconsistent backbone type '
f'{backbone_type}') f'{backbone_type}')
return BASNet_Encoder( return BASNet_Encoder(
input_specs=input_specs, input_specs=input_specs,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
use_bias=norm_activation_config.use_bias,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer)
...@@ -289,7 +289,6 @@ class BASNet_Decoder(tf.keras.layers.Layer): ...@@ -289,7 +289,6 @@ class BASNet_Decoder(tf.keras.layers.Layer):
"""BASNet decoder.""" """BASNet decoder."""
def __init__(self, def __init__(self,
use_separable_conv=False,
activation='relu', activation='relu',
use_sync_bn=False, use_sync_bn=False,
use_bias=True, use_bias=True,
...@@ -302,11 +301,11 @@ class BASNet_Decoder(tf.keras.layers.Layer): ...@@ -302,11 +301,11 @@ class BASNet_Decoder(tf.keras.layers.Layer):
"""BASNet decoder initialization function. """BASNet decoder initialization function.
Args: Args:
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function. activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization. use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in convolution. use_bias: if True, use bias in convolution.
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
norm_momentum: `float` normalization omentum for the moving average. norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by norm_epsilon: `float` small float added to variance to avoid dividing by
zero. zero.
...@@ -317,7 +316,6 @@ class BASNet_Decoder(tf.keras.layers.Layer): ...@@ -317,7 +316,6 @@ class BASNet_Decoder(tf.keras.layers.Layer):
""" """
super(BASNet_Decoder, self).__init__(**kwargs) super(BASNet_Decoder, self).__init__(**kwargs)
self._config_dict = { self._config_dict = {
'use_separable_conv': use_separable_conv,
'activation': activation, 'activation': activation,
'use_sync_bn': use_sync_bn, 'use_sync_bn': use_sync_bn,
'use_bias': use_bias, 'use_bias': use_bias,
...@@ -337,10 +335,7 @@ class BASNet_Decoder(tf.keras.layers.Layer): ...@@ -337,10 +335,7 @@ class BASNet_Decoder(tf.keras.layers.Layer):
def build(self, input_shape): def build(self, input_shape):
"""Creates the variables of the BASNet decoder.""" """Creates the variables of the BASNet decoder."""
if self._config_dict['use_separable_conv']: conv_op = tf.keras.layers.Conv2D
conv_op = tf.keras.layers.SeparableConv2D
else:
conv_op = tf.keras.layers.Conv2D
conv_kwargs = { conv_kwargs = {
'kernel_size': 3, 'kernel_size': 3,
'strides': 1, 'strides': 1,
...@@ -362,6 +357,7 @@ class BASNet_Decoder(tf.keras.layers.Layer): ...@@ -362,6 +357,7 @@ class BASNet_Decoder(tf.keras.layers.Layer):
filters=spec[2*j], filters=spec[2*j],
dilation_rate=spec[2*j+1], dilation_rate=spec[2*j+1],
activation='relu', activation='relu',
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
**conv_kwargs)) **conv_kwargs))
...@@ -384,6 +380,7 @@ class BASNet_Decoder(tf.keras.layers.Layer): ...@@ -384,6 +380,7 @@ class BASNet_Decoder(tf.keras.layers.Layer):
filters=spec[2*j], filters=spec[2*j],
dilation_rate=spec[2*j+1], dilation_rate=spec[2*j+1],
activation='relu', activation='relu',
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
**conv_kwargs)) **conv_kwargs))
......
...@@ -58,20 +58,20 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -58,20 +58,20 @@ class ConvBlock(tf.keras.layers.Layer):
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
super(ConvBlock, self).__init__(**kwargs) super(ConvBlock, self).__init__(**kwargs)
self._config_dict = {
self._filters = filters 'filters': filters,
self._kernel_size = kernel_size 'kernel_size': kernel_size,
self._strides = strides 'strides': strides,
self._dilation_rate = dilation_rate 'dilation_rate': dilation_rate,
self._kernel_initializer = kernel_initializer 'kernel_initializer': kernel_initializer,
self._kernel_regularizer = kernel_regularizer 'kernel_regularizer': kernel_regularizer,
self._bias_regularizer = bias_regularizer 'bias_regularizer': bias_regularizer,
self._activation = activation 'activation': activation,
self._use_bias = use_bias 'use_sync_bn': use_sync_bn,
self._use_sync_bn = use_sync_bn 'use_bias': use_bias,
self._norm_momentum = norm_momentum 'norm_momentum': norm_momentum,
self._norm_epsilon = norm_epsilon 'norm_epsilon': norm_epsilon
}
if use_sync_bn: if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else: else:
...@@ -83,40 +83,29 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -83,40 +83,29 @@ class ConvBlock(tf.keras.layers.Layer):
self._activation_fn = tf_utils.get_activation(activation) self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape): def build(self, input_shape):
conv_kwargs = {
'padding': 'same',
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
self._conv0 = tf.keras.layers.Conv2D( self._conv0 = tf.keras.layers.Conv2D(
filters=self._filters, filters=self._config_dict['filters'],
kernel_size=self._kernel_size, kernel_size=self._config_dict['kernel_size'],
strides=self._strides, strides=self._config_dict['strides'],
dilation_rate=self._dilation_rate, dilation_rate=self._config_dict['dilation_rate'],
padding='same', **conv_kwargs)
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm( self._norm0 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._config_dict['norm_momentum'],
epsilon=self._norm_epsilon) epsilon=self._config_dict['norm_epsilon'])
super(ConvBlock, self).build(input_shape) super(ConvBlock, self).build(input_shape)
def get_config(self): def get_config(self):
config = { return self._config_dict
'filters': self._filters,
'kernel_size': self._kernel_size,
'strides': self._strides,
'dilation_rate': self._dilation_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'use_bias': self._use_bias,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(ConvBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None): def call(self, inputs, training=None):
x = self._conv0(inputs) x = self._conv0(inputs)
...@@ -168,19 +157,19 @@ class ResBlock(tf.keras.layers.Layer): ...@@ -168,19 +157,19 @@ class ResBlock(tf.keras.layers.Layer):
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(ResBlock, self).__init__(**kwargs) super(ResBlock, self).__init__(**kwargs)
self._config_dict = {
self._filters = filters 'filters': filters,
self._strides = strides 'strides': strides,
self._use_projection = use_projection 'use_projection': use_projection,
self._use_sync_bn = use_sync_bn 'kernel_initializer': kernel_initializer,
self._use_bias = use_bias 'kernel_regularizer': kernel_regularizer,
self._activation = activation 'bias_regularizer': bias_regularizer,
self._kernel_initializer = kernel_initializer 'activation': activation,
self._norm_momentum = norm_momentum 'use_sync_bn': use_sync_bn,
self._norm_epsilon = norm_epsilon 'use_bias': use_bias,
self._kernel_regularizer = kernel_regularizer 'norm_momentum': norm_momentum,
self._bias_regularizer = bias_regularizer 'norm_epsilon': norm_epsilon
}
if use_sync_bn: if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else: else:
...@@ -192,70 +181,55 @@ class ResBlock(tf.keras.layers.Layer): ...@@ -192,70 +181,55 @@ class ResBlock(tf.keras.layers.Layer):
self._activation_fn = tf_utils.get_activation(activation) self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape): def build(self, input_shape):
if self._use_projection: conv_kwargs = {
'filters': self._config_dict['filters'],
'padding': 'same',
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
if self._config_dict['use_projection']:
self._shortcut = tf.keras.layers.Conv2D( self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters, filters=self._config_dict['filters'],
kernel_size=1, kernel_size=1,
strides=self._strides, strides=self._config_dict['strides'],
use_bias=self._use_bias, use_bias=self._config_dict['use_bias'],
kernel_initializer=self._kernel_initializer, kernel_initializer=self._config_dict['kernel_initializer'],
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._bias_regularizer) bias_regularizer=self._config_dict['bias_regularizer'])
self._norm0 = self._norm( self._norm0 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._config_dict['norm_momentum'],
epsilon=self._norm_epsilon) epsilon=self._config_dict['norm_epsilon'])
self._conv1 = tf.keras.layers.Conv2D( self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3, kernel_size=3,
strides=self._strides, strides=self._config_dict['strides'],
padding='same', **conv_kwargs)
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm1 = self._norm( self._norm1 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._config_dict['norm_momentum'],
epsilon=self._norm_epsilon) epsilon=self._config_dict['norm_epsilon'])
self._conv2 = tf.keras.layers.Conv2D( self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3, kernel_size=3,
strides=1, strides=1,
padding='same', **conv_kwargs)
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm2 = self._norm( self._norm2 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._config_dict['norm_momentum'],
epsilon=self._norm_epsilon) epsilon=self._config_dict['norm_epsilon'])
super(ResBlock, self).build(input_shape) super(ResBlock, self).build(input_shape)
def get_config(self): def get_config(self):
config = { return self._config_dict
'filters': self._filters,
'strides': self._strides,
'use_projection': self._use_projection,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'use_bias': self._use_bias,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(ResBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None): def call(self, inputs, training=None):
shortcut = inputs shortcut = inputs
if self._use_projection: if self._config_dict['use_projection']:
shortcut = self._shortcut(shortcut) shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut) shortcut = self._norm0(shortcut)
......
...@@ -27,7 +27,6 @@ class RefUnet(tf.keras.layers.Layer): ...@@ -27,7 +27,6 @@ class RefUnet(tf.keras.layers.Layer):
Basnet: Boundary-aware salient object detection. Basnet: Boundary-aware salient object detection.
""" """
def __init__(self, def __init__(self,
use_separable_conv=False,
activation='relu', activation='relu',
use_sync_bn=False, use_sync_bn=False,
use_bias=True, use_bias=True,
...@@ -40,8 +39,6 @@ class RefUnet(tf.keras.layers.Layer): ...@@ -40,8 +39,6 @@ class RefUnet(tf.keras.layers.Layer):
"""Residual Refinement Module of BASNet. """Residual Refinement Module of BASNet.
Args: Args:
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function. activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization. use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in conv2d. use_bias: if True, use bias in conv2d.
...@@ -57,7 +54,6 @@ class RefUnet(tf.keras.layers.Layer): ...@@ -57,7 +54,6 @@ class RefUnet(tf.keras.layers.Layer):
""" """
super(RefUnet, self).__init__(**kwargs) super(RefUnet, self).__init__(**kwargs)
self._config_dict = { self._config_dict = {
'use_separable_conv': use_separable_conv,
'activation': activation, 'activation': activation,
'use_sync_bn': use_sync_bn, 'use_sync_bn': use_sync_bn,
'use_bias': use_bias, 'use_bias': use_bias,
...@@ -83,11 +79,10 @@ class RefUnet(tf.keras.layers.Layer): ...@@ -83,11 +79,10 @@ class RefUnet(tf.keras.layers.Layer):
def build(self, input_shape): def build(self, input_shape):
"""Creates the variables of the BASNet decoder.""" """Creates the variables of the BASNet decoder."""
if self._config_dict['use_separable_conv']: conv_op = tf.keras.layers.Conv2D
conv_op = tf.keras.layers.SeparableConv2D
else:
conv_op = tf.keras.layers.Conv2D
conv_kwargs = { conv_kwargs = {
'dilation_rate': 1,
'activation': self._config_dict['activation'],
'kernel_size': 3, 'kernel_size': 3,
'strides': 1, 'strides': 1,
'use_bias': self._config_dict['use_bias'], 'use_bias': self._config_dict['use_bias'],
...@@ -96,21 +91,44 @@ class RefUnet(tf.keras.layers.Layer): ...@@ -96,21 +91,44 @@ class RefUnet(tf.keras.layers.Layer):
'bias_regularizer': self._config_dict['bias_regularizer'], 'bias_regularizer': self._config_dict['bias_regularizer'],
} }
self._in_conv = conv_op(filters=64, padding='same',**conv_kwargs) self._in_conv = conv_op(
filters=64,
padding='same',
**conv_kwargs)
self._en_convs = [] self._en_convs = []
for _ in range(4): for _ in range(4):
self._en_convs.append(nn_blocks.ConvBlock(filters=64, **conv_kwargs)) self._en_convs.append(nn_blocks.ConvBlock(
filters=64,
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=self._config_dict['norm_momentum'],
norm_epsilon=self._config_dict['norm_epsilon'],
**conv_kwargs))
self._bridge_convs = [] self._bridge_convs = []
for _ in range(1): for _ in range(1):
self._bridge_convs.append(nn_blocks.ConvBlock(filters=64, **conv_kwargs)) self._bridge_convs.append(nn_blocks.ConvBlock(
filters=64,
use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=self._config_dict['norm_momentum'],
norm_epsilon=self._config_dict['norm_epsilon'],
**conv_kwargs))
self._de_convs = [] self._de_convs = []
for _ in range(4): for _ in range(4):
self._de_convs.append(nn_blocks.ConvBlock(filters=64, **conv_kwargs)) self._de_convs.append(nn_blocks.ConvBlock(
filters=64,
self._out_conv = conv_op(padding='same', filters=1, **conv_kwargs) use_sync_bn=self._config_dict['use_sync_bn'],
norm_momentum=self._config_dict['norm_momentum'],
norm_epsilon=self._config_dict['norm_epsilon'],
**conv_kwargs))
self._out_conv = conv_op(
filters=1,
padding='same',
**conv_kwargs)
def call(self, inputs): def call(self, inputs):
endpoints = {} endpoints = {}
......
...@@ -36,21 +36,31 @@ def build_basnet_model( ...@@ -36,21 +36,31 @@ def build_basnet_model(
model_config: exp_cfg.BASNetModel, model_config: exp_cfg.BASNetModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None): l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds BASNet model.""" """Builds BASNet model."""
backbone = basnet_model.BASNet_Encoder(
input_specs=input_specs)
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
backbone = basnet_model.BASNet_Encoder(
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
use_bias=model_config.use_bias,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
decoder = basnet_model.BASNet_Decoder( decoder = basnet_model.BASNet_Decoder(
use_sync_bn=norm_activation_config.use_sync_bn, activation=norm_activation_config.activation,
norm_momentum=norm_activation_config.norm_momentum, use_sync_bn=norm_activation_config.use_sync_bn,
norm_epsilon=norm_activation_config.norm_epsilon, use_bias=model_config.use_bias,
activation=norm_activation_config.activation, norm_momentum=norm_activation_config.norm_momentum,
kernel_regularizer=l2_regularizer) norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
refinement = refunet.RefUnet()
refinement = refunet.RefUnet(
norm_activation_config = model_config.norm_activation activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
use_bias=model_config.use_bias,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
model = basnet_model.BASNetModel(backbone, decoder, refinement) model = basnet_model.BASNetModel(backbone, decoder, refinement)
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment