Unverified Commit 32671be9 authored by Xavier Gibert, committed by GitHub

attention_ocr: added export for SavedModel format. (#8757)

* Added export for SavedModel format.

* Fixed some pylint errors.
parent b548c7fd
...@@ -166,6 +166,14 @@ implement one in Python or C++.
The recommended way is to use the [Serving infrastructure][serving].
To export to SavedModel format:
```
python model_export.py \
--checkpoint=model.ckpt-399731 \
--export_dir=/tmp/attention_ocr_export
```
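To sanity-check an export, the SavedModel can be loaded back in Python. A
minimal sketch (TF 1.x API; it assumes the model was exported with
`--export_for_serving=false --batch_size=2`, so the graph takes raw uint8
images, and that the dataset is FSNS with its 150x600x3 image shape):

```
import numpy as np
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
  meta_graph = tf.saved_model.loader.load(
      sess, [tf.saved_model.tag_constants.SERVING],
      '/tmp/attention_ocr_export')
  signature = meta_graph.signature_def[
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  # Random stand-in images; substitute your dataset's image shape.
  images = np.random.randint(0, 255, size=(2, 150, 600, 3)).astype('uint8')
  predicted_text = sess.run(
      signature.outputs['predicted_text'].name,
      feed_dict={signature.inputs['images'].name: images})
  print(predicted_text)
```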
Alternatively you can:
1. define a placeholder for images (or use a numpy array directly)
2. [create a graph](https://github.com/tensorflow/models/blob/master/research/attention_ocr/python/eval.py#L60)
...@@ -188,7 +196,7 @@ other than a one time experiment please use the [TensorFlow Serving][serving].
[1]: https://github.com/tensorflow/tensorflow/blob/aaf7adc/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
[2]: https://www.tensorflow.org/api_docs/python/tf/contrib/framework/assign_from_checkpoint_fn
[serving]: https://www.tensorflow.org/tfx/serving/serving_basic
## Disclaimer
...
...@@ -14,10 +14,10 @@
# ==============================================================================
"""Define flags common to both train.py and eval.py scripts."""
import logging
import sys
from tensorflow.python.platform import flags
import logging
import datasets
import model
...@@ -35,9 +35,17 @@ logging.basicConfig(
    datefmt='%Y-%m-%d %H:%M:%S')
_common_flags_defined = False
def define():
  """Define common flags."""
  # yapf: disable
  # common_flags.define() may be called multiple times in unit tests.
  global _common_flags_defined
  if _common_flags_defined:
    return
  _common_flags_defined = True
  flags.DEFINE_integer('batch_size', 32,
                       'Batch size.')
...@@ -74,7 +82,7 @@ def define():
                     'the optimizer to use')
  flags.DEFINE_float('momentum', 0.9,
                     'momentum value for the momentum optimizer if used')
  flags.DEFINE_bool('use_augment_input', True,
                    'If True will use image augmentation')
...
...@@ -144,9 +144,6 @@ def preprocess_image(image, augment=False, central_crop_size=None,
    images = [augment_image(img) for img in images]
    image = tf.concat(images, 1)
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.5)
  return image
...
...@@ -177,6 +177,8 @@ def get_split(split_name, dataset_dir=None, config=None):
      items_to_descriptions=config['items_to_descriptions'],
      # additional parameters for convenience.
      charset=charset,
charset_file=charset_file,
image_shape=config['image_shape'],
      num_char_classes=len(charset),
      num_of_views=config['num_of_views'],
      max_sequence_length=config['max_sequence_length'],
...
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts existing checkpoint into a SavedModel.
Usage example:
python model_export.py \
--logtostderr --checkpoint=model.ckpt-399731 \
--export_dir=/tmp/attention_ocr_export
"""
import os
import tensorflow as tf
from tensorflow import app
from tensorflow.contrib import slim
from tensorflow.python.platform import flags
import common_flags
import model_export_lib
FLAGS = flags.FLAGS
common_flags.define()
flags.DEFINE_string('export_dir', None, 'Directory to export model files to.')
flags.DEFINE_integer(
'image_width', None,
    'Image width used during training (or crop width if used).'
' If not set, the dataset default is used instead.')
flags.DEFINE_integer(
'image_height', None,
    'Image height used during training (or crop height if used).'
' If not set, the dataset default is used instead.')
flags.DEFINE_string('work_dir', '/tmp', 'A directory to store temporary files.')
flags.DEFINE_integer('version_number', 1, 'Version number of the model')
flags.DEFINE_bool(
'export_for_serving', True,
'Whether the exported model accepts serialized tf.Example '
'protos as input')
def get_checkpoint_path():
"""Returns a path to a checkpoint based on specified commandline flags.
  In order to specify a full path to a checkpoint use the --checkpoint flag.
Alternatively, if --train_log_dir was specified it will return a path to the
most recent checkpoint.
Raises:
ValueError: in case it can't find a checkpoint.
Returns:
A string.
"""
if FLAGS.checkpoint:
return FLAGS.checkpoint
else:
model_save_path = tf.train.latest_checkpoint(FLAGS.train_log_dir)
if not model_save_path:
raise ValueError('Can\'t find a checkpoint in: %s' % FLAGS.train_log_dir)
return model_save_path
def export_model(export_dir,
export_for_serving,
batch_size=None,
crop_image_width=None,
crop_image_height=None):
"""Exports a model to the named directory.
  Note that --dataset_name and --checkpoint are required and parsed by the
underlying module common_flags.
Args:
export_dir: The output dir where model is exported to.
    export_for_serving: If True, expects a serialized tf.Example proto as
      input and attaches image normalization as part of the exported graph.
batch_size: For non-serving export, the input batch_size needs to be
specified.
crop_image_width: Width of the input image. Uses the dataset default if
None.
crop_image_height: Height of the input image. Uses the dataset default if
None.
Returns:
Returns the model signature_def.
"""
# Dataset object used only to get all parameters for the model.
dataset = common_flags.create_dataset(split_name='test')
model = common_flags.create_model(
dataset.num_char_classes,
dataset.max_sequence_length,
dataset.num_of_views,
dataset.null_code,
charset=dataset.charset)
dataset_image_height, dataset_image_width, image_depth = dataset.image_shape
  # Check that the charset file exists, since the export needs it.
  if not os.path.exists(dataset.charset_file):
    raise ValueError('No charset defined at {}: export will fail'.format(
        dataset.charset_file))
# Default to dataset dimensions, otherwise use provided dimensions.
image_width = crop_image_width or dataset_image_width
image_height = crop_image_height or dataset_image_height
if export_for_serving:
images_orig = tf.placeholder(
tf.string, shape=[batch_size], name='tf_example')
images_orig_float = model_export_lib.generate_tfexample_image(
images_orig,
image_height,
image_width,
image_depth,
name='float_images')
else:
images_shape = (batch_size, image_height, image_width, image_depth)
images_orig = tf.placeholder(
tf.uint8, shape=images_shape, name='original_image')
images_orig_float = tf.image.convert_image_dtype(
images_orig, dtype=tf.float32, name='float_images')
endpoints = model.create_base(images_orig_float, labels_one_hot=None)
sess = tf.Session()
saver = tf.train.Saver(slim.get_variables_to_restore(), sharded=True)
saver.restore(sess, get_checkpoint_path())
tf.logging.info('Model restored successfully.')
# Create model signature.
if export_for_serving:
input_tensors = {
tf.saved_model.signature_constants.CLASSIFY_INPUTS: images_orig
}
else:
input_tensors = {'images': images_orig}
signature_inputs = model_export_lib.build_tensor_info(input_tensors)
  # NOTE: Tensors 'images_float' and 'chars_logit' are exposed for inference
  # and for computing saliency maps.
output_tensors = {
'images_float': images_orig_float,
'predictions': endpoints.predicted_chars,
'scores': endpoints.predicted_scores,
'chars_logit': endpoints.chars_logit,
'predicted_length': endpoints.predicted_length,
'predicted_text': endpoints.predicted_text,
'predicted_conf': endpoints.predicted_conf,
'normalized_seq_conf': endpoints.normalized_seq_conf
}
for i, t in enumerate(
model_export_lib.attention_ocr_attention_masks(
dataset.max_sequence_length)):
output_tensors['attention_mask_%d' % i] = t
signature_outputs = model_export_lib.build_tensor_info(output_tensors)
signature_def = tf.saved_model.signature_def_utils.build_signature_def(
signature_inputs, signature_outputs,
tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME)
# Save model.
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
signature_def
},
main_op=tf.tables_initializer(),
strip_default_attrs=True)
builder.save()
tf.logging.info('Model has been exported to %s' % export_dir)
return signature_def
def main(unused_argv):
if os.path.exists(FLAGS.export_dir):
raise ValueError('export_dir already exists: exporting will fail')
export_model(FLAGS.export_dir, FLAGS.export_for_serving, FLAGS.batch_size,
FLAGS.image_width, FLAGS.image_height)
if __name__ == '__main__':
flags.mark_flag_as_required('dataset_name')
flags.mark_flag_as_required('export_dir')
app.run(main)
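The test file below drives the same entry point programmatically; condensed, the pattern looks like this (a sketch, with illustrative flag values that must match a real dataset and checkpoint):

```
# Condensed from model_export_test.py; values are illustrative.
import tensorflow as tf
import model_export

tf.flags.FLAGS.dataset_name = 'fsns'
tf.flags.FLAGS.checkpoint = 'model.ckpt-399731'
model_export.export_model('/tmp/attention_ocr_export', export_for_serving=True)
```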
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for exporting Attention OCR model."""
import tensorflow as tf
# Function borrowed from research/object_detection/core/preprocessor.py
def normalize_image(image, original_minval, original_maxval, target_minval,
target_maxval):
"""Normalizes pixel values in the image.
Moves the pixel values from the current [original_minval, original_maxval]
  range to the [target_minval, target_maxval] range.
Args:
image: rank 3 float32 tensor containing 1 image -> [height, width,
channels].
original_minval: current image minimum value.
original_maxval: current image maximum value.
target_minval: target image minimum value.
target_maxval: target image maximum value.
Returns:
image: image which is the same shape as input image.
"""
with tf.name_scope('NormalizeImage', values=[image]):
original_minval = float(original_minval)
original_maxval = float(original_maxval)
target_minval = float(target_minval)
target_maxval = float(target_maxval)
image = tf.cast(image, dtype=tf.float32)
image = tf.subtract(image, original_minval)
image = tf.multiply(image, (target_maxval - target_minval) /
(original_maxval - original_minval))
image = tf.add(image, target_minval)
return image
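# For example, rescaling uint8-range pixel values into [0, 1] (a sketch to
# illustrate the linear mapping above; the shape is arbitrary):
#
#   image = tf.constant([[[0.0], [127.5], [255.0]]])
#   normalized = normalize_image(image, original_minval=0.0,
#                                original_maxval=255.0,
#                                target_minval=0.0, target_maxval=1.0)
#   with tf.Session() as sess:
#     print(sess.run(normalized))  # [[[0.0], [0.5], [1.0]]]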
def generate_tfexample_image(input_example_strings,
image_height,
image_width,
image_channels,
name=None):
"""Parses a 1D tensor of serialized tf.Example protos and returns image batch.
Args:
input_example_strings: A 1-Dimensional tensor of size [batch_size] and type
tf.string containing a serialized Example proto per image.
image_height: First image dimension.
image_width: Second image dimension.
image_channels: Third image dimension.
name: optional tensor name.
Returns:
A tensor with shape [batch_size, height, width, channels] of type float32
    with values in the range [0..1].
"""
batch_size = tf.shape(input_example_strings)[0]
images_shape = tf.stack(
[batch_size, image_height, image_width, image_channels])
tf_example_image_key = 'image/encoded'
feature_configs = {
tf_example_image_key:
tf.FixedLenFeature(
image_height * image_width * image_channels, dtype=tf.float32)
}
feature_tensors = tf.parse_example(input_example_strings, feature_configs)
float_images = tf.reshape(
normalize_image(
feature_tensors[tf_example_image_key],
original_minval=0.0,
original_maxval=255.0,
target_minval=0.0,
target_maxval=1.0),
images_shape,
name=name)
return float_images
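# The expected proto layout matches _create_tf_example_string in
# model_export_test.py: raw float pixel values flattened into the
# 'image/encoded' float list (not an encoded PNG/JPEG string). Sketch,
# with an illustrative FSNS-sized image:
#
#   image = np.zeros((150, 600, 3), dtype=np.float32)
#   example = tf.train.Example()
#   example.features.feature['image/encoded'].float_list.value.extend(
#       image.ravel())
#   serialized = example.SerializeToString()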
def attention_ocr_attention_masks(num_characters):
  """Returns a list of attention mask tensors, one per predicted character."""
  # TODO(gorban): use tensors directly after replacing LSTM unroll methods.
prefix = ('AttentionOcr_v1/'
'sequence_logit_fn/SQLR/LSTM/attention_decoder/Attention_0')
names = ['%s/Softmax:0' % (prefix)]
for i in range(1, num_characters):
names += ['%s_%d/Softmax:0' % (prefix, i)]
return [tf.get_default_graph().get_tensor_by_name(n) for n in names]
def build_tensor_info(tensor_dict):
  """Converts a dict of tensors into a dict of TensorInfo protos."""
return {
k: tf.saved_model.utils.build_tensor_info(t)
for k, t in tensor_dict.items()
}
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for model_export."""
import os
import numpy as np
from absl.testing import flagsaver
import tensorflow as tf
import common_flags
import model_export
_CHECKPOINT = 'model.ckpt-399731'
_CHECKPOINT_URL = (
'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz')
def _clean_up():
tf.gfile.DeleteRecursively(tf.test.get_temp_dir())
def _create_tf_example_string(image):
"""Create a serialized tf.Example proto for feeding the model."""
example = tf.train.Example()
example.features.feature['image/encoded'].float_list.value.extend(
list(np.reshape(image, (-1))))
return example.SerializeToString()
class AttentionOcrExportTest(tf.test.TestCase):
"""Tests for model_export.export_model."""
def setUp(self):
for suffix in ['.meta', '.index', '.data-00000-of-00001']:
filename = _CHECKPOINT + suffix
self.assertTrue(
tf.gfile.Exists(filename),
msg='Missing checkpoint file %s. '
'Please download and extract it from %s' %
(filename, _CHECKPOINT_URL))
tf.flags.FLAGS.dataset_name = 'fsns'
tf.flags.FLAGS.checkpoint = _CHECKPOINT
tf.flags.FLAGS.dataset_dir = os.path.join(
os.path.dirname(__file__), 'datasets/testdata/fsns')
tf.test.TestCase.setUp(self)
_clean_up()
self.export_dir = os.path.join(tf.test.get_temp_dir(), 'exported_model')
self.minimal_output_signature = {
'predictions': 'AttentionOcr_v1/predicted_chars:0',
'scores': 'AttentionOcr_v1/predicted_scores:0',
'predicted_length': 'AttentionOcr_v1/predicted_length:0',
'predicted_text': 'AttentionOcr_v1/predicted_text:0',
'predicted_conf': 'AttentionOcr_v1/predicted_conf:0',
'normalized_seq_conf': 'AttentionOcr_v1/normalized_seq_conf:0'
}
def create_input_feed(self, graph_def, serving):
"""Returns the input feed for the model.
    Creates random images according to the size specified by dataset_name,
    formats them correctly depending on whether the model was exported for
    serving, and returns the correctly keyed feed_dict for inference.
Args:
graph_def: Graph definition of the loaded model.
serving: Whether the model was exported for Serving.
Returns:
The feed_dict suitable for model inference.
"""
# Creates a dataset based on FLAGS.dataset_name.
self.dataset = common_flags.create_dataset('test')
# Create some random images to test inference for any dataset.
self.images = {
'img1':
np.random.uniform(low=64, high=192,
size=self.dataset.image_shape).astype('uint8'),
'img2':
np.random.uniform(low=32, high=224,
size=self.dataset.image_shape).astype('uint8'),
}
signature_def = graph_def.signature_def[
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
if serving:
input_name = signature_def.inputs[
tf.saved_model.signature_constants.CLASSIFY_INPUTS].name
# Model for serving takes input: inputs['inputs'] = 'tf_example:0'
feed_dict = {
input_name: [
_create_tf_example_string(self.images['img1']),
_create_tf_example_string(self.images['img2'])
]
}
else:
input_name = signature_def.inputs['images'].name
# Model for direct use takes input: inputs['images'] = 'original_image:0'
feed_dict = {
input_name: np.stack([self.images['img1'], self.images['img2']])
}
return feed_dict
def verify_export_load_and_inference(self, export_for_serving=False):
"""Verify exported model can be loaded and inference can run successfully.
This function will load the exported model in self.export_dir, then create
some fake images according to the specification of FLAGS.dataset_name.
    It then feeds the input through the model and verifies that the minimal
    set of output signatures is present.
Note: Model and dataset creation in the underlying library depends on the
following commandline flags:
FLAGS.dataset_name
Args:
export_for_serving: True if the model was exported for Serving. This
affects how input is fed into the model.
"""
tf.reset_default_graph()
sess = tf.Session()
graph_def = tf.saved_model.loader.load(
sess=sess,
tags=[tf.saved_model.tag_constants.SERVING],
export_dir=self.export_dir)
feed_dict = self.create_input_feed(graph_def, export_for_serving)
results = sess.run(self.minimal_output_signature, feed_dict=feed_dict)
out_shape = (2,)
self.assertEqual(np.shape(results['predicted_conf']), out_shape)
self.assertEqual(np.shape(results['predicted_text']), out_shape)
self.assertEqual(np.shape(results['predicted_length']), out_shape)
self.assertEqual(np.shape(results['normalized_seq_conf']), out_shape)
out_shape = (2, self.dataset.max_sequence_length)
self.assertEqual(np.shape(results['scores']), out_shape)
self.assertEqual(np.shape(results['predictions']), out_shape)
@flagsaver.flagsaver
def test_fsns_export_for_serving_and_load_inference(self):
model_export.export_model(self.export_dir, True)
self.verify_export_load_and_inference(True)
@flagsaver.flagsaver
def test_fsns_export_and_load_inference(self):
model_export.export_model(self.export_dir, False, batch_size=2)
self.verify_export_load_and_inference(False)
if __name__ == '__main__':
tf.test.main()
...@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the model."""
import string
import numpy as np
import string
import tensorflow as tf
from tensorflow.contrib import slim
...@@ -32,6 +31,7 @@ def create_fake_charset(num_char_classes):
class ModelTest(tf.test.TestCase):
  def setUp(self):
    tf.test.TestCase.setUp(self)
...@@ -51,18 +51,21 @@ class ModelTest(tf.test.TestCase):
    self.chars_logit_shape = (self.batch_size, self.seq_length,
                              self.num_char_classes)
    self.length_logit_shape = (self.batch_size, self.seq_length + 1)
# Placeholder knows image dimensions, but not batch size.
self.input_images = tf.placeholder(
tf.float32,
shape=(None, self.image_height, self.image_width, 3),
name='input_node')
    self.initialize_fakes()

  def initialize_fakes(self):
    self.images_shape = (self.batch_size, self.image_height, self.image_width,
                         3)
    self.fake_images = self.rng.randint(
        low=0, high=255, size=self.images_shape).astype('float32')
    self.fake_conv_tower_np = self.rng.randn(*self.conv_tower_shape).astype(
        'float32')
    self.fake_conv_tower = tf.constant(self.fake_conv_tower_np)
    self.fake_logits = tf.constant(
        self.rng.randn(*self.chars_logit_shape).astype('float32'))
...@@ -74,33 +77,44 @@ class ModelTest(tf.test.TestCase):
  def create_model(self, charset=None):
    return model.Model(
        self.num_char_classes,
        self.seq_length,
        num_views=4,
        null_code=62,
        charset=charset)
  def test_char_related_shapes(self):
    charset = create_fake_charset(self.num_char_classes)
    ocr_model = self.create_model(charset=charset)
    with self.test_session() as sess:
      endpoints_tf = ocr_model.create_base(
          images=self.input_images, labels_one_hot=None)
      sess.run(tf.global_variables_initializer())
      tf.tables_initializer().run()
      endpoints = sess.run(
          endpoints_tf, feed_dict={self.input_images: self.fake_images})
      self.assertEqual(
          (self.batch_size, self.seq_length, self.num_char_classes),
          endpoints.chars_logit.shape)
      self.assertEqual(
          (self.batch_size, self.seq_length, self.num_char_classes),
          endpoints.chars_log_prob.shape)
      self.assertEqual((self.batch_size, self.seq_length),
                       endpoints.predicted_chars.shape)
      self.assertEqual((self.batch_size, self.seq_length),
                       endpoints.predicted_scores.shape)
      self.assertEqual((self.batch_size,), endpoints.predicted_text.shape)
      self.assertEqual((self.batch_size,), endpoints.predicted_conf.shape)
      self.assertEqual((self.batch_size,), endpoints.normalized_seq_conf.shape)
  def test_predicted_scores_are_within_range(self):
    ocr_model = self.create_model()
    _, _, scores = ocr_model.char_predictions(self.fake_logits)
    with self.test_session() as sess:
      scores_np = sess.run(
          scores, feed_dict={self.input_images: self.fake_images})

    values_in_range = (scores_np >= 0.0) & (scores_np <= 1.0)
    self.assertTrue(
...@@ -111,10 +125,11 @@ class ModelTest(tf.test.TestCase):
  def test_conv_tower_shape(self):
    with self.test_session() as sess:
      ocr_model = self.create_model()
      conv_tower = ocr_model.conv_tower_fn(self.input_images)
      sess.run(tf.global_variables_initializer())
      conv_tower_np = sess.run(
          conv_tower, feed_dict={self.input_images: self.fake_images})
      self.assertEqual(self.conv_tower_shape, conv_tower_np.shape)
...@@ -124,11 +139,12 @@ class ModelTest(tf.test.TestCase):
    # updates, gradients and variances. It also depends on the type of used
    # optimizer.
    ocr_model = self.create_model()
    ocr_model.create_base(images=self.input_images, labels_one_hot=None)
    with self.test_session() as sess:
      tfprof_root = tf.profiler.profile(
          sess.graph,
          options=tf.profiler.ProfileOptionBuilder
          .trainable_variables_parameter())
    model_size_bytes = 4 * tfprof_root.total_parameters
    self.assertLess(model_size_bytes, 1 * 2**30)
...@@ -158,7 +174,7 @@ class ModelTest(tf.test.TestCase):
    loss = model.sequence_loss_fn(self.fake_logits, self.fake_labels)
    with self.test_session() as sess:
      loss_np = sess.run(loss, feed_dict={self.input_images: self.fake_images})
    # This test checks that the loss function is 'runnable'.
    self.assertEqual(loss_np.shape, tuple())
...@@ -172,19 +188,20 @@ class ModelTest(tf.test.TestCase):
    Returns:
      a list of tensors with encoded image coordinates in them.
    """
    batch_size = tf.shape(net)[0]
    _, h, w, _ = net.shape.as_list()
    h_loc = [
        tf.tile(
            tf.reshape(
                tf.contrib.layers.one_hot_encoding(
                    tf.constant([i]), num_classes=h), [h, 1]), [1, w])
        for i in range(h)
    ]
    h_loc = tf.concat([tf.expand_dims(t, 2) for t in h_loc], 2)
    w_loc = [
        tf.tile(
            tf.contrib.layers.one_hot_encoding(tf.constant([i]), num_classes=w),
            [h, 1]) for i in range(w)
    ]
    w_loc = tf.concat([tf.expand_dims(t, 2) for t in w_loc], 2)
    loc = tf.concat([h_loc, w_loc], 2)
...@@ -197,11 +214,12 @@ class ModelTest(tf.test.TestCase):
    conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)
    with self.test_session() as sess:
      conv_w_coords = sess.run(
          conv_w_coords_tf, feed_dict={self.input_images: self.fake_images})
    batch_size, height, width, feature_size = self.conv_tower_shape
    self.assertEqual(conv_w_coords.shape,
                     (batch_size, height, width, feature_size + height + width))
  def test_disabled_coordinate_encoding_returns_features_unchanged(self):
    model = self.create_model()
...@@ -209,7 +227,8 @@ class ModelTest(tf.test.TestCase):
    conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)
    with self.test_session() as sess:
      conv_w_coords = sess.run(
          conv_w_coords_tf, feed_dict={self.input_images: self.fake_images})
    self.assertAllEqual(conv_w_coords, self.fake_conv_tower_np)
...@@ -221,7 +240,8 @@ class ModelTest(tf.test.TestCase):
    conv_w_coords_tf = model.encode_coordinates_fn(fake_conv_tower)
    with self.test_session() as sess:
      conv_w_coords = sess.run(
          conv_w_coords_tf, feed_dict={self.input_images: self.fake_images})
    # Original features
    self.assertAllEqual(conv_w_coords[0, :, :, :4],
...@@ -261,10 +281,11 @@ class ModelTest(tf.test.TestCase):
class CharsetMapperTest(tf.test.TestCase):
  def test_text_corresponds_to_ids(self):
    charset = create_fake_charset(36)
    ids = tf.constant([[17, 14, 21, 21, 24], [32, 24, 27, 21, 13]],
                      dtype=tf.int64)
    charset_mapper = model.CharsetMapper(charset)
    with self.test_session() as sess:
...
...@@ -111,7 +111,7 @@ class SequenceLayerBase(object):
    self._mparams = method_params
    self._net = net
    self._labels_one_hot = labels_one_hot
    self._batch_size = tf.shape(net)[0]
    # Initialize parameters for char logits which will be computed on the fly
    # inside an LSTM decoder.
...@@ -275,7 +275,7 @@ class NetSlice(SequenceLayerBase):
  def __init__(self, *args, **kwargs):
    super(NetSlice, self).__init__(*args, **kwargs)
    self._zero_label = tf.zeros(
        tf.stack([self._batch_size, self._params.num_char_classes]))

  def get_image_feature(self, char_index):
    """Returns a subset of image features for a character.
...@@ -352,7 +352,7 @@ class Attention(SequenceLayerBase):
  def __init__(self, *args, **kwargs):
    super(Attention, self).__init__(*args, **kwargs)
    self._zero_label = tf.zeros(
        tf.stack([self._batch_size, self._params.num_char_classes]))

  def get_eval_input(self, prev, i):
    """See SequenceLayerBase.get_eval_input for details."""
...
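The sequence_layers.py change above replaces the statically inferred batch size with a dynamic one, which is what allows the exported graph to take placeholders with an unknown batch dimension. A minimal illustration of the pattern (not code from this commit):

```
import tensorflow as tf

# With a None batch dimension, net.get_shape().dims[0].value is None, so a
# shape that depends on it must be built from the dynamic tf.shape() value.
net = tf.placeholder(tf.float32, shape=(None, 16, 16, 4))
batch_size = tf.shape(net)[0]  # scalar tensor, known only at run time
zero_label = tf.zeros(tf.stack([batch_size, 134]))  # 134 is illustrative
```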
...@@ -78,3 +78,20 @@ def variables_to_restore(scope=None, strip_scope=False):
    return variable_map
  else:
    return {v.op.name: v for v in slim.get_variables_to_restore()}
def ConvertAllInputsToTensors(func):
"""A decorator to convert all function's inputs into tensors.
Args:
func: a function to decorate.
Returns:
A decorated function.
"""
def FuncWrapper(*args):
tensors = [tf.convert_to_tensor(a) for a in args]
return func(*tensors)
return FuncWrapper
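A hypothetical usage of the decorator, so the decorated function accepts lists, numpy arrays, or tensors interchangeably (the import path is an assumption based on this file's location):

```
import tensorflow as tf
from utils import ConvertAllInputsToTensors  # assumed import path

@ConvertAllInputsToTensors
def scaled_sum(a, b):
  # Both arguments arrive as tensors regardless of what the caller passed.
  return tf.reduce_sum(a) + 2.0 * tf.reduce_sum(b)

result = scaled_sum([1.0, 2.0], [3.0])  # a tf.Tensor; evaluate via sess.run
```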