Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e09e0566
Commit
e09e0566
authored
Aug 12, 2020
by
Kaushik Shivakumar
Browse files
target assigner and similarity calculator fixes
parent
d54c86de
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
55 deletions
+41
-55
research/object_detection/core/region_similarity_calculator.py
...rch/object_detection/core/region_similarity_calculator.py
+1
-1
research/object_detection/core/region_similarity_calculator_test.py
...bject_detection/core/region_similarity_calculator_test.py
+2
-1
research/object_detection/core/target_assigner.py
research/object_detection/core/target_assigner.py
+31
-40
research/object_detection/core/target_assigner_test.py
research/object_detection/core/target_assigner_test.py
+7
-13
No files found.
research/object_detection/core/region_similarity_calculator.py
View file @
e09e0566
...
...
@@ -53,7 +53,7 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
a (float32) tensor of shape [N, M] with pairwise similarity score.
"""
with
tf
.
name_scope
(
scope
,
'Compare'
,
[
boxlist1
,
boxlist2
])
as
scope
:
return
self
.
_compare
(
boxlist1
,
boxlist2
,
groundtruth_labels
,
predicted_labels
)
return
self
.
_compare
(
boxlist1
,
boxlist2
)
@
abstractmethod
def
_compare
(
self
,
boxlist1
,
boxlist2
,
...
...
research/object_detection/core/region_similarity_calculator_test.py
View file @
e09e0566
...
...
@@ -103,7 +103,8 @@ class RegionSimilarityCalculatorTest(test_case.TestCase):
boxes2
=
box_list
.
BoxList
(
corners2
)
boxes1
.
add_field
(
fields
.
BoxListFields
.
classes
,
groundtruth_labels
)
boxes2
.
add_field
(
fields
.
BoxListFields
.
classes
,
predicted_labels
)
detr_similarity_calculator
=
region_similarity_calculator
.
DETRSimilarity
()
detr_similarity_calculator
=
\
region_similarity_calculator
.
DETRSimilarity
()
detr_similarity
=
detr_similarity_calculator
.
compare
(
boxes1
,
boxes2
,
None
)
return
detr_similarity
...
...
research/object_detection/core/target_assigner.py
View file @
e09e0566
...
...
@@ -437,9 +437,7 @@ def create_target_assigner(reference, stage=None,
box_coder_instance
=
faster_rcnn_box_coder
.
FasterRcnnBoxCoder
()
elif
reference
==
'DETR'
:
similarity_calc
=
sim_calc
.
DETRSimilarity
()
matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
return
DETRTargetAssigner
(
similarity_calc
,
matcher
)
return
DETRTargetAssigner
()
else
:
raise
ValueError
(
'No valid combination of reference and stage.'
)
...
...
@@ -1917,9 +1915,7 @@ class CenterNetCornerOffsetTargetAssigner(object):
class
DETRTargetAssigner
(
object
):
"""Target assigner to compute classification and regression targets."""
def
__init__
(
self
,
matcher
,
negative_class_weight
=
1.0
):
def
__init__
(
self
,
negative_class_weight
=
1.0
):
"""Construct Object Detection Target Assigner.
Args:
...
...
@@ -1931,8 +1927,6 @@ class DETRTargetAssigner(object):
boxes (default: 1.0). The weight must be in [0., 1.].
"""
if
not
isinstance
(
matcher
,
mat
.
Matcher
):
raise
ValueError
(
'matcher must be a Matcher'
)
self
.
_similarity_calc
=
sim_calc
.
DETRSimilarity
()
self
.
_matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
self
.
_negative_class_weight
=
negative_class_weight
...
...
@@ -2024,9 +2018,6 @@ class DETRTargetAssigner(object):
groundtruth_boxes
.
add_field
(
fields
.
BoxListFields
.
classes
,
groundtruth_labels
)
box_preds
.
add_field
(
fields
.
BoxListFields
.
classes
,
class_predictions
)
with
tf
.
control_dependencies
(
[
unmatched_shape_assert
,
labels_and_box_shapes_assert
]):
match_quality_matrix
=
self
.
_similarity_calc
.
compare
(
groundtruth_boxes
,
box_preds
)
...
...
research/object_detection/core/target_assigner_test.py
View file @
e09e0566
...
...
@@ -19,7 +19,6 @@ import tensorflow.compat.v1 as tf
from
object_detection.box_coders
import
keypoint_box_coder
from
object_detection.box_coders
import
mean_stddev_box_coder
from
object_detection.box_coders
import
detr_box_coder
from
object_detection.core
import
box_list
from
object_detection.core
import
region_similarity_calculator
from
object_detection.core
import
standard_fields
as
fields
...
...
@@ -2192,20 +2191,11 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase):
self
.
assertAllClose
(
foreground
,
np
.
zeros
((
1
,
5
,
5
)))
if
__name__
==
'__main__'
:
tf
.
enable_v2_behavior
()
tf
.
test
.
main
()
class
DETRTargetAssignerTest
(
testcase
.
TestCase
):
class
DETRTargetAssignerTest
(
test_case
.
TestCase
):
def
test_assign_detr
(
self
):
def
graph_fn
(
anchor_means
,
groundtruth_box_corners
,
groundtruth_labels
,
predicted_labels
):
similarity_calc
=
region_similarity_calculator
.
DETRSimilarity
()
matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
box_coder
=
detr_box_coder
.
DETRBoxCoder
()
detr_target_assigner
=
target_assigner
.
DETRTargetAssigner
(
similarity_calc
,
matcher
,
box_coder
)
detr_target_assigner
=
targetassigner
.
DETRTargetAssigner
()
anchors_boxlist
=
box_list
.
BoxList
(
anchor_means
)
groundtruth_boxlist
=
box_list
.
BoxList
(
groundtruth_box_corners
)
result
=
detr_target_assigner
.
assign
(
...
...
@@ -2248,3 +2238,7 @@ class DETRTargetAssignerTest(testcase.TestCase):
self
.
assertEqual
(
cls_weights_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_targets_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_weights_out
.
dtype
,
np
.
float32
)
if
__name__
==
'__main__'
:
tf
.
enable_v2_behavior
()
tf
.
test
.
main
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment