Commit 25bf4592 authored by Gunho Park

keras.Model to keras.layers.Layer

parent 13dffa31
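For context, the refactor this commit applies across the BASNet decoder, encoder and refinement modules follows a common Keras pattern: a network previously wired as a functional tf.keras.Model inside __init__ becomes a tf.keras.layers.Layer that creates its sublayers in build() and connects them in call(). A minimal sketch of that pattern, using a hypothetical block name rather than the project's own classes:

import tensorflow as tf


class TwoConvBlock(tf.keras.layers.Layer):
  """Hypothetical block illustrating the Model -> Layer refactor pattern."""

  def __init__(self, filters=64, **kwargs):
    super(TwoConvBlock, self).__init__(**kwargs)
    self._config_dict = {'filters': filters}

  def build(self, input_shape):
    # Sublayers are created here instead of tracing a tf.keras.Input graph.
    self._conv1 = tf.keras.layers.Conv2D(
        self._config_dict['filters'], 3, padding='same', activation='relu')
    self._conv2 = tf.keras.layers.Conv2D(
        self._config_dict['filters'], 3, padding='same', activation='relu')
    super(TwoConvBlock, self).build(input_shape)

  def call(self, inputs):
    # The forward pass runs on tensors at call time, not at construction time.
    return self._conv2(self._conv1(inputs))

  def get_config(self):
    return self._config_dict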
......@@ -12,14 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decoder of BASNet.
Boundary-aware network (BASNet) was proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
# Import libraries
from typing import Mapping
import tensorflow as tf
from official.modeling import tf_utils
......@@ -27,8 +22,11 @@ from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
# nf : num_filters, dr : dilation_rate
# (conv1_nf, conv1_dr, convm_nf, convm_dr, conv2_nf, conv2_dr, scale_factor)
BASNET_BRIDGE_SPECS = [
(512, 2, 512, 2, 512, 2, 32), #Sup0, Bridge
]
BASNET_DECODER_SPECS = [
(512, 2, 512, 2, 512, 2, 32), #Bridge(Sup0)
(512, 1, 512, 2, 512, 2, 32), #Sup1, stage6d
(512, 1, 512, 1, 512, 1, 16), #Sup2, stage5d
(512, 1, 512, 1, 256, 1, 8), #Sup3, stage4d
......@@ -37,12 +35,17 @@ BASNET_DECODER_SPECS = [
(64, 1, 64, 1, 64, 1, 1) #Sup6, stage1d
]
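# Worked example (annotation only, not part of this change): for the stage6d
# entry (512, 1, 512, 2, 512, 2, 32), the build loop below reads
# filters = spec[2*j] and dilation_rate = spec[2*j + 1] for j in range(3),
# and spec[6] = 32 is the bilinear upsampling factor for that side output.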
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Decoder(tf.keras.Model):
"""BASNet Decoder."""
class BASNet_Decoder(tf.keras.layers.Layer):
"""Decoder of BASNet.
Boundary-aware network (BASNet) was proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
def __init__(self,
input_specs,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
......@@ -56,12 +59,11 @@ class BASNet_Decoder(tf.keras.Model):
"""BASNet Decoder initialization function.
Args:
input_specs: `dict` input specifications. A dictionary that consists of
{level: TensorShape} from a backbone.
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in convolution.
norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
......@@ -70,8 +72,8 @@ class BASNet_Decoder(tf.keras.Model):
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
**kwargs: keyword arguments to be passed.
"""
super(BASNet_Decoder, self).__init__(**kwargs)
self._config_dict = {
'input_specs': input_specs,
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
......@@ -82,89 +84,104 @@ class BASNet_Decoder(tf.keras.Model):
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
}
if use_separable_conv:
conv2d = tf.keras.layers.SeparableConv2D
else:
conv2d = tf.keras.layers.Conv2D
if use_sync_bn:
norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
norm = tf.keras.layers.BatchNormalization
activation_fn = tf.keras.layers.Activation(
tf_utils.get_activation(activation))
# Build input feature pyramid.
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
else:
bn_axis = 1
self._activation = tf_utils.get_activation(activation)
self._concat = tf.keras.layers.Concatenate(axis=-1)
self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')
def build(self, input_shape):
"""Creates the variables of the BASNet decoder."""
if self._config_dict['use_separable_conv']:
conv_op = tf.keras.layers.SeparableConv2D
else:
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
'kernel_size': 3,
'strides': 1,
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
# Get input feature pyramid from backbone.
inputs = self._build_input_pyramid(input_specs)
levels = sorted(inputs.keys(), reverse=True)
self._out_convs = []
self._out_usmps = []
sup = {}
# Bridge layers.
self._bdg_convs = []
for i, spec in enumerate(BASNET_BRIDGE_SPECS):
blocks = []
for j in range(3):
blocks.append(nn_blocks.ConvBlock(
filters=spec[2*j],
dilation_rate=spec[2*j+1],
activation='relu',
norm_momentum=0.99,
norm_epsilon=0.001,
**conv_kwargs))
self._bdg_convs.append(blocks)
self._out_convs.append(conv_op(
filters=1,
padding='same',
**conv_kwargs))
self._out_usmps.append(tf.keras.layers.UpSampling2D(
size=spec[6],
interpolation='bilinear'
))
# Decoder layers.
self._dec_convs = []
for i, spec in enumerate(BASNET_DECODER_SPECS):
if i == 0:
#x = inputs['5'] # Bridge input
x = inputs[levels[0]] # Bridge input
# str(levels[-1]) ??
else:
x = tf.keras.layers.Concatenate(axis=-1)([x, inputs[levels[i-1]]])
blocks = []
for j in range(3):
x = nn_blocks.ConvBlock(
blocks.append(nn_blocks.ConvBlock(
filters=spec[2*j],
kernel_size=3,
strides=1,
dilation_rate=spec[2*j+1],
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activation='relu',
use_sync_bn=use_sync_bn,
use_bias=use_bias,
norm_momentum=0.99,
norm_epsilon=0.001
)(x)
output = tf.keras.layers.Conv2D(
filters=1, kernel_size=3, strides=1,
use_bias=use_bias, padding='same',
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer
)(x)
output = tf.keras.layers.UpSampling2D(
norm_epsilon=0.001,
**conv_kwargs))
self._dec_convs.append(blocks)
self._out_convs.append(conv_op(
filters=1,
padding='same',
**conv_kwargs))
self._out_usmps.append(tf.keras.layers.UpSampling2D(
size=spec[6],
interpolation='bilinear'
)(output)
output = tf.keras.layers.Activation(
activation='sigmoid'
)(output)
sup[str(i)] = output
if i != 0:
))
def call(self, backbone_output: Mapping[str, tf.Tensor]):
levels = sorted(backbone_output.keys(), reverse=True)
sup = {}
x = backbone_output[levels[0]]
for blocks in self._bdg_convs:
for block in blocks:
x = block(x)
sup['0'] = x
for i, blocks in enumerate(self._dec_convs):
x = self._concat([x, backbone_output[levels[i]]])
for block in blocks:
x = block(x)
sup[str(i+1)] = x
x = tf.keras.layers.UpSampling2D(
size=2,
interpolation='bilinear'
)(x)
for i, (conv, usmp) in enumerate(zip(self._out_convs, self._out_usmps)):
sup[str(i)] = self._sigmoid(usmp(conv(sup[str(i)])))
self._output_specs = {
str(order): sup[str(order)].get_shape()
for order in range(0, len(BASNET_DECODER_SPECS) + 1)  # '0' (bridge) through '6'
}
super(BASNet_Decoder, self).__init__(inputs=inputs, outputs=sup, **kwargs)
def _build_input_pyramid(self, input_specs):
assert isinstance(input_specs, dict)
inputs = {}
for level, spec in input_specs.items():
inputs[level] = tf.keras.Input(shape=spec[1:])
return inputs
return sup
def get_config(self):
return self._config_dict
......
......@@ -12,13 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BASNet Encoder
Boundary-aware network (BASNet) was proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
# Import libraries
import tensorflow as tf
......@@ -29,19 +22,26 @@ from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
# Specifications for BASNet encoder.
# Each element in the block configuration is in the following format:
# (block_fn, num_filters, stride, block_repeats, maxpool)
# (num_filters, stride, block_repeats, maxpool)
BASNET_ENCODER_SPECS = [
('residual', 64, 1, 3, 0), #ResNet-34,
('residual', 128, 2, 4, 0), #ResNet-34,
('residual', 256, 2, 6, 0), #ResNet-34,
('residual', 512, 2, 3, 1), #ResNet-34,
('residual', 512, 1, 3, 1), #BASNet,
('residual', 512, 1, 3, 0), #BASNet,
(64, 1, 3, 0), #ResNet-34,
(128, 2, 4, 0), #ResNet-34,
(256, 2, 6, 0), #ResNet-34,
(512, 2, 3, 1), #ResNet-34,
(512, 1, 3, 1), #BASNet,
(512, 1, 3, 0), #BASNet,
]
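# Annotation (not part of this change): with block_fn dropped from the spec,
# an entry such as (512, 2, 3, 1) now reads as 512 filters, stride 2, 3
# residual blocks, followed by a 2x2 max pooling because the last flag is 1.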
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Encoder(tf.keras.Model):
"""BASNet Encoder
Boundary-Awar network (BASNet) were proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
def __init__(self,
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
......@@ -54,7 +54,7 @@ class BASNet_Encoder(tf.keras.Model):
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""BASNet_En initialization function.
"""BASNet_Encoder initialization function.
Args:
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
......@@ -109,19 +109,14 @@ class BASNet_Encoder(tf.keras.Model):
endpoints = {}
for i, spec in enumerate(BASNET_ENCODER_SPECS):
if spec[0] == 'residual':
block_fn = nn_blocks.ResBlock
else:
raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
x = self._block_group(
inputs=x,
filters=spec[1],
strides=spec[2],
block_fn=block_fn,
block_repeats=spec[3],
filters=spec[0],
strides=spec[1],
block_repeats=spec[2],
name='block_group_l{}'.format(i + 2))
endpoints[str(i)] = x
if spec[4]:
if spec[3]:
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same')(x)
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
......@@ -131,24 +126,22 @@ class BASNet_Encoder(tf.keras.Model):
inputs,
filters,
strides,
block_fn,
block_repeats=1,
name='block_group'):
"""Creates one group of blocks for the ResNet model.
"""Creates one group of residual blocks for the BASNet encoder model.
Args:
inputs: `Tensor` of size `[batch, channels, height, width]`.
filters: `int` number of filters for the first convolution of the layer.
strides: `int` stride to use for the first convolution of the layer. If
greater than 1, this layer will downsample the input.
block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`.
block_repeats: `int` number of blocks contained in the layer.
name: `str` name for the block.
Returns:
The output `Tensor` of the block layer.
"""
x = block_fn(
x = nn_blocks.ResBlock(
filters=filters,
strides=strides,
use_projection=True,
......@@ -163,7 +156,7 @@ class BASNet_Encoder(tf.keras.Model):
inputs)
for _ in range(1, block_repeats):
x = block_fn(
x = nn_blocks.ResBlock(
filters=filters,
strides=1,
use_projection=False,
......
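The endpoints dict built in the encoder's forward pass (keys '0' through '5', each recorded before the optional max pooling) is what the decoder's call() above consumes as backbone_output. A rough shape check, assuming a 224x224 input and a stride-1 input convolution (not shown in this diff):

# Hypothetical spatial sizes per encoder endpoint for a 224x224 input, following
# the strides and max-pool flags in BASNET_ENCODER_SPECS above (channels omitted):
#   '0': 224, '1': 112, '2': 56, '3': 28, '4': 14, '5': 7
# Consistent with the visible decoder scale factors: 224/32 = 7 (Sup0/Sup1),
# 224/16 = 14 (Sup2), 224/8 = 28 (Sup3), 224/1 = 224 (Sup6).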
......@@ -12,13 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Residual Refinement Module of BASNet.
Boundary-aware network (BASNet) was proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
# Import libraries
import tensorflow as tf
......@@ -26,10 +19,15 @@ from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
@tf.keras.utils.register_keras_serializable(package='Vision')
class RefUnet(tf.keras.Model):
class RefUnet(tf.keras.layers.Layer):
"""Residual Refinement Module of BASNet.
Boundary-aware network (BASNet) was proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
def __init__(self,
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 1]),
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
use_bias=True,
......@@ -42,7 +40,8 @@ class RefUnet(tf.keras.Model):
"""Residual Refinement Module of BASNet.
Args:
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in conv2d.
......@@ -56,126 +55,91 @@ class RefUnet(tf.keras.Model):
Default to None.
**kwargs: keyword arguments to be passed.
"""
self._input_specs = input_specs
self._use_sync_bn = use_sync_bn
self._use_bias = use_bias
self._activation = activation
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
super(RefUnet, self).__init__(**kwargs)
self._config_dict = {
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
}
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
else:
bn_axis = 1
self._concat = tf.keras.layers.Concatenate(axis=-1)
self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')
self._maxpool = tf.keras.layers.MaxPool2D(
pool_size=2,
strides=2,
padding='valid')
self._upsample = tf.keras.layers.UpSampling2D(
size=2,
interpolation='bilinear')
# Build ResNet.
inputs = tf.keras.Input(shape=self._input_specs.shape[1:])
def build(self, input_shape):
"""Creates the variables of the BASNet decoder."""
if self._config_dict['use_separable_conv']:
conv_op = tf.keras.layers.SeparableConv2D
else:
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
'kernel_size': 3,
'strides': 1,
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
endpoints = {}
residual = inputs
self._in_conv = conv_op(filters=64, padding='same', **conv_kwargs)
x = tf.keras.layers.Conv2D(
filters=64, kernel_size=3, strides=1,
use_bias=self._use_bias, padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
self._en_convs = []
for _ in range(4):
self._en_convs.append(nn_blocks.ConvBlock(filters=64, **conv_kwargs))
self._bridge_convs = []
for _ in range(1):
self._bridge_convs.append(nn_blocks.ConvBlock(filters=64, **conv_kwargs))
# Top-down
for i in range(4):
x = nn_blocks.ConvBlock(
filters=64,
kernel_size=3,
strides=1,
dilation_rate=1,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation='relu',
use_sync_bn=self._use_sync_bn,
use_bias=self._use_bias,
norm_momentum=0.99,
norm_epsilon=0.001
)(x)
self._de_convs = []
for _ in range(4):
self._de_convs.append(nn_blocks.ConvBlock(filters=64, **conv_kwargs))
endpoints[str(i)] = x
self._out_conv = conv_op(padding='same', filters=1, **conv_kwargs)
def call(self, inputs):
endpoints = {}
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='valid')(x)
residual = inputs
x = self._in_conv(inputs)
# Bridge
x = nn_blocks.ConvBlock(
filters=64,
kernel_size=3,
strides=1,
dilation_rate=1,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation='relu',
use_sync_bn=self._use_sync_bn,
use_bias=self._use_bias,
norm_momentum=0.99,
norm_epsilon=0.001
)(x)
# Top-down
for i, block in enumerate(self._en_convs):
x = block(x)
endpoints[str(i)] = x
x = self._maxpool(x)
x = tf.keras.layers.UpSampling2D(
size=2,
interpolation='bilinear'
)(x)
# Bridge
for i, block in enumerate(self._bridge_convs):
x = block(x)
# Bottom-up
for i, block in enumerate(self._de_convs):
x = self._upsample(x)
x = self._concat([endpoints[str(3-i)], x])
x = block(x)
for i in range(4):
x = tf.keras.layers.Concatenate(axis=-1)([endpoints[str(3-i)], x])
x = nn_blocks.ConvBlock(
filters=64,
kernel_size=3,
strides=1,
dilation_rate=1,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation='relu',
use_sync_bn=self._use_sync_bn,
use_bias=self._use_bias,
norm_momentum=0.99,
norm_epsilon=0.001
)(x)
if i == 3:
x = tf.keras.layers.Conv2D(
filters=1, kernel_size=3, strides=1,
use_bias=self._use_bias, padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer
)(x)
else:
x = tf.keras.layers.UpSampling2D(
size=2,
interpolation='bilinear'
)(x)
x = self._out_conv(x)
residual = tf.cast(residual, dtype=x.dtype)
output = x + residual
output = tf.keras.layers.Activation(
activation='sigmoid'
)(output)
output = self._sigmoid(x + residual)
self._output_specs = output.get_shape()
super(RefUnet, self).__init__(inputs=inputs, outputs=output, **kwargs)
return output
@classmethod
def from_config(cls, config, custom_objects=None):
......
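Finally, a standalone sketch (shapes assumed, not the project code) of the residual refinement idea RefUnet implements above: the small U-Net predicts a residual that is added back to the coarse saliency map, and the sum is passed through a sigmoid.

import tensorflow as tf

coarse = tf.random.uniform([1, 224, 224, 1])  # coarse saliency map (assumed shape)
residual = tf.zeros_like(coarse)              # stand-in for the refinement U-Net output
refined = tf.keras.activations.sigmoid(coarse + residual)
assert refined.shape == coarse.shape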