unet_model.py 8.54 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model definition for the TF2 Keras UNet 3D Model."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import tensorflow as tf


def create_optimizer(init_learning_rate, params):
  """Builds the training optimizer: Adam with exponential learning-rate decay.

  Args:
    init_learning_rate: float, the learning rate at step 0.
    params: hyperparameter object; only `lr_decay_steps` and `lr_decay_rate`
      are read here, to configure the decay schedule.

  Returns:
    A `tf.keras.optimizers.Adam` instance driven by the decay schedule.
  """
  # TODO(hongjunchoi): Provide alternative optimizer options depending on model
  # config parameters.
  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate=init_learning_rate,
      decay_steps=params.lr_decay_steps,
      decay_rate=params.lr_decay_rate)
  return tf.keras.optimizers.Adam(lr_schedule)


def create_convolution_block(input_layer,
                             n_filters,
                             batch_normalization=False,
                             kernel=(3, 3, 3),
                             activation=tf.nn.relu,
                             padding='SAME',
                             strides=(1, 1, 1),
                             data_format='channels_last',
                             instance_normalization=False):
  """UNet convolution block.

  Args:
    input_layer: tf.Tensor, the input tensor.
    n_filters: integer, the number of the output channels of the convolution.
    batch_normalization: boolean, use batch normalization after the convolution.
    kernel: kernel size of the convolution.
    activation: Tensorflow activation layer to use. (default is 'relu')
    padding: padding type of the convolution.
    strides: strides of the convolution.
    data_format: data format of the convolution. One of 'channels_first' or
      'channels_last'.
    instance_normalization: use Instance normalization. Exclusive with batch
      normalization.

  Returns:
    The Tensor after apply the convolution block to the input.

  Raises:
    ValueError: if `instance_normalization` is requested; it is not supported.
  """
  if instance_normalization:
    # Raise instead of `assert` so the check survives `python -O`.
    raise ValueError('TF 2.0 does not support inst. norm.')
  layer = tf.keras.layers.Conv3D(
      filters=n_filters,
      kernel_size=kernel,
      strides=strides,
      padding=padding,
      data_format=data_format,
      activation=None,
  )(
      inputs=input_layer)
  if batch_normalization:
    # Bug fix: normalize over the channel axis. The channel axis is -1 for
    # 'channels_last' (the default here) and 1 for 'channels_first'; the
    # previous hard-coded axis=1 normalized over a spatial dimension in the
    # default layout.
    channel_axis = 1 if data_format == 'channels_first' else -1
    layer = tf.keras.layers.BatchNormalization(axis=channel_axis)(inputs=layer)
  return activation(layer)


def apply_up_convolution(inputs,
                         num_filters,
                         pool_size,
                         kernel_size=(2, 2, 2),
                         strides=(2, 2, 2),
                         deconvolution=False):
  """Up-scales `inputs` by transpose convolution or plain up-sampling.

  Args:
    inputs: input feature tensor.
    num_filters: number of output feature channels for the deconvolution.
    pool_size: up-scaling factor used on the up-sampling path.
    kernel_size: kernel size of the transpose convolution.
    strides: strides of the transpose convolution.
    deconvolution: when True use `Conv3DTranspose`; otherwise `UpSampling3D`.

  Returns:
    The tensor of the up-scaled features.
  """
  if not deconvolution:
    # Parameter-free up-sampling; `num_filters`, `kernel_size` and `strides`
    # are ignored on this path.
    return tf.keras.layers.UpSampling3D(size=pool_size)(inputs)
  transpose_conv = tf.keras.layers.Conv3DTranspose(
      filters=num_filters, kernel_size=kernel_size, strides=strides)
  return transpose_conv(inputs=inputs)


def unet3d_base(input_layer,
                pool_size=(2, 2, 2),
                n_labels=1,
                deconvolution=False,
                depth=4,
                n_base_filters=32,
                batch_normalization=False,
                data_format='channels_last'):
  """Builds the 3D UNet Tensorflow model and return the last layer logits.

  Args:
    input_layer: the input Tensor.
    pool_size: Pool size for the max pooling operations.
    n_labels: Number of binary labels that the model is learning.
    deconvolution: If set to True, will use transpose convolution(deconvolution)
      instead of up-sampling. This increases the amount memory required during
      training.
    depth: indicates the depth of the U-shape for the model. The greater the
      depth, the more max pooling layers will be added to the model. Lowering
      the depth may reduce the amount of memory required for training.
    n_base_filters: The number of filters that the first layer in the
      convolution network will have. Following layers will contain a multiple of
      this number. Lowering this number will likely reduce the amount of memory
      required to train the model.
    batch_normalization: boolean. True for use batch normalization after
      convolution and before activation.
    data_format: string, channel_last (default) or channel_first

  Returns:
    The last layer logits of 3D UNet.
  """
  channel_axis = -1 if data_format == 'channels_last' else 1

  def conv_block(features, filters):
    # Shared settings for every convolution block in the U-shape.
    return create_convolution_block(
        input_layer=features,
        n_filters=filters,
        batch_normalization=batch_normalization,
        kernel=(3, 3, 3),
        activation=tf.nn.relu,
        padding='SAME',
        strides=(1, 1, 1),
        data_format=data_format,
        instance_normalization=False)

  # Contracting path: two convolution blocks per level, with max pooling
  # between levels (no pooling after the deepest level).
  encoder_levels = []
  features = input_layer
  for level in range(depth):
    conv_a = conv_block(features, n_base_filters * (2**level))
    conv_b = conv_block(conv_a, n_base_filters * (2**level) * 2)
    if level < depth - 1:
      features = tf.keras.layers.MaxPool3D(
          pool_size=pool_size,
          strides=(2, 2, 2),
          padding='VALID',
          data_format=data_format)(
              inputs=conv_b)
      encoder_levels.append([conv_a, conv_b, features])
    else:
      features = conv_b
      encoder_levels.append([conv_a, conv_b])

  # Expanding path: up-scale, concatenate the matching encoder output as a
  # skip connection, then refine with two convolution blocks.
  for level in range(depth - 2, -1, -1):
    upscaled = apply_up_convolution(
        features,
        pool_size=pool_size,
        deconvolution=deconvolution,
        num_filters=features.get_shape().as_list()[channel_axis])
    skip = encoder_levels[level][1]
    features = tf.concat([upscaled, skip], axis=channel_axis)
    skip_channels = skip.get_shape().as_list()[channel_axis]
    features = conv_block(features, skip_channels)
    features = conv_block(features, skip_channels)

  # 1x1x1 convolution projecting the features to per-voxel label logits.
  return tf.keras.layers.Conv3D(
      filters=n_labels,
      kernel_size=(1, 1, 1),
      padding='VALID',
      data_format=data_format,
      activation=None)(features)


def build_unet_model(params):
  """Builds the UNet 3D Keras model and attaches its optimizer.

  Args:
    params: hyperparameter object; fields read here are `input_image_size`,
      `num_classes`, `deconvolution`, `depth`, `num_base_filters`,
      `use_batch_norm`, `data_format` and `init_learning_rate`.

  Returns:
    A `tf.keras.Model` emitting per-voxel softmax probabilities, with its
    `optimizer` attribute populated.
  """
  # A single trailing channel dimension is appended to the volume size.
  image_input = tf.keras.layers.Input(shape=params.input_image_size + [1])

  logits = unet3d_base(
      image_input,
      pool_size=(2, 2, 2),
      n_labels=params.num_classes,
      deconvolution=params.deconvolution,
      depth=params.depth,
      n_base_filters=params.num_base_filters,
      batch_normalization=params.use_batch_norm,
      data_format=params.data_format)

  # Set output of softmax to float32 to avoid potential numerical overflow.
  probabilities = tf.keras.layers.Softmax(dtype='float32')(logits)
  unet = tf.keras.models.Model(inputs=image_input, outputs=probabilities)
  unet.optimizer = create_optimizer(params.init_learning_rate, params)
  return unet