# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of RevNet."""

from typing import Any, Callable, Dict, Optional
# Import libraries
import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks


# Block layouts of the supported RevNet variants, keyed by model id.
# Every variant is an ordered list of block groups, and every group is a tuple
# of the form:
#   (block_fn, num_filters, block_repeats)
REVNET_SPECS = {
    # Model id 38: three groups of 'residual' blocks.
    38: [('residual', 32, 3), ('residual', 64, 3), ('residual', 112, 3)],
    # Model id 56: four groups of 'bottleneck' blocks.
    56: [
        ('bottleneck', 128, 2), ('bottleneck', 256, 2),
        ('bottleneck', 512, 3), ('bottleneck', 832, 2),
    ],
    # Model id 104: same filters as 56, with 11 repeats in the third group.
    104: [
        ('bottleneck', 128, 2), ('bottleneck', 256, 2),
        ('bottleneck', 512, 11), ('bottleneck', 832, 2),
    ],
}


@tf.keras.utils.register_keras_serializable(package='Beta')
class RevNet(tf.keras.Model):
  """Creates a Reversible ResNet (RevNet) family model.

  This implements:
    Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse.
    The Reversible Residual Network: Backpropagation Without Storing
    Activations.
    (https://arxiv.org/pdf/1707.04585.pdf)

  The model is assembled with the Keras functional API: a convolutional stem
  followed by one group of reversible layers per entry of
  `REVNET_SPECS[model_id]`. The output of every group is exposed as an endpoint
  keyed by its level rendered as a string ('2', '3', ...).
  """

  def __init__(
      self,
      model_id: int,
      input_specs: tf.keras.layers.InputSpec = tf.keras.layers.InputSpec(
          shape=[None, None, None, 3]),
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a RevNet model.

    Args:
      model_id: An `int` id of the RevNet variant; must be a key of
        `REVNET_SPECS`.
      input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
      activation: A `str` name of the activation function.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A str for kernel initializer of convolutional layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      **kwargs: Additional keyword arguments to be passed.

    Raises:
      KeyError: If `model_id` is not a key of `REVNET_SPECS`.
      ValueError: If a block spec names an unsupported block function or
        requests an odd number of filters.
    """
    self._model_id = model_id
    self._input_specs = input_specs
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization

    # Normalize over the channel axis, wherever the data format puts it.
    axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1

    block_specs = REVNET_SPECS[model_id]

    # Build RevNet.
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    # Stem: the convolution already produces the filter count of the first
    # block group.
    x = tf.keras.layers.Conv2D(
        filters=block_specs[0][1],
        kernel_size=7,
        strides=2,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer)(inputs)
    x = self._norm(
        axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
    x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

    endpoints = {}
    for i, (block_fn, filters, block_repeats) in enumerate(block_specs):
      if block_fn == 'residual':
        inner_block_fn = nn_blocks.ResidualInner
      elif block_fn == 'bottleneck':
        inner_block_fn = nn_blocks.BottleneckResidualInner
      else:
        raise ValueError(f'Block fn `{block_fn}` is not supported.')

      if filters % 2 != 0:
        raise ValueError('Number of output filters must be even to ensure '
                         'splitting in channel dimension for reversible blocks')

      # Endpoint levels start at 2 for the first block group.
      level = i + 2
      x = self._block_group(
          inputs=x,
          filters=filters,
          strides=(1 if i == 0 else 2),  # Only later groups downsample.
          inner_block_fn=inner_block_fn,
          block_repeats=block_repeats,
          batch_norm_first=(i != 0),  # Only skip on first block
          name=f'revblock_group_{level}')
      endpoints[str(level)] = x

    self._output_specs = {
        level: tensor.get_shape() for level, tensor in endpoints.items()
    }

    super().__init__(inputs=inputs, outputs=endpoints, **kwargs)

  def _block_group(self,
                   inputs: tf.Tensor,
                   filters: int,
                   strides: int,
                   inner_block_fn: Callable[..., tf.keras.layers.Layer],
                   block_repeats: int,
                   batch_norm_first: bool,
                   name: str = 'revblock_group') -> tf.Tensor:
    """Creates one group of reversible layers for the RevNet model.

    Args:
      inputs: A `tf.Tensor` holding the feature map fed into the group.
      filters: An `int` total number of filters of the group; each of the two
        inner blocks of a reversible layer is built with `filters // 2`.
      strides: An `int` stride used by the first reversible layer of the
        group. If greater than 1, this block group will downsample the input.
      inner_block_fn: Either `nn_blocks.ResidualInner` or
        `nn_blocks.BottleneckResidualInner`.
      block_repeats: An `int` number of reversible layers contained in this
        block group.
      batch_norm_first: A `bool` that specifies whether to apply
        BatchNormalization and activation layer before feeding into convolution
        layers. Only forwarded as True to the first layer of the group.
      name: A `str` name for the output tensor of the group.

    Returns:
      The output `tf.Tensor` of the block group.
    """
    # NOTE(review): only filters, strides, batch_norm_first and
    # kernel_regularizer are forwarded to the inner blocks. The activation,
    # normalization and kernel_initializer settings stored on `self` are not,
    # so the inner blocks use whatever defaults `nn_blocks` defines. Confirm
    # this is intended before changing it.
    x = inputs
    for i in range(block_repeats):
      is_first_block = i == 0
      shared_kwargs = dict(
          filters=filters // 2,
          batch_norm_first=batch_norm_first and is_first_block,
          kernel_regularizer=self._kernel_regularizer)
      # Only first residual layer in block gets downsampled
      f = inner_block_fn(
          strides=strides if is_first_block else 1, **shared_kwargs)
      g = inner_block_fn(strides=1, **shared_kwargs)
      x = nn_blocks.ReversibleLayer(f, g)(x)

    return tf.identity(x, name=name)

  def get_config(self) -> Dict[str, Any]:
    """Returns the constructor arguments needed to rebuild the model.

    NOTE(review): `input_specs` is not part of the config, so `from_config`
    rebuilds the model with the default input spec of `__init__`.
    """
    return {
        'model_id': self._model_id,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
    }

  @classmethod
  def from_config(cls,
                  config: Dict[str, Any],
                  custom_objects: Optional[Any] = None) -> tf.keras.Model:
    """Rebuilds a model from `get_config` output; `custom_objects` is unused."""
    return cls(**config)

  @property
  def output_specs(self) -> Dict[str, tf.TensorShape]:
    """A dict of {level: TensorShape} pairs for the model output.

    Levels are string keys ('2', '3', ...), one per block group.
    """
    return self._output_specs


@factory.register_backbone_builder('revnet')
def build_revnet(
    input_specs: tf.keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds RevNet backbone from a config.

  Args:
    input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
    backbone_config: A config whose `type` must be 'revnet' and whose `get()`
      result provides `model_id`.
    norm_activation_config: A config providing `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: An optional regularizer passed to the model as its
      `kernel_regularizer`. Default to None.

  Returns:
    A `RevNet` model instance.

  Raises:
    AssertionError: If `backbone_config.type` is not 'revnet'.
  """
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  # Guard against a builder being invoked with another backbone's config.
  assert backbone_type == 'revnet', (f'Inconsistent backbone type '
                                     f'{backbone_type}')

  return RevNet(
      model_id=backbone_cfg.model_id,
      input_specs=input_specs,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)