Unverified Commit 32671be9 authored by Xavier Gibert's avatar Xavier Gibert Committed by GitHub
Browse files

attention_ocr: added export for SavedModel format. (#8757)

* Added export for SavedModel format.

* Fixed some pylint errors.
parent b548c7fd
......@@ -166,6 +166,14 @@ implement one in Python or C++.
The recommended way is to use the [Serving infrastructure][serving].
To export to SavedModel format:
```
python model_export.py \
--checkpoint=model.ckpt-399731 \
--export_dir=/tmp/attention_ocr_export
```
Alternatively you can:
1. define a placeholder for images (or use directly an numpy array)
2. [create a graph ](https://github.com/tensorflow/models/blob/master/research/attention_ocr/python/eval.py#L60)
......@@ -188,7 +196,7 @@ other than a one time experiment please use the [TensorFlow Serving][serving].
[1]: https://github.com/tensorflow/tensorflow/blob/aaf7adc/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
[2]: https://www.tensorflow.org/api_docs/python/tf/contrib/framework/assign_from_checkpoint_fn
[serving]: https://tensorflow.github.io/serving/serving_basic
[serving]: https://www.tensorflow.org/tfx/serving/serving_basic
## Disclaimer
......
......@@ -14,10 +14,10 @@
# ==============================================================================
"""Define flags are common for both train.py and eval.py scripts."""
import logging
import sys
from tensorflow.python.platform import flags
import logging
import datasets
import model
......@@ -35,9 +35,17 @@ logging.basicConfig(
datefmt='%Y-%m-%d %H:%M:%S')
_common_flags_defined = False
def define():
"""Define common flags."""
# yapf: disable
# common_flags.define() may be called multiple times in unit tests.
global _common_flags_defined
if _common_flags_defined:
return
_common_flags_defined = True
flags.DEFINE_integer('batch_size', 32,
'Batch size.')
......@@ -74,7 +82,7 @@ def define():
'the optimizer to use')
flags.DEFINE_float('momentum', 0.9,
'momentum value for the momentum optimizer if used')
'momentum value for the momentum optimizer if used')
flags.DEFINE_bool('use_augment_input', True,
'If True will use image augmentation')
......
......@@ -144,9 +144,6 @@ def preprocess_image(image, augment=False, central_crop_size=None,
images = [augment_image(img) for img in images]
image = tf.concat(images, 1)
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.5)
return image
......
......@@ -177,6 +177,8 @@ def get_split(split_name, dataset_dir=None, config=None):
items_to_descriptions=config['items_to_descriptions'],
# additional parameters for convenience.
charset=charset,
charset_file=charset_file,
image_shape=config['image_shape'],
num_char_classes=len(charset),
num_of_views=config['num_of_views'],
max_sequence_length=config['max_sequence_length'],
......
This diff is collapsed.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts existing checkpoint into a SavedModel.
Usage example:
python model_export.py \
--logtostderr --checkpoint=model.ckpt-399731 \
--export_dir=/tmp/attention_ocr_export
"""
import os
import tensorflow as tf
from tensorflow import app
from tensorflow.contrib import slim
from tensorflow.python.platform import flags
import common_flags
import model_export_lib
FLAGS = flags.FLAGS
common_flags.define()
flags.DEFINE_string('export_dir', None, 'Directory to export model files to.')
flags.DEFINE_integer(
'image_width', None,
'Image width used during training (or crop width if used)'
' If not set, the dataset default is used instead.')
flags.DEFINE_integer(
'image_height', None,
'Image height used during training(or crop height if used)'
' If not set, the dataset default is used instead.')
flags.DEFINE_string('work_dir', '/tmp', 'A directory to store temporary files.')
flags.DEFINE_integer('version_number', 1, 'Version number of the model')
flags.DEFINE_bool(
'export_for_serving', True,
'Whether the exported model accepts serialized tf.Example '
'protos as input')
def get_checkpoint_path():
"""Returns a path to a checkpoint based on specified commandline flags.
In order to specify a full path to a checkpoint use --checkpoint flag.
Alternatively, if --train_log_dir was specified it will return a path to the
most recent checkpoint.
Raises:
ValueError: in case it can't find a checkpoint.
Returns:
A string.
"""
if FLAGS.checkpoint:
return FLAGS.checkpoint
else:
model_save_path = tf.train.latest_checkpoint(FLAGS.train_log_dir)
if not model_save_path:
raise ValueError('Can\'t find a checkpoint in: %s' % FLAGS.train_log_dir)
return model_save_path
def export_model(export_dir,
export_for_serving,
batch_size=None,
crop_image_width=None,
crop_image_height=None):
"""Exports a model to the named directory.
Note that --datatset_name and --checkpoint are required and parsed by the
underlying module common_flags.
Args:
export_dir: The output dir where model is exported to.
export_for_serving: If True, expects a serialized image as input and attach
image normalization as part of exported graph.
batch_size: For non-serving export, the input batch_size needs to be
specified.
crop_image_width: Width of the input image. Uses the dataset default if
None.
crop_image_height: Height of the input image. Uses the dataset default if
None.
Returns:
Returns the model signature_def.
"""
# Dataset object used only to get all parameters for the model.
dataset = common_flags.create_dataset(split_name='test')
model = common_flags.create_model(
dataset.num_char_classes,
dataset.max_sequence_length,
dataset.num_of_views,
dataset.null_code,
charset=dataset.charset)
dataset_image_height, dataset_image_width, image_depth = dataset.image_shape
# Add check for charmap file
if not os.path.exists(dataset.charset_file):
raise ValueError('No charset defined at {}: export will fail'.format(
dataset.charset))
# Default to dataset dimensions, otherwise use provided dimensions.
image_width = crop_image_width or dataset_image_width
image_height = crop_image_height or dataset_image_height
if export_for_serving:
images_orig = tf.placeholder(
tf.string, shape=[batch_size], name='tf_example')
images_orig_float = model_export_lib.generate_tfexample_image(
images_orig,
image_height,
image_width,
image_depth,
name='float_images')
else:
images_shape = (batch_size, image_height, image_width, image_depth)
images_orig = tf.placeholder(
tf.uint8, shape=images_shape, name='original_image')
images_orig_float = tf.image.convert_image_dtype(
images_orig, dtype=tf.float32, name='float_images')
endpoints = model.create_base(images_orig_float, labels_one_hot=None)
sess = tf.Session()
saver = tf.train.Saver(slim.get_variables_to_restore(), sharded=True)
saver.restore(sess, get_checkpoint_path())
tf.logging.info('Model restored successfully.')
# Create model signature.
if export_for_serving:
input_tensors = {
tf.saved_model.signature_constants.CLASSIFY_INPUTS: images_orig
}
else:
input_tensors = {'images': images_orig}
signature_inputs = model_export_lib.build_tensor_info(input_tensors)
# NOTE: Tensors 'image_float' and 'chars_logit' are used by the inference
# or to compute saliency maps.
output_tensors = {
'images_float': images_orig_float,
'predictions': endpoints.predicted_chars,
'scores': endpoints.predicted_scores,
'chars_logit': endpoints.chars_logit,
'predicted_length': endpoints.predicted_length,
'predicted_text': endpoints.predicted_text,
'predicted_conf': endpoints.predicted_conf,
'normalized_seq_conf': endpoints.normalized_seq_conf
}
for i, t in enumerate(
model_export_lib.attention_ocr_attention_masks(
dataset.max_sequence_length)):
output_tensors['attention_mask_%d' % i] = t
signature_outputs = model_export_lib.build_tensor_info(output_tensors)
signature_def = tf.saved_model.signature_def_utils.build_signature_def(
signature_inputs, signature_outputs,
tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME)
# Save model.
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
signature_def
},
main_op=tf.tables_initializer(),
strip_default_attrs=True)
builder.save()
tf.logging.info('Model has been exported to %s' % export_dir)
return signature_def
def main(unused_argv):
if os.path.exists(FLAGS.export_dir):
raise ValueError('export_dir already exists: exporting will fail')
export_model(FLAGS.export_dir, FLAGS.export_for_serving, FLAGS.batch_size,
FLAGS.image_width, FLAGS.image_height)
if __name__ == '__main__':
flags.mark_flag_as_required('dataset_name')
flags.mark_flag_as_required('export_dir')
app.run(main)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for exporting Attention OCR model."""
import tensorflow as tf
# Function borrowed from research/object_detection/core/preprocessor.py
def normalize_image(image, original_minval, original_maxval, target_minval,
target_maxval):
"""Normalizes pixel values in the image.
Moves the pixel values from the current [original_minval, original_maxval]
range to a the [target_minval, target_maxval] range.
Args:
image: rank 3 float32 tensor containing 1 image -> [height, width,
channels].
original_minval: current image minimum value.
original_maxval: current image maximum value.
target_minval: target image minimum value.
target_maxval: target image maximum value.
Returns:
image: image which is the same shape as input image.
"""
with tf.name_scope('NormalizeImage', values=[image]):
original_minval = float(original_minval)
original_maxval = float(original_maxval)
target_minval = float(target_minval)
target_maxval = float(target_maxval)
image = tf.cast(image, dtype=tf.float32)
image = tf.subtract(image, original_minval)
image = tf.multiply(image, (target_maxval - target_minval) /
(original_maxval - original_minval))
image = tf.add(image, target_minval)
return image
def generate_tfexample_image(input_example_strings,
image_height,
image_width,
image_channels,
name=None):
"""Parses a 1D tensor of serialized tf.Example protos and returns image batch.
Args:
input_example_strings: A 1-Dimensional tensor of size [batch_size] and type
tf.string containing a serialized Example proto per image.
image_height: First image dimension.
image_width: Second image dimension.
image_channels: Third image dimension.
name: optional tensor name.
Returns:
A tensor with shape [batch_size, height, width, channels] of type float32
with values in the range [0..1]
"""
batch_size = tf.shape(input_example_strings)[0]
images_shape = tf.stack(
[batch_size, image_height, image_width, image_channels])
tf_example_image_key = 'image/encoded'
feature_configs = {
tf_example_image_key:
tf.FixedLenFeature(
image_height * image_width * image_channels, dtype=tf.float32)
}
feature_tensors = tf.parse_example(input_example_strings, feature_configs)
float_images = tf.reshape(
normalize_image(
feature_tensors[tf_example_image_key],
original_minval=0.0,
original_maxval=255.0,
target_minval=0.0,
target_maxval=1.0),
images_shape,
name=name)
return float_images
def attention_ocr_attention_masks(num_characters):
# TODO(gorban): use tensors directly after replacing LSTM unroll methods.
prefix = ('AttentionOcr_v1/'
'sequence_logit_fn/SQLR/LSTM/attention_decoder/Attention_0')
names = ['%s/Softmax:0' % (prefix)]
for i in range(1, num_characters):
names += ['%s_%d/Softmax:0' % (prefix, i)]
return [tf.get_default_graph().get_tensor_by_name(n) for n in names]
def build_tensor_info(tensor_dict):
return {
k: tf.saved_model.utils.build_tensor_info(t)
for k, t in tensor_dict.items()
}
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for model_export."""
import os
import numpy as np
from absl.testing import flagsaver
import tensorflow as tf
import common_flags
import model_export
_CHECKPOINT = 'model.ckpt-399731'
_CHECKPOINT_URL = (
'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz')
def _clean_up():
tf.gfile.DeleteRecursively(tf.test.get_temp_dir())
def _create_tf_example_string(image):
"""Create a serialized tf.Example proto for feeding the model."""
example = tf.train.Example()
example.features.feature['image/encoded'].float_list.value.extend(
list(np.reshape(image, (-1))))
return example.SerializeToString()
class AttentionOcrExportTest(tf.test.TestCase):
"""Tests for model_export.export_model."""
def setUp(self):
for suffix in ['.meta', '.index', '.data-00000-of-00001']:
filename = _CHECKPOINT + suffix
self.assertTrue(
tf.gfile.Exists(filename),
msg='Missing checkpoint file %s. '
'Please download and extract it from %s' %
(filename, _CHECKPOINT_URL))
tf.flags.FLAGS.dataset_name = 'fsns'
tf.flags.FLAGS.checkpoint = _CHECKPOINT
tf.flags.FLAGS.dataset_dir = os.path.join(
os.path.dirname(__file__), 'datasets/testdata/fsns')
tf.test.TestCase.setUp(self)
_clean_up()
self.export_dir = os.path.join(tf.test.get_temp_dir(), 'exported_model')
self.minimal_output_signature = {
'predictions': 'AttentionOcr_v1/predicted_chars:0',
'scores': 'AttentionOcr_v1/predicted_scores:0',
'predicted_length': 'AttentionOcr_v1/predicted_length:0',
'predicted_text': 'AttentionOcr_v1/predicted_text:0',
'predicted_conf': 'AttentionOcr_v1/predicted_conf:0',
'normalized_seq_conf': 'AttentionOcr_v1/normalized_seq_conf:0'
}
def create_input_feed(self, graph_def, serving):
"""Returns the input feed for the model.
Creates random images, according to the size specified by dataset_name,
format it in the correct way depending on whether the model was exported
for serving, and return the correctly keyed feed_dict for inference.
Args:
graph_def: Graph definition of the loaded model.
serving: Whether the model was exported for Serving.
Returns:
The feed_dict suitable for model inference.
"""
# Creates a dataset based on FLAGS.dataset_name.
self.dataset = common_flags.create_dataset('test')
# Create some random images to test inference for any dataset.
self.images = {
'img1':
np.random.uniform(low=64, high=192,
size=self.dataset.image_shape).astype('uint8'),
'img2':
np.random.uniform(low=32, high=224,
size=self.dataset.image_shape).astype('uint8'),
}
signature_def = graph_def.signature_def[
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
if serving:
input_name = signature_def.inputs[
tf.saved_model.signature_constants.CLASSIFY_INPUTS].name
# Model for serving takes input: inputs['inputs'] = 'tf_example:0'
feed_dict = {
input_name: [
_create_tf_example_string(self.images['img1']),
_create_tf_example_string(self.images['img2'])
]
}
else:
input_name = signature_def.inputs['images'].name
# Model for direct use takes input: inputs['images'] = 'original_image:0'
feed_dict = {
input_name: np.stack([self.images['img1'], self.images['img2']])
}
return feed_dict
def verify_export_load_and_inference(self, export_for_serving=False):
"""Verify exported model can be loaded and inference can run successfully.
This function will load the exported model in self.export_dir, then create
some fake images according to the specification of FLAGS.dataset_name.
It then feeds the input through the model, and verify the minimal set of
output signatures are present.
Note: Model and dataset creation in the underlying library depends on the
following commandline flags:
FLAGS.dataset_name
Args:
export_for_serving: True if the model was exported for Serving. This
affects how input is fed into the model.
"""
tf.reset_default_graph()
sess = tf.Session()
graph_def = tf.saved_model.loader.load(
sess=sess,
tags=[tf.saved_model.tag_constants.SERVING],
export_dir=self.export_dir)
feed_dict = self.create_input_feed(graph_def, export_for_serving)
results = sess.run(self.minimal_output_signature, feed_dict=feed_dict)
out_shape = (2,)
self.assertEqual(np.shape(results['predicted_conf']), out_shape)
self.assertEqual(np.shape(results['predicted_text']), out_shape)
self.assertEqual(np.shape(results['predicted_length']), out_shape)
self.assertEqual(np.shape(results['normalized_seq_conf']), out_shape)
out_shape = (2, self.dataset.max_sequence_length)
self.assertEqual(np.shape(results['scores']), out_shape)
self.assertEqual(np.shape(results['predictions']), out_shape)
@flagsaver.flagsaver
def test_fsns_export_for_serving_and_load_inference(self):
model_export.export_model(self.export_dir, True)
self.verify_export_load_and_inference(True)
@flagsaver.flagsaver
def test_fsns_export_and_load_inference(self):
model_export.export_model(self.export_dir, False, batch_size=2)
self.verify_export_load_and_inference(False)
if __name__ == '__main__':
tf.test.main()
......@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the model."""
import string
import numpy as np
import string
import tensorflow as tf
from tensorflow.contrib import slim
......@@ -32,6 +31,7 @@ def create_fake_charset(num_char_classes):
class ModelTest(tf.test.TestCase):
def setUp(self):
tf.test.TestCase.setUp(self)
......@@ -51,18 +51,21 @@ class ModelTest(tf.test.TestCase):
self.chars_logit_shape = (self.batch_size, self.seq_length,
self.num_char_classes)
self.length_logit_shape = (self.batch_size, self.seq_length + 1)
# Placeholder knows image dimensions, but not batch size.
self.input_images = tf.placeholder(
tf.float32,
shape=(None, self.image_height, self.image_width, 3),
name='input_node')
self.initialize_fakes()
def initialize_fakes(self):
self.images_shape = (self.batch_size, self.image_height, self.image_width,
3)
self.fake_images = tf.constant(
self.rng.randint(low=0, high=255,
size=self.images_shape).astype('float32'),
name='input_node')
self.fake_conv_tower_np = self.rng.randn(
*self.conv_tower_shape).astype('float32')
self.fake_images = self.rng.randint(
low=0, high=255, size=self.images_shape).astype('float32')
self.fake_conv_tower_np = self.rng.randn(*self.conv_tower_shape).astype(
'float32')
self.fake_conv_tower = tf.constant(self.fake_conv_tower_np)
self.fake_logits = tf.constant(
self.rng.randn(*self.chars_logit_shape).astype('float32'))
......@@ -74,33 +77,44 @@ class ModelTest(tf.test.TestCase):
def create_model(self, charset=None):
return model.Model(
self.num_char_classes, self.seq_length, num_views=4, null_code=62,
self.num_char_classes,
self.seq_length,
num_views=4,
null_code=62,
charset=charset)
def test_char_related_shapes(self):
ocr_model = self.create_model()
charset = create_fake_charset(self.num_char_classes)
ocr_model = self.create_model(charset=charset)
with self.test_session() as sess:
endpoints_tf = ocr_model.create_base(
images=self.fake_images, labels_one_hot=None)
images=self.input_images, labels_one_hot=None)
sess.run(tf.global_variables_initializer())
endpoints = sess.run(endpoints_tf)
self.assertEqual((self.batch_size, self.seq_length,
self.num_char_classes), endpoints.chars_logit.shape)
self.assertEqual((self.batch_size, self.seq_length,
self.num_char_classes), endpoints.chars_log_prob.shape)
tf.tables_initializer().run()
endpoints = sess.run(
endpoints_tf, feed_dict={self.input_images: self.fake_images})
self.assertEqual(
(self.batch_size, self.seq_length, self.num_char_classes),
endpoints.chars_logit.shape)
self.assertEqual(
(self.batch_size, self.seq_length, self.num_char_classes),
endpoints.chars_log_prob.shape)
self.assertEqual((self.batch_size, self.seq_length),
endpoints.predicted_chars.shape)
self.assertEqual((self.batch_size, self.seq_length),
endpoints.predicted_scores.shape)
self.assertEqual((self.batch_size,), endpoints.predicted_text.shape)
self.assertEqual((self.batch_size,), endpoints.predicted_conf.shape)
self.assertEqual((self.batch_size,), endpoints.normalized_seq_conf.shape)
def test_predicted_scores_are_within_range(self):
ocr_model = self.create_model()
_, _, scores = ocr_model.char_predictions(self.fake_logits)
with self.test_session() as sess:
scores_np = sess.run(scores)
scores_np = sess.run(
scores, feed_dict={self.input_images: self.fake_images})
values_in_range = (scores_np >= 0.0) & (scores_np <= 1.0)
self.assertTrue(
......@@ -111,10 +125,11 @@ class ModelTest(tf.test.TestCase):
def test_conv_tower_shape(self):
with self.test_session() as sess:
ocr_model = self.create_model()
conv_tower = ocr_model.conv_tower_fn(self.fake_images)
conv_tower = ocr_model.conv_tower_fn(self.input_images)
sess.run(tf.global_variables_initializer())
conv_tower_np = sess.run(conv_tower)
conv_tower_np = sess.run(
conv_tower, feed_dict={self.input_images: self.fake_images})
self.assertEqual(self.conv_tower_shape, conv_tower_np.shape)
......@@ -124,11 +139,12 @@ class ModelTest(tf.test.TestCase):
# updates, gradients and variances. It also depends on the type of used
# optimizer.
ocr_model = self.create_model()
ocr_model.create_base(images=self.fake_images, labels_one_hot=None)
ocr_model.create_base(images=self.input_images, labels_one_hot=None)
with self.test_session() as sess:
tfprof_root = tf.profiler.profile(
sess.graph,
options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
options=tf.profiler.ProfileOptionBuilder
.trainable_variables_parameter())
model_size_bytes = 4 * tfprof_root.total_parameters
self.assertLess(model_size_bytes, 1 * 2**30)
......@@ -158,7 +174,7 @@ class ModelTest(tf.test.TestCase):
loss = model.sequence_loss_fn(self.fake_logits, self.fake_labels)
with self.test_session() as sess:
loss_np = sess.run(loss)
loss_np = sess.run(loss, feed_dict={self.input_images: self.fake_images})
# This test checks that the loss function is 'runnable'.
self.assertEqual(loss_np.shape, tuple())
......@@ -172,19 +188,20 @@ class ModelTest(tf.test.TestCase):
Returns:
a list of tensors with encoded image coordinates in them.
"""
batch_size, h, w, _ = net.shape.as_list()
batch_size = tf.shape(net)[0]
_, h, w, _ = net.shape.as_list()
h_loc = [
tf.tile(
tf.reshape(
tf.contrib.layers.one_hot_encoding(
tf.constant([i]), num_classes=h), [h, 1]), [1, w])
for i in range(h)
tf.tile(
tf.reshape(
tf.contrib.layers.one_hot_encoding(
tf.constant([i]), num_classes=h), [h, 1]), [1, w])
for i in range(h)
]
h_loc = tf.concat([tf.expand_dims(t, 2) for t in h_loc], 2)
w_loc = [
tf.tile(
tf.contrib.layers.one_hot_encoding(tf.constant([i]), num_classes=w),
[h, 1]) for i in range(w)
tf.tile(
tf.contrib.layers.one_hot_encoding(tf.constant([i]), num_classes=w),
[h, 1]) for i in range(w)
]
w_loc = tf.concat([tf.expand_dims(t, 2) for t in w_loc], 2)
loc = tf.concat([h_loc, w_loc], 2)
......@@ -197,11 +214,12 @@ class ModelTest(tf.test.TestCase):
conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)
with self.test_session() as sess:
conv_w_coords = sess.run(conv_w_coords_tf)
conv_w_coords = sess.run(
conv_w_coords_tf, feed_dict={self.input_images: self.fake_images})
batch_size, height, width, feature_size = self.conv_tower_shape
self.assertEqual(conv_w_coords.shape, (batch_size, height, width,
feature_size + height + width))
self.assertEqual(conv_w_coords.shape,
(batch_size, height, width, feature_size + height + width))
def test_disabled_coordinate_encoding_returns_features_unchanged(self):
model = self.create_model()
......@@ -209,7 +227,8 @@ class ModelTest(tf.test.TestCase):
conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)
with self.test_session() as sess:
conv_w_coords = sess.run(conv_w_coords_tf)
conv_w_coords = sess.run(
conv_w_coords_tf, feed_dict={self.input_images: self.fake_images})
self.assertAllEqual(conv_w_coords, self.fake_conv_tower_np)
......@@ -221,7 +240,8 @@ class ModelTest(tf.test.TestCase):
conv_w_coords_tf = model.encode_coordinates_fn(fake_conv_tower)
with self.test_session() as sess:
conv_w_coords = sess.run(conv_w_coords_tf)
conv_w_coords = sess.run(
conv_w_coords_tf, feed_dict={self.input_images: self.fake_images})
# Original features
self.assertAllEqual(conv_w_coords[0, :, :, :4],
......@@ -261,10 +281,11 @@ class ModelTest(tf.test.TestCase):
class CharsetMapperTest(tf.test.TestCase):
def test_text_corresponds_to_ids(self):
charset = create_fake_charset(36)
ids = tf.constant(
[[17, 14, 21, 21, 24], [32, 24, 27, 21, 13]], dtype=tf.int64)
ids = tf.constant([[17, 14, 21, 21, 24], [32, 24, 27, 21, 13]],
dtype=tf.int64)
charset_mapper = model.CharsetMapper(charset)
with self.test_session() as sess:
......
......@@ -111,7 +111,7 @@ class SequenceLayerBase(object):
self._mparams = method_params
self._net = net
self._labels_one_hot = labels_one_hot
self._batch_size = net.get_shape().dims[0].value
self._batch_size = tf.shape(net)[0]
# Initialize parameters for char logits which will be computed on the fly
# inside an LSTM decoder.
......@@ -275,7 +275,7 @@ class NetSlice(SequenceLayerBase):
def __init__(self, *args, **kwargs):
super(NetSlice, self).__init__(*args, **kwargs)
self._zero_label = tf.zeros(
[self._batch_size, self._params.num_char_classes])
tf.stack([self._batch_size, self._params.num_char_classes]))
def get_image_feature(self, char_index):
"""Returns a subset of image features for a character.
......@@ -352,7 +352,7 @@ class Attention(SequenceLayerBase):
def __init__(self, *args, **kwargs):
super(Attention, self).__init__(*args, **kwargs)
self._zero_label = tf.zeros(
[self._batch_size, self._params.num_char_classes])
tf.stack([self._batch_size, self._params.num_char_classes]))
def get_eval_input(self, prev, i):
"""See SequenceLayerBase.get_eval_input for details."""
......
......@@ -78,3 +78,20 @@ def variables_to_restore(scope=None, strip_scope=False):
return variable_map
else:
return {v.op.name: v for v in slim.get_variables_to_restore()}
def ConvertAllInputsToTensors(func):
"""A decorator to convert all function's inputs into tensors.
Args:
func: a function to decorate.
Returns:
A decorated function.
"""
def FuncWrapper(*args):
tensors = [tf.convert_to_tensor(a) for a in args]
return func(*tensors)
return FuncWrapper
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment