"Ruyi-Models/__init__.py" did not exist on "1a6b26f1127c207418d23861ce84bad7d5710ab5"
segmentation_losses.py 9.3 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Losses used for segmentation models."""

import tensorflow as tf

from official.modeling import tf_utils
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
20
from official.vision.dataloaders import utils
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
21
22
23
24
25
26
27

# Small positive constant used to guard denominators against division by zero
# when normalizing losses and mask scores.
EPSILON = 1.0e-5


class SegmentationLoss:
  """Semantic segmentation loss.

  Computes a per-pixel softmax cross-entropy (optionally label-smoothed and
  class-weighted). Pixels whose label equals `ignore_label` are masked out.
  The result is aggregated either over all valid pixels or over the top-k
  percent largest pixel losses.
  """

  def __init__(self,
               label_smoothing,
               class_weights,
               ignore_label,
               use_groundtruth_dimension,
               top_k_percent_pixels=1.0,
               gt_is_matting_map=False
               ):
    """Initializes `SegmentationLoss`.

    Args:
      label_smoothing: A float, if > 0., smooth out one-hot probability by
        spreading the amount of probability to all other label classes.
      class_weights: A float list containing the weight of each class. When
        None or empty, every class gets weight 1.0.
      ignore_label: An integer specifying the ignore label.
      use_groundtruth_dimension: A boolean, whether to resize the output to
        match the dimension of the ground truth.
      top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its
        value < 1., only compute the loss for the top k percent pixels. This is
        useful for hard pixel mining.
      gt_is_matting_map: Whether or not the groundtruth mask is a matting map.
        Note that the matting map is only supported for 2 class segmentation.
    """
    self._label_smoothing = label_smoothing
    self._class_weights = class_weights
    self._ignore_label = ignore_label
    self._use_groundtruth_dimension = use_groundtruth_dimension
    self._top_k_percent_pixels = top_k_percent_pixels
    self._gt_is_matting_map = gt_is_matting_map

  def __call__(self, logits, labels, **kwargs):
    """Computes `SegmentationLoss`.

    Args:
      logits: A float tensor in shape (batch_size, height, width, num_classes)
        which is the output of the network.
      labels: A tensor in shape (batch_size, height, width, 1), which is the
        label mask of the ground truth.
      **kwargs: additional keyword arguments, forwarded to
        `compute_pixelwise_loss`.

    Returns:
       A 0-D float which stores the overall loss of the batch.
    """
    _, height, width, _ = logits.get_shape().as_list()

    if self._use_groundtruth_dimension:
      # TODO(arashwan): Test using align corners to match deeplab alignment.
      logits = tf.image.resize(
          logits, tf.shape(labels)[1:3], method=tf.image.ResizeMethod.BILINEAR)
    else:
      # Nearest-neighbor sampling keeps label values discrete, so no
      # interpolated (non-existent) class ids are introduced.
      labels = tf.image.resize(
          labels, (height, width),
          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Do not need to cast into int32 if it is a matting map
    if not self._gt_is_matting_map:
      labels = tf.cast(labels, tf.int32)

    valid_mask = tf.not_equal(labels, self._ignore_label)

    cross_entropy_loss = self.compute_pixelwise_loss(labels, logits, valid_mask,
                                                     **kwargs)

    if self._top_k_percent_pixels < 1.0:
      return self.aggregate_loss_top_k(cross_entropy_loss)
    else:
      return self.aggregate_loss(cross_entropy_loss, valid_mask)

  def _resolve_class_weights(self, num_classes):
    """Returns the per-class weights as a list, defaulting to all 1.0.

    Args:
      num_classes: The number of classes (last dimension of the logits).

    Returns:
      A list of `num_classes` weights.

    Raises:
      ValueError: If configured weights do not have exactly `num_classes`
        entries.
    """
    # Explicit None/len check instead of truthiness so that array-like weights
    # (e.g. numpy arrays) do not raise an ambiguous truth-value error.
    if self._class_weights is None or len(self._class_weights) == 0:  # pylint: disable=g-explicit-length-test
      return [1.0] * num_classes
    if len(self._class_weights) != num_classes:
      raise ValueError(
          'Length of class_weights should be {}'.format(num_classes))
    return list(self._class_weights)

  def compute_pixelwise_loss(self, labels, logits, valid_mask, **kwargs):
    """Computes the loss for each pixel.

    Args:
      labels: An int32 tensor in shape (batch_size, height, width, 1), which is
        the label mask of the ground truth. If the groundtruth is a matting
        map, a float tensor of the same shape.
      logits: A float tensor in shape (batch_size, height, width, num_classes)
        which is the output of the network.
      valid_mask: A bool tensor in shape (batch_size, height, width, 1) which
        masks out ignored pixels.
      **kwargs: additional keyword arguments, forwarded to
        `get_labels_with_prob`.

    Returns:
       A float tensor in shape (batch_size, height, width) which stores the loss
       value for each pixel.

    Raises:
      ValueError: If `class_weights` was given with a length different from
        the number of classes.
    """
    num_classes = logits.get_shape().as_list()[-1]
    # Validate the configuration before building the cross-entropy op so a
    # misconfigured weight list fails fast.
    class_weights = self._resolve_class_weights(num_classes)

    # Assign pixel with ignore label to class 0 (background). The loss on the
    # pixel will later be masked out.
    labels = tf.where(valid_mask, labels, tf.zeros_like(labels))

    cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=self.get_labels_with_prob(labels, logits, **kwargs),
        logits=logits)

    # Drop the trailing channel axis to match the (b, h, w) loss tensor.
    valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=-1)

    # If groundtruth is matting map, binarize the value to create the weight
    # mask
    if self._gt_is_matting_map:
      labels = tf.cast(utils.binarize_matting_map(labels), tf.int32)

    # Look up each pixel's class weight via a one-hot dot product.
    weight_mask = tf.einsum(
        '...y,y->...',
        tf.one_hot(tf.squeeze(labels, axis=-1), num_classes, dtype=tf.float32),
        tf.constant(class_weights, tf.float32))
    return cross_entropy_loss * valid_mask * weight_mask

  def get_labels_with_prob(self, labels, logits, **unused_kwargs):
    """Get a tensor representing the probability of each class for each pixel.

    This method can be overridden in subclasses for customizing loss function.

    Args:
      labels: If groundtruth mask is not matting map, an int32 tensor which is
        the label map of the groundtruth. If groundtruth mask is matting map,
        a float32 tensor. The shape is always (batch_size, height, width, 1).
      logits: A float tensor in shape (batch_size, height, width, num_classes)
        which is the output of the network.
      **unused_kwargs: Unused keyword arguments.

    Returns:
       A float tensor in shape (batch_size, height, width, num_classes).
    """
    num_classes = logits.get_shape().as_list()[-1]

    if self._gt_is_matting_map:
      # Two-channel soft target [1 - p, p] built from the matting value p.
      train_labels = tf.concat([1 - labels, labels], axis=-1)
    else:
      labels = tf.squeeze(labels, axis=-1)
      train_labels = tf.one_hot(labels, num_classes)
    # Label smoothing: y * (1 - eps) + eps / num_classes.
    return train_labels * (
        1 - self._label_smoothing) + self._label_smoothing / num_classes

  def aggregate_loss(self, pixelwise_loss, valid_mask):
    """Aggregate the pixelwise loss.

    Args:
      pixelwise_loss: A float tensor in shape (batch_size, height, width) which
        stores the loss of each pixel.
      valid_mask: A bool tensor in shape (batch_size, height, width, 1) which
        masks out ignored pixels.

    Returns:
       A 0-D float which stores the overall loss of the batch.
    """
    # Mean over valid pixels; EPSILON guards against an all-ignored batch.
    normalizer = tf.reduce_sum(tf.cast(valid_mask, tf.float32)) + EPSILON
    return tf.reduce_sum(pixelwise_loss) / normalizer

  def aggregate_loss_top_k(self, pixelwise_loss):
    """Aggregate the top-k greatest pixelwise loss.

    Args:
      pixelwise_loss: A float tensor in shape (batch_size, height, width) which
        stores the loss of each pixel.

    Returns:
       A 0-D float which stores the overall loss of the batch.
    """
    pixelwise_loss = tf.reshape(pixelwise_loss, shape=[-1])
    top_k_pixels = tf.cast(
        self._top_k_percent_pixels *
        tf.cast(tf.size(pixelwise_loss), tf.float32), tf.int32)
    top_k_losses, _ = tf.math.top_k(pixelwise_loss, k=top_k_pixels, sorted=True)
    # Ignored pixels were zeroed in compute_pixelwise_loss, so only non-zero
    # losses are counted in the normalizer.
    normalizer = tf.reduce_sum(
        tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32)) + EPSILON
    return tf.reduce_sum(top_k_losses) / normalizer

A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
205
206
207
208
209
210
211

def get_actual_mask_scores(logits, labels, ignore_label):
  """Gets actual mask scores.

  For every image and every class, computes the intersection-over-union
  between the argmax prediction of `logits` and `labels`. Pixels whose label
  equals `ignore_label` are excluded. No gradient flows back into `logits`.

  Args:
    logits: A float tensor of shape (batch_size, height, width, num_classes).
      Its static height, width and num_classes are used for resizing and
      one-hot encoding, so they must be known at graph-construction time.
    labels: The groundtruth label tensor. It is resized to (height, width)
      with nearest-neighbor sampling. Presumably it has shape
      (batch_size, h, w, 1); TODO confirm against callers.
    ignore_label: The label value whose pixels are excluded from both the
      intersection and the union.

  Returns:
    A float32 tensor of shape (batch_size, num_classes) with per-class IoU.
    Classes absent from both the prediction and the labels get a score of 0.
  """
  _, height, width, num_classes = logits.get_shape().as_list()
  batch_size = tf.shape(logits)[0]
  # Scores are regression targets; they must not backpropagate into logits.
  logits = tf.stop_gradient(logits)
  labels = tf.image.resize(
      labels, (height, width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  predicted_labels = tf.argmax(logits, -1, output_type=tf.int32)
  flat_predictions = tf.reshape(predicted_labels, [batch_size, -1])
  flat_labels = tf.cast(tf.reshape(labels, [batch_size, -1]), tf.int32)

  one_hot_predictions = tf.one_hot(
      flat_predictions, num_classes, on_value=True, off_value=False)
  one_hot_labels = tf.one_hot(
      flat_labels, num_classes, on_value=True, off_value=False)
  # Broadcastable (batch, pixels, 1) mask of non-ignored pixels.
  keep_mask = tf.not_equal(flat_labels, ignore_label)
  keep_mask = tf.expand_dims(keep_mask, 2)

  overlap = tf.logical_and(one_hot_predictions, one_hot_labels)
  overlap = tf.logical_and(overlap, keep_mask)
  overlap = tf.reduce_sum(tf.cast(overlap, tf.float32), axis=1)
  union = tf.logical_or(one_hot_predictions, one_hot_labels)
  union = tf.logical_and(union, keep_mask)
  union = tf.reduce_sum(tf.cast(union, tf.float32), axis=1)
  # EPSILON floor avoids 0/0 for classes absent from prediction and label.
  actual_scores = tf.divide(overlap, tf.maximum(union, EPSILON))
  return actual_scores


class MaskScoringLoss:
  """Mask Scoring loss.

  Regresses predicted per-class mask scores toward the intersection-over-union
  actually achieved by `logits` against `labels`. It uses an unreduced mean
  squared error, which is then averaged with `tf_utils.safe_mean`.
  """

  def __init__(self, ignore_label):
    """Initializes `MaskScoringLoss`.

    Args:
      ignore_label: The label value whose pixels are excluded when computing
        the target scores.
    """
    self._ignore_label = ignore_label
    # Reduction is deferred to `tf_utils.safe_mean` in `__call__`.
    no_reduction = tf.keras.losses.Reduction.NONE
    self._mse_loss = tf.keras.losses.MeanSquaredError(reduction=no_reduction)

  def __call__(self, predicted_scores, logits, labels):
    """Computes the mask scoring loss.

    Args:
      predicted_scores: The scores predicted by the mask scoring head.
      logits: Segmentation logits used to derive the target scores.
      labels: Groundtruth labels used to derive the target scores.

    Returns:
      The averaged squared error between target and predicted scores.
    """
    target_scores = get_actual_mask_scores(
        logits, labels, ignore_label=self._ignore_label)
    unreduced_mse = self._mse_loss(target_scores, predicted_scores)
    return tf_utils.safe_mean(unreduced_mse)