Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
fcb152bf
Commit
fcb152bf
authored
Sep 27, 2021
by
A. Unique TensorFlower
Committed by
TF Object Detection Team
Sep 27, 2021
Browse files
Limit k of top_k to the size of the dynamic input.
PiperOrigin-RevId: 399210729
parent
4b6cbef4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
1 deletion
+21
-1
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+3
-1
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+18
-0
No files found.
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
fcb152bf
...
@@ -317,7 +317,9 @@ def top_k_feature_map_locations(feature_map, max_pool_kernel_size=3, k=100,
...
@@ -317,7 +317,9 @@ def top_k_feature_map_locations(feature_map, max_pool_kernel_size=3, k=100,
feature_map_peaks_flat
,
axis
=
1
,
output_type
=
tf
.
dtypes
.
int32
),
axis
=-
1
)
feature_map_peaks_flat
,
axis
=
1
,
output_type
=
tf
.
dtypes
.
int32
),
axis
=-
1
)
else
:
else
:
feature_map_peaks_flat
=
tf
.
reshape
(
feature_map_peaks
,
[
batch_size
,
-
1
])
feature_map_peaks_flat
=
tf
.
reshape
(
feature_map_peaks
,
[
batch_size
,
-
1
])
scores
,
peak_flat_indices
=
tf
.
math
.
top_k
(
feature_map_peaks_flat
,
k
=
k
)
safe_k
=
tf
.
minimum
(
k
,
tf
.
shape
(
feature_map_peaks_flat
)[
1
])
scores
,
peak_flat_indices
=
tf
.
math
.
top_k
(
feature_map_peaks_flat
,
k
=
safe_k
)
# Get x, y and channel indices corresponding to the top indices in the flat
# Get x, y and channel indices corresponding to the top indices in the flat
# array.
# array.
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
fcb152bf
...
@@ -469,6 +469,24 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
...
@@ -469,6 +469,24 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np
.
testing
.
assert_array_equal
([
1
,
0
,
2
],
x_inds
[
1
])
np
.
testing
.
assert_array_equal
([
1
,
0
,
2
],
x_inds
[
1
])
np
.
testing
.
assert_array_equal
([
1
,
0
,
0
],
channel_inds
[
1
])
np
.
testing
.
assert_array_equal
([
1
,
0
,
0
],
channel_inds
[
1
])
def
test_top_k_feature_map_locations_very_large
(
self
):
feature_map_np
=
np
.
zeros
((
2
,
3
,
3
,
2
),
dtype
=
np
.
float32
)
feature_map_np
[
0
,
2
,
0
,
1
]
=
1.0
def
graph_fn
():
feature_map
=
tf
.
constant
(
feature_map_np
)
feature_map
.
set_shape
(
tf
.
TensorShape
([
2
,
3
,
None
,
2
]))
scores
,
y_inds
,
x_inds
,
channel_inds
=
(
cnma
.
top_k_feature_map_locations
(
feature_map
,
max_pool_kernel_size
=
1
,
k
=
3000
))
return
scores
,
y_inds
,
x_inds
,
channel_inds
# graph execution will fail if large k's are not handled.
scores
,
y_inds
,
x_inds
,
channel_inds
=
self
.
execute
(
graph_fn
,
[])
self
.
assertEqual
(
scores
.
shape
,
(
2
,
18
))
self
.
assertEqual
(
y_inds
.
shape
,
(
2
,
18
))
self
.
assertEqual
(
x_inds
.
shape
,
(
2
,
18
))
self
.
assertEqual
(
channel_inds
.
shape
,
(
2
,
18
))
def
test_top_k_feature_map_locations_per_channel
(
self
):
def
test_top_k_feature_map_locations_per_channel
(
self
):
feature_map_np
=
np
.
zeros
((
2
,
3
,
3
,
2
),
dtype
=
np
.
float32
)
feature_map_np
=
np
.
zeros
((
2
,
3
,
3
,
2
),
dtype
=
np
.
float32
)
feature_map_np
[
0
,
2
,
0
,
0
]
=
1.0
# Selected.
feature_map_np
[
0
,
2
,
0
,
0
]
=
1.0
# Selected.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment