Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
481cf8da
Commit
481cf8da
authored
Oct 08, 2020
by
Vivek Rathod
Committed by
TF Object Detection Team
Oct 08, 2020
Browse files
Output multiclass scores from post-process.
PiperOrigin-RevId: 336090589
parent
92752da2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
2 deletions
+12
-2
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+6
-1
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+6
-1
No files found.
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
481cf8da
...
@@ -2856,6 +2856,8 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2856,6 +2856,8 @@ class CenterNetMetaArch(model.DetectionModel):
feature extractor's final layer output.
feature extractor's final layer output.
detection_scores: A tensor of shape [batch, max_detections] holding
detection_scores: A tensor of shape [batch, max_detections] holding
the predicted score for each box.
the predicted score for each box.
detection_multiclass_scores: A tensor of shape [batch, max_detection,
num_classes] holding multiclass score for each box.
detection_classes: An integer tensor of shape [batch, max_detections]
detection_classes: An integer tensor of shape [batch, max_detections]
containing the detected class for each box.
containing the detected class for each box.
num_detections: An integer tensor of shape [batch] containing the
num_detections: An integer tensor of shape [batch] containing the
...
@@ -2883,7 +2885,8 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2883,7 +2885,8 @@ class CenterNetMetaArch(model.DetectionModel):
top_k_feature_map_locations
(
top_k_feature_map_locations
(
object_center_prob
,
max_pool_kernel_size
=
3
,
object_center_prob
,
max_pool_kernel_size
=
3
,
k
=
self
.
_center_params
.
max_box_predictions
))
k
=
self
.
_center_params
.
max_box_predictions
))
multiclass_scores
=
tf
.
gather_nd
(
object_center_prob
,
tf
.
stack
([
y_indices
,
x_indices
],
-
1
),
batch_dims
=
1
)
boxes_strided
,
classes
,
scores
,
num_detections
=
(
boxes_strided
,
classes
,
scores
,
num_detections
=
(
prediction_tensors_to_boxes
(
prediction_tensors_to_boxes
(
detection_scores
,
y_indices
,
x_indices
,
channel_indices
,
detection_scores
,
y_indices
,
x_indices
,
channel_indices
,
...
@@ -2895,6 +2898,8 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2895,6 +2898,8 @@ class CenterNetMetaArch(model.DetectionModel):
postprocess_dict
=
{
postprocess_dict
=
{
fields
.
DetectionResultFields
.
detection_boxes
:
boxes
,
fields
.
DetectionResultFields
.
detection_boxes
:
boxes
,
fields
.
DetectionResultFields
.
detection_scores
:
scores
,
fields
.
DetectionResultFields
.
detection_scores
:
scores
,
fields
.
DetectionResultFields
.
detection_multiclass_scores
:
multiclass_scores
,
fields
.
DetectionResultFields
.
detection_classes
:
classes
,
fields
.
DetectionResultFields
.
detection_classes
:
classes
,
fields
.
DetectionResultFields
.
num_detections
:
num_detections
,
fields
.
DetectionResultFields
.
num_detections
:
num_detections
,
'detection_boxes_strided'
:
boxes_strided
'detection_boxes_strided'
:
boxes_strided
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
481cf8da
...
@@ -1507,7 +1507,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
...
@@ -1507,7 +1507,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
keypoint_offsets
=
np
.
zeros
((
1
,
32
,
32
,
2
),
dtype
=
np
.
float32
)
keypoint_offsets
=
np
.
zeros
((
1
,
32
,
32
,
2
),
dtype
=
np
.
float32
)
keypoint_regression
=
np
.
random
.
randn
(
1
,
32
,
32
,
num_keypoints
*
2
)
keypoint_regression
=
np
.
random
.
randn
(
1
,
32
,
32
,
num_keypoints
*
2
)
class_probs
=
np
.
zero
s
(
10
)
class_probs
=
np
.
one
s
(
10
)
*
_logit
(
0.25
)
class_probs
[
target_class_id
]
=
_logit
(
0.75
)
class_probs
[
target_class_id
]
=
_logit
(
0.75
)
class_center
[
0
,
16
,
16
]
=
class_probs
class_center
[
0
,
16
,
16
]
=
class_probs
height_width
[
0
,
16
,
16
]
=
[
5
,
10
]
height_width
[
0
,
16
,
16
]
=
[
5
,
10
]
...
@@ -1582,6 +1582,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
...
@@ -1582,6 +1582,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
np
.
array
([
55
,
46
,
75
,
86
])
/
128.0
)
np
.
array
([
55
,
46
,
75
,
86
])
/
128.0
)
self
.
assertAllClose
(
detections
[
'detection_scores'
][
0
],
self
.
assertAllClose
(
detections
[
'detection_scores'
][
0
],
[.
75
,
.
5
,
.
5
,
.
5
,
.
5
])
[.
75
,
.
5
,
.
5
,
.
5
,
.
5
])
expected_multiclass_scores
=
[.
25
]
*
10
expected_multiclass_scores
[
target_class_id
]
=
.
75
self
.
assertAllClose
(
expected_multiclass_scores
,
detections
[
'detection_multiclass_scores'
][
0
][
0
])
# The output embedding extracted at the object center will be a 3-D array of
# The output embedding extracted at the object center will be a 3-D array of
# shape [batch, num_boxes, embedding_size]. The valid predicted embedding
# shape [batch, num_boxes, embedding_size]. The valid predicted embedding
# will be the first embedding in the first batch. It is a 1-D array of
# will be the first embedding in the first batch. It is a 1-D array of
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment