Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f8b0f1dd
Commit
f8b0f1dd
authored
Feb 05, 2021
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Feb 05, 2021
Browse files
Add support for dice loss.
PiperOrigin-RevId: 355908695
parent
8a064338
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
242 additions
and
6 deletions
+242
-6
research/object_detection/builders/losses_builder.py
research/object_detection/builders/losses_builder.py
+12
-6
research/object_detection/builders/losses_builder_test.py
research/object_detection/builders/losses_builder_test.py
+39
-0
research/object_detection/core/losses.py
research/object_detection/core/losses.py
+73
-0
research/object_detection/core/losses_test.py
research/object_detection/core/losses_test.py
+106
-0
research/object_detection/protos/losses.proto
research/object_detection/protos/losses.proto
+12
-0
No files found.
research/object_detection/builders/losses_builder.py
View file @
f8b0f1dd
...
...
@@ -227,7 +227,7 @@ def _build_classification_loss(loss_config):
if
loss_type
==
'weighted_sigmoid'
:
return
losses
.
WeightedSigmoidClassificationLoss
()
if
loss_type
==
'weighted_sigmoid_focal'
:
el
if
loss_type
==
'weighted_sigmoid_focal'
:
config
=
loss_config
.
weighted_sigmoid_focal
alpha
=
None
if
config
.
HasField
(
'alpha'
):
...
...
@@ -236,25 +236,31 @@ def _build_classification_loss(loss_config):
gamma
=
config
.
gamma
,
alpha
=
alpha
)
if
loss_type
==
'weighted_softmax'
:
el
if
loss_type
==
'weighted_softmax'
:
config
=
loss_config
.
weighted_softmax
return
losses
.
WeightedSoftmaxClassificationLoss
(
logit_scale
=
config
.
logit_scale
)
if
loss_type
==
'weighted_logits_softmax'
:
el
if
loss_type
==
'weighted_logits_softmax'
:
config
=
loss_config
.
weighted_logits_softmax
return
losses
.
WeightedSoftmaxClassificationAgainstLogitsLoss
(
logit_scale
=
config
.
logit_scale
)
if
loss_type
==
'bootstrapped_sigmoid'
:
el
if
loss_type
==
'bootstrapped_sigmoid'
:
config
=
loss_config
.
bootstrapped_sigmoid
return
losses
.
BootstrappedSigmoidClassificationLoss
(
alpha
=
config
.
alpha
,
bootstrap_type
=
(
'hard'
if
config
.
hard_bootstrap
else
'soft'
))
if
loss_type
==
'penalty_reduced_logistic_focal_loss'
:
el
if
loss_type
==
'penalty_reduced_logistic_focal_loss'
:
config
=
loss_config
.
penalty_reduced_logistic_focal_loss
return
losses
.
PenaltyReducedLogisticFocalLoss
(
alpha
=
config
.
alpha
,
beta
=
config
.
beta
)
raise
ValueError
(
'Empty loss config.'
)
elif
loss_type
==
'weighted_dice_classification_loss'
:
config
=
loss_config
.
weighted_dice_classification_loss
return
losses
.
WeightedDiceClassificationLoss
(
squared_normalization
=
config
.
squared_normalization
)
else
:
raise
ValueError
(
'Empty loss config.'
)
research/object_detection/builders/losses_builder_test.py
View file @
f8b0f1dd
...
...
@@ -298,6 +298,45 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
with
self
.
assertRaises
(
ValueError
):
losses_builder
.
build
(
losses_proto
)
def
test_build_penalty_reduced_logistic_focal_loss
(
self
):
losses_text_proto
=
"""
classification_loss {
penalty_reduced_logistic_focal_loss {
alpha: 2.0
beta: 4.0
}
}
localization_loss {
l1_localization_loss {
}
}
"""
losses_proto
=
losses_pb2
.
Loss
()
text_format
.
Merge
(
losses_text_proto
,
losses_proto
)
classification_loss
,
_
,
_
,
_
,
_
,
_
,
_
=
losses_builder
.
build
(
losses_proto
)
self
.
assertIsInstance
(
classification_loss
,
losses
.
PenaltyReducedLogisticFocalLoss
)
self
.
assertAlmostEqual
(
classification_loss
.
_alpha
,
2.0
)
self
.
assertAlmostEqual
(
classification_loss
.
_beta
,
4.0
)
def test_build_dice_loss(self):
  # Method of ClassificationLossBuilderTest (class header outside this view).
  # Verifies that a `weighted_dice_classification_loss` proto config is built
  # into a losses.WeightedDiceClassificationLoss and that
  # `squared_normalization` is forwarded from the config.
  losses_text_proto = """
    classification_loss {
      weighted_dice_classification_loss {
        squared_normalization: true
      }
    }
    localization_loss {
      l1_localization_loss {
      }
    }
  """
  losses_proto = losses_pb2.Loss()
  text_format.Merge(losses_text_proto, losses_proto)
  # losses_builder.build returns a 7-tuple; only the classification loss
  # (first element) is relevant to this test.
  classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
  self.assertIsInstance(classification_loss,
                        losses.WeightedDiceClassificationLoss)
  # Fixed: was a bare `assert`, which is stripped under `python -O` and
  # produces no useful failure message inside a TestCase; use the unittest
  # assertion instead, consistent with the rest of this test class.
  self.assertTrue(classification_loss._squared_normalization)
class
HardExampleMinerBuilderTest
(
tf
.
test
.
TestCase
):
...
...
research/object_detection/core/losses.py
View file @
f8b0f1dd
...
...
@@ -278,6 +278,79 @@ class WeightedSigmoidClassificationLoss(Loss):
return
per_entry_cross_ent
*
weights
class WeightedDiceClassificationLoss(Loss):
  """Dice loss for classification [1][2].

  Computes, per class channel, 1 - Dice coefficient between sigmoid
  probabilities of the predictions and the (one-hot) targets.

  [1]: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
  [2]: https://arxiv.org/abs/1606.04797
  """

  def __init__(self, squared_normalization):
    """Initializes the loss object.

    Args:
      squared_normalization: boolean, if set, we square the probabilities in
        the denominator term used for normalization.
    """
    self._squared_normalization = squared_normalization
    super(WeightedDiceClassificationLoss, self).__init__()

  def _compute_loss(self, prediction_tensor, target_tensor, weights,
                    class_indices=None):
    """Computes the loss value.

    Dice loss uses the area of the ground truth and prediction tensors for
    normalization. We compute area by summing along the pixel (2nd) dimension.

    Args:
      prediction_tensor: A float tensor of shape [batch_size, num_pixels,
        num_classes] representing the predicted logits for each class.
        num_pixels denotes the total number of pixels in the spatial
        dimensions of the mask after flattening.
      target_tensor: A float tensor of shape [batch_size, num_pixels,
        num_classes] representing one-hot encoded classification targets.
        num_pixels denotes the total number of pixels in the spatial
        dimensions of the mask after flattening.
      weights: a float tensor of shape, either [batch_size, num_pixels,
        num_classes] or [batch_size, num_pixels, 1]. If the shape is
        [batch_size, num_pixels, 1], all the classes are equally weighted.
      class_indices: (Optional) A 1-D integer tensor of class indices.
        If provided, computes loss only for the specified class indices;
        weights for all other classes are zeroed out.

    Returns:
      loss: a float tensor of shape [batch_size, num_classes]
        representing the value of the loss function.
    """
    if class_indices is not None:
      # Build a dense 0/1 vector over the class dimension and broadcast it
      # across batch and pixels, masking out non-selected classes.
      weights *= tf.reshape(
          ops.indices_to_dense_vector(class_indices,
                                      tf.shape(prediction_tensor)[2]),
          [1, 1, -1])
    prob_tensor = tf.nn.sigmoid(prediction_tensor)
    if self._squared_normalization:
      # "Squared" Dice variant: normalization areas use squared values.
      # Squaring the targets is a no-op for strictly 0/1 targets but is kept
      # for symmetry with soft (non-binary) targets.
      prob_tensor = tf.pow(prob_tensor, 2)
      target_tensor = tf.pow(target_tensor, 2)
    prob_tensor *= weights
    target_tensor *= weights
    # Sum over the pixel dimension: each of these is [batch_size, num_classes].
    prediction_area = tf.reduce_sum(prob_tensor, axis=1)
    gt_area = tf.reduce_sum(target_tensor, axis=1)
    intersection = tf.reduce_sum(prob_tensor * target_tensor, axis=1)
    # Clamp the denominator at 1.0 to avoid division by zero when both areas
    # are empty.
    dice_coeff = 2 * intersection / tf.maximum(gt_area + prediction_area, 1.0)
    dice_loss = 1 - dice_coeff
    return dice_loss
class
SigmoidFocalClassificationLoss
(
Loss
):
"""Sigmoid focal cross entropy loss.
...
...
research/object_detection/core/losses_test.py
View file @
f8b0f1dd
...
...
@@ -1447,5 +1447,111 @@ class L1LocalizationLossTest(test_case.TestCase):
self
.
assertAllClose
(
computed_value
,
[[
0.8
,
0.0
],
[
0.6
,
0.1
]],
rtol
=
1e-6
)
class WeightedDiceClassificationLoss(test_case.TestCase):
  # Tests for losses.WeightedDiceClassificationLoss.
  # NOTE(review): this test class shares the exact name of the loss class it
  # tests; a `...LossTest` suffix would match the file's other test classes.

  def test_compute_weights_1(self):
    # All-ones weights. For each (batch, class) channel the expected dice
    # coefficient is 2 * intersection / (pred_area + gt_area), where unset
    # logits contribute sigmoid(0) = 0.5 to the prediction area.
    def graph_fn():
      loss = losses.WeightedDiceClassificationLoss(squared_normalization=False)
      pred = np.zeros((2, 3, 4), dtype=np.float32)
      target = np.zeros((2, 3, 4), dtype=np.float32)
      pred[0, 1, 0] = _logit(0.9)
      pred[0, 2, 0] = _logit(0.1)
      pred[0, 2, 2] = _logit(0.5)
      pred[0, 1, 3] = _logit(0.1)
      pred[1, 2, 3] = _logit(0.2)
      pred[1, 1, 1] = _logit(0.3)
      pred[1, 0, 2] = _logit(0.1)
      target[0, 1, 0] = 1.0
      target[0, 2, 2] = 1.0
      target[0, 1, 3] = 1.0
      target[1, 2, 3] = 1.0
      target[1, 1, 1] = 0.0
      target[1, 0, 2] = 0.0
      weights = np.ones_like(target)
      return loss._compute_loss(pred, target, weights)

    # Hand-computed coefficients, e.g. [0, 0]: intersection 0.9,
    # pred_area 0.9 + 0.1 + 0.5 = 1.5, gt_area 1.0 -> 2 * 0.9 / 2.5.
    # Channels with no positive targets have zero intersection -> coeff 0.
    dice_coeff = np.zeros((2, 4))
    dice_coeff[0, 0] = 2 * 0.9 / 2.5
    dice_coeff[0, 2] = 2 * 0.5 / 2.5
    dice_coeff[0, 3] = 2 * 0.1 / 2.1
    dice_coeff[1, 3] = 2 * 0.2 / 2.2

    computed_value = self.execute(graph_fn, [])
    self.assertAllClose(computed_value, 1 - dice_coeff, rtol=1e-6)

  def test_compute_weights_set(self):
    # Same fixture as test_compute_weights_1, but class channel 0 is masked
    # out via the weights tensor, so its expected coefficient drops to 0.
    def graph_fn():
      loss = losses.WeightedDiceClassificationLoss(squared_normalization=False)
      pred = np.zeros((2, 3, 4), dtype=np.float32)
      target = np.zeros((2, 3, 4), dtype=np.float32)
      pred[0, 1, 0] = _logit(0.9)
      pred[0, 2, 0] = _logit(0.1)
      pred[0, 2, 2] = _logit(0.5)
      pred[0, 1, 3] = _logit(0.1)
      pred[1, 2, 3] = _logit(0.2)
      pred[1, 1, 1] = _logit(0.3)
      pred[1, 0, 2] = _logit(0.1)
      target[0, 1, 0] = 1.0
      target[0, 2, 2] = 1.0
      target[0, 1, 3] = 1.0
      target[1, 2, 3] = 1.0
      target[1, 1, 1] = 0.0
      target[1, 0, 2] = 0.0
      weights = np.ones_like(target)
      # Zero out class 0 across all batches and pixels.
      weights[:, :, 0] = 0.0
      return loss._compute_loss(pred, target, weights)

    # Identical to test_compute_weights_1 except class 0 is excluded.
    dice_coeff = np.zeros((2, 4))
    dice_coeff[0, 2] = 2 * 0.5 / 2.5
    dice_coeff[0, 3] = 2 * 0.1 / 2.1
    dice_coeff[1, 3] = 2 * 0.2 / 2.2

    computed_value = self.execute(graph_fn, [])
    self.assertAllClose(computed_value, 1 - dice_coeff, rtol=1e-6)

  def test_class_indices(self):
    # Same fixture again, restricted via class_indices=[0]: only class 0
    # keeps a nonzero coefficient; all other channels are zeroed.
    def graph_fn():
      loss = losses.WeightedDiceClassificationLoss(squared_normalization=False)
      pred = np.zeros((2, 3, 4), dtype=np.float32)
      target = np.zeros((2, 3, 4), dtype=np.float32)
      pred[0, 1, 0] = _logit(0.9)
      pred[0, 2, 0] = _logit(0.1)
      pred[0, 2, 2] = _logit(0.5)
      pred[0, 1, 3] = _logit(0.1)
      pred[1, 2, 3] = _logit(0.2)
      pred[1, 1, 1] = _logit(0.3)
      pred[1, 0, 2] = _logit(0.1)
      target[0, 1, 0] = 1.0
      target[0, 2, 2] = 1.0
      target[0, 1, 3] = 1.0
      target[1, 2, 3] = 1.0
      target[1, 1, 1] = 0.0
      target[1, 0, 2] = 0.0
      weights = np.ones_like(target)
      return loss._compute_loss(pred, target, weights,
                                class_indices=[0])

    dice_coeff = np.zeros((2, 4))
    dice_coeff[0, 0] = 2 * 0.9 / 2.5

    computed_value = self.execute(graph_fn, [])
    self.assertAllClose(computed_value, 1 - dice_coeff, rtol=1e-6)
# Script entry point: run all test cases in this module.
if __name__ == '__main__':
  tf.test.main()
research/object_detection/protos/losses.proto
View file @
f8b0f1dd
...
...
@@ -110,6 +110,7 @@ message ClassificationLoss {
BootstrappedSigmoidClassificationLoss
bootstrapped_sigmoid
=
3
;
SigmoidFocalClassificationLoss
weighted_sigmoid_focal
=
4
;
PenaltyReducedLogisticFocalLoss
penalty_reduced_logistic_focal_loss
=
6
;
WeightedDiceClassificationLoss
weighted_dice_classification_loss
=
7
;
}
}
...
...
@@ -217,3 +218,14 @@ message RandomExampleSampler {
// example sampling.
optional
float
positive_sample_fraction
=
1
[
default
=
0.01
];
}
// Configuration for the Dice loss used when training instance masks [1][2].
// [1]: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
// [2]: https://arxiv.org/abs/1606.04797
message WeightedDiceClassificationLoss {
  // If set, we square the probabilities in the denominator term used for
  // normalization.
  optional bool squared_normalization = 1 [default = false];
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment