Commit da341f70 authored by Alexander Gorban

Demo script to do inference on a pre-trained model.

Changes:
- A working version
- Make predicted_text part of the model.
parent f282f6ef
"""A script to run inference on a set of image files.
NOTE #1: The Attention OCR model was trained only using FSNS train dataset and
it will work only for images which look more or less similar to french street
names. In order to apply it to images from a different distribution you need
to retrain (or at least fine-tune) it using images from that distribution.
NOTE #2: This script exists for demo purposes only. It is highly recommended
to use tools and mechanisms provided by the TensorFlow Serving system to run
inference on TensorFlow models in production:
https://www.tensorflow.org/serving/serving_basic
Usage:
python demo_inference.py --batch_size=32 \
--image_path_pattern=./datasets/data/fsns/temp/fsns_train_%02d.png
"""
import numpy as np
import PIL.Image

import tensorflow as tf
from tensorflow.python.platform import flags

import common_flags
import datasets
import model as attention_ocr

FLAGS = flags.FLAGS
common_flags.define()

# e.g. ./datasets/data/fsns/temp/fsns_train_%02d.png
flags.DEFINE_string('image_path_pattern', '',
                    'A file pattern with a placeholder for the image index.')

def get_dataset_image_size(dataset_name):
  # Ideally this info should be exposed through the dataset interface itself.
  # But currently it is not available by other means.
  ds_module = getattr(datasets, dataset_name)
  height, width, _ = ds_module.DEFAULT_CONFIG['image_shape']
  return width, height

def load_images(file_pattern, batch_size, dataset_name):
  width, height = get_dataset_image_size(dataset_name)
  images_actual_data = np.ndarray(shape=(batch_size, height, width, 3),
                                  dtype='float32')
  for i in range(batch_size):
    path = file_pattern % i
    print("Reading %s" % path)
    pil_image = PIL.Image.open(tf.gfile.GFile(path))
    images_actual_data[i, ...] = np.asarray(pil_image)
  return images_actual_data

def load_model(checkpoint, batch_size, dataset_name):
  width, height = get_dataset_image_size(dataset_name)
  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(
      num_char_classes=dataset.num_char_classes,
      seq_length=dataset.max_sequence_length,
      num_views=dataset.num_of_views,
      null_code=dataset.null_code,
      charset=dataset.charset)
  images_placeholder = tf.placeholder(tf.float32,
                                      shape=[batch_size, height, width, 3])
  endpoints = model.create_base(images_placeholder, labels_one_hot=None)
  init_fn = model.create_init_fn_to_restore(checkpoint)
  return images_placeholder, endpoints, init_fn

def main(_):
  images_placeholder, endpoints, init_fn = load_model(FLAGS.checkpoint,
                                                      FLAGS.batch_size,
                                                      FLAGS.dataset_name)
  images_data = load_images(FLAGS.image_path_pattern, FLAGS.batch_size,
                            FLAGS.dataset_name)
  with tf.Session() as sess:
    tf.tables_initializer().run()  # required by the CharsetMapper
    init_fn(sess)
    predictions = sess.run(endpoints.predicted_text,
                           feed_dict={images_placeholder: images_data})
  print("Predicted strings:")
  for line in predictions:
    print(line)


if __name__ == '__main__':
  tf.app.run()
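
For context: tf.tables_initializer() is required above because, when a charset
is supplied, predicted_text is computed through a TensorFlow lookup table that
maps character ids to strings. The CharsetMapper class that implements this in
model.py is not shown in this diff, so below is a minimal standalone sketch of
the same idea, not the actual implementation; the charset contents and the '?'
default are made up for illustration:

import tensorflow as tf

charset = {0: 'a', 1: 'b', 2: 'c'}  # character id -> utf8 string
mapping_strings = tf.constant([charset[i] for i in sorted(charset)])
table = tf.contrib.lookup.index_to_string_table_from_tensor(
    mapping_strings, default_value='?')
ids = tf.constant([[0, 1, 2], [2, 2, 0]], dtype=tf.int64)
# Join the per-id characters of each sequence into a single string.
text = tf.reduce_join(table.lookup(ids), reduction_indices=1)

with tf.Session() as sess:
  tf.tables_initializer().run()  # without this, the lookup fails at run time
  print(sess.run(text))  # [b'abc' b'cca']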

model.py:

@@ -34,25 +34,25 @@ import metrics
 import sequence_layers
 import utils

 OutputEndpoints = collections.namedtuple('OutputEndpoints', [
-    'chars_logit', 'chars_log_prob', 'predicted_chars', 'predicted_scores'
+    'chars_logit', 'chars_log_prob', 'predicted_chars', 'predicted_scores',
+    'predicted_text'
 ])

 # TODO(gorban): replace with tf.HParams when it is released.
 ModelParams = collections.namedtuple('ModelParams', [
     'num_char_classes', 'seq_length', 'num_views', 'null_code'
 ])

 ConvTowerParams = collections.namedtuple('ConvTowerParams', ['final_endpoint'])

 SequenceLogitsParams = collections.namedtuple('SequenceLogitsParams', [
     'use_attention', 'use_autoregression', 'num_lstm_units', 'weight_decay',
     'lstm_state_clip_value'
 ])

 SequenceLossParams = collections.namedtuple('SequenceLossParams', [
     'label_smoothing', 'ignore_nulls', 'average_across_timesteps'
 ])

 EncodeCoordinatesParams = collections.namedtuple('EncodeCoordinatesParams', [

@@ -125,11 +125,12 @@ class Model(object):
   """Class to create the Attention OCR Model."""

   def __init__(self,
                num_char_classes,
                seq_length,
                num_views,
                null_code,
-               mparams=None):
+               mparams=None,
+               charset=None):
     """Initializes model parameters.

     Args:

@@ -140,6 +141,13 @@ class Model(object):
         indicates end of a sequence.
       mparams: a dictionary with hyper parameters for methods; keys are
         function names, values are the corresponding namedtuples.
+      charset: an optional dictionary with a mapping between character ids and
+        utf8 strings. If specified, OutputEndpoints.predicted_text will contain
+        the utf8-encoded strings corresponding to the character ids returned by
+        OutputEndpoints.predicted_chars (by default predicted_text contains an
+        empty vector).
+        NOTE: Make sure you call tf.tables_initializer().run() if the charset
+        is specified.
     """
     super(Model, self).__init__()
     self._params = ModelParams(

@@ -150,24 +158,25 @@ class Model(object):
     self._mparams = self.default_mparams()
     if mparams:
       self._mparams.update(mparams)
+    self._charset = charset

   def default_mparams(self):
     return {
         'conv_tower_fn':
         ConvTowerParams(final_endpoint='Mixed_5d'),
         'sequence_logit_fn':
         SequenceLogitsParams(
             use_attention=True,
             use_autoregression=True,
             num_lstm_units=256,
             weight_decay=0.00004,
             lstm_state_clip_value=10.0),
         'sequence_loss_fn':
         SequenceLossParams(
             label_smoothing=0.1,
             ignore_nulls=True,
             average_across_timesteps=False),
         'encode_coordinates_fn': EncodeCoordinatesParams(enabled=False)
     }

   def set_mparam(self, function, **kwargs):

@@ -241,7 +250,7 @@ class Model(object):
       A tensor with the same size as any input tensors.
     """
     batch_size, height, width, num_features = [
         d.value for d in nets_list[0].get_shape().dims
     ]
     xy_flat_shape = (batch_size, 1, height * width, num_features)
     nets_for_merge = []

@@ -323,10 +332,10 @@ class Model(object):
     return net

   def create_base(self,
                   images,
                   labels_one_hot,
                   scope='AttentionOcr_v1',
                   reuse=None):
     """Creates a base part of the Model (no gradients, losses or summaries).

     Args:

@@ -348,8 +357,8 @@ class Model(object):
     logging.debug('Views=%d single view: %s', len(views), views[0])

     nets = [
         self.conv_tower_fn(v, is_training, reuse=(i != 0))
         for i, v in enumerate(views)
     ]

     logging.debug('Conv tower: %s', nets[0])

@@ -363,13 +372,18 @@ class Model(object):
     logging.debug('chars_logit: %s', chars_logit)

     predicted_chars, chars_log_prob, predicted_scores = (
         self.char_predictions(chars_logit))
+    if self._charset:
+      character_mapper = CharsetMapper(self._charset)
+      predicted_text = character_mapper.get_text(predicted_chars)
+    else:
+      predicted_text = tf.constant([])

     return OutputEndpoints(
         chars_logit=chars_logit,
         chars_log_prob=chars_log_prob,
         predicted_chars=predicted_chars,
-        predicted_scores=predicted_scores)
+        predicted_scores=predicted_scores,
+        predicted_text=predicted_text)

   def create_loss(self, data, endpoints):
     """Creates all losses required to train the model.

@@ -523,7 +537,8 @@ class Model(object):
     tf.summary.scalar(summary_name, tf.Print(value, [value], summary_name))
     return names_to_updates.values()

-  def create_init_fn_to_restore(self, master_checkpoint, inception_checkpoint):
+  def create_init_fn_to_restore(self, master_checkpoint,
+                                inception_checkpoint=None):
     """Creates an init operation to restore weights from various checkpoints.

     Args:
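
To make the new constructor argument concrete, here is a hypothetical
construction of the model with a charset. All ids, characters, and sizes below
are illustrative only, not values from any real dataset:

import model as attention_ocr

# Hypothetical charset: integer character ids -> utf8 strings.
charset = {0: 'a', 1: 'b', 2: 'c', 3: ' ', 4: '<nul>'}

ocr_model = attention_ocr.Model(
    num_char_classes=5,  # must cover every id in the charset
    seq_length=16,       # illustrative maximum sequence length
    num_views=4,
    null_code=4,         # id of the padding character
    charset=charset)     # enables the predicted_text endpoint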

model_test.py:

@@ -73,9 +73,10 @@ class ModelTest(tf.test.TestCase):
         high=self.num_char_classes,
         size=(self.batch_size, self.seq_length)).astype('int64'))

-  def create_model(self):
+  def create_model(self, charset=None):
     return model.Model(
-        self.num_char_classes, self.seq_length, num_views=4, null_code=62)
+        self.num_char_classes, self.seq_length, num_views=4, null_code=62,
+        charset=charset)

   def test_char_related_shapes(self):
     ocr_model = self.create_model()

@@ -244,6 +245,21 @@ class ModelTest(tf.test.TestCase):
     self.assertAllEqual(conv_w_coords_tf, conv_w_coords_alt_tf)

+  def test_predicted_text_has_correct_shape_w_charset(self):
+    charset = create_fake_charset(self.num_char_classes)
+    ocr_model = self.create_model(charset=charset)
+
+    with self.test_session() as sess:
+      endpoints_tf = ocr_model.create_base(
+          images=self.fake_images, labels_one_hot=None)
+
+      sess.run(tf.global_variables_initializer())
+      tf.tables_initializer().run()
+      endpoints = sess.run(endpoints_tf)
+
+      self.assertEqual(endpoints.predicted_text.shape, (self.batch_size,))
+      self.assertEqual(len(endpoints.predicted_text[0]), self.seq_length)
+

 class CharsetMapperTest(tf.test.TestCase):

   def test_text_corresponds_to_ids(self):
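
The new test relies on a create_fake_charset helper that is defined outside
this hunk. A minimal sketch of what such a helper could look like, assuming
the test only needs some valid id-to-string mapping:

def create_fake_charset(num_char_classes):
  # Map every character id to the same placeholder character; the test
  # checks only shapes, so the actual characters do not matter.
  return {char_code: 'a' for char_code in range(num_char_classes)}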