Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e09e0566
Commit
e09e0566
authored
Aug 12, 2020
by
Kaushik Shivakumar
Browse files
target assigner and similarity calculator fixes
parent
d54c86de
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
55 deletions
+41
-55
research/object_detection/core/region_similarity_calculator.py
...rch/object_detection/core/region_similarity_calculator.py
+1
-1
research/object_detection/core/region_similarity_calculator_test.py
...bject_detection/core/region_similarity_calculator_test.py
+2
-1
research/object_detection/core/target_assigner.py
research/object_detection/core/target_assigner.py
+31
-40
research/object_detection/core/target_assigner_test.py
research/object_detection/core/target_assigner_test.py
+7
-13
No files found.
research/object_detection/core/region_similarity_calculator.py
View file @
e09e0566
...
@@ -53,7 +53,7 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
...
@@ -53,7 +53,7 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
a (float32) tensor of shape [N, M] with pairwise similarity score.
a (float32) tensor of shape [N, M] with pairwise similarity score.
"""
"""
with
tf
.
name_scope
(
scope
,
'Compare'
,
[
boxlist1
,
boxlist2
])
as
scope
:
with
tf
.
name_scope
(
scope
,
'Compare'
,
[
boxlist1
,
boxlist2
])
as
scope
:
return
self
.
_compare
(
boxlist1
,
boxlist2
,
groundtruth_labels
,
predicted_labels
)
return
self
.
_compare
(
boxlist1
,
boxlist2
)
@
abstractmethod
@
abstractmethod
def
_compare
(
self
,
boxlist1
,
boxlist2
,
def
_compare
(
self
,
boxlist1
,
boxlist2
,
...
...
research/object_detection/core/region_similarity_calculator_test.py
View file @
e09e0566
...
@@ -103,7 +103,8 @@ class RegionSimilarityCalculatorTest(test_case.TestCase):
...
@@ -103,7 +103,8 @@ class RegionSimilarityCalculatorTest(test_case.TestCase):
boxes2
=
box_list
.
BoxList
(
corners2
)
boxes2
=
box_list
.
BoxList
(
corners2
)
boxes1
.
add_field
(
fields
.
BoxListFields
.
classes
,
groundtruth_labels
)
boxes1
.
add_field
(
fields
.
BoxListFields
.
classes
,
groundtruth_labels
)
boxes2
.
add_field
(
fields
.
BoxListFields
.
classes
,
predicted_labels
)
boxes2
.
add_field
(
fields
.
BoxListFields
.
classes
,
predicted_labels
)
detr_similarity_calculator
=
region_similarity_calculator
.
DETRSimilarity
()
detr_similarity_calculator
=
\
region_similarity_calculator
.
DETRSimilarity
()
detr_similarity
=
detr_similarity_calculator
.
compare
(
detr_similarity
=
detr_similarity_calculator
.
compare
(
boxes1
,
boxes2
,
None
)
boxes1
,
boxes2
,
None
)
return
detr_similarity
return
detr_similarity
...
...
research/object_detection/core/target_assigner.py
View file @
e09e0566
...
@@ -437,9 +437,7 @@ def create_target_assigner(reference, stage=None,
...
@@ -437,9 +437,7 @@ def create_target_assigner(reference, stage=None,
box_coder_instance
=
faster_rcnn_box_coder
.
FasterRcnnBoxCoder
()
box_coder_instance
=
faster_rcnn_box_coder
.
FasterRcnnBoxCoder
()
elif
reference
==
'DETR'
:
elif
reference
==
'DETR'
:
similarity_calc
=
sim_calc
.
DETRSimilarity
()
return
DETRTargetAssigner
()
matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
return
DETRTargetAssigner
(
similarity_calc
,
matcher
)
else
:
else
:
raise
ValueError
(
'No valid combination of reference and stage.'
)
raise
ValueError
(
'No valid combination of reference and stage.'
)
...
@@ -1917,9 +1915,7 @@ class CenterNetCornerOffsetTargetAssigner(object):
...
@@ -1917,9 +1915,7 @@ class CenterNetCornerOffsetTargetAssigner(object):
class
DETRTargetAssigner
(
object
):
class
DETRTargetAssigner
(
object
):
"""Target assigner to compute classification and regression targets."""
"""Target assigner to compute classification and regression targets."""
def
__init__
(
self
,
def
__init__
(
self
,
negative_class_weight
=
1.0
):
matcher
,
negative_class_weight
=
1.0
):
"""Construct Object Detection Target Assigner.
"""Construct Object Detection Target Assigner.
Args:
Args:
...
@@ -1931,8 +1927,6 @@ class DETRTargetAssigner(object):
...
@@ -1931,8 +1927,6 @@ class DETRTargetAssigner(object):
boxes (default: 1.0). The weight must be in [0., 1.].
boxes (default: 1.0). The weight must be in [0., 1.].
"""
"""
if
not
isinstance
(
matcher
,
mat
.
Matcher
):
raise
ValueError
(
'matcher must be a Matcher'
)
self
.
_similarity_calc
=
sim_calc
.
DETRSimilarity
()
self
.
_similarity_calc
=
sim_calc
.
DETRSimilarity
()
self
.
_matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
self
.
_matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
self
.
_negative_class_weight
=
negative_class_weight
self
.
_negative_class_weight
=
negative_class_weight
...
@@ -2024,39 +2018,36 @@ class DETRTargetAssigner(object):
...
@@ -2024,39 +2018,36 @@ class DETRTargetAssigner(object):
groundtruth_boxes
.
add_field
(
fields
.
BoxListFields
.
classes
,
groundtruth_labels
)
groundtruth_boxes
.
add_field
(
fields
.
BoxListFields
.
classes
,
groundtruth_labels
)
box_preds
.
add_field
(
fields
.
BoxListFields
.
classes
,
class_predictions
)
box_preds
.
add_field
(
fields
.
BoxListFields
.
classes
,
class_predictions
)
with
tf
.
control_dependencies
(
match_quality_matrix
=
self
.
_similarity_calc
.
compare
(
[
unmatched_shape_assert
,
labels_and_box_shapes_assert
]):
groundtruth_boxes
,
box_preds
)
match_quality_matrix
=
self
.
_similarity_calc
.
compare
(
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
groundtruth_boxes
,
valid_rows
=
tf
.
greater
(
groundtruth_weights
,
0
))
box_preds
)
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
valid_rows
=
tf
.
greater
(
groundtruth_weights
,
0
))
reg_targets
=
self
.
_create_regression_targets
(
box_preds
,
reg_targets
=
self
.
_create_regression_targets
(
box_preds
,
groundtruth_boxes
,
groundtruth_boxes
,
match
)
match
)
cls_targets
=
match
.
gather_based_on_match
(
cls_targets
=
match
.
gather_based_on_match
(
groundtruth_labels
,
groundtruth_labels
,
unmatched_value
=
unmatched_class_label
,
unmatched_value
=
unmatched_class_label
,
ignored_value
=
unmatched_class_label
)
ignored_value
=
unmatched_class_label
)
reg_weights
=
match
.
gather_based_on_match
(
groundtruth_weights
,
reg_weights
=
match
.
gather_based_on_match
(
groundtruth_weights
,
ignored_value
=
0.
,
ignored_value
=
0.
,
unmatched_value
=
0.
)
unmatched_value
=
0.
)
cls_weights
=
match
.
gather_based_on_match
(
cls_weights
=
match
.
gather_based_on_match
(
groundtruth_weights
,
groundtruth_weights
,
ignored_value
=
0.
,
ignored_value
=
0.
,
unmatched_value
=
self
.
_negative_class_weight
)
unmatched_value
=
self
.
_negative_class_weight
)
# convert cls_weights from per-box_pred to per-class.
# convert cls_weights from per-box_pred to per-class.
class_label_shape
=
tf
.
shape
(
cls_targets
)[
1
:]
class_label_shape
=
tf
.
shape
(
cls_targets
)[
1
:]
weights_shape
=
tf
.
shape
(
cls_weights
)
weights_shape
=
tf
.
shape
(
cls_weights
)
weights_multiple
=
tf
.
concat
(
weights_multiple
=
tf
.
concat
(
[
tf
.
ones_like
(
weights_shape
),
class_label_shape
],
[
tf
.
ones_like
(
weights_shape
),
class_label_shape
],
axis
=
0
)
axis
=
0
)
for
_
in
range
(
len
(
cls_targets
.
get_shape
()[
1
:])):
for
_
in
range
(
len
(
cls_targets
.
get_shape
()[
1
:])):
cls_weights
=
tf
.
expand_dims
(
cls_weights
,
-
1
)
cls_weights
=
tf
.
expand_dims
(
cls_weights
,
-
1
)
cls_weights
=
tf
.
tile
(
cls_weights
,
weights_multiple
)
cls_weights
=
tf
.
tile
(
cls_weights
,
weights_multiple
)
num_box_preds
=
box_preds
.
num_boxes_static
()
num_box_preds
=
box_preds
.
num_boxes_static
()
if
num_box_preds
is
not
None
:
if
num_box_preds
is
not
None
:
...
...
research/object_detection/core/target_assigner_test.py
View file @
e09e0566
...
@@ -19,7 +19,6 @@ import tensorflow.compat.v1 as tf
...
@@ -19,7 +19,6 @@ import tensorflow.compat.v1 as tf
from
object_detection.box_coders
import
keypoint_box_coder
from
object_detection.box_coders
import
keypoint_box_coder
from
object_detection.box_coders
import
mean_stddev_box_coder
from
object_detection.box_coders
import
mean_stddev_box_coder
from
object_detection.box_coders
import
detr_box_coder
from
object_detection.core
import
box_list
from
object_detection.core
import
box_list
from
object_detection.core
import
region_similarity_calculator
from
object_detection.core
import
region_similarity_calculator
from
object_detection.core
import
standard_fields
as
fields
from
object_detection.core
import
standard_fields
as
fields
...
@@ -2192,20 +2191,11 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase):
...
@@ -2192,20 +2191,11 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase):
self
.
assertAllClose
(
foreground
,
np
.
zeros
((
1
,
5
,
5
)))
self
.
assertAllClose
(
foreground
,
np
.
zeros
((
1
,
5
,
5
)))
if
__name__
==
'__main__'
:
class
DETRTargetAssignerTest
(
test_case
.
TestCase
):
tf
.
enable_v2_behavior
()
tf
.
test
.
main
()
class
DETRTargetAssignerTest
(
testcase
.
TestCase
):
def
test_assign_detr
(
self
):
def
test_assign_detr
(
self
):
def
graph_fn
(
anchor_means
,
groundtruth_box_corners
,
def
graph_fn
(
anchor_means
,
groundtruth_box_corners
,
groundtruth_labels
,
predicted_labels
):
groundtruth_labels
,
predicted_labels
):
similarity_calc
=
region_similarity_calculator
.
DETRSimilarity
()
detr_target_assigner
=
targetassigner
.
DETRTargetAssigner
()
matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
box_coder
=
detr_box_coder
.
DETRBoxCoder
()
detr_target_assigner
=
target_assigner
.
DETRTargetAssigner
(
similarity_calc
,
matcher
,
box_coder
)
anchors_boxlist
=
box_list
.
BoxList
(
anchor_means
)
anchors_boxlist
=
box_list
.
BoxList
(
anchor_means
)
groundtruth_boxlist
=
box_list
.
BoxList
(
groundtruth_box_corners
)
groundtruth_boxlist
=
box_list
.
BoxList
(
groundtruth_box_corners
)
result
=
detr_target_assigner
.
assign
(
result
=
detr_target_assigner
.
assign
(
...
@@ -2247,4 +2237,8 @@ class DETRTargetAssignerTest(testcase.TestCase):
...
@@ -2247,4 +2237,8 @@ class DETRTargetAssignerTest(testcase.TestCase):
self
.
assertEqual
(
cls_targets_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
cls_targets_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
cls_weights_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
cls_weights_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_targets_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_targets_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_weights_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_weights_out
.
dtype
,
np
.
float32
)
\ No newline at end of file
if
__name__
==
'__main__'
:
tf
.
enable_v2_behavior
()
tf
.
test
.
main
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment