Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f0bc684c
Commit
f0bc684c
authored
Jul 30, 2020
by
Kaushik Shivakumar
Browse files
fix this pr
parent
dabfc27b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
38 deletions
+21
-38
research/object_detection/box_coders/detr_box_coder.py
research/object_detection/box_coders/detr_box_coder.py
+14
-32
research/object_detection/core/target_assigner_test.py
research/object_detection/core/target_assigner_test.py
+7
-6
No files found.
research/object_detection/box_coders/detr_box_coder.py
View file @
f0bc684c
...
@@ -13,13 +13,13 @@
...
@@ -13,13 +13,13 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""
Faster RCNN
box coder.
"""
DETR
box coder.
Faster RCNN
box coder follows the coding schema described below:
DETR
box coder follows the coding schema described below:
ty =
(y - ya) / ha
ty =
y
tx =
(x - xa) / wa
tx =
x
th =
log(h / ha)
th =
h
tw =
log(w / wa)
tw =
w
where x, y, w, h denote the box's center coordinates, width and height
where x, y, w, h denote the box's center coordinates, width and height
respectively. Similarly, xa, ya, wa, ha denote the anchor's center
respectively. Similarly, xa, ya, wa, ha denote the anchor's center
coordinates, width and height. tx, ty, tw and th denote the anchor-encoded
coordinates, width and height. tx, ty, tw and th denote the anchor-encoded
...
@@ -39,19 +39,15 @@ EPSILON = 1e-8
...
@@ -39,19 +39,15 @@ EPSILON = 1e-8
class
DETRBoxCoder
(
box_coder
.
BoxCoder
):
class
DETRBoxCoder
(
box_coder
.
BoxCoder
):
"""Faster RCNN box coder."""
"""Faster RCNN box coder."""
def
__init__
(
self
,
scale_factors
=
None
):
def
__init__
(
self
):
"""Constructor for
FasterRcnn
BoxCoder.
"""Constructor for
DETR
BoxCoder.
Args:
Args:
scale_factors: List of 4 positive scalars to scale ty, tx, th and tw.
scale_factors: List of 4 positive scalars to scale ty, tx, th and tw.
If set to None, does not perform scaling. For Faster RCNN,
If set to None, does not perform scaling. For Faster RCNN,
the open-source implementation recommends using [10.0, 10.0, 5.0, 5.0].
the open-source implementation recommends using [10.0, 10.0, 5.0, 5.0].
"""
"""
if
None
:
pass
assert
len
(
scale_factors
)
==
4
for
scalar
in
scale_factors
:
assert
scalar
>
0
self
.
_scale_factors
=
scale_factors
@
property
@
property
def
code_size
(
self
):
def
code_size
(
self
):
...
@@ -69,16 +65,7 @@ class DETRBoxCoder(box_coder.BoxCoder):
...
@@ -69,16 +65,7 @@ class DETRBoxCoder(box_coder.BoxCoder):
[ty, tx, th, tw].
[ty, tx, th, tw].
"""
"""
# Convert anchors to the center coordinate representation.
# Convert anchors to the center coordinate representation.
ycenter
,
xcenter
,
h
,
w
=
boxes
.
get_center_coordinates_and_sizes
()
ty
,
tx
,
th
,
tw
=
boxes
.
get_center_coordinates_and_sizes
()
# Avoid NaN in division and log below.
h
+=
EPSILON
w
+=
EPSILON
tx
=
xcenter
ty
=
ycenter
tw
=
w
#tf.log(w)
th
=
h
#tf.log(h)
return
tf
.
transpose
(
tf
.
stack
([
ty
,
tx
,
th
,
tw
]))
return
tf
.
transpose
(
tf
.
stack
([
ty
,
tx
,
th
,
tw
]))
def
_decode
(
self
,
rel_codes
,
anchors
):
def
_decode
(
self
,
rel_codes
,
anchors
):
...
@@ -92,14 +79,9 @@ class DETRBoxCoder(box_coder.BoxCoder):
...
@@ -92,14 +79,9 @@ class DETRBoxCoder(box_coder.BoxCoder):
boxes: BoxList holding N bounding boxes.
boxes: BoxList holding N bounding boxes.
"""
"""
ty
,
tx
,
th
,
tw
=
tf
.
unstack
(
tf
.
transpose
(
rel_codes
))
ty
,
tx
,
th
,
tw
=
tf
.
unstack
(
tf
.
transpose
(
rel_codes
))
ymin
=
ty
-
th
/
2.
w
=
tw
xmin
=
tx
-
tw
/
2.
h
=
th
ymax
=
ty
+
th
/
2.
ycenter
=
ty
xmax
=
tx
+
tw
/
2.
xcenter
=
tx
ymin
=
ycenter
-
h
/
2.
xmin
=
xcenter
-
w
/
2.
ymax
=
ycenter
+
h
/
2.
xmax
=
xcenter
+
w
/
2.
return
box_list
.
BoxList
(
tf
.
transpose
(
tf
.
stack
([
ymin
,
xmin
,
ymax
,
xmax
])))
return
box_list
.
BoxList
(
tf
.
transpose
(
tf
.
stack
([
ymin
,
xmin
,
ymax
,
xmax
])))
research/object_detection/core/target_assigner_test.py
View file @
f0bc684c
...
@@ -135,8 +135,8 @@ class TargetAssignerTest(test_case.TestCase):
...
@@ -135,8 +135,8 @@ class TargetAssignerTest(test_case.TestCase):
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
,
_
)
=
result
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
,
_
)
=
result
return
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
)
return
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
)
anchor_means
=
np
.
array
([[
0.
0
,
0.
0
,
0.4
,
0.2
],
anchor_means
=
np
.
array
([[
0.
25
,
0.
25
,
0.4
,
0.2
],
[
0.5
,
0.
5
,
1.0
,
0.8
],
[
0.5
,
0.
8
,
1.0
,
0.8
],
[
0.9
,
0.5
,
0.1
,
1.0
]],
dtype
=
np
.
float32
)
[
0.9
,
0.5
,
0.1
,
1.0
]],
dtype
=
np
.
float32
)
groundtruth_box_corners
=
np
.
array
([[
0.0
,
0.0
,
0.5
,
0.5
],
groundtruth_box_corners
=
np
.
array
([[
0.0
,
0.0
,
0.5
,
0.5
],
[
0.5
,
0.5
,
0.9
,
0.9
]],
[
0.5
,
0.5
,
0.9
,
0.9
]],
...
@@ -146,10 +146,10 @@ class TargetAssignerTest(test_case.TestCase):
...
@@ -146,10 +146,10 @@ class TargetAssignerTest(test_case.TestCase):
groundtruth_labels
=
np
.
array
([[
0.0
,
1.0
],
[
0.0
,
1.0
]],
groundtruth_labels
=
np
.
array
([[
0.0
,
1.0
],
[
0.0
,
1.0
]],
dtype
=
np
.
float32
)
dtype
=
np
.
float32
)
exp_cls_targets
=
[[
1
],
[
1
],
[
0
]]
exp_cls_targets
=
[[
0
,
1
],
[
0
,
1
],
[
1
,
0
]]
exp_cls_weights
=
[[
1
],
[
1
],
[
1
]]
exp_cls_weights
=
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
]]
exp_reg_targets
=
[[
0.
0
,
0.
0
,
0.5
,
0.5
],
exp_reg_targets
=
[[
0.
25
,
0.
25
,
0.5
,
0.5
],
[
0.
5
,
0.
5
,
0.
9
,
0.
9
],
[
0.
7
,
0.
7
,
0.
4
,
0.
4
],
[
0
,
0
,
0
,
0
]]
[
0
,
0
,
0
,
0
]]
exp_reg_weights
=
[
1
,
1
,
0
]
exp_reg_weights
=
[
1
,
1
,
0
]
...
@@ -157,6 +157,7 @@ class TargetAssignerTest(test_case.TestCase):
...
@@ -157,6 +157,7 @@ class TargetAssignerTest(test_case.TestCase):
cls_weights_out
,
reg_targets_out
,
reg_weights_out
)
=
self
.
execute
(
cls_weights_out
,
reg_targets_out
,
reg_weights_out
)
=
self
.
execute
(
graph_fn
,
[
anchor_means
,
groundtruth_box_corners
,
graph_fn
,
[
anchor_means
,
groundtruth_box_corners
,
groundtruth_labels
,
predicted_labels
])
groundtruth_labels
,
predicted_labels
])
self
.
assertAllClose
(
cls_targets_out
,
exp_cls_targets
)
self
.
assertAllClose
(
cls_targets_out
,
exp_cls_targets
)
self
.
assertAllClose
(
cls_weights_out
,
exp_cls_weights
)
self
.
assertAllClose
(
cls_weights_out
,
exp_cls_weights
)
self
.
assertAllClose
(
reg_targets_out
,
exp_reg_targets
)
self
.
assertAllClose
(
reg_targets_out
,
exp_reg_targets
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment