Commit 89e19ed2 authored by Alexander Gorban's avatar Alexander Gorban
Browse files

Fix demo_inference to properly normalize input.

Before the fix the demo_inference.py used batch_norm and it did the
normalization of input image implicitly. If at inference time the
batch_norm was disabled the inference produced incorrect results.
This fix does the proper input image normalization and disables the batch_norm
at inference time.
parent f893da6d
"""A script to run inference on a set of image files.
NOTE #1: The Attention OCR model was trained only using FSNS train dataset and
it will work only for images which look more or less similar to french street
names. In order to apply it to images from a different distribution you need
to retrain (or at least fine-tune) it using images from that distribution.
NOTE #1: The Attention OCR model was trained only using FSNS train dataset and
it will work only for images which look more or less similar to french street
names. In order to apply it to images from a different distribution you need
to retrain (or at least fine-tune) it using images from that distribution.
NOTE #2: This script exists for demo purposes only. It is highly recommended
to use tools and mechanisms provided by the TensorFlow Serving system to run
......@@ -20,10 +20,11 @@ import PIL.Image
import tensorflow as tf
from tensorflow.python.platform import flags
from tensorflow.python.training import monitored_session
import common_flags
import datasets
import model as attention_ocr
import data_provider
FLAGS = flags.FLAGS
common_flags.define()
......@@ -44,7 +45,7 @@ def get_dataset_image_size(dataset_name):
def load_images(file_pattern, batch_size, dataset_name):
width, height = get_dataset_image_size(dataset_name)
images_actual_data = np.ndarray(shape=(batch_size, height, width, 3),
dtype='float32')
dtype='uint8')
for i in range(batch_size):
path = file_pattern % i
print("Reading %s" % path)
......@@ -53,35 +54,40 @@ def load_images(file_pattern, batch_size, dataset_name):
return images_actual_data
def load_model(checkpoint, batch_size, dataset_name):
def create_model(batch_size, dataset_name):
width, height = get_dataset_image_size(dataset_name)
dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
model = common_flags.create_model(
num_char_classes=dataset.num_char_classes,
seq_length=dataset.max_sequence_length,
num_views=dataset.num_of_views,
null_code=dataset.null_code,
charset=dataset.charset)
images_placeholder = tf.placeholder(tf.float32,
shape=[batch_size, height, width, 3])
endpoints = model.create_base(images_placeholder, labels_one_hot=None)
init_fn = model.create_init_fn_to_restore(checkpoint)
return images_placeholder, endpoints, init_fn
num_char_classes=dataset.num_char_classes,
seq_length=dataset.max_sequence_length,
num_views=dataset.num_of_views,
null_code=dataset.null_code,
charset=dataset.charset)
raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3])
images = tf.map_fn(data_provider.preprocess_image, raw_images,
dtype=tf.float32)
endpoints = model.create_base(images, labels_one_hot=None)
return raw_images, endpoints
def run(checkpoint, batch_size, dataset_name, image_path_pattern):
images_placeholder, endpoints = create_model(batch_size,
dataset_name)
images_data = load_images(image_path_pattern, batch_size,
dataset_name)
session_creator = monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint)
with monitored_session.MonitoredSession(
session_creator=session_creator) as sess:
predictions = sess.run(endpoints.predicted_text,
feed_dict={images_placeholder: images_data})
return predictions.tolist()
def main(_):
images_placeholder, endpoints, init_fn = load_model(FLAGS.checkpoint,
FLAGS.batch_size,
FLAGS.dataset_name)
images_data = load_images(FLAGS.image_path_pattern, FLAGS.batch_size,
FLAGS.dataset_name)
with tf.Session() as sess:
tf.tables_initializer().run() # required by the CharsetMapper
init_fn(sess)
predictions = sess.run(endpoints.predicted_text,
feed_dict={images_placeholder: images_data})
print("Predicted strings:")
for line in predictions:
for line in run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name,
FLAGS.image_path_pattern):
print(line)
......
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import demo_inference
import tensorflow as tf
from tensorflow.python.training import monitored_session
_CHECKPOINT = 'model.ckpt-399731'
_CHECKPOINT_URL = 'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz'
class DemoInferenceTest(tf.test.TestCase):
def setUp(self):
super(DemoInferenceTest, self).setUp()
for suffix in ['.meta', '.index', '.data-00000-of-00001']:
filename = _CHECKPOINT + suffix
self.assertTrue(tf.gfile.Exists(filename),
msg='Missing checkpoint file %s. '
'Please download and extract it from %s' %
(filename, _CHECKPOINT_URL))
self._batch_size = 32
def test_moving_variables_properly_loaded_from_a_checkpoint(self):
batch_size = 32
dataset_name = 'fsns'
images_placeholder, endpoints = demo_inference.create_model(batch_size,
dataset_name)
image_path_pattern = 'testdata/fsns_train_%02d.png'
images_data = demo_inference.load_images(image_path_pattern, batch_size,
dataset_name)
tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
moving_mean_tf = tf.get_default_graph().get_tensor_by_name(
tensor_name + ':0')
reader = tf.train.NewCheckpointReader(_CHECKPOINT)
moving_mean_expected = reader.get_tensor(tensor_name)
session_creator = monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=_CHECKPOINT)
with monitored_session.MonitoredSession(
session_creator=session_creator) as sess:
moving_mean_np = sess.run(moving_mean_tf,
feed_dict={images_placeholder: images_data})
self.assertAllEqual(moving_mean_expected, moving_mean_np)
def test_correct_results_on_test_data(self):
image_path_pattern = 'testdata/fsns_train_%02d.png'
predictions = demo_inference.run(_CHECKPOINT, self._batch_size,
'fsns',
image_path_pattern)
self.assertEqual([
'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░',
'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░',
'Rue Victor Schoulcher░░░░░░░░░░░░░░░░',
'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░',
'Rue André Maginot░░░░░░░░░░░░░░░░░░░░',
'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Landelles░░░░░░░░░░░░░░░░░░░░',
'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░',
'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░',
'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░', # GT='Rue Thérésa'
'Route de la Balme░░░░░░░░░░░░░░░░░░░░',
'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░',
'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░',
'Place de la Mairie░░░░░░░░░░░░░░░░░░░',
'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░',
'Rue de la Libération░░░░░░░░░░░░░░░░░',
'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░',
'Avenue de la Grand Mare░░░░░░░░░░░░░░',
'Rue Pierre Brossolette░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue du Docteur Mourre░░░░░░░░░░░░░░░░',
'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░',
'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░',
'Impasse Pierre Mourgues░░░░░░░░░░░░░░',
'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░'
], predictions)
if __name__ == '__main__':
tf.test.main()
......@@ -85,7 +85,7 @@ class CharsetMapper(object):
"""
mapping_strings = tf.constant(_dict_to_array(charset, default_character))
self.table = tf.contrib.lookup.index_to_string_table_from_tensor(
mapping=mapping_strings, default_value=default_character)
mapping=mapping_strings, default_value=default_character)
def get_text(self, ids):
"""Returns a string corresponding to a sequence of character ids.
......@@ -94,7 +94,7 @@ class CharsetMapper(object):
ids: a tensor with shape [batch_size, max_sequence_length]
"""
return tf.reduce_join(
self.table.lookup(tf.to_int64(ids)), reduction_indices=1)
self.table.lookup(tf.to_int64(ids)), reduction_indices=1)
def get_softmax_loss_fn(label_smoothing):
......@@ -111,12 +111,12 @@ def get_softmax_loss_fn(label_smoothing):
def loss_fn(labels, logits):
return (tf.nn.softmax_cross_entropy_with_logits(
logits=logits, labels=labels))
logits=logits, labels=labels))
else:
def loss_fn(labels, logits):
return tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=labels)
logits=logits, labels=labels)
return loss_fn
......@@ -125,12 +125,12 @@ class Model(object):
"""Class to create the Attention OCR Model."""
def __init__(self,
num_char_classes,
seq_length,
num_views,
null_code,
mparams=None,
charset=None):
num_char_classes,
seq_length,
num_views,
null_code,
mparams=None,
charset=None):
"""Initialized model parameters.
Args:
......@@ -151,10 +151,10 @@ class Model(object):
"""
super(Model, self).__init__()
self._params = ModelParams(
num_char_classes=num_char_classes,
seq_length=seq_length,
num_views=num_views,
null_code=null_code)
num_char_classes=num_char_classes,
seq_length=seq_length,
num_views=num_views,
null_code=null_code)
self._mparams = self.default_mparams()
if mparams:
self._mparams.update(mparams)
......@@ -166,16 +166,16 @@ class Model(object):
ConvTowerParams(final_endpoint='Mixed_5d'),
'sequence_logit_fn':
SequenceLogitsParams(
use_attention=True,
use_autoregression=True,
num_lstm_units=256,
weight_decay=0.00004,
lstm_state_clip_value=10.0),
use_attention=True,
use_autoregression=True,
num_lstm_units=256,
weight_decay=0.00004,
lstm_state_clip_value=10.0),
'sequence_loss_fn':
SequenceLossParams(
label_smoothing=0.1,
ignore_nulls=True,
average_across_timesteps=False),
label_smoothing=0.1,
ignore_nulls=True,
average_across_timesteps=False),
'encode_coordinates_fn': EncodeCoordinatesParams(enabled=False)
}
......@@ -201,11 +201,11 @@ class Model(object):
with tf.variable_scope('conv_tower_fn/INCE'):
if reuse:
tf.get_variable_scope().reuse_variables()
with slim.arg_scope(
[slim.batch_norm, slim.dropout], is_training=is_training):
with slim.arg_scope(inception.inception_v3_arg_scope()):
net, _ = inception.inception_v3_base(
images, final_endpoint=mparams.final_endpoint)
with slim.arg_scope(inception.inception_v3_arg_scope()):
with slim.arg_scope([slim.batch_norm, slim.dropout],
is_training=is_training):
net, _ = inception.inception_v3_base(
images, final_endpoint=mparams.final_endpoint)
return net
def _create_lstm_inputs(self, net):
......@@ -261,7 +261,7 @@ class Model(object):
nets_for_merge.append(tf.reshape(net, xy_flat_shape))
merged_net = tf.concat(nets_for_merge, 1)
net = slim.max_pool2d(
merged_net, kernel_size=[len(nets_list), 1], stride=1)
merged_net, kernel_size=[len(nets_list), 1], stride=1)
net = tf.reshape(net, (batch_size, height, width, num_features))
return net
......@@ -303,7 +303,7 @@ class Model(object):
log_prob = utils.logits_to_log_prob(chars_logit)
ids = tf.to_int32(tf.argmax(log_prob, axis=2), name='predicted_chars')
mask = tf.cast(
slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool)
slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool)
all_scores = tf.nn.softmax(chars_logit)
selected_scores = tf.boolean_mask(all_scores, mask, name='char_scores')
scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length))
......@@ -334,10 +334,10 @@ class Model(object):
return net
def create_base(self,
images,
labels_one_hot,
scope='AttentionOcr_v1',
reuse=None):
images,
labels_one_hot,
scope='AttentionOcr_v1',
reuse=None):
"""Creates a base part of the Model (no gradients, losses or summaries).
Args:
......@@ -355,7 +355,7 @@ class Model(object):
is_training = labels_one_hot is not None
with tf.variable_scope(scope, reuse=reuse):
views = tf.split(
value=images, num_or_size_splits=self._params.num_views, axis=2)
value=images, num_or_size_splits=self._params.num_views, axis=2)
logging.debug('Views=%d single view: %s', len(views), views[0])
nets = [
......@@ -381,11 +381,11 @@ class Model(object):
else:
predicted_text = tf.constant([])
return OutputEndpoints(
chars_logit=chars_logit,
chars_log_prob=chars_log_prob,
predicted_chars=predicted_chars,
predicted_scores=predicted_scores,
predicted_text=predicted_text)
chars_logit=chars_logit,
chars_log_prob=chars_log_prob,
predicted_chars=predicted_chars,
predicted_scores=predicted_scores,
predicted_text=predicted_text)
def create_loss(self, data, endpoints):
"""Creates all losses required to train the model.
......@@ -421,7 +421,7 @@ class Model(object):
A sensor with the same shape as the input.
"""
one_hot_labels = tf.one_hot(
chars_labels, depth=self._params.num_char_classes, axis=-1)
chars_labels, depth=self._params.num_char_classes, axis=-1)
pos_weight = 1.0 - weight
neg_weight = weight / self._params.num_char_classes
return one_hot_labels * pos_weight + neg_weight
......@@ -446,7 +446,7 @@ class Model(object):
with tf.variable_scope('sequence_loss_fn/SLF'):
if mparams.label_smoothing > 0:
smoothed_one_hot_labels = self.label_smoothing_regularization(
chars_labels, mparams.label_smoothing)
chars_labels, mparams.label_smoothing)
labels_list = tf.unstack(smoothed_one_hot_labels, axis=1)
else:
# NOTE: in case of sparse softmax we are not using one-hot
......@@ -459,20 +459,20 @@ class Model(object):
else:
# Suppose that reject character is the last in the charset.
reject_char = tf.constant(
self._params.num_char_classes - 1,
shape=(batch_size, seq_length),
dtype=tf.int64)
self._params.num_char_classes - 1,
shape=(batch_size, seq_length),
dtype=tf.int64)
known_char = tf.not_equal(chars_labels, reject_char)
weights = tf.to_float(known_char)
logits_list = tf.unstack(chars_logits, axis=1)
weights_list = tf.unstack(weights, axis=1)
loss = tf.contrib.legacy_seq2seq.sequence_loss(
logits_list,
labels_list,
weights_list,
softmax_loss_function=get_softmax_loss_fn(mparams.label_smoothing),
average_across_timesteps=mparams.average_across_timesteps)
logits_list,
labels_list,
weights_list,
softmax_loss_function=get_softmax_loss_fn(mparams.label_smoothing),
average_across_timesteps=mparams.average_across_timesteps)
tf.losses.add_loss(loss)
return loss
......@@ -507,7 +507,7 @@ class Model(object):
if is_training:
tf.summary.image(
sname('image/orig'), data.images_orig, max_outputs=max_outputs)
sname('image/orig'), data.images_orig, max_outputs=max_outputs)
for var in tf.trainable_variables():
tf.summary.histogram(var.op.name, var)
return None
......@@ -522,17 +522,17 @@ class Model(object):
use_metric('CharacterAccuracy',
metrics.char_accuracy(
endpoints.predicted_chars,
data.labels,
streaming=True,
rej_char=self._params.null_code))
endpoints.predicted_chars,
data.labels,
streaming=True,
rej_char=self._params.null_code))
# Sequence accuracy computed by cutting sequence at the first null char
use_metric('SequenceAccuracy',
metrics.sequence_accuracy(
endpoints.predicted_chars,
data.labels,
streaming=True,
rej_char=self._params.null_code))
endpoints.predicted_chars,
data.labels,
streaming=True,
rej_char=self._params.null_code))
for name, value in names_to_values.iteritems():
summary_name = 'eval/' + name
......@@ -540,7 +540,7 @@ class Model(object):
return names_to_updates.values()
def create_init_fn_to_restore(self, master_checkpoint,
inception_checkpoint=None):
inception_checkpoint=None):
"""Creates an init operations to restore weights from various checkpoints.
Args:
......@@ -565,12 +565,15 @@ class Model(object):
all_assign_ops.append(assign_op)
all_feed_dict.update(feed_dict)
logging.info('variables_to_restore:\n%s' % utils.variables_to_restore().keys())
logging.info('moving_average_variables:\n%s' % [v.op.name for v in tf.moving_average_variables()])
logging.info('trainable_variables:\n%s' % [v.op.name for v in tf.trainable_variables()])
if master_checkpoint:
assign_from_checkpoint(utils.variables_to_restore(), master_checkpoint)
if inception_checkpoint:
variables = utils.variables_to_restore(
'AttentionOcr_v1/conv_tower_fn/INCE', strip_scope=True)
'AttentionOcr_v1/conv_tower_fn/INCE', strip_scope=True)
assign_from_checkpoint(variables, inception_checkpoint)
def init_assign_fn(sess):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment