# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image segmentation task definition."""
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

from absl import logging
import tensorflow as tf

from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.vision.beta.configs import semantic_segmentation as exp_cfg
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders import segmentation_input
from official.vision.beta.dataloaders import tfds_factory
from official.vision.beta.evaluation import segmentation_metrics
from official.vision.beta.losses import segmentation_losses
from official.vision.beta.modeling import factory


@task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
class SemanticSegmentationTask(base_task.Task):
  """A task for semantic segmentation."""

  def build_model(self) -> tf.keras.Model:
    """Builds segmentation model.

    Returns:
      The model built by `factory.build_segmentation_model` from
      `task_config.model`, with an L2 regularizer attached when
      `task_config.losses.l2_weight_decay` is non-zero.
    """
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_segmentation_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loads pretrained checkpoint.

    Does nothing when `task_config.init_checkpoint` is empty. If it names a
    directory, the latest checkpoint inside it is used.
    `task_config.init_checkpoint_modules` selects what is restored: 'all'
    restores `model.checkpoint_items`; otherwise 'backbone' and/or 'decoder'
    restore the corresponding sub-modules of `model`.

    Args:
      model: the model whose variables are restored in place.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    modules = self.task_config.init_checkpoint_modules
    if 'all' in modules:
      ckpt_items = model.checkpoint_items
    else:
      ckpt_items = {}
      if 'backbone' in modules:
        ckpt_items['backbone'] = model.backbone
      if 'decoder' in modules:
        ckpt_items['decoder'] = model.decoder
      if not ckpt_items:
        # An empty Checkpoint reads successfully but restores nothing; make
        # that visible instead of silently reporting a successful load.
        logging.warning(
            'init_checkpoint_modules=%s selects none of all/backbone/decoder; '
            'nothing will be restored from %s', modules, ckpt_dir_or_file)

    # Restoring checkpoint.
    ckpt = tf.train.Checkpoint(**ckpt_items)
    status = ckpt.read(ckpt_dir_or_file)
    # Extra values in the checkpoint (e.g. optimizer slots) are allowed, but
    # every object we asked for must be matched.
    status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None
                   ) -> tf.data.Dataset:
    """Builds segmentation input.

    Args:
      params: the data config for this split.
      input_context: optional distribution input context passed to the reader.

    Returns:
      The dataset produced by the input reader, decoded with a TFDS-specific
      decoder when `params.tfds_name` is set and parsed by
      `segmentation_input.Parser`.
    """
    ignore_label = self.task_config.losses.ignore_label

    if params.tfds_name:
      decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
    else:
      decoder = segmentation_input.Decoder()

    parser = segmentation_input.Parser(
        output_size=params.output_size,
        crop_size=params.crop_size,
        ignore_label=ignore_label,
        resize_eval_groundtruth=params.resize_eval_groundtruth,
        groundtruth_padded_size=params.groundtruth_padded_size,
        aug_scale_min=params.aug_scale_min,
        aug_scale_max=params.aug_scale_max,
        aug_rand_hflip=params.aug_rand_hflip,
        preserve_aspect_ratio=params.preserve_aspect_ratio,
        dtype=params.dtype)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   labels: Mapping[str, tf.Tensor],
                   model_outputs: Union[Mapping[str, tf.Tensor], tf.Tensor],
                   aux_losses: Optional[Any] = None) -> tf.Tensor:
    """Segmentation loss.

    Args:
      labels: labels; the ground-truth masks are read from `labels['masks']`.
      model_outputs: Output logits of the classifier.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor,
      `loss_weight * (segmentation_loss + sum(aux_losses))`.
    """
    loss_params = self.task_config.losses
    segmentation_loss_fn = segmentation_losses.SegmentationLoss(
        loss_params.label_smoothing,
        loss_params.class_weights,
        loss_params.ignore_label,
        use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
        top_k_percent_pixels=loss_params.top_k_percent_pixels)

    total_loss = segmentation_loss_fn(model_outputs, labels['masks'])

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    # NOTE: loss_weight scales the auxiliary losses (e.g. weight decay) too.
    total_loss = loss_params.loss_weight * total_loss

    return total_loss

  def build_metrics(self,
                    training: bool = True) -> List[tf.keras.metrics.Metric]:
    """Gets streaming metrics for training/validation.

    When `training` is True and `evaluation.report_train_mean_iou` is set, a
    `MeanIoU` metric is returned. Otherwise an empty list is returned and a
    `PerClassIoU` metric is stored on `self.iou_metric` instead, along with
    `self._process_iou_metric_on_cpu`; both are used by `validation_step`,
    `aggregate_logs` and `reduce_aggregated_logs`.

    Args:
      training: whether the metrics are for the training loop.

    Returns:
      A list of metrics to be updated through `process_metrics`.
    """
    metrics = []
    if training and self.task_config.evaluation.report_train_mean_iou:
      metrics.append(segmentation_metrics.MeanIoU(
          name='mean_iou',
          num_classes=self.task_config.model.num_classes,
          rescale_predictions=False,
          dtype=tf.float32))
    else:
      # Predictions are rescaled inside the metric only when the input
      # pipeline does not resize the eval ground truth.
      self.iou_metric = segmentation_metrics.PerClassIoU(
          name='per_class_iou',
          num_classes=self.task_config.model.num_classes,
          rescale_predictions=not self.task_config.validation_data
          .resize_eval_groundtruth,
          dtype=tf.float32)

      # Update state on CPU if TPUStrategy due to dynamic resizing.
      self._process_iou_metric_on_cpu = isinstance(
          tf.distribute.get_strategy(), tf.distribute.TPUStrategy)

    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
    """Does forward and backward.

    Args:
      inputs: a `(features, labels)` tuple.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a list of metric objects to update, or None.

    Returns:
      A dictionary of logs: the loss under `self.loss` plus one entry per
      metric name.
    """
    features, labels = inputs

    input_partition_dims = self.task_config.train_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features = strategy.experimental_split_to_logical_devices(
          features, input_partition_dims)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})

    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf.keras.Model,
                      metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
    """Validation step.

    Args:
      inputs: a `(features, labels)` tuple.
      model: the keras.Model.
      metrics: a list of metric objects to update, or None.

    Returns:
      A dictionary of logs. When IoU is processed on the host, it also holds
      `(labels, outputs)` under the IoU metric name for `aggregate_logs`.
    """
    features, labels = inputs

    input_partition_dims = self.task_config.eval_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features = strategy.experimental_split_to_logical_devices(
          features, input_partition_dims)

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

    # The loss is only computed when the eval ground truth is resized;
    # otherwise 0 is logged (presumably because outputs and un-resized
    # ground truth need not share a shape -- confirm).
    if self.task_config.validation_data.resize_eval_groundtruth:
      loss = self.build_losses(model_outputs=outputs, labels=labels,
                               aux_losses=model.losses)
    else:
      loss = 0

    logs = {self.loss: loss}

    if self._process_iou_metric_on_cpu:
      # Deferred to aggregate_logs, which runs outside the replica context.
      logs.update({self.iou_metric.name: (labels, outputs)})
    else:
      self.iou_metric.update_state(labels, outputs)

    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})

    return logs

  def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
    """Performs the forward step."""
    return model(inputs, training=False)

  def aggregate_logs(self,
                     state: Optional[Any] = None,
                     step_outputs: Optional[Mapping[str, Any]] = None):
    """Accumulates per-class IoU across validation steps.

    On the first call (`state is None`) `self.iou_metric` is reset and used as
    the running state. When IoU is processed on the host, the
    `(labels, outputs)` pair stored by `validation_step` under the metric
    name is applied here.

    Args:
      state: the value returned by the previous call, or None on the first.
      step_outputs: the logs returned by one validation step.

    Returns:
      `state`, which after the first call is `self.iou_metric`.
    """
    if state is None:
      # NOTE(review): when IoU is updated on device, validation_step has
      # already called update_state for the step whose outputs arrive here,
      # so this reset may discard that step if the caller runs a step before
      # the first reduce. Confirm against the evaluation loop.
      self.iou_metric.reset_states()
      state = self.iou_metric
    if self._process_iou_metric_on_cpu:
      self.iou_metric.update_state(step_outputs[self.iou_metric.name][0],
                                   step_outputs[self.iou_metric.name][1])
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None
                             ) -> Dict[str, Any]:
    """Builds the final IoU summary from `self.iou_metric`.

    `aggregated_logs` and `global_step` are unused; the accumulated state
    lives in `self.iou_metric`.

    Returns:
      A dict with 'mean_iou' and, when `evaluation.report_per_class_iou` is
      set, one 'iou/<class index>' entry per class.
    """
    result = {}
    ious = self.iou_metric.result()
    # TODO(arashwan): support loading class name from a label map file.
    if self.task_config.evaluation.report_per_class_iou:
      for i, value in enumerate(ious.numpy()):
        result['iou/{}'.format(i)] = value
    # Computes mean IoU
    result['mean_iou'] = tf.reduce_mean(ious).numpy()
    return result