Commit 6685fb8b authored by derekjchow's avatar derekjchow Committed by GitHub
Browse files

Merge pull request #1593 from justanhduc/patch-2

Change key of type 'tuple' to 'str'
parents 1b7eb90f f906646c
......@@ -20,6 +20,8 @@ import tensorflow as tf
from object_detection.core import prefetcher
rt_shape_str = '_runtime_shapes'
class BatchQueue(object):
"""BatchQueue class.
......@@ -81,8 +83,9 @@ class BatchQueue(object):
{key: tensor.get_shape() for key, tensor in tensor_dict.items()})
# Remember runtime shapes to unpad tensors after batching.
runtime_shapes = collections.OrderedDict(
{(key, 'runtime_shapes'): tf.shape(tensor)
for key, tensor in tensor_dict.items()})
{(key + rt_shape_str): tf.shape(tensor)
for key, tensor in tensor_dict.iteritems()})
all_tensors = tensor_dict
all_tensors.update(runtime_shapes)
batched_tensors = tf.train.batch(
......@@ -112,8 +115,8 @@ class BatchQueue(object):
for key, batched_tensor in batched_tensors.items():
unbatched_tensor_list = tf.unstack(batched_tensor)
for i, unbatched_tensor in enumerate(unbatched_tensor_list):
if isinstance(key, tuple) and key[1] == 'runtime_shapes':
shapes[(key[0], i)] = unbatched_tensor
if rt_shape_str in key:
shapes[(key[:-len(rt_shape_str)], i)] = unbatched_tensor
else:
tensors[(key, i)] = unbatched_tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment