Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ab96cb33
Commit
ab96cb33
authored
Aug 11, 2020
by
Kaushik Shivakumar
Browse files
separate out DETR
parent
e0b082ed
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
357 additions
and
64 deletions
+357
-64
research/object_detection/core/target_assigner.py
research/object_detection/core/target_assigner.py
+304
-13
research/object_detection/core/target_assigner_test.py
research/object_detection/core/target_assigner_test.py
+53
-51
No files found.
research/object_detection/core/target_assigner.py
View file @
ab96cb33
...
@@ -57,8 +57,6 @@ from object_detection.utils import tf_version
...
@@ -57,8 +57,6 @@ from object_detection.utils import tf_version
if
tf_version
.
is_tf1
():
if
tf_version
.
is_tf1
():
from
object_detection.matchers
import
bipartite_matcher
# pylint: disable=g-import-not-at-top
from
object_detection.matchers
import
bipartite_matcher
# pylint: disable=g-import-not-at-top
elif
tf_version
.
is_tf2
():
from
object_detection.matchers
import
hungarian_matcher
ResizeMethod
=
tf2
.
image
.
ResizeMethod
ResizeMethod
=
tf2
.
image
.
ResizeMethod
...
@@ -142,8 +140,6 @@ class TargetAssigner(object):
...
@@ -142,8 +140,6 @@ class TargetAssigner(object):
aware of groundtruth weights. Additionally, `cls_weights` and
aware of groundtruth weights. Additionally, `cls_weights` and
`reg_weights` are calculated using groundtruth weights as an added
`reg_weights` are calculated using groundtruth weights as an added
safety.
safety.
class_predictions: A tensor with shape [max_num_boxes, d_1, d_2, ..., d_k]
to be used by certain similarity calculators.
Returns:
Returns:
cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k],
cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k],
...
@@ -199,15 +195,10 @@ class TargetAssigner(object):
...
@@ -199,15 +195,10 @@ class TargetAssigner(object):
with
tf
.
control_dependencies
(
with
tf
.
control_dependencies
(
[
unmatched_shape_assert
,
labels_and_box_shapes_assert
]):
[
unmatched_shape_assert
,
labels_and_box_shapes_assert
]):
match_quality_matrix
=
self
.
_similarity_calc
.
compare
(
groundtruth_boxes
,
match_quality_matrix
=
self
.
_similarity_calc
.
compare
(
anchors
)
groundtruth_boxes
,
anchors
,
groundtruth_labels
=
groundtruth_labels
,
predicted_labels
=
class_predictions
)
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
valid_rows
=
tf
.
greater
(
groundtruth_weights
,
0
))
valid_rows
=
tf
.
greater
(
groundtruth_weights
,
0
))
reg_targets
=
self
.
_create_regression_targets
(
anchors
,
reg_targets
=
self
.
_create_regression_targets
(
anchors
,
groundtruth_boxes
,
groundtruth_boxes
,
match
)
match
)
...
@@ -447,7 +438,7 @@ def create_target_assigner(reference, stage=None,
...
@@ -447,7 +438,7 @@ def create_target_assigner(reference, stage=None,
elif
reference
==
'DETR'
:
elif
reference
==
'DETR'
:
similarity_calc
=
sim_calc
.
DETRSimilarity
()
similarity_calc
=
sim_calc
.
DETRSimilarity
()
matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
box_coder_instance
=
detr_box_coder
.
DETRBoxCod
er
(
)
return
DETRTargetAssigner
(
similarity_calc
,
match
er
)
else
:
else
:
raise
ValueError
(
'No valid combination of reference and stage.'
)
raise
ValueError
(
'No valid combination of reference and stage.'
)
...
@@ -1920,4 +1911,304 @@ class CenterNetCornerOffsetTargetAssigner(object):
...
@@ -1920,4 +1911,304 @@ class CenterNetCornerOffsetTargetAssigner(object):
corner_targets
.
append
(
corner_target
)
corner_targets
.
append
(
corner_target
)
return
(
tf
.
stack
(
corner_targets
,
axis
=
0
),
return
(
tf
.
stack
(
corner_targets
,
axis
=
0
),
tf
.
stack
(
foreground_targets
,
axis
=
0
))
tf
.
stack
(
foreground_targets
,
axis
=
0
))
\ No newline at end of file
class
DETRTargetAssigner
(
object
):
"""Target assigner to compute classification and regression targets."""
def
__init__
(
self
,
similarity_calc
,
matcher
,
negative_class_weight
=
1.0
):
"""Construct Object Detection Target Assigner.
Args:
similarity_calc: a RegionSimilarityCalculator
matcher: an object_detection.core.Matcher used to match groundtruth to
anchors.
box_coder_instance: an object_detection.core.BoxCoder used to encode
matching groundtruth boxes with respect to anchors.
negative_class_weight: classification weight to be associated to negative
anchors (default: 1.0). The weight must be in [0., 1.].
Raises:
ValueError: if similarity_calc is not a RegionSimilarityCalculator or
if matcher is not a Matcher or if box_coder is not a BoxCoder
"""
if
not
isinstance
(
similarity_calc
,
sim_calc
.
RegionSimilarityCalculator
):
raise
ValueError
(
'similarity_calc must be a RegionSimilarityCalculator'
)
if
not
isinstance
(
matcher
,
mat
.
Matcher
):
raise
ValueError
(
'matcher must be a Matcher'
)
self
.
_similarity_calc
=
similarity_calc
self
.
_matcher
=
matcher
self
.
_negative_class_weight
=
negative_class_weight
def
assign
(
self
,
anchors
,
groundtruth_boxes
,
groundtruth_labels
=
None
,
unmatched_class_label
=
None
,
groundtruth_weights
=
None
,
class_predictions
=
None
):
"""Assign classification and regression targets to each anchor.
For a given set of anchors and groundtruth detections, match anchors
to groundtruth_boxes and assign classification and regression targets to
each anchor as well as weights based on the resulting match (specifying,
e.g., which anchors should not contribute to training loss).
Anchors that are not matched to anything are given a classification target
of self._unmatched_cls_target which can be specified via the constructor.
Args:
anchors: a BoxList representing N anchors
groundtruth_boxes: a BoxList representing M groundtruth boxes
groundtruth_labels: a tensor of shape [M, d_1, ... d_k]
with labels for each of the ground_truth boxes. The subshape
[d_1, ... d_k] can be empty (corresponding to scalar inputs). When set
to None, groundtruth_labels assumes a binary problem where all
ground_truth boxes get a positive label (of 1).
unmatched_class_label: a float32 tensor with shape [d_1, d_2, ..., d_k]
which is consistent with the classification target for each
anchor (and can be empty for scalar targets). This shape must thus be
compatible with the groundtruth labels that are passed to the "assign"
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
If set to None, unmatched_cls_target is set to be [0] for each anchor.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box. The weights
must be in [0., 1.]. If None, all weights are set to 1. Generally no
groundtruth boxes with zero weight match to any anchors as matchers are
aware of groundtruth weights. Additionally, `cls_weights` and
`reg_weights` are calculated using groundtruth weights as an added
safety.
class_predictions: A tensor with shape [max_num_boxes, d_1, d_2, ..., d_k]
to be used by certain similarity calculators.
Returns:
cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k],
where the subshape [d_1, ..., d_k] is compatible with groundtruth_labels
which has shape [num_gt_boxes, d_1, d_2, ... d_k].
cls_weights: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k],
representing weights for each element in cls_targets.
reg_targets: a float32 tensor with shape [num_anchors, box_code_dimension]
reg_weights: a float32 tensor with shape [num_anchors]
match: an int32 tensor of shape [num_anchors] containing result of anchor
groundtruth matching. Each position in the tensor indicates an anchor
and holds the following meaning:
(1) if match[i] >= 0, anchor i is matched with groundtruth match[i].
(2) if match[i]=-1, anchor i is marked to be background .
(3) if match[i]=-2, anchor i is ignored since it is not background and
does not have sufficient overlap to call it a foreground.
Raises:
ValueError: if anchors or groundtruth_boxes are not of type
box_list.BoxList
"""
if
not
isinstance
(
anchors
,
box_list
.
BoxList
):
raise
ValueError
(
'anchors must be an BoxList'
)
if
not
isinstance
(
groundtruth_boxes
,
box_list
.
BoxList
):
raise
ValueError
(
'groundtruth_boxes must be an BoxList'
)
if
unmatched_class_label
is
None
:
unmatched_class_label
=
tf
.
constant
([
0
],
tf
.
float32
)
if
groundtruth_labels
is
None
:
groundtruth_labels
=
tf
.
ones
(
tf
.
expand_dims
(
groundtruth_boxes
.
num_boxes
(),
0
))
groundtruth_labels
=
tf
.
expand_dims
(
groundtruth_labels
,
-
1
)
unmatched_shape_assert
=
shape_utils
.
assert_shape_equal
(
shape_utils
.
combined_static_and_dynamic_shape
(
groundtruth_labels
)[
1
:],
shape_utils
.
combined_static_and_dynamic_shape
(
unmatched_class_label
))
labels_and_box_shapes_assert
=
shape_utils
.
assert_shape_equal
(
shape_utils
.
combined_static_and_dynamic_shape
(
groundtruth_labels
)[:
1
],
shape_utils
.
combined_static_and_dynamic_shape
(
groundtruth_boxes
.
get
())[:
1
])
if
groundtruth_weights
is
None
:
num_gt_boxes
=
groundtruth_boxes
.
num_boxes_static
()
if
not
num_gt_boxes
:
num_gt_boxes
=
groundtruth_boxes
.
num_boxes
()
groundtruth_weights
=
tf
.
ones
([
num_gt_boxes
],
dtype
=
tf
.
float32
)
# set scores on the gt boxes
scores
=
1
-
groundtruth_labels
[:,
0
]
groundtruth_boxes
.
add_field
(
fields
.
BoxListFields
.
scores
,
scores
)
with
tf
.
control_dependencies
(
[
unmatched_shape_assert
,
labels_and_box_shapes_assert
]):
match_quality_matrix
=
self
.
_similarity_calc
.
compare
(
groundtruth_boxes
,
anchors
,
groundtruth_labels
=
groundtruth_labels
,
predicted_labels
=
class_predictions
)
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
valid_rows
=
tf
.
greater
(
groundtruth_weights
,
0
))
reg_targets
=
self
.
_create_regression_targets
(
anchors
,
groundtruth_boxes
,
match
)
cls_targets
=
self
.
_create_classification_targets
(
groundtruth_labels
,
unmatched_class_label
,
match
)
reg_weights
=
self
.
_create_regression_weights
(
match
,
groundtruth_weights
)
cls_weights
=
self
.
_create_classification_weights
(
match
,
groundtruth_weights
)
# convert cls_weights from per-anchor to per-class.
class_label_shape
=
tf
.
shape
(
cls_targets
)[
1
:]
weights_shape
=
tf
.
shape
(
cls_weights
)
weights_multiple
=
tf
.
concat
(
[
tf
.
ones_like
(
weights_shape
),
class_label_shape
],
axis
=
0
)
for
_
in
range
(
len
(
cls_targets
.
get_shape
()[
1
:])):
cls_weights
=
tf
.
expand_dims
(
cls_weights
,
-
1
)
cls_weights
=
tf
.
tile
(
cls_weights
,
weights_multiple
)
num_anchors
=
anchors
.
num_boxes_static
()
if
num_anchors
is
not
None
:
reg_targets
=
self
.
_reset_target_shape
(
reg_targets
,
num_anchors
)
cls_targets
=
self
.
_reset_target_shape
(
cls_targets
,
num_anchors
)
reg_weights
=
self
.
_reset_target_shape
(
reg_weights
,
num_anchors
)
cls_weights
=
self
.
_reset_target_shape
(
cls_weights
,
num_anchors
)
return
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
,
match
.
match_results
)
def
_reset_target_shape
(
self
,
target
,
num_anchors
):
"""Sets the static shape of the target.
Args:
target: the target tensor. Its first dimension will be overwritten.
num_anchors: the number of anchors, which is used to override the target's
first dimension.
Returns:
A tensor with the shape info filled in.
"""
target_shape
=
target
.
get_shape
().
as_list
()
target_shape
[
0
]
=
num_anchors
target
.
set_shape
(
target_shape
)
return
target
def
_create_regression_targets
(
self
,
anchors
,
groundtruth_boxes
,
match
):
"""Returns a regression target for each anchor.
Args:
anchors: a BoxList representing N anchors
groundtruth_boxes: a BoxList representing M groundtruth_boxes
match: a matcher.Match object
Returns:
reg_targets: a float32 tensor with shape [N, box_code_dimension]
"""
matched_gt_boxes
=
match
.
gather_based_on_match
(
groundtruth_boxes
.
get
(),
unmatched_value
=
tf
.
zeros
(
4
),
ignored_value
=
tf
.
zeros
(
4
))
matched_gt_boxlist
=
box_list
.
BoxList
(
matched_gt_boxes
)
ty
,
tx
,
th
,
tw
=
matched_gt_boxlist
.
get_center_coordinates_and_sizes
()
matched_reg_targets
=
tf
.
transpose
(
tf
.
stack
([
ty
,
tx
,
th
,
tw
]))
match_results_shape
=
shape_utils
.
combined_static_and_dynamic_shape
(
match
.
match_results
)
# Zero out the unmatched and ignored regression targets.
unmatched_ignored_reg_targets
=
tf
.
tile
(
self
.
_default_regression_target
(),
[
match_results_shape
[
0
],
1
])
matched_anchors_mask
=
match
.
matched_column_indicator
()
reg_targets
=
tf
.
where
(
matched_anchors_mask
,
matched_reg_targets
,
unmatched_ignored_reg_targets
)
return
reg_targets
def
_default_regression_target
(
self
):
"""Returns the default target for anchors to regress to.
Default regression targets are set to zero (though in
this implementation what these targets are set to should
not matter as the regression weight of any box set to
regress to the default target is zero).
Returns:
default_target: a float32 tensor with shape [1, box_code_dimension]
"""
return
tf
.
constant
([
4
*
[
0
]],
tf
.
float32
)
def
_create_classification_targets
(
self
,
groundtruth_labels
,
unmatched_class_label
,
match
):
"""Create classification targets for each anchor.
Assign a classification target of for each anchor to the matching
groundtruth label that is provided by match. Anchors that are not matched
to anything are given the target self._unmatched_cls_target
Args:
groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k]
with labels for each of the ground_truth boxes. The subshape
[d_1, ... d_k] can be empty (corresponding to scalar labels).
unmatched_class_label: a float32 tensor with shape [d_1, d_2, ..., d_k]
which is consistent with the classification target for each
anchor (and can be empty for scalar targets). This shape must thus be
compatible with the groundtruth labels that are passed to the "assign"
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
Returns:
a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the
subshape [d_1, ..., d_k] is compatible with groundtruth_labels which has
shape [num_gt_boxes, d_1, d_2, ... d_k].
"""
return
match
.
gather_based_on_match
(
groundtruth_labels
,
unmatched_value
=
unmatched_class_label
,
ignored_value
=
unmatched_class_label
)
def
_create_regression_weights
(
self
,
match
,
groundtruth_weights
):
"""Set regression weight for each anchor.
Only positive anchors are set to contribute to the regression loss, so this
method returns a weight of 1 for every positive anchor and 0 for every
negative anchor.
Args:
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box.
Returns:
a float32 tensor with shape [num_anchors] representing regression weights.
"""
return
match
.
gather_based_on_match
(
groundtruth_weights
,
ignored_value
=
0.
,
unmatched_value
=
0.
)
def
_create_classification_weights
(
self
,
match
,
groundtruth_weights
):
"""Create classification weights for each anchor.
Positive (matched) anchors are associated with a weight of
positive_class_weight and negative (unmatched) anchors are associated with
a weight of negative_class_weight. When anchors are ignored, weights are set
to zero. By default, both positive/negative weights are set to 1.0,
but they can be adjusted to handle class imbalance (which is almost always
the case in object detection).
Args:
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box.
Returns:
a float32 tensor with shape [num_anchors] representing classification
weights.
"""
return
match
.
gather_based_on_match
(
groundtruth_weights
,
ignored_value
=
0.
,
unmatched_value
=
self
.
_negative_class_weight
)
research/object_detection/core/target_assigner_test.py
View file @
ab96cb33
...
@@ -116,57 +116,6 @@ class TargetAssignerTest(test_case.TestCase):
...
@@ -116,57 +116,6 @@ class TargetAssignerTest(test_case.TestCase):
self
.
assertEqual
(
reg_targets_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_targets_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_weights_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_weights_out
.
dtype
,
np
.
float32
)
def
test_assign_detr
(
self
):
def
graph_fn
(
anchor_means
,
groundtruth_box_corners
,
groundtruth_labels
,
predicted_labels
):
similarity_calc
=
region_similarity_calculator
.
DETRSimilarity
()
matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
box_coder
=
detr_box_coder
.
DETRBoxCoder
()
target_assigner
=
targetassigner
.
TargetAssigner
(
similarity_calc
,
matcher
,
box_coder
)
anchors_boxlist
=
box_list
.
BoxList
(
anchor_means
)
groundtruth_boxlist
=
box_list
.
BoxList
(
groundtruth_box_corners
)
result
=
target_assigner
.
assign
(
anchors_boxlist
,
groundtruth_boxlist
,
unmatched_class_label
=
tf
.
constant
(
[
1
,
0
],
dtype
=
tf
.
float32
),
groundtruth_labels
=
groundtruth_labels
,
class_predictions
=
predicted_labels
)
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
,
_
)
=
result
return
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
)
anchor_means
=
np
.
array
([[
0.25
,
0.25
,
0.4
,
0.2
],
[
0.5
,
0.8
,
1.0
,
0.8
],
[
0.9
,
0.5
,
0.1
,
1.0
]],
dtype
=
np
.
float32
)
groundtruth_box_corners
=
np
.
array
([[
0.0
,
0.0
,
0.5
,
0.5
],
[
0.5
,
0.5
,
0.9
,
0.9
]],
dtype
=
np
.
float32
)
predicted_labels
=
np
.
array
([[
-
3.0
,
3.0
],
[
2.0
,
9.4
],
[
5.0
,
1.0
]],
dtype
=
np
.
float32
)
groundtruth_labels
=
np
.
array
([[
0.0
,
1.0
],
[
0.0
,
1.0
]],
dtype
=
np
.
float32
)
exp_cls_targets
=
[[
0
,
1
],
[
0
,
1
],
[
1
,
0
]]
exp_cls_weights
=
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
]]
exp_reg_targets
=
[[
0.25
,
0.25
,
0.5
,
0.5
],
[
0.7
,
0.7
,
0.4
,
0.4
],
[
0
,
0
,
0
,
0
]]
exp_reg_weights
=
[
1
,
1
,
0
]
(
cls_targets_out
,
cls_weights_out
,
reg_targets_out
,
reg_weights_out
)
=
self
.
execute
(
graph_fn
,
[
anchor_means
,
groundtruth_box_corners
,
groundtruth_labels
,
predicted_labels
])
self
.
assertAllClose
(
cls_targets_out
,
exp_cls_targets
)
self
.
assertAllClose
(
cls_weights_out
,
exp_cls_weights
)
self
.
assertAllClose
(
reg_targets_out
,
exp_reg_targets
)
self
.
assertAllClose
(
reg_weights_out
,
exp_reg_weights
)
self
.
assertEqual
(
cls_targets_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
cls_weights_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_targets_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_weights_out
.
dtype
,
np
.
float32
)
def
test_assign_agnostic_with_keypoints
(
self
):
def
test_assign_agnostic_with_keypoints
(
self
):
def
graph_fn
(
anchor_means
,
groundtruth_box_corners
,
def
graph_fn
(
anchor_means
,
groundtruth_box_corners
,
...
@@ -2246,3 +2195,56 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase):
...
@@ -2246,3 +2195,56 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
enable_v2_behavior
()
tf
.
enable_v2_behavior
()
tf
.
test
.
main
()
tf
.
test
.
main
()
class
DETRTargetAssignerTest
(
testcase
.
TestCase
):
def
test_assign_detr
(
self
):
def
graph_fn
(
anchor_means
,
groundtruth_box_corners
,
groundtruth_labels
,
predicted_labels
):
similarity_calc
=
region_similarity_calculator
.
DETRSimilarity
()
matcher
=
hungarian_matcher
.
HungarianBipartiteMatcher
()
box_coder
=
detr_box_coder
.
DETRBoxCoder
()
target_assigner
=
targetassigner
.
TargetAssigner
(
similarity_calc
,
matcher
,
box_coder
)
anchors_boxlist
=
box_list
.
BoxList
(
anchor_means
)
groundtruth_boxlist
=
box_list
.
BoxList
(
groundtruth_box_corners
)
result
=
target_assigner
.
assign
(
anchors_boxlist
,
groundtruth_boxlist
,
unmatched_class_label
=
tf
.
constant
(
[
1
,
0
],
dtype
=
tf
.
float32
),
groundtruth_labels
=
groundtruth_labels
,
class_predictions
=
predicted_labels
)
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
,
_
)
=
result
return
(
cls_targets
,
cls_weights
,
reg_targets
,
reg_weights
)
anchor_means
=
np
.
array
([[
0.25
,
0.25
,
0.4
,
0.2
],
[
0.5
,
0.8
,
1.0
,
0.8
],
[
0.9
,
0.5
,
0.1
,
1.0
]],
dtype
=
np
.
float32
)
groundtruth_box_corners
=
np
.
array
([[
0.0
,
0.0
,
0.5
,
0.5
],
[
0.5
,
0.5
,
0.9
,
0.9
]],
dtype
=
np
.
float32
)
predicted_labels
=
np
.
array
([[
-
3.0
,
3.0
],
[
2.0
,
9.4
],
[
5.0
,
1.0
]],
dtype
=
np
.
float32
)
groundtruth_labels
=
np
.
array
([[
0.0
,
1.0
],
[
0.0
,
1.0
]],
dtype
=
np
.
float32
)
exp_cls_targets
=
[[
0
,
1
],
[
0
,
1
],
[
1
,
0
]]
exp_cls_weights
=
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
]]
exp_reg_targets
=
[[
0.25
,
0.25
,
0.5
,
0.5
],
[
0.7
,
0.7
,
0.4
,
0.4
],
[
0
,
0
,
0
,
0
]]
exp_reg_weights
=
[
1
,
1
,
0
]
(
cls_targets_out
,
cls_weights_out
,
reg_targets_out
,
reg_weights_out
)
=
self
.
execute
(
graph_fn
,
[
anchor_means
,
groundtruth_box_corners
,
groundtruth_labels
,
predicted_labels
])
self
.
assertAllClose
(
cls_targets_out
,
exp_cls_targets
)
self
.
assertAllClose
(
cls_weights_out
,
exp_cls_weights
)
self
.
assertAllClose
(
reg_targets_out
,
exp_reg_targets
)
self
.
assertAllClose
(
reg_weights_out
,
exp_reg_weights
)
self
.
assertEqual
(
cls_targets_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
cls_weights_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_targets_out
.
dtype
,
np
.
float32
)
self
.
assertEqual
(
reg_weights_out
.
dtype
,
np
.
float32
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment