# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A function to build localization and classification losses from config."""

from object_detection.core import losses
from object_detection.protos import losses_pb2


def build(loss_config):
  """Build losses based on the config.

  Builds classification, localization losses and optionally a hard example miner
  based on the config.

  Args:
    loss_config: A losses_pb2.Loss object.

  Returns:
    classification_loss: Classification loss object.
    localization_loss: Localization loss object.
    classification_weight: Classification loss weight.
    localization_weight: Localization loss weight.
    hard_example_miner: Hard example miner object, or None when the
      `hard_example_miner` field is not set in `loss_config`.

  Raises:
    ValueError: If hard_example_miner is used with sigmoid_focal_loss, or if
      the classification/localization loss sub-configs are invalid or empty
      (raised by the helper builders).
  """
  classification_loss = _build_classification_loss(
      loss_config.classification_loss)
  localization_loss = _build_localization_loss(
      loss_config.localization_loss)
  classification_weight = loss_config.classification_weight
  localization_weight = loss_config.localization_weight
  hard_example_miner = None
  if loss_config.HasField('hard_example_miner'):
    # The hard example miner is explicitly disallowed in combination with the
    # sigmoid focal classification loss.
    if (loss_config.classification_loss.WhichOneof('classification_loss') ==
        'weighted_sigmoid_focal'):
      raise ValueError('HardExampleMiner should not be used with sigmoid focal '
                       'loss')
    # The miner receives the same loss weights that are returned to the caller.
    hard_example_miner = build_hard_example_miner(
        loss_config.hard_example_miner,
        classification_weight,
        localization_weight)
  return (classification_loss, localization_loss,
          classification_weight,
          localization_weight, hard_example_miner)


def build_hard_example_miner(config,
                             classification_weight,
                             localization_weight):
  """Creates a losses.HardExampleMiner from a HardExampleMiner proto.

  Args:
    config: A losses_pb2.HardExampleMiner object.
    classification_weight: Value forwarded to the miner as `cls_loss_weight`.
    localization_weight: Value forwarded to the miner as `loc_loss_weight`.

  Returns:
    A losses.HardExampleMiner instance. Non-positive `num_hard_examples` and
    `max_negatives_per_positive` config values are forwarded as None, and a
    `loss_type` enum value other than BOTH/CLASSIFICATION/LOCALIZATION is
    forwarded as None.
  """
  # Translate the proto enum into the string code passed as `loss_type`.
  loss_type_names = {
      losses_pb2.HardExampleMiner.BOTH: 'both',
      losses_pb2.HardExampleMiner.CLASSIFICATION: 'cls',
      losses_pb2.HardExampleMiner.LOCALIZATION: 'loc',
  }
  loss_type_name = loss_type_names.get(config.loss_type)

  # Values <= 0 are forwarded as None (presumably "no limit" inside
  # losses.HardExampleMiner -- confirm there).
  negatives_ratio = (config.max_negatives_per_positive
                     if config.max_negatives_per_positive > 0 else None)
  num_examples = (config.num_hard_examples
                  if config.num_hard_examples > 0 else None)

  return losses.HardExampleMiner(
      num_hard_examples=num_examples,
      iou_threshold=config.iou_threshold,
      loss_type=loss_type_name,
      cls_loss_weight=classification_weight,
      loc_loss_weight=localization_weight,
      max_negatives_per_positive=negatives_ratio,
      min_negatives_per_image=config.min_negatives_per_image)


def build_faster_rcnn_classification_loss(loss_config):
  """Builds a classification loss for Faster RCNN based on the loss config.

  Only `weighted_sigmoid` selects a sigmoid loss. Every other setting of the
  `classification_loss` oneof produces a softmax loss configured from
  `loss_config.weighted_softmax`. This includes `weighted_softmax` itself, any
  other loss type, and an unset oneof.

  Args:
    loss_config: A losses_pb2.ClassificationLoss object.

  Returns:
    Loss based on the config.

  Raises:
    ValueError: If loss_config is not a losses_pb2.ClassificationLoss.
  """
  if not isinstance(loss_config, losses_pb2.ClassificationLoss):
    raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')

  loss_type = loss_config.WhichOneof('classification_loss')

  if loss_type == 'weighted_sigmoid':
    return losses.WeightedSigmoidClassificationLoss()

  # By default, Faster RCNN second stage classifier uses Softmax loss
  # with anchor-wise outputs. An explicit `weighted_softmax` config takes the
  # same path. When another oneof member is set, reading the unset
  # `weighted_softmax` submessage returns its proto default values.
  config = loss_config.weighted_softmax
  return losses.WeightedSoftmaxClassificationLoss(
      logit_scale=config.logit_scale)


def _build_localization_loss(loss_config):
  """Builds a localization loss based on the loss config.

  Args:
    loss_config: A losses_pb2.LocalizationLoss object.

  Returns:
    Loss based on the config: one of WeightedL2LocalizationLoss,
    WeightedSmoothL1LocalizationLoss or WeightedIOULocalizationLoss.

  Raises:
    ValueError: If loss_config is not a losses_pb2.LocalizationLoss, or if its
      `localization_loss` oneof is unset or names a loss not handled here.
  """
  if not isinstance(loss_config, losses_pb2.LocalizationLoss):
    raise ValueError('loss_config not of type losses_pb2.LocalizationLoss.')

  loss_type = loss_config.WhichOneof('localization_loss')

  if loss_type == 'weighted_l2':
    return losses.WeightedL2LocalizationLoss()

  if loss_type == 'weighted_smooth_l1':
    return losses.WeightedSmoothL1LocalizationLoss()

  if loss_type == 'weighted_iou':
    return losses.WeightedIOULocalizationLoss()

  # WhichOneof returned None (nothing set) or an unhandled member.
  raise ValueError('Empty loss config.')


def _build_classification_loss(loss_config):
  """Builds a classification loss based on the loss config.

  Args:
    loss_config: A losses_pb2.ClassificationLoss object.

  Returns:
    Loss based on the config: one of WeightedSigmoidClassificationLoss,
    SigmoidFocalClassificationLoss, WeightedSoftmaxClassificationLoss or
    BootstrappedSigmoidClassificationLoss.

  Raises:
    ValueError: If loss_config is not a losses_pb2.ClassificationLoss, or if
      its `classification_loss` oneof is unset or names a loss not handled
      here.
  """
  if not isinstance(loss_config, losses_pb2.ClassificationLoss):
    raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')

  loss_type = loss_config.WhichOneof('classification_loss')

  if loss_type == 'weighted_sigmoid':
    return losses.WeightedSigmoidClassificationLoss()

  if loss_type == 'weighted_sigmoid_focal':
    config = loss_config.weighted_sigmoid_focal
    # `alpha` is forwarded only when explicitly set in the config; otherwise
    # None is passed to the loss constructor.
    alpha = None
    if config.HasField('alpha'):
      alpha = config.alpha
    return losses.SigmoidFocalClassificationLoss(
        gamma=config.gamma,
        alpha=alpha)

  if loss_type == 'weighted_softmax':
    config = loss_config.weighted_softmax
    return losses.WeightedSoftmaxClassificationLoss(
        logit_scale=config.logit_scale)

  if loss_type == 'bootstrapped_sigmoid':
    config = loss_config.bootstrapped_sigmoid
    # The boolean `hard_bootstrap` flag selects between the two bootstrap modes.
    return losses.BootstrappedSigmoidClassificationLoss(
        alpha=config.alpha,
        bootstrap_type=('hard' if config.hard_bootstrap else 'soft'))

  # WhichOneof returned None (nothing set) or an unhandled member.
  raise ValueError('Empty loss config.')