losses_builder.py 7.72 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A function to build localization and classification losses from config."""

18
from object_detection.core import balanced_positive_negative_sampler as sampler
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from object_detection.core import losses
from object_detection.protos import losses_pb2


def build(loss_config):
  """Build losses based on the config.

  Builds classification and localization losses, and optionally a hard example
  miner and a random example sampler, based on the config.

  Args:
    loss_config: A losses_pb2.Loss object.

  Returns:
    A 6-tuple of:
      classification_loss: Classification loss object.
      localization_loss: Localization loss object.
      classification_weight: Classification loss weight.
      localization_weight: Localization loss weight.
      hard_example_miner: Hard example miner object, or None when the config
        has no `hard_example_miner` field set.
      random_example_sampler: BalancedPositiveNegativeSampler object, or None
        when the config has no `random_example_sampler` field set.

  Raises:
    ValueError: If hard_example_miner is used with sigmoid_focal_loss.
    ValueError: If random_example_sampler is getting non-positive value as
      desired positive example fraction.
    ValueError: If the classification or localization loss config is of the
      wrong type or does not select a supported loss.
  """
  classification_loss = _build_classification_loss(
      loss_config.classification_loss)
  localization_loss = _build_localization_loss(
      loss_config.localization_loss)
  classification_weight = loss_config.classification_weight
  localization_weight = loss_config.localization_weight

  hard_example_miner = None
  if loss_config.HasField('hard_example_miner'):
    # Hard example mining is explicitly rejected in combination with the
    # sigmoid focal loss. NOTE(review): presumably because both mechanisms
    # target hard examples; confirm the rationale.
    if (loss_config.classification_loss.WhichOneof('classification_loss') ==
        'weighted_sigmoid_focal'):
      raise ValueError('HardExampleMiner should not be used with sigmoid focal '
                       'loss')
    hard_example_miner = build_hard_example_miner(
        loss_config.hard_example_miner,
        classification_weight,
        localization_weight)

  random_example_sampler = None
  if loss_config.HasField('random_example_sampler'):
    sampler_config = loss_config.random_example_sampler
    if sampler_config.positive_sample_fraction <= 0:
      # The trailing space matters: the two literals are concatenated.
      raise ValueError('RandomExampleSampler should not use non-positive '
                       'value as positive sample fraction.')
    random_example_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=sampler_config.positive_sample_fraction)
  return (classification_loss, localization_loss, classification_weight,
          localization_weight, hard_example_miner, random_example_sampler)
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111


def build_hard_example_miner(config,
                             classification_weight,
                             localization_weight):
  """Creates a losses.HardExampleMiner from its proto configuration.

  Args:
    config: A losses_pb2.HardExampleMiner object.
    classification_weight: Classification loss weight.
    localization_weight: Localization loss weight.

  Returns:
    Hard example miner.
  """
  # Translate the proto enum into the string codes HardExampleMiner accepts;
  # an enum value not listed here is forwarded as None.
  loss_type_by_enum = {
      losses_pb2.HardExampleMiner.BOTH: 'both',
      losses_pb2.HardExampleMiner.CLASSIFICATION: 'cls',
      losses_pb2.HardExampleMiner.LOCALIZATION: 'loc',
  }

  def _positive_or_none(value):
    # Non-positive proto values are forwarded as None.
    return value if value > 0 else None

  return losses.HardExampleMiner(
      num_hard_examples=_positive_or_none(config.num_hard_examples),
      iou_threshold=config.iou_threshold,
      loss_type=loss_type_by_enum.get(config.loss_type),
      cls_loss_weight=classification_weight,
      loc_loss_weight=localization_weight,
      max_negatives_per_positive=_positive_or_none(
          config.max_negatives_per_positive),
      min_negatives_per_image=config.min_negatives_per_image)


Vivek Rathod's avatar
Vivek Rathod committed
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def build_faster_rcnn_classification_loss(loss_config):
  """Creates the classification loss used by Faster RCNN from its config.

  Unlike `_build_classification_loss`, an unset or unhandled loss type does
  not raise here: it falls back to a weighted softmax loss.

  Args:
    loss_config: A losses_pb2.ClassificationLoss object.

  Returns:
    Loss based on the config.

  Raises:
    ValueError: If loss_config is not a losses_pb2.ClassificationLoss.
  """
  if not isinstance(loss_config, losses_pb2.ClassificationLoss):
    raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')

  selected = loss_config.WhichOneof('classification_loss')

  if selected == 'weighted_sigmoid':
    return losses.WeightedSigmoidClassificationLoss()
  if selected == 'weighted_logits_softmax':
    return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
        logit_scale=loss_config.weighted_logits_softmax.logit_scale)

  # An explicit 'weighted_softmax' and every other (or unset) loss type land
  # here: by default, Faster RCNN second stage classifier uses Softmax loss
  # with anchor-wise outputs.
  return losses.WeightedSoftmaxClassificationLoss(
      logit_scale=loss_config.weighted_softmax.logit_scale)
Vivek Rathod's avatar
Vivek Rathod committed
145
146


147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def _build_localization_loss(loss_config):
  """Creates a localization loss object from its proto config.

  Args:
    loss_config: A losses_pb2.LocalizationLoss object.

  Returns:
    Loss based on the config.

  Raises:
    ValueError: If loss_config is not a losses_pb2.LocalizationLoss, or if its
      `localization_loss` oneof selects no supported loss.
  """
  if not isinstance(loss_config, losses_pb2.LocalizationLoss):
    raise ValueError('loss_config not of type losses_pb2.LocalizationLoss.')

  # Factories are lambdas so only the selected loss class is looked up and
  # constructed.
  factories = {
      'weighted_l2': lambda: losses.WeightedL2LocalizationLoss(),
      'weighted_smooth_l1': lambda: losses.WeightedSmoothL1LocalizationLoss(
          loss_config.weighted_smooth_l1.delta),
      'weighted_iou': lambda: losses.WeightedIOULocalizationLoss(),
  }
  factory = factories.get(loss_config.WhichOneof('localization_loss'))
  if factory is None:
    raise ValueError('Empty loss config.')
  return factory()


def _build_classification_loss(loss_config):
  """Creates a classification loss object from its proto config.

  Args:
    loss_config: A losses_pb2.ClassificationLoss object.

  Returns:
    Loss based on the config.

  Raises:
    ValueError: If loss_config is not a losses_pb2.ClassificationLoss, or if
      its `classification_loss` oneof selects no supported loss.
  """
  if not isinstance(loss_config, losses_pb2.ClassificationLoss):
    raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')

  def _sigmoid_focal():
    focal = loss_config.weighted_sigmoid_focal
    # alpha is only forwarded when explicitly set in the proto; else None.
    return losses.SigmoidFocalClassificationLoss(
        gamma=focal.gamma,
        alpha=focal.alpha if focal.HasField('alpha') else None)

  def _bootstrapped_sigmoid():
    bootstrap = loss_config.bootstrapped_sigmoid
    return losses.BootstrappedSigmoidClassificationLoss(
        alpha=bootstrap.alpha,
        bootstrap_type='hard' if bootstrap.hard_bootstrap else 'soft')

  # Builders are callables so only the selected loss is constructed.
  builders = {
      'weighted_sigmoid':
          lambda: losses.WeightedSigmoidClassificationLoss(),
      'weighted_sigmoid_focal': _sigmoid_focal,
      'weighted_softmax':
          lambda: losses.WeightedSoftmaxClassificationLoss(
              logit_scale=loss_config.weighted_softmax.logit_scale),
      'weighted_logits_softmax':
          lambda: losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
              logit_scale=loss_config.weighted_logits_softmax.logit_scale),
      'bootstrapped_sigmoid': _bootstrapped_sigmoid,
  }
  builder = builders.get(loss_config.WhichOneof('classification_loss'))
  if builder is None:
    raise ValueError('Empty loss config.')
  return builder()