# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RefUNet model."""

import tensorflow as tf

from official.projects.basnet.modeling import nn_blocks


@tf.keras.utils.register_keras_serializable(package='Vision')
class RefUnet(tf.keras.layers.Layer):
  """Residual Refinement Module of BASNet.

  A small encoder/bridge/decoder network whose single-channel output is added
  to its input and squashed with a sigmoid.

  Boundary-Aware network (BASNet) was proposed in:
  [1] Qin, Xuebin, et al.
      Basnet: Boundary-aware salient object detection.
  """

  def __init__(self,
               activation='relu',
               use_sync_bn=False,
               use_bias=True,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """Residual Refinement Module of BASNet.

    Args:
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in conv2d.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
                          Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
                        Default to None.
      **kwargs: keyword arguments to be passed.
    """
    super(RefUnet, self).__init__(**kwargs)
    # NOTE(review): `activation` is recorded here but is not forwarded to any
    # sublayer created in `build` -- confirm whether that is intended.
    self._config_dict = {
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    # Populated by `call`; defined up front so that `output_specs` can be read
    # before the first forward pass without raising AttributeError.
    self._output_specs = None

    # Stateless helper layers shared by every stage.
    self._concat = tf.keras.layers.Concatenate(axis=-1)
    self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')
    self._maxpool = tf.keras.layers.MaxPool2D(
        pool_size=2,
        strides=2,
        padding='valid')
    self._upsample = tf.keras.layers.UpSampling2D(
        size=2,
        interpolation='bilinear')

  def build(self, input_shape):
    """Creates the sublayers of the refinement module.

    Args:
      input_shape: shape of the input tensor; not inspected here, only
        forwarded to the base class.
    """
    conv_op = tf.keras.layers.Conv2D
    conv_kwargs = {
        'kernel_size': 3,
        'strides': 1,
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }

    def make_conv_block():
      # Every encoder, bridge and decoder stage uses the same 64-filter block.
      return nn_blocks.ConvBlock(
          filters=64,
          use_sync_bn=self._config_dict['use_sync_bn'],
          norm_momentum=self._config_dict['norm_momentum'],
          norm_epsilon=self._config_dict['norm_epsilon'],
          **conv_kwargs)

    self._in_conv = conv_op(
        filters=64,
        padding='same',
        **conv_kwargs)

    # Four encoder stages, one bridge stage, four decoder stages. The decoder
    # must have as many stages as the encoder: `call` pairs them one-to-one.
    self._en_convs = [make_conv_block() for _ in range(4)]
    self._bridge_convs = [make_conv_block() for _ in range(1)]
    self._de_convs = [make_conv_block() for _ in range(4)]

    # Single-channel output so it can be added to the input as a residual.
    self._out_conv = conv_op(
        filters=1,
        padding='same',
        **conv_kwargs)

    super(RefUnet, self).build(input_shape)

  def call(self, inputs):
    """Runs the refinement module.

    Args:
      inputs: input tensor. It is added to the single-channel output of the
        final convolution, so it must be broadcast-compatible with it. Each
        encoder stage halves the spatial size with 'valid' pooling and each
        decoder stage doubles it, so the spatial dimensions need to be
        divisible by 16 for the skip concatenations to line up.

    Returns:
      `sigmoid(refined + inputs)`, where `refined` is the output of the final
      convolution.
    """
    residual = inputs
    x = self._in_conv(inputs)

    # Top-down: keep each stage's pre-pooling activation as a skip connection.
    skips = []
    for block in self._en_convs:
      x = block(x)
      skips.append(x)
      x = self._maxpool(x)

    # Bridge
    for block in self._bridge_convs:
      x = block(x)

    # Bottom-up: the deepest skip is consumed first.
    for block, skip in zip(self._de_convs, reversed(skips)):
      # Upsample in float32 and cast back to the incoming dtype afterwards.
      compute_dtype = x.dtype
      x = tf.cast(x, tf.float32)
      x = self._upsample(x)
      x = tf.cast(x, compute_dtype)
      x = self._concat([skip, x])
      x = block(x)

    x = self._out_conv(x)
    residual = tf.cast(residual, dtype=x.dtype)
    output = self._sigmoid(x + residual)

    self._output_specs = output.get_shape()
    return output

  def get_config(self):
    """Returns the constructor arguments recorded in `__init__`."""
    # Shallow copy so callers cannot mutate the layer's stored configuration.
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds a `RefUnet` from a dict such as the one from `get_config`."""
    return cls(**config)

  @property
  def output_specs(self):
    """Shape of the most recent output, or None if `call` has not run yet."""
    return self._output_specs