# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Build BASNet models."""

from typing import Mapping

import tensorflow as tf

from official.modeling import tf_utils
from official.projects.basnet.modeling import nn_blocks
from official.vision.beta.modeling.backbones import factory

# Specifications for BASNet encoder.
# Each element in the block configuration is in the following format:
# (num_filters, stride, block_repeats, maxpool)
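# For example, (128, 2, 4, 0) builds a group of four residual blocks with 128
# filters whose first block uses stride 2, and does not append a max-pool
# after the group.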
BASNET_ENCODER_SPECS = [
    (64, 1, 3, 0),  # ResNet-34,
    (128, 2, 4, 0),  # ResNet-34,
    (256, 2, 6, 0),  # ResNet-34,
    (512, 2, 3, 1),  # ResNet-34,
    (512, 1, 3, 1),  # BASNet,
    (512, 1, 3, 0),  # BASNet,
]

# Specifications for BASNet decoder.
# Each element in the block configuration is in the following format:
# (conv1_nf, conv1_dr, convm_nf, convm_dr, conv2_nf, conv2_dr, scale_factor)
# nf : num_filters, dr : dilation_rate
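# For example, the bridge spec (512, 2, 512, 2, 512, 2, 32) stacks three
# ConvBlocks of 512 filters each with dilation rate 2 and upsamples the
# corresponding side output by a factor of 32 before the sigmoid.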
BASNET_BRIDGE_SPECS = [
    (512, 2, 512, 2, 512, 2, 32),  # Sup0, Bridge
]

BASNET_DECODER_SPECS = [
    (512, 1, 512, 2, 512, 2, 32),  # Sup1, stage6d
    (512, 1, 512, 1, 512, 1, 16),  # Sup2, stage5d
    (512, 1, 512, 1, 256, 1, 8),  # Sup3, stage4d
    (256, 1, 256, 1, 128, 1, 4),  # Sup4, stage3d
    (128, 1, 128, 1, 64, 1, 2),  # Sup5, stage2d
    (64, 1, 64, 1, 64, 1, 1)  # Sup6, stage1d
]


@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetModel(tf.keras.Model):
  """A BASNet model.

  Boundary-Aware Network (BASNet) was proposed in:
  [1] Qin, Xuebin, et al.
      BASNet: Boundary-aware salient object detection.
      Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
      Recognition (CVPR), 2019.

  Input images are first passed through the backbone. The decoder network is
  then applied, and finally, the refinement module is applied to the output of
  the decoder network.
  """

  def __init__(self,
               backbone,
               decoder,
               refinement=None,
               **kwargs):
    """BASNet initialization function.

    Args:
      backbone: a backbone network, e.g. `BASNetEncoder`.
      decoder: a decoder network, e.g. `BASNetDecoder`.
      refinement: an optional module for saliency map refinement.
      **kwargs: keyword arguments to be passed.
    """
    super(BASNetModel, self).__init__(**kwargs)
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'refinement': refinement,
    }
    self.backbone = backbone
    self.decoder = decoder
    self.refinement = refinement

  def call(self, inputs, training=None):
    features = self.backbone(inputs)

    if self.decoder:
      features = self.decoder(features)

    levels = sorted(features.keys())
    new_key = str(len(levels))
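    # Append the refined saliency map (if any) as an extra level on top of the
    # decoder outputs.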
    if self.refinement:
      features[new_key] = self.refinement(features[levels[-1]])

    return features

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(backbone=self.backbone)
    if self.decoder is not None:
      items.update(decoder=self.decoder)
    if self.refinement is not None:
      items.update(refinement=self.refinement)
    return items

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)


@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetEncoder(tf.keras.Model):
  """BASNet encoder."""

  def __init__(
      self,
      input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
      activation='relu',
      use_sync_bn=False,
      use_bias=True,
      norm_momentum=0.99,
      norm_epsilon=0.001,
      kernel_initializer='VarianceScaling',
      kernel_regularizer=None,
      bias_regularizer=None,
      **kwargs):
    """BASNet encoder initialization function.

    Args:
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use a bias term in Conv2D layers.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Defaults to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Defaults to None.
      **kwargs: keyword arguments to be passed.
    """
    self._input_specs = input_specs
    self._use_sync_bn = use_sync_bn
    self._use_bias = use_bias
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer

    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    # Build BASNet Encoder.
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    x = tf.keras.layers.Conv2D(
        filters=64, kernel_size=3, strides=1,
        use_bias=self._use_bias, padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            inputs)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
            x)
    x = tf_utils.get_activation(activation)(x)

    endpoints = {}
    for i, spec in enumerate(BASNET_ENCODER_SPECS):
      x = self._block_group(
          inputs=x,
          filters=spec[0],
          strides=spec[1],
          block_repeats=spec[2],
          name='block_group_l{}'.format(i + 2))
      endpoints[str(i)] = x
      if spec[3]:
        x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same')(x)
    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
    super(BASNetEncoder, self).__init__(
        inputs=inputs, outputs=endpoints, **kwargs)

  def _block_group(self,
                   inputs,
                   filters,
                   strides,
                   block_repeats=1,
                   name='block_group'):
    """Creates one group of residual blocks for the BASNet encoder model.

    Args:
      inputs: `Tensor` of size `[batch, channels, height, width]`.
      filters: `int` number of filters for the first convolution of the layer.
      strides: `int` stride to use for the first convolution of the layer. If
        greater than 1, this layer will downsample the input.
      block_repeats: `int` number of blocks contained in the layer.
      name: `str` name for the block.

    Returns:
      The output `Tensor` of the block layer.
    """
    x = nn_blocks.ResBlock(
        filters=filters,
        strides=strides,
        use_projection=True,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activation=self._activation,
        use_sync_bn=self._use_sync_bn,
        use_bias=self._use_bias,
        norm_momentum=self._norm_momentum,
        norm_epsilon=self._norm_epsilon)(
            inputs)

    for _ in range(1, block_repeats):
      x = nn_blocks.ResBlock(
          filters=filters,
          strides=1,
          use_projection=False,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          use_bias=self._use_bias,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon)(
              x)

    return tf.identity(x, name=name)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_backbone_builder('basnet_encoder')
def build_basnet_encoder(
    input_specs: tf.keras.layers.InputSpec,
    model_config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Builds BASNet Encoder backbone from a config."""
  backbone_type = model_config.backbone.type
  norm_activation_config = model_config.norm_activation
  assert backbone_type == 'basnet_encoder', (f'Inconsistent backbone type '
                                             f'{backbone_type}')
  return BASNetEncoder(
      input_specs=input_specs,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      use_bias=norm_activation_config.use_bias,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)


@tf.keras.utils.register_keras_serializable(package='Vision')
class BASNetDecoder(tf.keras.layers.Layer):
  """BASNet decoder."""

  def __init__(self,
               activation='relu',
               use_sync_bn=False,
               use_bias=True,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """BASNet decoder initialization function.

    Args:
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      use_bias: if True, use bias in convolution.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
      **kwargs: keyword arguments to be passed.
    """
    super(BASNetDecoder, self).__init__(**kwargs)
    self._config_dict = {
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }

    self._activation = tf_utils.get_activation(activation)
    self._concat = tf.keras.layers.Concatenate(axis=-1)
    self._sigmoid = tf.keras.layers.Activation(activation='sigmoid')

  def build(self, input_shape):
    """Creates the variables of the BASNet decoder."""
    conv_op = tf.keras.layers.Conv2D
    conv_kwargs = {
        'kernel_size': 3,
        'strides': 1,
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }

    self._out_convs = []
    self._out_usmps = []

    # Bridge layers.
    self._bdg_convs = []
    for spec in BASNET_BRIDGE_SPECS:
      blocks = []
      for j in range(3):
        blocks.append(nn_blocks.ConvBlock(
            filters=spec[2*j],
            dilation_rate=spec[2*j+1],
            activation='relu',
            use_sync_bn=self._config_dict['use_sync_bn'],
            norm_momentum=0.99,
            norm_epsilon=0.001,
            **conv_kwargs))
      self._bdg_convs.append(blocks)
      self._out_convs.append(conv_op(
          filters=1,
          padding='same',
          **conv_kwargs))
      self._out_usmps.append(tf.keras.layers.UpSampling2D(
          size=spec[6],
          interpolation='bilinear'
          ))

    # Decoder layers.
    self._dec_convs = []
    for spec in BASNET_DECODER_SPECS:
      blocks = []
      for j in range(3):
        blocks.append(nn_blocks.ConvBlock(
            filters=spec[2*j],
            dilation_rate=spec[2*j+1],
            activation='relu',
            use_sync_bn=self._config_dict['use_sync_bn'],
            norm_momentum=0.99,
            norm_epsilon=0.001,
            **conv_kwargs))
      self._dec_convs.append(blocks)
      self._out_convs.append(conv_op(
          filters=1,
          padding='same',
          **conv_kwargs))
      self._out_usmps.append(tf.keras.layers.UpSampling2D(
          size=spec[6],
          interpolation='bilinear'
          ))

  def call(self, backbone_output: Mapping[str, tf.Tensor]):
    """Forward pass of the BASNet decoder.

    Args:
      backbone_output: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].

    Returns:
      sup: A `dict` of tensors
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].
    """
    levels = sorted(backbone_output.keys(), reverse=True)
    sup = {}
    x = backbone_output[levels[0]]

    for blocks in self._bdg_convs:
      for block in blocks:
        x = block(x)
    sup['0'] = x

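    # Decode from the deepest level upward: concatenate the running feature
    # map with the matching encoder endpoint, refine it with three ConvBlocks,
    # record the side output, and upsample by 2x for the next stage.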
    for i, blocks in enumerate(self._dec_convs):
      x = self._concat([x, backbone_output[levels[i]]])
      for block in blocks:
        x = block(x)
      sup[str(i+1)] = x
      x = tf.keras.layers.UpSampling2D(
          size=2,
          interpolation='bilinear'
          )(x)
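    # Project every side output to a single-channel map, upsample it to the
    # input resolution, and apply a sigmoid.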
    for i, (conv, usmp) in enumerate(zip(self._out_convs, self._out_usmps)):
      sup[str(i)] = self._sigmoid(usmp(conv(sup[str(i)])))

    self._output_specs = {
        str(order): sup[str(order)].get_shape()
        for order in range(0, len(BASNET_DECODER_SPECS))
    }

    return sup

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {order: TensorShape} pairs for the model output."""
    return self._output_specs