"vscode:/vscode.git/clone" did not exist on "c0b60faaf0051e4cbd295bd829b1989b20431b50"
Commit 6ddd627a authored by Gunho Park's avatar Gunho Park
Browse files

Merge en_de to model

parent e570fda5
......@@ -17,7 +17,6 @@
# pylint: disable=unused-import
from official.vision import beta
from official.vision.beta.projects.basnet.configs import basnet
from official.vision.beta.projects.basnet.modeling import basnet_encoder
from official.vision.beta.projects.basnet.modeling import basnet_decoder
from official.vision.beta.projects.basnet.modeling import basnet_model
from official.vision.beta.projects.basnet.modeling import refunet
from official.vision.beta.projects.basnet.tasks import basnet
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Import libraries
from typing import Mapping
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
# nf : num_filters, dr : dilation_rate
# (conv1_nf, conv1_dr, convm_nf, convm_dr, conv2_nf, conv2_dr, scale_factor)
BASNET_BRIDGE_SPECS = [
(512, 2, 512, 2, 512, 2, 32), #Sup0, Bridge
]
BASNET_DECODER_SPECS = [
(512, 1, 512, 2, 512, 2, 32), #Sup1, stage6d
(512, 1, 512, 1, 512, 1, 16), #Sup2, stage5d
(512, 1, 512, 1, 256, 1, 8), #Sup3, stage4d
(256, 1, 256, 1, 128, 1, 4), #Sup4, stage3d
(128, 1, 128, 1, 64, 1, 2), #Sup5, stage2d
(64, 1, 64, 1, 64, 1, 1) #Sup6, stage1d
]
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Decoder(tf.keras.layers.Layer):
"""Decoder of BASNet.
Boundary-Awar network (BASNet) were proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
def __init__(self,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
use_bias=True,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""BASNet Decoder initialization function.
Args:
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in convolution.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
**kwargs: keyword arguments to be passed.
"""
super(BASNet_Decoder, self).__init__(**kwargs)
self._config_dict = {
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
}
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
else:
bn_axis = 1
self._activation = tf_utils.get_activation(activation)
self._concat = tf.keras.layers.Concatenate(axis=-1)
self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')
def build(self, input_shape):
"""Creates the variables of the BASNet decoder."""
if self._config_dict['use_separable_conv']:
conv_op = tf.keras.layers.SeparableConv2D
else:
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
'kernel_size': 3,
'strides': 1,
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
self._out_convs = []
self._out_usmps = []
# Bridge layers.
self._bdg_convs = []
for i, spec in enumerate(BASNET_BRIDGE_SPECS):
blocks = []
for j in range(3):
blocks.append(nn_blocks.ConvBlock(
filters=spec[2*j],
dilation_rate=spec[2*j+1],
activation='relu',
norm_momentum=0.99,
norm_epsilon=0.001,
**conv_kwargs))
self._bdg_convs.append(blocks)
self._out_convs.append(conv_op(
filters=1,
padding='same',
**conv_kwargs))
self._out_usmps.append(tf.keras.layers.UpSampling2D(
size=spec[6],
interpolation='bilinear'
))
# Decoder layers.
self._dec_convs = []
for i, spec in enumerate(BASNET_DECODER_SPECS):
blocks = []
for j in range(3):
blocks.append(nn_blocks.ConvBlock(
filters=spec[2*j],
dilation_rate=spec[2*j+1],
activation='relu',
norm_momentum=0.99,
norm_epsilon=0.001,
**conv_kwargs))
self._dec_convs.append(blocks)
self._out_convs.append(conv_op(
filters=1,
padding='same',
**conv_kwargs))
self._out_usmps.append(tf.keras.layers.UpSampling2D(
size=spec[6],
interpolation='bilinear'
))
def call(self, backbone_output: Mapping[str, tf.Tensor]):
levels = sorted(backbone_output.keys(), reverse=True)
sup = {}
x = backbone_output[levels[0]]
for blocks in self._bdg_convs:
for block in blocks:
x = block(x)
sup['0'] = x
for i, blocks in enumerate(self._dec_convs):
x = self._concat([x, backbone_output[levels[i]]])
for block in blocks:
x = block(x)
sup[str(i+1)] = x
x = tf.keras.layers.UpSampling2D(
size=2,
interpolation='bilinear'
)(x)
for i, (conv, usmp) in enumerate(zip(self._out_convs, self._out_usmps)):
sup[str(i)] = self._sigmoid(usmp(conv(sup[str(i)])))
self._output_specs = {
str(order): sup[str(order)].get_shape()
for order in range(0, len(BASNET_DECODER_SPECS))
}
return sup
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@property
def output_specs(self):
"""A dict of {order: TensorShape} pairs for the model output."""
return self._output_specs
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
# Specifications for BASNet encoder.
# Each element in the block configuration is in the following format:
# (num_filters, stride, block_repeats, maxpool)
BASNET_ENCODER_SPECS = [
(64, 1, 3, 0), #ResNet-34,
(128, 2, 4, 0), #ResNet-34,
(256, 2, 6, 0), #ResNet-34,
(512, 2, 3, 1), #ResNet-34,
(512, 1, 3, 1), #BASNet,
(512, 1, 3, 0), #BASNet,
]
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Encoder(tf.keras.Model):
"""BASNet Encoder
Boundary-Awar network (BASNet) were proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
def __init__(self,
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
activation='relu',
use_sync_bn=False,
use_bias=True,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""BASNet_Encoder initialization function.
Args:
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in conv2d.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
**kwargs: keyword arguments to be passed.
"""
self._input_specs = input_specs
self._use_sync_bn = use_sync_bn
self._use_bias = use_bias
self._activation = activation
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
else:
bn_axis = 1
# Build BASNet.
inputs = tf.keras.Input(shape=input_specs.shape[1:])
x = tf.keras.layers.Conv2D(
filters=64, kernel_size=3, strides=1,
use_bias=self._use_bias, padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
endpoints = {}
for i, spec in enumerate(BASNET_ENCODER_SPECS):
x = self._block_group(
inputs=x,
filters=spec[0],
strides=spec[1],
block_repeats=spec[2],
name='block_group_l{}'.format(i + 2))
endpoints[str(i)] = x
if spec[3]:
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same')(x)
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
super(BASNet_Encoder, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)
def _block_group(self,
inputs,
filters,
strides,
block_repeats=1,
name='block_group'):
"""Creates one group of residual blocks for the BASNet encoder model.
Args:
inputs: `Tensor` of size `[batch, channels, height, width]`.
filters: `int` number of filters for the first convolution of the layer.
strides: `int` stride to use for the first convolution of the layer. If
greater than 1, this layer will downsample the input.
block_repeats: `int` number of blocks contained in the layer.
name: `str`name for the block.
Returns:
The output `Tensor` of the block layer.
"""
x = nn_blocks.ResBlock(
filters=filters,
strides=strides,
use_projection=True,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
use_bias=self._use_bias,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
inputs)
for _ in range(1, block_repeats):
x = nn_block.ResBlock(
filters=filters,
strides=1,
use_projection=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
use_bias=self._use_bias,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
x)
return tf.identity(x, name=name)
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@property
def output_specs(self):
"""A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs
@factory.register_backbone_builder('basnet_encoder')
def build_basnet_encoder(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds BASNet Encoder backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'basnet_encoder', (f'Inconsistent backbone type '
f'{backbone_type}')
return BASNet_Encoder(
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
......@@ -15,13 +15,53 @@
"""Build BASNet models."""
# Import libraries
from typing import Mapping
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
# Specifications for BASNet encoder.
# Each element in the block configuration is in the following format:
# (num_filters, stride, block_repeats, maxpool)
BASNET_ENCODER_SPECS = [
(64, 1, 3, 0), #ResNet-34,
(128, 2, 4, 0), #ResNet-34,
(256, 2, 6, 0), #ResNet-34,
(512, 2, 3, 1), #ResNet-34,
(512, 1, 3, 1), #BASNet,
(512, 1, 3, 0), #BASNet,
]
# Specifications for BASNet decoder.
# Each element in the block configuration is in the following format:
# (conv1_nf, conv1_dr, convm_nf, convm_dr, conv2_nf, conv2_dr, scale_factor)
# nf : num_filters, dr : dilation_rate
BASNET_BRIDGE_SPECS = [
(512, 2, 512, 2, 512, 2, 32), #Sup0, Bridge
]
BASNET_DECODER_SPECS = [
(512, 1, 512, 2, 512, 2, 32), #Sup1, stage6d
(512, 1, 512, 1, 512, 1, 16), #Sup2, stage5d
(512, 1, 512, 1, 256, 1, 8), #Sup3, stage4d
(256, 1, 256, 1, 128, 1, 4), #Sup4, stage3d
(128, 1, 128, 1, 64, 1, 2), #Sup5, stage2d
(64, 1, 64, 1, 64, 1, 1) #Sup6, stage1d
]
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetModel(tf.keras.Model):
"""A BASNet model.
Boundary-Awar network (BASNet) were proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
Input images are passed through backbone first. Decoder network is then
applied, and finally, refinement module is applied on the output of the
decoder network.
......@@ -80,3 +120,334 @@ class BASNetModel(tf.keras.Model):
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Encoder(tf.keras.Model):
"""BASNet encoder"""
def __init__(self,
input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
activation='relu',
use_sync_bn=False,
use_bias=True,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""BASNet encoder initialization function.
Args:
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in conv2d.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
**kwargs: keyword arguments to be passed.
"""
self._input_specs = input_specs
self._use_sync_bn = use_sync_bn
self._use_bias = use_bias
self._activation = activation
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
else:
bn_axis = 1
# Build BASNet.
inputs = tf.keras.Input(shape=input_specs.shape[1:])
x = tf.keras.layers.Conv2D(
filters=64, kernel_size=3, strides=1,
use_bias=self._use_bias, padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
endpoints = {}
for i, spec in enumerate(BASNET_ENCODER_SPECS):
x = self._block_group(
inputs=x,
filters=spec[0],
strides=spec[1],
block_repeats=spec[2],
name='block_group_l{}'.format(i + 2))
endpoints[str(i)] = x
if spec[3]:
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same')(x)
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
super(BASNet_Encoder, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)
def _block_group(self,
inputs,
filters,
strides,
block_repeats=1,
name='block_group'):
"""Creates one group of residual blocks for the BASNet encoder model.
Args:
inputs: `Tensor` of size `[batch, channels, height, width]`.
filters: `int` number of filters for the first convolution of the layer.
strides: `int` stride to use for the first convolution of the layer. If
greater than 1, this layer will downsample the input.
block_repeats: `int` number of blocks contained in the layer.
name: `str`name for the block.
Returns:
The output `Tensor` of the block layer.
"""
x = nn_blocks.ResBlock(
filters=filters,
strides=strides,
use_projection=True,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
use_bias=self._use_bias,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
inputs)
for _ in range(1, block_repeats):
x = nn_blocks.ResBlock(
filters=filters,
strides=1,
use_projection=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
use_bias=self._use_bias,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
x)
return tf.identity(x, name=name)
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@property
def output_specs(self):
"""A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs
@factory.register_backbone_builder('basnet_encoder')
def build_basnet_encoder(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds BASNet Encoder backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'basnet_encoder', (f'Inconsistent backbone type '
f'{backbone_type}')
return BASNet_Encoder(
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNet_Decoder(tf.keras.layers.Layer):
"""BASNet decoder."""
def __init__(self,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
use_bias=True,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""BASNet decoder initialization function.
Args:
use_separable_conv: `bool`, if True use separable convolution for
convolution in BASNet layers.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
use_bias: if True, use bias in convolution.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
**kwargs: keyword arguments to be passed.
"""
super(BASNet_Decoder, self).__init__(**kwargs)
self._config_dict = {
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
}
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
else:
bn_axis = 1
self._activation = tf_utils.get_activation(activation)
self._concat = tf.keras.layers.Concatenate(axis=-1)
self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')
def build(self, input_shape):
"""Creates the variables of the BASNet decoder."""
if self._config_dict['use_separable_conv']:
conv_op = tf.keras.layers.SeparableConv2D
else:
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
'kernel_size': 3,
'strides': 1,
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
self._out_convs = []
self._out_usmps = []
# Bridge layers.
self._bdg_convs = []
for i, spec in enumerate(BASNET_BRIDGE_SPECS):
blocks = []
for j in range(3):
blocks.append(nn_blocks.ConvBlock(
filters=spec[2*j],
dilation_rate=spec[2*j+1],
activation='relu',
norm_momentum=0.99,
norm_epsilon=0.001,
**conv_kwargs))
self._bdg_convs.append(blocks)
self._out_convs.append(conv_op(
filters=1,
padding='same',
**conv_kwargs))
self._out_usmps.append(tf.keras.layers.UpSampling2D(
size=spec[6],
interpolation='bilinear'
))
# Decoder layers.
self._dec_convs = []
for i, spec in enumerate(BASNET_DECODER_SPECS):
blocks = []
for j in range(3):
blocks.append(nn_blocks.ConvBlock(
filters=spec[2*j],
dilation_rate=spec[2*j+1],
activation='relu',
norm_momentum=0.99,
norm_epsilon=0.001,
**conv_kwargs))
self._dec_convs.append(blocks)
self._out_convs.append(conv_op(
filters=1,
padding='same',
**conv_kwargs))
self._out_usmps.append(tf.keras.layers.UpSampling2D(
size=spec[6],
interpolation='bilinear'
))
def call(self, backbone_output: Mapping[str, tf.Tensor]):
"""Forward pass of the BASNet decoder.
Args:
backbone_output: A `dict` of tensors
- key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor` of the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
Returns:
sup: A `dict` of tensors
- key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor` of the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
"""
levels = sorted(backbone_output.keys(), reverse=True)
sup = {}
x = backbone_output[levels[0]]
for blocks in self._bdg_convs:
for block in blocks:
x = block(x)
sup['0'] = x
for i, blocks in enumerate(self._dec_convs):
x = self._concat([x, backbone_output[levels[i]]])
for block in blocks:
x = block(x)
sup[str(i+1)] = x
x = tf.keras.layers.UpSampling2D(
size=2,
interpolation='bilinear'
)(x)
for i, (conv, usmp) in enumerate(zip(self._out_convs, self._out_usmps)):
sup[str(i)] = self._sigmoid(usmp(conv(sup[str(i)])))
self._output_specs = {
str(order): sup[str(order)].get_shape()
for order in range(0, len(BASNET_DECODER_SPECS))
}
return sup
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@property
def output_specs(self):
"""A dict of {order: TensorShape} pairs for the model output."""
return self._output_specs
......@@ -19,9 +19,7 @@ from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.basnet.modeling import basnet_encoder
from official.vision.beta.projects.basnet.modeling import basnet_model
from official.vision.beta.projects.basnet.modeling import basnet_decoder
from official.vision.beta.projects.basnet.modeling import refunet
......@@ -38,9 +36,8 @@ class BASNetNetworkTest(parameterized.TestCase, tf.test.TestCase):
inputs = np.random.rand(2, input_size, input_size, 3)
tf.keras.backend.set_image_data_format('channels_last')
backbone = basnet_encoder.BASNet_Encoder()
decoder = basnet_decoder.BASNet_Decoder(
input_specs=backbone.output_specs)
backbone = basnet_model.BASNet_Encoder()
decoder = basnet_model.BASNet_Decoder()
refinement = refunet.RefUnet()
model = basnet_model.BASNetModel(
......@@ -50,16 +47,15 @@ class BASNetNetworkTest(parameterized.TestCase, tf.test.TestCase):
)
sigmoids = model(inputs)
#print(sigmoids['ref'].numpy().shape)
levels = sorted(sigmoids.keys())
self.assertAllEqual(
[2, input_size, input_size, 1],
sigmoids['ref'].numpy().shape)
sigmoids[levels[-1]].numpy().shape)
def test_serialize_deserialize(self):
"""Validate the network can be serialized and deserialized."""
backbone = basnet_encoder.BASNet_Encoder()
decoder = basnet_decoder.BASNet_Decoder(
input_specs=backbone.output_specs)
backbone = basnet_model.BASNet_Encoder()
decoder = basnet_model.BASNet_Decoder()
refinement = refunet.RefUnet()
model = basnet_model.BASNetModel(
......
......@@ -22,7 +22,7 @@ from official.vision.beta.projects.basnet.modeling.layers import nn_blocks
class RefUnet(tf.keras.layers.Layer):
"""Residual Refinement Module of BASNet.
Boundary-Awar network (BASNet) were proposed in:
Boundary-Aware network (BASNet) were proposed in:
[1] Qin, Xuebin, et al.
Basnet: Boundary-aware salient object detection.
"""
......@@ -141,6 +141,9 @@ class RefUnet(tf.keras.layers.Layer):
self._output_specs = output.get_shape()
return output
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
......
......@@ -30,9 +30,7 @@ from official.vision.beta.projects.basnet.evaluation import relax_f
from official.vision.beta.projects.basnet.evaluation import mae
from official.vision.beta.projects.basnet.losses import basnet_losses
from official.vision.beta.projects.basnet.modeling import basnet_encoder
from official.vision.beta.projects.basnet.modeling import basnet_model
from official.vision.beta.projects.basnet.modeling import basnet_decoder
from official.vision.beta.projects.basnet.modeling import refunet
def build_basnet_model(
......@@ -40,12 +38,12 @@ def build_basnet_model(
model_config: exp_cfg.BASNetModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds BASNet model."""
backbone = basnet_encoder.BASNet_Encoder(
backbone = basnet_model.BASNet_Encoder(
input_specs=input_specs)
norm_activation_config = model_config.norm_activation
decoder = basnet_decoder.BASNet_Decoder(
decoder = basnet_model.BASNet_Decoder(
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment