# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BASNet task definition."""
from typing import Optional

from absl import logging
import tensorflow as tf

from official.common import dataset_fn
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.projects.basnet.configs import basnet as exp_cfg
from official.projects.basnet.evaluation import metrics as basnet_metrics
from official.projects.basnet.losses import basnet_losses
from official.projects.basnet.modeling import basnet_model
from official.projects.basnet.modeling import refunet
from official.vision.beta.dataloaders import segmentation_input


def build_basnet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: exp_cfg.BASNetModel,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
  """Builds BASNet model.

  Constructs a `BASNetEncoder`, a `BASNetDecoder` and a `RefUnet` refinement
  module that all share the same activation / normalization / regularization
  settings, and wires them together in a `BASNetModel`.

  Args:
    input_specs: input specification forwarded to the encoder.
    model_config: model configuration; `norm_activation` and `use_bias` are
      read from it.
    l2_regularizer: optional kernel regularizer applied to the encoder, the
      decoder and the refinement module. `None` disables regularization.

  Returns:
    The assembled `basnet_model.BASNetModel`.
  """
  norm_activation_config = model_config.norm_activation

  # Keyword arguments that are identical for all three sub-modules.
  shared_kwargs = dict(
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      use_bias=model_config.use_bias,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  # Only the encoder additionally receives the input specification.
  backbone = basnet_model.BASNetEncoder(
      input_specs=input_specs, **shared_kwargs)
  decoder = basnet_model.BASNetDecoder(**shared_kwargs)
  refinement = refunet.RefUnet(**shared_kwargs)

  model = basnet_model.BASNetModel(backbone, decoder, refinement)
  return model

@task_factory.register_task_cls(exp_cfg.BASNetTask)
class BASNetTask(base_task.Task):
  """A task for basnet.

  Builds the BASNet model, its input pipeline and its loss, and implements the
  train / validation steps. The evaluation metric objects are created by
  `build_metrics(training=False)` and stored on the task instance
  (`mae_metric`, `maxf_metric`, `relaxf_metric`); `validation_step`,
  `aggregate_logs` and `reduce_aggregated_logs` all read those attributes, so
  `build_metrics(training=False)` must have been called before evaluation.
  """

  def build_model(self):
    """Builds basnet model.

    Returns:
      The model returned by `build_basnet_model`, configured from
      `self.task_config.model` and the L2 weight decay in
      `self.task_config.losses`.
    """
    # Prepend an unspecified batch dimension to the configured input size.
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    # A falsy weight decay (0 / None) disables the regularizer entirely.
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = build_basnet_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loads pretrained checkpoint.

    Does nothing when `task_config.init_checkpoint` is empty. If the configured
    path is a directory, the latest checkpoint inside it is used.

    Args:
      model: the model whose weights are restored. `model.checkpoint_items` is
        used when 'all' is requested; otherwise `model.backbone` and/or
        `model.decoder` are restored depending on
        `task_config.init_checkpoint_modules`.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if 'all' in self.task_config.init_checkpoint_modules:
      # Full restore: every value in the checkpoint must be consumed.
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    else:
      # Partial restore: only the requested sub-modules are matched.
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds BASNet input.

    Args:
      params: data configuration; `output_size`, `crop_size`,
        `aug_rand_hflip`, `dtype`, `file_type` and `is_training` are read.
      input_context: optional distribution input context forwarded to the
        reader.

    Returns:
      The dataset produced by `input_reader.InputReader.read`.
    """

    ignore_label = self.task_config.losses.ignore_label

    decoder = segmentation_input.Decoder()
    parser = segmentation_input.Parser(
        output_size=params.output_size,
        crop_size=params.crop_size,
        ignore_label=ignore_label,
        aug_rand_hflip=params.aug_rand_hflip,
        dtype=params.dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self, label, model_outputs, aux_losses=None):
    """Hybrid loss proposed in BASNet.

    Args:
      label: a dictionary of labels; only `label['masks']` is used.
      model_outputs: the model outputs passed to `BASNetLoss`.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model. They
        are summed and added to the BASNet loss when non-empty.

    Returns:
      The total loss tensor.
    """
    basnet_loss_fn = basnet_losses.BASNetLoss()
    total_loss = basnet_loss_fn(model_outputs, label['masks'])

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    return total_loss

  def build_metrics(self, training=False):
    """Gets streaming metrics for training/validation.

    Args:
      training: when False, the `MAE`, `MaxFscore` and `RelaxedFscore` metric
        objects are created and stored on the task instance as a side effect.

    Returns:
      An empty list in both modes; the evaluation metrics are kept on `self`
      rather than returned.
    """
    evaluations = []

    if not training:
      self.mae_metric = basnet_metrics.MAE()
      self.maxf_metric = basnet_metrics.MaxFscore()
      self.relaxf_metric = basnet_metrics.RelaxedFscore()

    return evaluations

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Does forward and backward.

    Args:
      inputs: a `(features, labels)` pair.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects. Unused by this
        implementation.

    Returns:
      A dictionary of logs holding the unscaled loss under `self.loss`.
    """
    features, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    # Evaluated once: the same check guards both loss scaling and gradient
    # unscaling below.
    uses_loss_scaling = isinstance(
        optimizer, tf.keras.mixed_precision.LossScaleOptimizer)

    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs, label=labels, aux_losses=model.losses)

      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if uses_loss_scaling:
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)

    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if uses_loss_scaling:
      grads = optimizer.get_unscaled_gradients(grads)

    # Apply gradient clipping.
    if self.task_config.gradient_clip_norm > 0:
      grads, _ = tf.clip_by_global_norm(
          grads, self.task_config.gradient_clip_norm)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    logs = {self.loss: loss}
    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Runs inference and packages the ground-truth masks together with one model
    output for each evaluation metric. No validation loss is computed; the
    loss entry of the logs is the constant 0.

    Args:
      inputs: a `(features, labels)` pair; `labels['masks']` is used.
      model: the keras.Model.
      metrics: a nested structure of metrics objects. Unused by this
        implementation.

    Returns:
      A dictionary of logs: `self.loss` mapped to 0 and, under each metric's
      name, a `(labels['masks'], output)` tuple.
    """
    features, labels = inputs

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

    loss = 0
    logs = {self.loss: loss}

    # Only the output stored under the largest key is evaluated.
    # NOTE(review): presumably this is the final refined prediction; confirm
    # against BASNetModel's output keys.
    levels = sorted(outputs.keys())
    eval_pair = (labels['masks'], outputs[levels[-1]])

    logs.update({self.mae_metric.name: eval_pair})
    logs.update({self.maxf_metric.name: eval_pair})
    logs.update({self.relaxf_metric.name: eval_pair})
    return logs

  def inference_step(self, inputs, model):
    """Performs the forward step."""
    return model(inputs, training=False)

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates one validation step's outputs into the evaluation metrics.

    Args:
      state: aggregation state. `None` marks the first call of an evaluation
        run: all three metrics are reset and `self.mae_metric` becomes the
        state.
      step_outputs: the logs returned by `validation_step`; for each metric
        name it holds a two-element tuple whose items are passed, in order,
        to that metric's `update_state`.

    Returns:
      The (possibly newly created) state.
    """
    if state is None:
      self.mae_metric.reset_states()
      self.maxf_metric.reset_states()
      self.relaxf_metric.reset_states()
      state = self.mae_metric
    self.mae_metric.update_state(
        step_outputs[self.mae_metric.name][0],
        step_outputs[self.mae_metric.name][1])
    self.maxf_metric.update_state(
        step_outputs[self.maxf_metric.name][0],
        step_outputs[self.maxf_metric.name][1])
    self.relaxf_metric.update_state(
        step_outputs[self.relaxf_metric.name][0],
        step_outputs[self.relaxf_metric.name][1])
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Returns the final evaluation results.

    Args:
      aggregated_logs: unused; results are read from the metric objects stored
        on the task.
      global_step: unused.

    Returns:
      A dictionary with the metric results under 'MAE', 'maxF' and 'relaxF'.
    """
    result = {}
    result['MAE'] = self.mae_metric.result()
    result['maxF'] = self.maxf_metric.result()
    result['relaxF'] = self.relaxf_metric.result()
    return result