Commit 582bf927 authored by derekjchow's avatar derekjchow Committed by GitHub
Browse files

Merge pull request #2053 from derekjchow/master

object_detection exporter updates
parents ecf5edf1 a2cb67c2
......@@ -111,3 +111,26 @@ def pad_or_clip_tensor(t, length):
if not _is_tensor(length):
processed_t = _set_dim_0(processed_t, length)
return processed_t
def combined_static_and_dynamic_shape(tensor):
"""Returns a list containing static and dynamic values for the dimensions.
Returns a list of static and dynamic values for shape dimensions. This is
useful to preserve static shapes when available in reshape operation.
Args:
tensor: A tensor of any type.
Returns:
A list of size tensor.shape.ndims containing integers or a scalar tensor.
"""
static_shape = tensor.shape.as_list()
dynamic_shape = tf.shape(tensor)
combined_shape = []
for index, dim in enumerate(static_shape):
if dim is not None:
combined_shape.append(dim)
else:
combined_shape.append(dynamic_shape[index])
return combined_shape
......@@ -115,6 +115,13 @@ class UtilTest(tf.test.TestCase):
self.assertAllEqual([1, 2], tt3_result)
self.assertAllClose([[0.1, 0.2], [0.2, 0.4]], tt4_result)
def test_combines_static_dynamic_shape(self):
tensor = tf.placeholder(tf.float32, shape=(None, 2, 3))
combined_shape = shape_utils.combined_static_and_dynamic_shape(
tensor)
self.assertTrue(tf.contrib.framework.is_tensor(combined_shape[0]))
self.assertListEqual(combined_shape[1:], [2, 3])
if __name__ == '__main__':
tf.test.main()
......@@ -22,6 +22,7 @@ from object_detection.core import box_coder
from object_detection.core import box_list
from object_detection.core import box_predictor
from object_detection.core import matcher
from object_detection.utils import shape_utils
class MockBoxCoder(box_coder.BoxCoder):
......@@ -45,9 +46,10 @@ class MockBoxPredictor(box_predictor.BoxPredictor):
super(MockBoxPredictor, self).__init__(is_training, num_classes)
def _predict(self, image_features, num_predictions_per_location):
batch_size = image_features.get_shape().as_list()[0]
num_anchors = (image_features.get_shape().as_list()[1]
* image_features.get_shape().as_list()[2])
combined_feature_shape = shape_utils.combined_static_and_dynamic_shape(
image_features)
batch_size = combined_feature_shape[0]
num_anchors = (combined_feature_shape[1] * combined_feature_shape[2])
code_size = 4
zero = tf.reduce_sum(0 * image_features)
box_encodings = zero + tf.zeros(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment