Commit 51b7c2b3 authored by Martin Wicke, committed by GitHub

Merge pull request #2166 from alexgorban/master

Demo script to do inference on a trained Attention OCR model
Parents: 6024579b, dff0f0c1
"""A script to run inference on a set of image files.
NOTE #1: The Attention OCR model was trained only using FSNS train dataset and
it will work only for images which look more or less similar to french street
names. In order to apply it to images from a different distribution you need
to retrain (or at least fine-tune) it using images from that distribution.
NOTE #2: This script exists for demo purposes only. It is highly recommended
to use tools and mechanisms provided by the TensorFlow Serving system to run
inference on TensorFlow models in production:
https://www.tensorflow.org/serving/serving_basic
Usage:
python demo_inference.py --batch_size=32 \
--image_path_pattern=./datasets/data/fsns/temp/fsns_train_%02d.png
"""
import numpy as np
import PIL.Image
import tensorflow as tf
from tensorflow.python.platform import flags
import common_flags
import datasets
import model as attention_ocr
FLAGS = flags.FLAGS
common_flags.define()
# e.g. ./datasets/data/fsns/temp/fsns_train_%02d.png
flags.DEFINE_string('image_path_pattern', '',
'A file pattern with a placeholder for the image index.')
def get_dataset_image_size(dataset_name):
# Ideally this info should be exposed through the dataset interface itself.
# But currently it is not available by other means.
ds_module = getattr(datasets, dataset_name)
height, width, _ = ds_module.DEFAULT_CONFIG['image_shape']
return width, height
def load_images(file_pattern, batch_size, dataset_name):
width, height = get_dataset_image_size(dataset_name)
images_actual_data = np.ndarray(shape=(batch_size, height, width, 3),
dtype='float32')
for i in range(batch_size):
path = file_pattern % i
print("Reading %s" % path)
pil_image = PIL.Image.open(tf.gfile.GFile(path))
images_actual_data[i, ...] = np.asarray(pil_image)
return images_actual_data
def load_model(checkpoint, batch_size, dataset_name):
width, height = get_dataset_image_size(dataset_name)
dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
model = common_flags.create_model(
num_char_classes=dataset.num_char_classes,
seq_length=dataset.max_sequence_length,
num_views=dataset.num_of_views,
null_code=dataset.null_code,
charset=dataset.charset)
images_placeholder = tf.placeholder(tf.float32,
shape=[batch_size, height, width, 3])
endpoints = model.create_base(images_placeholder, labels_one_hot=None)
init_fn = model.create_init_fn_to_restore(checkpoint)
return images_placeholder, endpoints, init_fn
def main(_):
images_placeholder, endpoints, init_fn = load_model(FLAGS.checkpoint,
FLAGS.batch_size,
FLAGS.dataset_name)
images_data = load_images(FLAGS.image_path_pattern, FLAGS.batch_size,
FLAGS.dataset_name)
with tf.Session() as sess:
tf.tables_initializer().run() # required by the CharsetMapper
init_fn(sess)
predictions = sess.run(endpoints.predicted_text,
feed_dict={images_placeholder: images_data})
print("Predicted strings:")
for line in predictions:
print(line)
if __name__ == '__main__':
tf.app.run()
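A note on input sizes (not part of the commit): load_images above copies each decoded image straight into a fixed-shape float32 array, so every file matched by --image_path_pattern must already have the dataset's exact (height, width, 3) shape. A minimal pre-processing sketch, assuming PIL and ignoring the multi-view layout of FSNS images; resize_for_ocr is a hypothetical helper, not part of the script:

import numpy as np
import PIL.Image

def resize_for_ocr(path, width, height):
  # Hypothetical helper: force an arbitrary RGB image into the
  # (height, width, 3) shape that load_images() expects.
  pil_image = PIL.Image.open(path).convert('RGB')
  pil_image = pil_image.resize((width, height), PIL.Image.BILINEAR)
  return np.asarray(pil_image, dtype='float32')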
model.py:

@@ -34,25 +34,25 @@ import metrics
 import sequence_layers
 import utils

 OutputEndpoints = collections.namedtuple('OutputEndpoints', [
-    'chars_logit', 'chars_log_prob', 'predicted_chars', 'predicted_scores'
+    'chars_logit', 'chars_log_prob', 'predicted_chars', 'predicted_scores',
+    'predicted_text'
 ])

 # TODO(gorban): replace with tf.HParams when it is released.
 ModelParams = collections.namedtuple('ModelParams', [
     'num_char_classes', 'seq_length', 'num_views', 'null_code'
 ])

 ConvTowerParams = collections.namedtuple('ConvTowerParams', ['final_endpoint'])

 SequenceLogitsParams = collections.namedtuple('SequenceLogitsParams', [
     'use_attention', 'use_autoregression', 'num_lstm_units', 'weight_decay',
     'lstm_state_clip_value'
 ])

 SequenceLossParams = collections.namedtuple('SequenceLossParams', [
     'label_smoothing', 'ignore_nulls', 'average_across_timesteps'
 ])

 EncodeCoordinatesParams = collections.namedtuple('EncodeCoordinatesParams', [
@@ -125,11 +125,12 @@ class Model(object):
   """Class to create the Attention OCR Model."""

   def __init__(self,
-               num_char_classes,
-               seq_length,
-               num_views,
-               null_code,
-               mparams=None):
+               num_char_classes,
+               seq_length,
+               num_views,
+               null_code,
+               mparams=None,
+               charset=None):
     """Initialized model parameters.

     Args:
......@@ -140,6 +141,13 @@ class Model(object):
indicates end of a sequence.
mparams: a dictionary with hyper parameters for methods, keys -
function names, values - corresponding namedtuples.
charset: an optional dictionary with a mapping between character ids and
utf8 strings. If specified the OutputEndpoints.predicted_text will
utf8 encoded strings corresponding to the character ids returned by
OutputEndpoints.predicted_chars (by default the predicted_text contains
an empty vector).
NOTE: Make sure you call tf.tables_initializer().run() if the charset
specified.
"""
super(Model, self).__init__()
self._params = ModelParams(
@@ -150,24 +158,25 @@ class Model(object):
     self._mparams = self.default_mparams()
     if mparams:
       self._mparams.update(mparams)
+    self._charset = charset

   def default_mparams(self):
     return {
         'conv_tower_fn':
             ConvTowerParams(final_endpoint='Mixed_5d'),
         'sequence_logit_fn':
             SequenceLogitsParams(
                 use_attention=True,
                 use_autoregression=True,
                 num_lstm_units=256,
                 weight_decay=0.00004,
                 lstm_state_clip_value=10.0),
         'sequence_loss_fn':
             SequenceLossParams(
                 label_smoothing=0.1,
                 ignore_nulls=True,
                 average_across_timesteps=False),
         'encode_coordinates_fn': EncodeCoordinatesParams(enabled=False)
     }
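For concreteness, the charset described in the docstring above is just a plain Python dictionary from integer character ids to utf8 strings. A minimal sketch of constructing a model with one (the charset contents and seq_length value are illustrative; num_views=4 and null_code=62 mirror the tests below):

# Illustrative values only; a real charset covers the full alphabet
# of the dataset, with null_code mapping to a padding symbol.
charset = {0: 'a', 1: 'b', 2: 'c', 62: ''}
ocr_model = model.Model(num_char_classes=63, seq_length=37, num_views=4,
                        null_code=62, charset=charset)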
def set_mparam(self, function, **kwargs):
......@@ -241,7 +250,7 @@ class Model(object):
A tensor with the same size as any input tensors.
"""
batch_size, height, width, num_features = [
d.value for d in nets_list[0].get_shape().dims
d.value for d in nets_list[0].get_shape().dims
]
xy_flat_shape = (batch_size, 1, height * width, num_features)
nets_for_merge = []
@@ -323,10 +332,10 @@ class Model(object):
     return net

   def create_base(self,
                   images,
                   labels_one_hot,
                   scope='AttentionOcr_v1',
                   reuse=None):
     """Creates a base part of the Model (no gradients, losses or summaries).

     Args:
@@ -348,8 +357,8 @@ class Model(object):
     logging.debug('Views=%d single view: %s', len(views), views[0])

     nets = [
         self.conv_tower_fn(v, is_training, reuse=(i != 0))
         for i, v in enumerate(views)
     ]
     logging.debug('Conv tower: %s', nets[0])
@@ -363,13 +372,18 @@ class Model(object):
     logging.debug('chars_logit: %s', chars_logit)

     predicted_chars, chars_log_prob, predicted_scores = (
         self.char_predictions(chars_logit))
+    if self._charset:
+      character_mapper = CharsetMapper(self._charset)
+      predicted_text = character_mapper.get_text(predicted_chars)
+    else:
+      predicted_text = tf.constant([])

     return OutputEndpoints(
         chars_logit=chars_logit,
         chars_log_prob=chars_log_prob,
         predicted_chars=predicted_chars,
-        predicted_scores=predicted_scores)
+        predicted_scores=predicted_scores,
+        predicted_text=predicted_text)

   def create_loss(self, data, endpoints):
     """Creates all losses required to train the model.
@@ -523,7 +537,8 @@ class Model(object):
       tf.summary.scalar(summary_name, tf.Print(value, [value], summary_name))
     return names_to_updates.values()

-  def create_init_fn_to_restore(self, master_checkpoint, inception_checkpoint):
+  def create_init_fn_to_restore(self, master_checkpoint,
+                                inception_checkpoint=None):
     """Creates an init operations to restore weights from various checkpoints.

     Args:
...
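With inception_checkpoint now defaulting to None, inference-only callers can restore just the trained Attention OCR weights, which is exactly what demo_inference.py does. A short example (ocr_model and FLAGS.checkpoint as in the demo script):

init_fn = ocr_model.create_init_fn_to_restore(FLAGS.checkpoint)
with tf.Session() as sess:
  init_fn(sess)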
model_test.py:

@@ -73,9 +73,10 @@ class ModelTest(tf.test.TestCase):
             high=self.num_char_classes,
             size=(self.batch_size, self.seq_length)).astype('int64'))

-  def create_model(self):
+  def create_model(self, charset=None):
     return model.Model(
-        self.num_char_classes, self.seq_length, num_views=4, null_code=62)
+        self.num_char_classes, self.seq_length, num_views=4, null_code=62,
+        charset=charset)

   def test_char_related_shapes(self):
     ocr_model = self.create_model()
@@ -244,6 +245,21 @@ class ModelTest(tf.test.TestCase):
     self.assertAllEqual(conv_w_coords_tf, conv_w_coords_alt_tf)

+  def test_predicted_text_has_correct_shape_w_charset(self):
+    charset = create_fake_charset(self.num_char_classes)
+    ocr_model = self.create_model(charset=charset)
+
+    with self.test_session() as sess:
+      endpoints_tf = ocr_model.create_base(
+          images=self.fake_images, labels_one_hot=None)
+
+      sess.run(tf.global_variables_initializer())
+      tf.tables_initializer().run()
+      endpoints = sess.run(endpoints_tf)
+
+      self.assertEqual(endpoints.predicted_text.shape, (self.batch_size,))
+      self.assertEqual(len(endpoints.predicted_text[0]), self.seq_length)

 class CharsetMapperTest(tf.test.TestCase):

   def test_text_corresponds_to_ids(self):
...
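The new test relies on a create_fake_charset helper whose definition is outside the hunks shown here. A plausible sketch (hypothetical; the real helper may differ): one short utf8 string per character id, enough to give predicted_text a deterministic per-character length:

def create_fake_charset(num_char_classes):
  # Hypothetical stand-in: map every id to a single ASCII letter.
  return {code: chr(ord('a') + code % 26) for code in range(num_char_classes)}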