Commit 5c7ec0df authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Register GIOU loss in losses.proto & losses_builder.

This CL register a pre-defined GIOU loss to the losses.proto and losses_builder so that the GIOU loss can be correctly called & built when GIOU loss is specified in the loss config file.

PiperOrigin-RevId: 365891022
parent 604abc08
...@@ -204,6 +204,9 @@ def _build_localization_loss(loss_config): ...@@ -204,6 +204,9 @@ def _build_localization_loss(loss_config):
if loss_type == 'l1_localization_loss': if loss_type == 'l1_localization_loss':
return losses.L1LocalizationLoss() return losses.L1LocalizationLoss()
if loss_type == 'weighted_giou':
return losses.WeightedGIOULocalizationLoss()
raise ValueError('Empty loss config.') raise ValueError('Empty loss config.')
......
...@@ -97,6 +97,23 @@ class LocalizationLossBuilderTest(tf.test.TestCase): ...@@ -97,6 +97,23 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
self.assertIsInstance(localization_loss, self.assertIsInstance(localization_loss,
losses.WeightedIOULocalizationLoss) losses.WeightedIOULocalizationLoss)
def test_build_weighted_giou_localization_loss(self):
losses_text_proto = """
localization_loss {
weighted_giou {
}
}
classification_loss {
weighted_softmax {
}
}
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
_, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
self.assertIsInstance(localization_loss,
losses.WeightedGIOULocalizationLoss)
def test_anchorwise_output(self): def test_anchorwise_output(self):
losses_text_proto = """ losses_text_proto = """
localization_loss { localization_loss {
......
...@@ -70,6 +70,7 @@ message LocalizationLoss { ...@@ -70,6 +70,7 @@ message LocalizationLoss {
WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2; WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2;
WeightedIOULocalizationLoss weighted_iou = 3; WeightedIOULocalizationLoss weighted_iou = 3;
L1LocalizationLoss l1_localization_loss = 4; L1LocalizationLoss l1_localization_loss = 4;
WeightedGIOULocalizationLoss weighted_giou = 5;
} }
} }
...@@ -101,6 +102,10 @@ message WeightedIOULocalizationLoss { ...@@ -101,6 +102,10 @@ message WeightedIOULocalizationLoss {
message L1LocalizationLoss { message L1LocalizationLoss {
} }
// Generalized intersection over union location loss: 1 - GIOU
message WeightedGIOULocalizationLoss {
}
// Configuration for class prediction loss function. // Configuration for class prediction loss function.
message ClassificationLoss { message ClassificationLoss {
oneof classification_loss { oneof classification_loss {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment