Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d54c86de
Commit
d54c86de
authored
Aug 11, 2020
by
Kaushik Shivakumar
Browse files
make suggested fixes to target assigner and similarity calculator
parent
4f135c70
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
69 additions
and
105 deletions
+69
-105
research/object_detection/core/region_similarity_calculator.py
...rch/object_detection/core/region_similarity_calculator.py
+15
-39
research/object_detection/core/region_similarity_calculator_test.py
...bject_detection/core/region_similarity_calculator_test.py
+3
-1
research/object_detection/core/target_assigner.py
research/object_detection/core/target_assigner.py
+49
-63
research/object_detection/core/target_assigner_test.py
research/object_detection/core/target_assigner_test.py
+2
-2
No files found.
research/object_detection/core/region_similarity_calculator.py
View file @
d54c86de
...
@@ -35,8 +35,7 @@ from object_detection.core import standard_fields as fields
...
@@ -35,8 +35,7 @@ from object_detection.core import standard_fields as fields
class
RegionSimilarityCalculator
(
six
.
with_metaclass
(
ABCMeta
,
object
)):
class
RegionSimilarityCalculator
(
six
.
with_metaclass
(
ABCMeta
,
object
)):
"""Abstract base class for region similarity calculator."""
"""Abstract base class for region similarity calculator."""
def
compare
(
self
,
boxlist1
,
boxlist2
,
scope
=
None
,
def
compare
(
self
,
boxlist1
,
boxlist2
,
scope
=
None
):
groundtruth_labels
=
None
,
predicted_labels
=
None
):
"""Computes matrix of pairwise similarity between BoxLists.
"""Computes matrix of pairwise similarity between BoxLists.
This op (to be overridden) computes a measure of pairwise similarity between
This op (to be overridden) computes a measure of pairwise similarity between
...
@@ -49,10 +48,6 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
...
@@ -49,10 +48,6 @@ class RegionSimilarityCalculator(six.with_metaclass(ABCMeta, object)):
boxlist1: BoxList holding N boxes.
boxlist1: BoxList holding N boxes.
boxlist2: BoxList holding M boxes.
boxlist2: BoxList holding M boxes.
scope: Op scope name. Defaults to 'Compare' if None.
scope: Op scope name. Defaults to 'Compare' if None.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns:
Returns:
a (float32) tensor of shape [N, M] with pairwise similarity score.
a (float32) tensor of shape [N, M] with pairwise similarity score.
...
@@ -72,17 +67,12 @@ class IouSimilarity(RegionSimilarityCalculator):
...
@@ -72,17 +67,12 @@ class IouSimilarity(RegionSimilarityCalculator):
This class computes pairwise similarity between two BoxLists based on IOU.
This class computes pairwise similarity between two BoxLists based on IOU.
"""
"""
def
_compare
(
self
,
boxlist1
,
boxlist2
,
def
_compare
(
self
,
boxlist1
,
boxlist2
):
groundtruth_labels
=
None
,
predicted_labels
=
None
):
"""Compute pairwise IOU similarity between the two BoxLists.
"""Compute pairwise IOU similarity between the two BoxLists.
Args:
Args:
boxlist1: BoxList holding N boxes.
boxlist1: BoxList holding N boxes.
boxlist2: BoxList holding M boxes.
boxlist2: BoxList holding M boxes.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns:
Returns:
A tensor with shape [N, M] representing pairwise iou scores.
A tensor with shape [N, M] representing pairwise iou scores.
...
@@ -95,25 +85,26 @@ class DETRSimilarity(RegionSimilarityCalculator):
...
@@ -95,25 +85,26 @@ class DETRSimilarity(RegionSimilarityCalculator):
This class computes pairwise similarity between two BoxLists using a weighted
This class computes pairwise similarity between two BoxLists using a weighted
combination of IOU, classification scores, and the L1 loss.
combination of IOU, classification scores, and the L1 loss.
"""
"""
def
__init__
(
self
,
l1_weight
=
5
,
giou_weight
=
2
):
self
.
l1_weight
=
l1_weight
self
.
giou_weight
=
giou_weight
def
_compare
(
self
,
boxlist1
,
boxlist2
,
def
_compare
(
self
,
boxlist1
,
boxlist2
):
groundtruth_labels
=
None
,
predicted_labels
=
None
):
"""Compute pairwise IOU similarity between the two BoxLists.
"""Compute pairwise IOU similarity between the two BoxLists.
Args:
Args:
boxlist1: BoxList holding N boxes.
boxlist1: BoxList holding N groundtruth boxes.
boxlist2: BoxList holding M boxes.
boxlist2: BoxList holding M predicted boxes.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns:
Returns:
A tensor with shape [N, M] representing pairwise iou scores.
A tensor with shape [N, M] representing pairwise iou scores.
"""
"""
groundtruth_labels
=
boxlist1
.
get_field
(
fields
.
BoxListFields
.
classes
)
predicted_labels
=
boxlist2
.
get_field
(
fields
.
BoxListFields
.
classes
)
classification_scores
=
tf
.
matmul
(
groundtruth_labels
,
classification_scores
=
tf
.
matmul
(
groundtruth_labels
,
tf
.
nn
.
softmax
(
predicted_labels
),
transpose_b
=
True
)
tf
.
nn
.
softmax
(
predicted_labels
),
transpose_b
=
True
)
return
-
5
*
box_list_ops
.
l1
(
boxlist1
,
boxlist2
)
+
2
*
box_list_ops
.
giou
(
return
-
self
.
l1_weight
*
box_list_ops
.
l1
(
boxlist1
,
boxlist2
)
+
self
.
giou_weight
*
box_list_ops
.
giou
(
boxlist1
,
boxlist2
)
+
classification_scores
boxlist1
,
boxlist2
)
+
classification_scores
class
NegSqDistSimilarity
(
RegionSimilarityCalculator
):
class
NegSqDistSimilarity
(
RegionSimilarityCalculator
):
...
@@ -123,17 +114,12 @@ class NegSqDistSimilarity(RegionSimilarityCalculator):
...
@@ -123,17 +114,12 @@ class NegSqDistSimilarity(RegionSimilarityCalculator):
negative squared distance metric.
negative squared distance metric.
"""
"""
def
_compare
(
self
,
boxlist1
,
boxlist2
,
def
_compare
(
self
,
boxlist1
,
boxlist2
):
groundtruth_labels
=
None
,
predicted_labels
=
None
):
"""Compute matrix of (negated) sq distances.
"""Compute matrix of (negated) sq distances.
Args:
Args:
boxlist1: BoxList holding N boxes.
boxlist1: BoxList holding N boxes.
boxlist2: BoxList holding M boxes.
boxlist2: BoxList holding M boxes.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns:
Returns:
A tensor with shape [N, M] representing negated pairwise squared distance.
A tensor with shape [N, M] representing negated pairwise squared distance.
...
@@ -147,17 +133,12 @@ class IoaSimilarity(RegionSimilarityCalculator):
...
@@ -147,17 +133,12 @@ class IoaSimilarity(RegionSimilarityCalculator):
pairwise intersections divided by the areas of second BoxLists.
pairwise intersections divided by the areas of second BoxLists.
"""
"""
def
_compare
(
self
,
boxlist1
,
boxlist2
,
def
_compare
(
self
,
boxlist1
,
boxlist2
):
groundtruth_labels
=
None
,
predicted_labels
=
None
):
"""Compute pairwise IOA similarity between the two BoxLists.
"""Compute pairwise IOA similarity between the two BoxLists.
Args:
Args:
boxlist1: BoxList holding N boxes.
boxlist1: BoxList holding N boxes.
boxlist2: BoxList holding M boxes.
boxlist2: BoxList holding M boxes.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns:
Returns:
A tensor with shape [N, M] representing pairwise IOA scores.
A tensor with shape [N, M] representing pairwise IOA scores.
...
@@ -184,17 +165,12 @@ class ThresholdedIouSimilarity(RegionSimilarityCalculator):
...
@@ -184,17 +165,12 @@ class ThresholdedIouSimilarity(RegionSimilarityCalculator):
super
(
ThresholdedIouSimilarity
,
self
).
__init__
()
super
(
ThresholdedIouSimilarity
,
self
).
__init__
()
self
.
_iou_threshold
=
iou_threshold
self
.
_iou_threshold
=
iou_threshold
def
_compare
(
self
,
boxlist1
,
boxlist2
,
def
_compare
(
self
,
boxlist1
,
boxlist2
):
groundtruth_labels
=
None
,
predicted_labels
=
None
):
"""Compute pairwise IOU similarity between the two BoxLists and score.
"""Compute pairwise IOU similarity between the two BoxLists and score.
Args:
Args:
boxlist1: BoxList holding N boxes. Must have a score field.
boxlist1: BoxList holding N boxes. Must have a score field.
boxlist2: BoxList holding M boxes.
boxlist2: BoxList holding M boxes.
groundtruth_labels: a Tensor of shape [num_boxes, num_classes]
containing groundtruth labels.
predicted_labels: a Tensor of shape [num_boxes, num_classes]
containing predicted labels.
Returns:
Returns:
A tensor with shape [N, M] representing scores threholded by pairwise
A tensor with shape [N, M] representing scores threholded by pairwise
...
...
research/object_detection/core/region_similarity_calculator_test.py
View file @
d54c86de
...
@@ -101,9 +101,11 @@ class RegionSimilarityCalculatorTest(test_case.TestCase):
...
@@ -101,9 +101,11 @@ class RegionSimilarityCalculatorTest(test_case.TestCase):
predicted_labels
=
tf
.
constant
([[
0.0
,
1000.0
],
[
1000.0
,
0.0
]])
predicted_labels
=
tf
.
constant
([[
0.0
,
1000.0
],
[
1000.0
,
0.0
]])
boxes1
=
box_list
.
BoxList
(
corners1
)
boxes1
=
box_list
.
BoxList
(
corners1
)
boxes2
=
box_list
.
BoxList
(
corners2
)
boxes2
=
box_list
.
BoxList
(
corners2
)
boxes1
.
add_field
(
fields
.
BoxListFields
.
classes
,
groundtruth_labels
)
boxes2
.
add_field
(
fields
.
BoxListFields
.
classes
,
predicted_labels
)
detr_similarity_calculator
=
region_similarity_calculator
.
DETRSimilarity
()
detr_similarity_calculator
=
region_similarity_calculator
.
DETRSimilarity
()
detr_similarity
=
detr_similarity_calculator
.
compare
(
detr_similarity
=
detr_similarity_calculator
.
compare
(
boxes1
,
boxes2
,
None
,
groundtruth_labels
,
predicted_labels
)
boxes1
,
boxes2
,
None
)
return
detr_similarity
return
detr_similarity
exp_output
=
[[
2.0
,
-
2.0
/
3.0
+
1.0
-
20.0
]]
exp_output
=
[[
2.0
,
-
2.0
/
3.0
+
1.0
-
20.0
]]
sim_output
=
self
.
execute
(
graph_fn
,
[])
sim_output
=
self
.
execute
(
graph_fn
,
[])
...
...
research/object_detection/core/target_assigner.py
View file @
d54c86de
...
@@ -51,6 +51,7 @@ from object_detection.core import matcher as mat
...
@@ -51,6 +51,7 @@ from object_detection.core import matcher as mat
from
object_detection.core
import
region_similarity_calculator
as
sim_calc
from
object_detection.core
import
region_similarity_calculator
as
sim_calc
from
object_detection.core
import
standard_fields
as
fields
from
object_detection.core
import
standard_fields
as
fields
from
object_detection.matchers
import
argmax_matcher
from
object_detection.matchers
import
argmax_matcher
from
object_detection.matchers
import
hungarian_matcher
from
object_detection.utils
import
shape_utils
from
object_detection.utils
import
shape_utils
from
object_detection.utils
import
target_assigner_utils
as
ta_utils
from
object_detection.utils
import
target_assigner_utils
as
ta_utils
from
object_detection.utils
import
tf_version
from
object_detection.utils
import
tf_version
...
@@ -1917,51 +1918,44 @@ class DETRTargetAssigner(object):
...
@@ -1917,51 +1918,44 @@ class DETRTargetAssigner(object):
"""Target assigner to compute classification and regression targets."""
"""Target assigner to compute classification and regression targets."""
def
__init__
(
self
,
def
__init__
(
self
,
similarity_calc
,
matcher
,
matcher
,
negative_class_weight
=
1.0
):
negative_class_weight
=
1.0
):
"""Construct Object Detection Target Assigner.
"""Construct Object Detection Target Assigner.
Args:
Args:
similarity_calc: a RegionSimilarityCalculator
matcher: an object_detection.core.Matcher used to match groundtruth to
matcher: an object_detection.core.Matcher used to match groundtruth to
anchor
s.
predicted boxe
s.
box_coder_instance: an object_detection.core.BoxCoder used to encode
box_coder_instance: an object_detection.core.BoxCoder used to encode
matching groundtruth boxes with respect to
anchor
s.
matching groundtruth boxes with respect to
predicted boxe
s.
negative_class_weight: classification weight to be associated to negative
negative_class_weight: classification weight to be associated to negative
anchor
s (default: 1.0). The weight must be in [0., 1.].
boxe
s (default: 1.0). The weight must be in [0., 1.].
Raises:
ValueError: if similarity_calc is not a RegionSimilarityCalculator or
if matcher is not a Matcher or if box_coder is not a BoxCoder
"""
"""
if
not
isinstance
(
similarity_calc
,
sim_calc
.
RegionSimilarityCalculator
):
raise
ValueError
(
'similarity_calc must be a RegionSimilarityCalculator'
)
if
not
isinstance
(
matcher
,
mat
.
Matcher
):
if
not
isinstance
(
matcher
,
mat
.
Matcher
):
raise
ValueError
(
'matcher must be a Matcher'
)
raise
ValueError
(
'matcher must be a Matcher'
)
self
.
_similarity_calc
=
similarity
_calc
self
.
_similarity_calc
=
sim
_calc
.
DETRSim
ilarity
()
self
.
_matcher
=
matcher
self
.
_matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
self
.
_negative_class_weight
=
negative_class_weight
self
.
_negative_class_weight
=
negative_class_weight
def
assign
(
self
,
def
assign
(
self
,
anchor
s
,
box_pred
s
,
groundtruth_boxes
,
groundtruth_boxes
,
groundtruth_labels
=
None
,
groundtruth_labels
=
None
,
unmatched_class_label
=
None
,
unmatched_class_label
=
None
,
groundtruth_weights
=
None
,
groundtruth_weights
=
None
,
class_predictions
=
None
):
class_predictions
=
None
):
"""Assign classification and regression targets to each
anchor
.
"""Assign classification and regression targets to each
box_pred
.
For a given set of
anchor
s and groundtruth detections, match
anchor
s
For a given set of
box_pred
s and groundtruth detections, match
box_pred
s
to groundtruth_boxes and assign classification and regression targets to
to groundtruth_boxes and assign classification and regression targets to
each
anchor
as well as weights based on the resulting match (specifying,
each
box_pred
as well as weights based on the resulting match (specifying,
e.g., which
anchor
s should not contribute to training loss).
e.g., which
box_pred
s should not contribute to training loss).
Anchor
s that are not matched to anything are given a classification target
box_pred
s that are not matched to anything are given a classification target
of self._unmatched_cls_target which can be specified via the constructor.
of self._unmatched_cls_target which can be specified via the constructor.
Args:
Args:
anchor
s: a BoxList representing N
anchor
s
box_pred
s: a BoxList representing N
box_pred
s
groundtruth_boxes: a BoxList representing M groundtruth boxes
groundtruth_boxes: a BoxList representing M groundtruth boxes
groundtruth_labels: a tensor of shape [M, d_1, ... d_k]
groundtruth_labels: a tensor of shape [M, d_1, ... d_k]
with labels for each of the ground_truth boxes. The subshape
with labels for each of the ground_truth boxes. The subshape
...
@@ -1970,14 +1964,14 @@ class DETRTargetAssigner(object):
...
@@ -1970,14 +1964,14 @@ class DETRTargetAssigner(object):
ground_truth boxes get a positive label (of 1).
ground_truth boxes get a positive label (of 1).
unmatched_class_label: a float32 tensor with shape [d_1, d_2, ..., d_k]
unmatched_class_label: a float32 tensor with shape [d_1, d_2, ..., d_k]
which is consistent with the classification target for each
which is consistent with the classification target for each
anchor
(and can be empty for scalar targets). This shape must thus be
box_pred
(and can be empty for scalar targets). This shape must thus be
compatible with the groundtruth labels that are passed to the "assign"
compatible with the groundtruth labels that are passed to the "assign"
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
If set to None, unmatched_cls_target is set to be [0] for each
anchor
.
If set to None, unmatched_cls_target is set to be [0] for each
box_pred
.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all
anchor
s match to a particular groundtruth box. The weights
assign to all
box_pred
s match to a particular groundtruth box. The weights
must be in [0., 1.]. If None, all weights are set to 1. Generally no
must be in [0., 1.]. If None, all weights are set to 1. Generally no
groundtruth boxes with zero weight match to any
anchor
s as matchers are
groundtruth boxes with zero weight match to any
box_pred
s as matchers are
aware of groundtruth weights. Additionally, `cls_weights` and
aware of groundtruth weights. Additionally, `cls_weights` and
`reg_weights` are calculated using groundtruth weights as an added
`reg_weights` are calculated using groundtruth weights as an added
safety.
safety.
...
@@ -1985,27 +1979,27 @@ class DETRTargetAssigner(object):
...
@@ -1985,27 +1979,27 @@ class DETRTargetAssigner(object):
to be used by certain similarity calculators.
to be used by certain similarity calculators.
Returns:
Returns:
cls_targets: a float32 tensor with shape [num_
anchor
s, d_1, d_2 ... d_k],
cls_targets: a float32 tensor with shape [num_
box_pred
s, d_1, d_2 ... d_k],
where the subshape [d_1, ..., d_k] is compatible with groundtruth_labels
where the subshape [d_1, ..., d_k] is compatible with groundtruth_labels
which has shape [num_gt_boxes, d_1, d_2, ... d_k].
which has shape [num_gt_boxes, d_1, d_2, ... d_k].
cls_weights: a float32 tensor with shape [num_
anchor
s, d_1, d_2 ... d_k],
cls_weights: a float32 tensor with shape [num_
box_pred
s, d_1, d_2 ... d_k],
representing weights for each element in cls_targets.
representing weights for each element in cls_targets.
reg_targets: a float32 tensor with shape [num_
anchor
s, box_code_dimension]
reg_targets: a float32 tensor with shape [num_
box_pred
s, box_code_dimension]
reg_weights: a float32 tensor with shape [num_
anchor
s]
reg_weights: a float32 tensor with shape [num_
box_pred
s]
match: an int32 tensor of shape [num_
anchor
s] containing result of
anchor
match: an int32 tensor of shape [num_
box_pred
s] containing result of
box_pred
groundtruth matching. Each position in the tensor indicates an
anchor
groundtruth matching. Each position in the tensor indicates an
box_pred
and holds the following meaning:
and holds the following meaning:
(1) if match[i] >= 0,
anchor
i is matched with groundtruth match[i].
(1) if match[i] >= 0,
box_pred
i is matched with groundtruth match[i].
(2) if match[i]=-1,
anchor
i is marked to be background .
(2) if match[i]=-1,
box_pred
i is marked to be background .
(3) if match[i]=-2,
anchor
i is ignored since it is not background and
(3) if match[i]=-2,
box_pred
i is ignored since it is not background and
does not have sufficient overlap to call it a foreground.
does not have sufficient overlap to call it a foreground.
Raises:
Raises:
ValueError: if
anchor
s or groundtruth_boxes are not of type
ValueError: if
box_pred
s or groundtruth_boxes are not of type
box_list.BoxList
box_list.BoxList
"""
"""
if
not
isinstance
(
anchor
s
,
box_list
.
BoxList
):
if
not
isinstance
(
box_pred
s
,
box_list
.
BoxList
):
raise
ValueError
(
'
anchor
s must be an BoxList'
)
raise
ValueError
(
'
box_pred
s must be an BoxList'
)
if
not
isinstance
(
groundtruth_boxes
,
box_list
.
BoxList
):
if
not
isinstance
(
groundtruth_boxes
,
box_list
.
BoxList
):
raise
ValueError
(
'groundtruth_boxes must be an BoxList'
)
raise
ValueError
(
'groundtruth_boxes must be an BoxList'
)
...
@@ -2017,15 +2011,6 @@ class DETRTargetAssigner(object):
...
@@ -2017,15 +2011,6 @@ class DETRTargetAssigner(object):
0
))
0
))
groundtruth_labels
=
tf
.
expand_dims
(
groundtruth_labels
,
-
1
)
groundtruth_labels
=
tf
.
expand_dims
(
groundtruth_labels
,
-
1
)
unmatched_shape_assert
=
shape_utils
.
assert_shape_equal
(
shape_utils
.
combined_static_and_dynamic_shape
(
groundtruth_labels
)[
1
:],
shape_utils
.
combined_static_and_dynamic_shape
(
unmatched_class_label
))
labels_and_box_shapes_assert
=
shape_utils
.
assert_shape_equal
(
shape_utils
.
combined_static_and_dynamic_shape
(
groundtruth_labels
)[:
1
],
shape_utils
.
combined_static_and_dynamic_shape
(
groundtruth_boxes
.
get
())[:
1
])
if
groundtruth_weights
is
None
:
if
groundtruth_weights
is
None
:
num_gt_boxes
=
groundtruth_boxes
.
num_boxes_static
()
num_gt_boxes
=
groundtruth_boxes
.
num_boxes_static
()
if
not
num_gt_boxes
:
if
not
num_gt_boxes
:
...
@@ -2036,18 +2021,19 @@ class DETRTargetAssigner(object):
...
@@ -2036,18 +2021,19 @@ class DETRTargetAssigner(object):
scores
=
1
-
groundtruth_labels
[:,
0
]
scores
=
1
-
groundtruth_labels
[:,
0
]
groundtruth_boxes
.
add_field
(
fields
.
BoxListFields
.
scores
,
scores
)
groundtruth_boxes
.
add_field
(
fields
.
BoxListFields
.
scores
,
scores
)
groundtruth_boxes
.
add_field
(
fields
.
BoxListFields
.
classes
,
groundtruth_labels
)
box_preds
.
add_field
(
fields
.
BoxListFields
.
classes
,
class_predictions
)
with
tf
.
control_dependencies
(
with
tf
.
control_dependencies
(
[
unmatched_shape_assert
,
labels_and_box_shapes_assert
]):
[
unmatched_shape_assert
,
labels_and_box_shapes_assert
]):
match_quality_matrix
=
self
.
_similarity_calc
.
compare
(
match_quality_matrix
=
self
.
_similarity_calc
.
compare
(
groundtruth_boxes
,
groundtruth_boxes
,
anchors
,
box_preds
)
groundtruth_labels
=
groundtruth_labels
,
predicted_labels
=
class_predictions
)
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
valid_rows
=
tf
.
greater
(
groundtruth_weights
,
0
))
valid_rows
=
tf
.
greater
(
groundtruth_weights
,
0
))
reg_targets
=
self
.
_create_regression_targets
(
anchor
s
,
reg_targets
=
self
.
_create_regression_targets
(
box_pred
s
,
groundtruth_boxes
,
groundtruth_boxes
,
match
)
match
)
cls_targets
=
match
.
gather_based_on_match
(
cls_targets
=
match
.
gather_based_on_match
(
...
@@ -2062,7 +2048,7 @@ class DETRTargetAssigner(object):
...
@@ -2062,7 +2048,7 @@ class DETRTargetAssigner(object):
ignored_value
=
0.
,
ignored_value
=
0.
,
unmatched_value
=
self
.
_negative_class_weight
)
unmatched_value
=
self
.
_negative_class_weight
)
# convert cls_weights from per-
anchor
to per-class.
# convert cls_weights from per-
box_pred
to per-class.
class_label_shape
=
tf
.
shape
(
cls_targets
)[
1
:]
class_label_shape
=
tf
.
shape
(
cls_targets
)[
1
:]
weights_shape
=
tf
.
shape
(
cls_weights
)
weights_shape
=
tf
.
shape
(
cls_weights
)
weights_multiple
=
tf
.
concat
(
weights_multiple
=
tf
.
concat
(
...
@@ -2072,37 +2058,37 @@ class DETRTargetAssigner(object):
...
@@ -2072,37 +2058,37 @@ class DETRTargetAssigner(object):
cls_weights
=
tf
.
expand_dims
(
cls_weights
,
-
1
)
cls_weights
=
tf
.
expand_dims
(
cls_weights
,
-
1
)
cls_weights
=
tf
.
tile
(
cls_weights
,
weights_multiple
)
cls_weights
=
tf
.
tile
(
cls_weights
,
weights_multiple
)
num_
anchors
=
anchor
s
.
num_boxes_static
()
num_
box_preds
=
box_pred
s
.
num_boxes_static
()
if
num_
anchor
s
is
not
None
:
if
num_
box_pred
s
is
not
None
:
reg_targets
=
self
.
_reset_target_shape
(
reg_targets
,
num_
anchor
s
)
reg_targets
=
self
.
_reset_target_shape
(
reg_targets
,
num_
box_pred
s
)
cls_targets
=
self
.
_reset_target_shape
(
cls_targets
,
num_
anchor
s
)
cls_targets
=
self
.
_reset_target_shape
(
cls_targets
,
num_
box_pred
s
)
reg_weights
=
self
.
_reset_target_shape
(
reg_weights
,
num_
anchor
s
)
reg_weights
=
self
.
_reset_target_shape
(
reg_weights
,
num_
box_pred
s
)
cls_weights
=
self
.
_reset_target_shape
(
cls_weights
,
num_
anchor
s
)
cls_weights
=
self
.
_reset_target_shape
(
cls_weights
,
num_
box_pred
s
)
return
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
,
return
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
,
match
.
match_results
)
match
.
match_results
)
def
_reset_target_shape
(
self
,
target
,
num_
anchor
s
):
def
_reset_target_shape
(
self
,
target
,
num_
box_pred
s
):
"""Sets the static shape of the target.
"""Sets the static shape of the target.
Args:
Args:
target: the target tensor. Its first dimension will be overwritten.
target: the target tensor. Its first dimension will be overwritten.
num_
anchor
s: the number of
anchor
s, which is used to override the target's
num_
box_pred
s: the number of
box_pred
s, which is used to override the target's
first dimension.
first dimension.
Returns:
Returns:
A tensor with the shape info filled in.
A tensor with the shape info filled in.
"""
"""
target_shape
=
target
.
get_shape
().
as_list
()
target_shape
=
target
.
get_shape
().
as_list
()
target_shape
[
0
]
=
num_
anchor
s
target_shape
[
0
]
=
num_
box_pred
s
target
.
set_shape
(
target_shape
)
target
.
set_shape
(
target_shape
)
return
target
return
target
def
_create_regression_targets
(
self
,
anchor
s
,
groundtruth_boxes
,
match
):
def
_create_regression_targets
(
self
,
box_pred
s
,
groundtruth_boxes
,
match
):
"""Returns a regression target for each
anchor
.
"""Returns a regression target for each
box_pred
.
Args:
Args:
anchor
s: a BoxList representing N
anchor
s
box_pred
s: a BoxList representing N
box_pred
s
groundtruth_boxes: a BoxList representing M groundtruth_boxes
groundtruth_boxes: a BoxList representing M groundtruth_boxes
match: a matcher.Match object
match: a matcher.Match object
...
@@ -2123,8 +2109,8 @@ class DETRTargetAssigner(object):
...
@@ -2123,8 +2109,8 @@ class DETRTargetAssigner(object):
# Zero out the unmatched and ignored regression targets.
# Zero out the unmatched and ignored regression targets.
unmatched_ignored_reg_targets
=
tf
.
tile
(
unmatched_ignored_reg_targets
=
tf
.
tile
(
tf
.
constant
([
4
*
[
0
]],
tf
.
float32
),
[
match_results_shape
[
0
],
1
])
tf
.
constant
([
4
*
[
0
]],
tf
.
float32
),
[
match_results_shape
[
0
],
1
])
matched_
anchor
s_mask
=
match
.
matched_column_indicator
()
matched_
box_pred
s_mask
=
match
.
matched_column_indicator
()
reg_targets
=
tf
.
where
(
matched_
anchor
s_mask
,
reg_targets
=
tf
.
where
(
matched_
box_pred
s_mask
,
matched_reg_targets
,
matched_reg_targets
,
unmatched_ignored_reg_targets
)
unmatched_ignored_reg_targets
)
return
reg_targets
return
reg_targets
research/object_detection/core/target_assigner_test.py
View file @
d54c86de
...
@@ -2204,11 +2204,11 @@ class DETRTargetAssignerTest(testcase.TestCase):
...
@@ -2204,11 +2204,11 @@ class DETRTargetAssignerTest(testcase.TestCase):
similarity_calc
=
region_similarity_calculator
.
DETRSimilarity
()
similarity_calc
=
region_similarity_calculator
.
DETRSimilarity
()
matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
box_coder
=
detr_box_coder
.
DETRBoxCoder
()
box_coder
=
detr_box_coder
.
DETRBoxCoder
()
target_assigner
=
targetassigner
.
TargetAssigner
(
detr_
target_assigner
=
target
_
assigner
.
DETR
TargetAssigner
(
similarity_calc
,
matcher
,
box_coder
)
similarity_calc
,
matcher
,
box_coder
)
anchors_boxlist
=
box_list
.
BoxList
(
anchor_means
)
anchors_boxlist
=
box_list
.
BoxList
(
anchor_means
)
groundtruth_boxlist
=
box_list
.
BoxList
(
groundtruth_box_corners
)
groundtruth_boxlist
=
box_list
.
BoxList
(
groundtruth_box_corners
)
result
=
target_assigner
.
assign
(
result
=
detr_
target_assigner
.
assign
(
anchors_boxlist
,
groundtruth_boxlist
,
anchors_boxlist
,
groundtruth_boxlist
,
unmatched_class_label
=
tf
.
constant
(
unmatched_class_label
=
tf
.
constant
(
[
1
,
0
],
dtype
=
tf
.
float32
),
[
1
,
0
],
dtype
=
tf
.
float32
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment