Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
fe31beae
Commit
fe31beae
authored
Feb 09, 2018
by
Zhichao Lu
Committed by
lzc5123016
Feb 13, 2018
Browse files
Ensure that `ops.matmul_gather_on_zeroth_axis` works with dynamic shapes.
PiperOrigin-RevId: 185218732
parent
1efe98bb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
7 deletions
+23
-7
research/object_detection/utils/ops.py
research/object_detection/utils/ops.py
+5
-4
research/object_detection/utils/ops_test.py
research/object_detection/utils/ops_test.py
+18
-3
No files found.
research/object_detection/utils/ops.py
View file @
fe31beae
...
...
@@ -813,9 +813,10 @@ def matmul_gather_on_zeroth_axis(params, indices, scope=None):
from indices given by indices, with shape indices.shape + params.shape[1:].
"""
with
tf
.
name_scope
(
scope
,
'MatMulGather'
):
index_range
=
params
.
shape
[
0
]
params2d
=
tf
.
reshape
(
params
,
[
index_range
,
-
1
])
indicator_matrix
=
tf
.
one_hot
(
indices
,
index_range
)
params_shape
=
shape_utils
.
combined_static_and_dynamic_shape
(
params
)
indices_shape
=
shape_utils
.
combined_static_and_dynamic_shape
(
indices
)
params2d
=
tf
.
reshape
(
params
,
[
params_shape
[
0
],
-
1
])
indicator_matrix
=
tf
.
one_hot
(
indices
,
params_shape
[
0
])
gathered_result_flattened
=
tf
.
matmul
(
indicator_matrix
,
params2d
)
return
tf
.
reshape
(
gathered_result_flattened
,
indices
.
shape
.
concatenate
(
params
.
shape
[
1
:]))
tf
.
stack
(
indices
_
shape
+
params
_
shape
[
1
:]))
research/object_detection/utils/ops_test.py
View file @
fe31beae
...
...
@@ -840,7 +840,7 @@ class OpsTestPositionSensitiveCropRegions(tf.test.TestCase):
# All channels are equal so position-sensitive crop and resize should
# work as the usual crop and resize for just one channel.
crop
=
tf
.
image
.
crop_and_resize
(
image
,
boxes
,
box_ind
,
crop_size
)
crop_and_pool
=
tf
.
reduce_mean
(
crop
,
[
1
,
2
],
keep
_
dims
=
True
)
crop_and_pool
=
tf
.
reduce_mean
(
crop
,
[
1
,
2
],
keepdims
=
True
)
ps_crop_and_pool
=
ops
.
position_sensitive_crop_regions
(
tiled_image
,
...
...
@@ -866,7 +866,7 @@ class OpsTestPositionSensitiveCropRegions(tf.test.TestCase):
# When a single bin is used, position-sensitive crop and pool should be
# the same as non-position sensitive crop and pool.
crop
=
tf
.
image
.
crop_and_resize
(
image
,
boxes
,
box_ind
,
crop_size
)
crop_and_pool
=
tf
.
reduce_mean
(
crop
,
[
1
,
2
],
keep
_
dims
=
True
)
crop_and_pool
=
tf
.
reduce_mean
(
crop
,
[
1
,
2
],
keepdims
=
True
)
ps_crop_and_pool
=
ops
.
position_sensitive_crop_regions
(
image
,
boxes
,
box_ind
,
crop_size
,
num_spatial_bins
,
global_pool
=
True
)
...
...
@@ -1054,7 +1054,7 @@ class OpsTestPositionSensitiveCropRegions(tf.test.TestCase):
ps_crop
=
ops
.
position_sensitive_crop_regions
(
image
,
boxes
,
box_ind
,
crop_size
,
num_spatial_bins
,
global_pool
=
False
)
ps_crop_and_pool
=
tf
.
reduce_mean
(
ps_crop
,
reduction_indices
=
(
1
,
2
),
keep
_
dims
=
True
)
ps_crop
,
reduction_indices
=
(
1
,
2
),
keepdims
=
True
)
with
self
.
test_session
()
as
sess
:
output
=
sess
.
run
(
ps_crop_and_pool
)
...
...
@@ -1225,6 +1225,21 @@ class MatmulGatherOnZerothAxis(test_case.TestCase):
gather_output
=
self
.
execute
(
graph_fn
,
[
params
,
indices
])
self
.
assertAllClose
(
gather_output
,
expected_output
)
def
test_gather_with_dynamic_shape_input
(
self
):
params_placeholder
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
4
])
indices_placeholder
=
tf
.
placeholder
(
tf
.
int32
,
shape
=
[
None
])
gather_result
=
ops
.
matmul_gather_on_zeroth_axis
(
params_placeholder
,
indices_placeholder
)
params
=
np
.
array
([[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
8
],
[
9
,
10
,
11
,
12
],
[
0
,
1
,
0
,
0
]],
dtype
=
np
.
float32
)
indices
=
np
.
array
([
0
,
0
,
0
,
0
,
0
,
0
])
expected_output
=
np
.
array
(
6
*
[[
1
,
2
,
3
,
4
]])
with
self
.
test_session
()
as
sess
:
gather_output
=
sess
.
run
(
gather_result
,
feed_dict
=
{
params_placeholder
:
params
,
indices_placeholder
:
indices
})
self
.
assertAllClose
(
gather_output
,
expected_output
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment