syntax = "proto2";

package object_detection.protos;

// Message for configuring the localization loss, classification loss and hard
// example miner used for training object detection models. See core/losses.py
// for details
message Loss {
  // Localization loss to use.
  optional LocalizationLoss localization_loss = 1;

  // Classification loss to use.
  optional ClassificationLoss classification_loss = 2;

  // If not left to default, applies hard example mining.
  optional HardExampleMiner hard_example_miner = 3;

  // Classification loss weight.
  optional float classification_weight = 4 [default=1.0];

  // Localization loss weight.
  optional float localization_weight = 5 [default=1.0];

  // If not left to default, applies random example sampling.
  optional RandomExampleSampler random_example_sampler = 6;

  // Equalization loss.
  message EqualizationLoss {
    // Weight equalization loss strength.
    optional float weight = 1 [default=0.0];

    // When computing equalization loss, ops that start with
    // equalization_exclude_prefixes will be ignored. Only used when
    // equalization_weight > 0.
    repeated string exclude_prefixes = 2;
  }

  // If weight > 0, applies equalization loss.
  optional EqualizationLoss equalization_loss = 7;

  enum ExpectedLossWeights {
    NONE = 0;
    // Use expected_classification_loss_by_expected_sampling
    // from third_party/tensorflow_models/object_detection/utils/ops.py
    EXPECTED_SAMPLING = 1;
    // Use expected_classification_loss_by_reweighting_unmatched_anchors
    // from third_party/tensorflow_models/object_detection/utils/ops.py
    REWEIGHTING_UNMATCHED_ANCHORS = 2;
  }

  // Method to compute expected loss weights with respect to balanced
  // positive/negative sampling scheme. If NONE, use explicit sampling.
  // TODO(birdbrain): Move under ExpectedLossWeights.
  optional ExpectedLossWeights expected_loss_weights = 18 [default = NONE];

  // Minimum number of effective negative samples.
  // Only applies if expected_loss_weights is not NONE.
  // TODO(birdbrain): Move under ExpectedLossWeights.
  optional float min_num_negative_samples = 19 [default=0];

  // Desired number of effective negative samples per positive sample.
  // Only applies if expected_loss_weights is not NONE.
  // TODO(birdbrain): Move under ExpectedLossWeights.
  optional float desired_negative_sampling_ratio = 20 [default=3];
}

// Configuration for bounding box localization loss function.
message LocalizationLoss {
  oneof localization_loss {
    WeightedL2LocalizationLoss weighted_l2 = 1;
    WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2;
    WeightedIOULocalizationLoss weighted_iou = 3;
    L1LocalizationLoss l1_localization_loss = 4;
    WeightedGIOULocalizationLoss weighted_giou = 5;
  }
}

// L2 location loss: 0.5 * ||weight * (a - b)|| ^ 2
message WeightedL2LocalizationLoss {
  // DEPRECATED, do not use.
  // Output loss per anchor.
  optional bool anchorwise_output = 1 [default=false];
}

// SmoothL1 (Huber) location loss.
// The smooth L1_loss is defined elementwise as .5 x^2 if |x| <= delta and
// delta * (|x|-0.5*delta) otherwise, where x is the difference between
// predictions and target.
message WeightedSmoothL1LocalizationLoss {
  // DEPRECATED, do not use.
  // Output loss per anchor.
  optional bool anchorwise_output = 1 [default=false];

  // Delta value for huber loss.
  optional float delta = 2 [default=1.0];
}

// Intersection over union (IOU) location loss: 1 - IOU.
// Takes no parameters; configuration is presence-only.
message WeightedIOULocalizationLoss {
}

// L1 Localization Loss.
// Takes no parameters; configuration is presence-only.
message L1LocalizationLoss {
}

// Generalized intersection over union (GIOU) location loss: 1 - GIOU.
// Takes no parameters; configuration is presence-only.
message WeightedGIOULocalizationLoss {
}

// Configuration for class prediction loss function.
message ClassificationLoss {
  oneof classification_loss {
    WeightedSigmoidClassificationLoss weighted_sigmoid = 1;
    WeightedSoftmaxClassificationLoss weighted_softmax = 2;
    WeightedSoftmaxClassificationAgainstLogitsLoss weighted_logits_softmax = 5;
    BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3;
    SigmoidFocalClassificationLoss weighted_sigmoid_focal = 4;
    PenaltyReducedLogisticFocalLoss penalty_reduced_logistic_focal_loss = 6;
    WeightedDiceClassificationLoss weighted_dice_classification_loss = 7;
  }
}

// Classification loss using a sigmoid function over class predictions.
message WeightedSigmoidClassificationLoss {
  // DEPRECATED, do not use.
  // Output loss per anchor.
  optional bool anchorwise_output = 1 [default=false];
}

// Sigmoid Focal cross entropy loss as described in
// https://arxiv.org/abs/1708.02002
message SigmoidFocalClassificationLoss {
  // DEPRECATED, do not use.
  optional bool anchorwise_output = 1 [default = false];
  // modulating factor for the loss.
  optional float gamma = 2 [default = 2.0];
  // alpha weighting factor for the loss.
  optional float alpha = 3;
}

// Classification loss using a softmax function over class predictions.
message WeightedSoftmaxClassificationLoss {
  // DEPRECATED, do not use.
  // Output loss per anchor.
  optional bool anchorwise_output = 1 [default=false];

  // Scale logit (input) value before calculating softmax classification loss.
  // Typically used for softmax distillation.
  optional float logit_scale = 2 [default = 1.0];
}

// Classification loss using a softmax function over class predictions and
// a softmax function over the groundtruth labels (assumed to be logits).
message WeightedSoftmaxClassificationAgainstLogitsLoss {
  // DEPRECATED, do not use.
  // Output loss per anchor.
  optional bool anchorwise_output = 1 [default = false];
  // Scale and softmax groundtruth logits before calculating softmax
  // classification loss. Typically used for softmax distillation with teacher
  // annotations stored as logits.
  optional float logit_scale = 2 [default = 1.0];
}

// Classification loss using a sigmoid function over the class prediction with
// the highest prediction score.
message BootstrappedSigmoidClassificationLoss {
  // Interpolation weight between 0 and 1.
  optional float alpha = 1;

  // Whether hard boot strapping should be used or not. If true, will only use
  // one class favored by model. Otherwise, will use all predicted class
  // probabilities.
  optional bool hard_bootstrap = 2 [default=false];

  // DEPRECATED, do not use.
  // Output loss per anchor.
  optional bool anchorwise_output = 3 [default=false];
}

// Pixelwise logistic focal loss with pixels near the target having a reduced
// penalty.
message PenaltyReducedLogisticFocalLoss {

  // Focussing parameter of the focal loss.
  optional float alpha = 1;

  // Penalty reduction factor.
  optional float beta = 2;
}

// Configuration for hard example miner.
message HardExampleMiner {
  // Maximum number of hard examples to be selected per image (prior to
  // enforcing max negative to positive ratio constraint).  If set to 0,
  // all examples obtained after NMS are considered.
  optional int32 num_hard_examples = 1 [default=64];

  // Minimum intersection over union for an example to be discarded during NMS.
  optional float iou_threshold = 2 [default=0.7];

  // Whether to use classification losses ('cls', default), localization losses
  // ('loc') or both losses ('both'). In the case of 'both', cls_loss_weight and
  // loc_loss_weight are used to compute weighted sum of the two losses.
  enum LossType {
    BOTH = 0;
    CLASSIFICATION = 1;
    LOCALIZATION = 2;
  }
  // Which per-example loss(es) to rank hard examples by.
  optional LossType loss_type = 3 [default=BOTH];

  // Maximum number of negatives to retain for each positive anchor. If
  // num_negatives_per_positive is 0 no prespecified negative:positive ratio is
  // enforced.
  optional int32 max_negatives_per_positive = 4 [default=0];

  // Minimum number of negative anchors to sample for a given image. Setting
  // this to a positive number samples negatives in an image without any
  // positive anchors and thus not bias the model towards having at least one
  // detection per image.
  optional int32 min_negatives_per_image = 5 [default=0];
}
// Configuration for random example sampler.
message RandomExampleSampler {
  // The desired fraction of positive samples in batch when applying random
  // example sampling. Must be in (0, 1].
  optional float positive_sample_fraction = 1 [default = 0.01];
}
// Dice loss for training instance masks[1][2].
// [1]: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
// [2]: https://arxiv.org/abs/1606.04797
message WeightedDiceClassificationLoss {
  // If set, we square the probabilities in the denominator term used for
  // normalization.
  optional bool squared_normalization = 1 [default=false];
}