Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e0e440e1
Commit
e0e440e1
authored
Aug 04, 2020
by
Kaushik Shivakumar
Browse files
fix PR
parent
eb795bf7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
208 additions
and
0 deletions
+208
-0
research/object_detection/core/losses.py
research/object_detection/core/losses.py
+34
-0
research/object_detection/core/losses_test.py
research/object_detection/core/losses_test.py
+40
-0
research/object_detection/utils/ops.py
research/object_detection/utils/ops.py
+57
-0
research/object_detection/utils/ops_test.py
research/object_detection/utils/ops_test.py
+77
-0
No files found.
research/object_detection/core/losses.py
View file @
e0e440e1
...
@@ -36,6 +36,7 @@ import tensorflow.compat.v1 as tf
...
@@ -36,6 +36,7 @@ import tensorflow.compat.v1 as tf
from
object_detection.core
import
box_list
from
object_detection.core
import
box_list
from
object_detection.core
import
box_list_ops
from
object_detection.core
import
box_list_ops
from
object_detection.utils
import
ops
from
object_detection.utils
import
ops
from
object_detection.utils
import
shape_utils
class
Loss
(
six
.
with_metaclass
(
abc
.
ABCMeta
,
object
)):
class
Loss
(
six
.
with_metaclass
(
abc
.
ABCMeta
,
object
)):
...
@@ -210,6 +211,39 @@ class WeightedIOULocalizationLoss(Loss):
...
@@ -210,6 +211,39 @@ class WeightedIOULocalizationLoss(Loss):
return
tf
.
reshape
(
weights
,
[
-
1
])
*
per_anchor_iou_loss
return
tf
.
reshape
(
weights
,
[
-
1
])
*
per_anchor_iou_loss
class
WeightedGIOULocalizationLoss
(
Loss
):
"""GIOU localization loss function.
Sums the GIOU loss for corresponding pairs of predicted/groundtruth boxes
and for each pair assign a loss of 1 - GIOU. We then compute a weighted
sum over all pairs which is returned as the total loss.
"""
def
_compute_loss
(
self
,
prediction_tensor
,
target_tensor
,
weights
):
"""Compute loss function.
Args:
prediction_tensor: A float tensor of shape [batch_size, num_anchors, 4]
representing the decoded predicted boxes
target_tensor: A float tensor of shape [batch_size, num_anchors, 4]
representing the decoded target boxes
weights: a float tensor of shape [batch_size, num_anchors]
Returns:
loss: a float tensor of shape [batch_size, num_anchors] tensor
representing the value of the loss function.
"""
batch_size
,
num_anchors
,
_
=
shape_utils
.
combined_static_and_dynamic_shape
(
prediction_tensor
)
predicted_boxes
=
tf
.
reshape
(
prediction_tensor
,
[
-
1
,
4
])
target_boxes
=
tf
.
reshape
(
target_tensor
,
[
-
1
,
4
])
per_anchor_iou_loss
=
1
-
ops
.
giou
(
predicted_boxes
,
target_boxes
)
return
tf
.
reshape
(
tf
.
reshape
(
weights
,
[
-
1
])
*
per_anchor_iou_loss
,
[
batch_size
,
num_anchors
])
class
WeightedSigmoidClassificationLoss
(
Loss
):
class
WeightedSigmoidClassificationLoss
(
Loss
):
"""Sigmoid cross entropy classification loss function."""
"""Sigmoid cross entropy classification loss function."""
...
...
research/object_detection/core/losses_test.py
View file @
e0e440e1
...
@@ -197,6 +197,46 @@ class WeightedIOULocalizationLossTest(test_case.TestCase):
...
@@ -197,6 +197,46 @@ class WeightedIOULocalizationLossTest(test_case.TestCase):
loss_output
=
self
.
execute
(
graph_fn
,
[])
loss_output
=
self
.
execute
(
graph_fn
,
[])
self
.
assertAllClose
(
loss_output
,
exp_loss
)
self
.
assertAllClose
(
loss_output
,
exp_loss
)
class
WeightedGIOULocalizationLossTest
(
test_case
.
TestCase
):
def
testReturnsCorrectLoss
(
self
):
def
graph_fn
():
prediction_tensor
=
tf
.
constant
([[[
1.5
,
0
,
2.4
,
1
],
[
0
,
0
,
1
,
1
],
[
0
,
0
,
0
,
0
]]])
target_tensor
=
tf
.
constant
([[[
1.5
,
0
,
2.4
,
1
],
[
0
,
0
,
1
,
1
],
[
5
,
5
,
10
,
10
]]])
weights
=
[[
1.0
,
.
5
,
2.0
]]
loss_op
=
losses
.
WeightedGIOULocalizationLoss
()
loss
=
loss_op
(
prediction_tensor
,
target_tensor
,
weights
=
weights
)
loss
=
tf
.
reduce_sum
(
loss
)
return
loss
exp_loss
=
3.5
loss_output
=
self
.
execute
(
graph_fn
,
[])
self
.
assertAllClose
(
loss_output
,
exp_loss
)
def
testReturnsCorrectLossWithNoLabels
(
self
):
def
graph_fn
():
prediction_tensor
=
tf
.
constant
([[[
1.5
,
0
,
2.4
,
1
],
[
0
,
0
,
1
,
1
],
[
0
,
0
,
.
5
,
.
25
]]])
target_tensor
=
tf
.
constant
([[[
1.5
,
0
,
2.4
,
1
],
[
0
,
0
,
1
,
1
],
[
50
,
50
,
500.5
,
100.25
]]])
weights
=
[[
1.0
,
.
5
,
2.0
]]
losses_mask
=
tf
.
constant
([
False
],
tf
.
bool
)
loss_op
=
losses
.
WeightedGIOULocalizationLoss
()
loss
=
loss_op
(
prediction_tensor
,
target_tensor
,
weights
=
weights
,
losses_mask
=
losses_mask
)
loss
=
tf
.
reduce_sum
(
loss
)
return
loss
exp_loss
=
0.0
loss_output
=
self
.
execute
(
graph_fn
,
[])
self
.
assertAllClose
(
loss_output
,
exp_loss
)
class
WeightedSigmoidClassificationLossTest
(
test_case
.
TestCase
):
class
WeightedSigmoidClassificationLossTest
(
test_case
.
TestCase
):
...
...
research/object_detection/utils/ops.py
View file @
e0e440e1
...
@@ -1134,3 +1134,60 @@ def decode_image(tensor_dict):
...
@@ -1134,3 +1134,60 @@ def decode_image(tensor_dict):
tensor_dict
[
fields
.
InputDataFields
.
image
],
channels
=
3
)
tensor_dict
[
fields
.
InputDataFields
.
image
],
channels
=
3
)
tensor_dict
[
fields
.
InputDataFields
.
image
].
set_shape
([
None
,
None
,
3
])
tensor_dict
[
fields
.
InputDataFields
.
image
].
set_shape
([
None
,
None
,
3
])
return
tensor_dict
return
tensor_dict
def
giou
(
boxes1
,
boxes2
):
"""
Computes generalized IOU between two tensors. Each box should be
represented as [ymin, xmin, ymax, xmax].
Args:
boxes1: a tensor with shape [num_boxes, 4]
boxes2: a tensor with shape [num_boxes, 4]
Returns:
a tensor of shape [num_boxes] containing GIoUs
"""
def
two_boxes_giou
(
boxes
):
boxes1
,
boxes2
=
boxes
pred_ymin
,
pred_xmin
,
pred_ymax
,
pred_xmax
=
tf
.
unstack
(
boxes1
)
gt_ymin
,
gt_xmin
,
gt_ymax
,
gt_xmax
=
tf
.
unstack
(
boxes2
)
gt_area
=
(
gt_ymax
-
gt_ymin
)
*
(
gt_xmax
-
gt_xmin
)
pred_area
=
(
pred_ymax
-
pred_ymin
)
*
(
pred_xmax
-
pred_xmin
)
x1I
=
tf
.
maximum
(
pred_xmin
,
gt_xmin
)
x2I
=
tf
.
minimum
(
pred_xmax
,
gt_xmax
)
y1I
=
tf
.
maximum
(
pred_ymin
,
gt_ymin
)
y2I
=
tf
.
minimum
(
pred_ymax
,
gt_ymax
)
intersection_area
=
tf
.
maximum
(
0.0
,
y2I
-
y1I
)
*
tf
.
maximum
(
0.0
,
x2I
-
x1I
)
x1C
=
tf
.
minimum
(
pred_xmin
,
gt_xmin
)
x2C
=
tf
.
maximum
(
pred_xmax
,
gt_xmax
)
y1C
=
tf
.
minimum
(
pred_ymin
,
gt_ymin
)
y2C
=
tf
.
maximum
(
pred_ymax
,
gt_ymax
)
hull_area
=
(
y2C
-
y1C
)
*
(
x2C
-
x1C
)
union_area
=
gt_area
+
pred_area
-
intersection_area
IoU
=
tf
.
where
(
tf
.
equal
(
union_area
,
0.0
),
0.0
,
intersection_area
/
union_area
)
gIoU
=
IoU
-
tf
.
where
(
hull_area
>
0.0
,
(
hull_area
-
union_area
)
/
hull_area
,
IoU
)
return
gIoU
return
shape_utils
.
static_or_dynamic_map_fn
(
two_boxes_giou
,
[
boxes1
,
boxes2
])
def
center_to_corner_coordinate
(
input_tensor
):
"""Converts input boxes from center to corner representation."""
reshaped_encodings
=
tf
.
reshape
(
input_tensor
,
[
-
1
,
4
])
ycenter
=
tf
.
gather
(
reshaped_encodings
,
[
0
],
axis
=
1
)
xcenter
=
tf
.
gather
(
reshaped_encodings
,
[
1
],
axis
=
1
)
h
=
tf
.
gather
(
reshaped_encodings
,
[
2
],
axis
=
1
)
w
=
tf
.
gather
(
reshaped_encodings
,
[
3
],
axis
=
1
)
ymin
=
ycenter
-
h
/
2.
xmin
=
xcenter
-
w
/
2.
ymax
=
ycenter
+
h
/
2.
xmax
=
xcenter
+
w
/
2.
return
tf
.
squeeze
(
tf
.
stack
([
ymin
,
xmin
,
ymax
,
xmax
],
axis
=
1
))
\ No newline at end of file
research/object_detection/utils/ops_test.py
View file @
e0e440e1
...
@@ -1630,8 +1630,85 @@ class TestGatherWithPaddingValues(test_case.TestCase):
...
@@ -1630,8 +1630,85 @@ class TestGatherWithPaddingValues(test_case.TestCase):
self
.
assertAllClose
(
expected_gathered_tensor
,
gathered_tensor_np
)
self
.
assertAllClose
(
expected_gathered_tensor
,
gathered_tensor_np
)
class
TestGIoU
(
test_case
.
TestCase
):
def
test_giou_general
(
self
):
expected_giou_tensor
=
[
0
,
-
1
/
3
,
1
/
25
,
-
3
/
4
,
0
,
-
98
/
100
]
def
graph_fn
():
boxes1
=
tf
.
constant
([[
3
,
4
,
5
,
6
],
[
3
,
3
,
5
,
5
],
[
2
,
1
,
7
,
6
],
[
0
,
0
,
0
,
0
],
[
3
,
3
,
5
,
5
],
[
9
,
9
,
10
,
10
]],
dtype
=
tf
.
float32
)
boxes2
=
tf
.
constant
([[
3
,
2
,
5
,
4
],
[
3
,
7
,
5
,
9
],
[
4
,
3
,
5
,
4
],
[
5
,
5
,
10
,
10
],
[
3
,
5
,
5
,
7
],
[
0
,
0
,
1
,
1
]],
dtype
=
tf
.
float32
)
giou
=
ops
.
giou
(
boxes1
,
boxes2
)
self
.
assertEqual
(
giou
.
dtype
,
tf
.
float32
)
return
giou
giou
=
self
.
execute
(
graph_fn
,
[])
self
.
assertAllClose
(
expected_giou_tensor
,
giou
)
def
test_giou_edge_cases
(
self
):
expected_giou_tensor
=
[
1
,
0
]
def
graph_fn
():
boxes1
=
tf
.
constant
([[
3
,
3
,
5
,
5
],
[
1
,
1
,
1
,
1
]],
dtype
=
tf
.
float32
)
boxes2
=
tf
.
constant
([[
3
,
3
,
5
,
5
],
[
1
,
1
,
1
,
1
]],
dtype
=
tf
.
float32
)
giou
=
ops
.
giou
(
boxes1
,
boxes2
)
self
.
assertEqual
(
giou
.
dtype
,
tf
.
float32
)
return
giou
giou
=
self
.
execute
(
graph_fn
,
[])
self
.
assertAllClose
(
expected_giou_tensor
,
giou
)
def
test_giou_l1_same
(
self
):
expected_giou_tensor
=
[
2
/
3
,
3
/
5
]
def
graph_fn
():
boxes1
=
tf
.
constant
([[
3
,
3
,
5
,
5
],
[
3
,
3
,
5
,
5
]],
dtype
=
tf
.
float32
)
boxes2
=
tf
.
constant
([[
3
,
2.5
,
5
,
5.5
],
[
3
,
2.5
,
5
,
4.5
]],
dtype
=
tf
.
float32
)
giou
=
ops
.
giou
(
boxes1
,
boxes2
)
self
.
assertEqual
(
giou
.
dtype
,
tf
.
float32
)
return
giou
giou
=
self
.
execute
(
graph_fn
,
[])
self
.
assertAllClose
(
expected_giou_tensor
,
giou
)
class
TestCoordinateConversion
(
test_case
.
TestCase
):
def
test_coord_conv
(
self
):
expected_box_tensor
=
[
[
0.5
,
0.5
,
5.5
,
5.5
],
[
2
,
1
,
4
,
7
],
[
0
,
0
,
0
,
0
]
]
def
graph_fn
():
boxes
=
tf
.
constant
([[
3
,
3
,
5
,
5
],
[
3
,
4
,
2
,
6
],
[
0
,
0
,
0
,
0
]],
dtype
=
tf
.
float32
)
converted
=
ops
.
center_to_corner_coordinate
(
boxes
)
self
.
assertEqual
(
converted
.
dtype
,
tf
.
float32
)
return
converted
converted
=
self
.
execute
(
graph_fn
,
[])
self
.
assertAllClose
(
expected_box_tensor
,
converted
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment