Commit 89e19ed2 authored by Alexander Gorban's avatar Alexander Gorban
Browse files

Fix demo_inference to properly normalize input.

Before this fix, demo_inference.py ran with batch_norm enabled, which
normalized the input image implicitly. If batch_norm was disabled at
inference time, the inference produced incorrect results.
This fix normalizes the input image explicitly and disables batch_norm
at inference time.
parent f893da6d
......@@ -20,10 +20,11 @@ import PIL.Image
import tensorflow as tf
from tensorflow.python.platform import flags
from tensorflow.python.training import monitored_session
import common_flags
import datasets
import model as attention_ocr
import data_provider
FLAGS = flags.FLAGS
common_flags.define()
......@@ -44,7 +45,7 @@ def get_dataset_image_size(dataset_name):
def load_images(file_pattern, batch_size, dataset_name):
width, height = get_dataset_image_size(dataset_name)
images_actual_data = np.ndarray(shape=(batch_size, height, width, 3),
dtype='float32')
dtype='uint8')
for i in range(batch_size):
path = file_pattern % i
print("Reading %s" % path)
......@@ -53,7 +54,7 @@ def load_images(file_pattern, batch_size, dataset_name):
return images_actual_data
def load_model(checkpoint, batch_size, dataset_name):
def create_model(batch_size, dataset_name):
width, height = get_dataset_image_size(dataset_name)
dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
model = common_flags.create_model(
......@@ -62,26 +63,31 @@ def load_model(checkpoint, batch_size, dataset_name):
num_views=dataset.num_of_views,
null_code=dataset.null_code,
charset=dataset.charset)
images_placeholder = tf.placeholder(tf.float32,
shape=[batch_size, height, width, 3])
endpoints = model.create_base(images_placeholder, labels_one_hot=None)
init_fn = model.create_init_fn_to_restore(checkpoint)
return images_placeholder, endpoints, init_fn
raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3])
images = tf.map_fn(data_provider.preprocess_image, raw_images,
dtype=tf.float32)
endpoints = model.create_base(images, labels_one_hot=None)
return raw_images, endpoints
def run(checkpoint, batch_size, dataset_name, image_path_pattern):
  """Runs inference on one batch of images and returns the predicted text.

  Args:
    checkpoint: checkpoint path passed to ChiefSessionCreator as
      checkpoint_filename_with_path.
    batch_size: number of images to load and feed to the model.
    dataset_name: dataset name forwarded to create_model and load_images.
    image_path_pattern: file pattern forwarded to load_images.

  Returns:
    The evaluated endpoints.predicted_text, converted with tolist().
  """
  raw_images, endpoints = create_model(batch_size, dataset_name)
  batch = load_images(image_path_pattern, batch_size, dataset_name)
  # The session creator is given the checkpoint path, so no explicit
  # init function is invoked here.
  creator = monitored_session.ChiefSessionCreator(
      checkpoint_filename_with_path=checkpoint)
  with monitored_session.MonitoredSession(session_creator=creator) as sess:
    predicted = sess.run(endpoints.predicted_text,
                         feed_dict={raw_images: batch})
  return predicted.tolist()
def main(_):
images_placeholder, endpoints, init_fn = load_model(FLAGS.checkpoint,
FLAGS.batch_size,
FLAGS.dataset_name)
images_data = load_images(FLAGS.image_path_pattern, FLAGS.batch_size,
FLAGS.dataset_name)
with tf.Session() as sess:
tf.tables_initializer().run() # required by the CharsetMapper
init_fn(sess)
predictions = sess.run(endpoints.predicted_text,
feed_dict={images_placeholder: images_data})
print("Predicted strings:")
for line in predictions:
for line in run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name,
FLAGS.image_path_pattern):
print(line)
......
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import demo_inference
import tensorflow as tf
from tensorflow.python.training import monitored_session
_CHECKPOINT = 'model.ckpt-399731'
_CHECKPOINT_URL = 'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz'
class DemoInferenceTest(tf.test.TestCase):
  """Tests demo_inference against the checkpoint named by _CHECKPOINT."""

  def setUp(self):
    """Asserts that every checkpoint file exists and sets the batch size."""
    super(DemoInferenceTest, self).setUp()
    for suffix in ('.meta', '.index', '.data-00000-of-00001'):
      checkpoint_file = _CHECKPOINT + suffix
      error = ('Missing checkpoint file %s. '
               'Please download and extract it from %s' %
               (checkpoint_file, _CHECKPOINT_URL))
      self.assertTrue(tf.gfile.Exists(checkpoint_file), msg=error)
    self._batch_size = 32

  def test_moving_variables_properly_loaded_from_a_checkpoint(self):
    """Compares a moving_mean tensor in the session with the checkpoint."""
    dataset_name = 'fsns'
    # The model graph is created first so the tensor lookup below can
    # find the moving_mean tensor in the default graph.
    raw_images, _ = demo_inference.create_model(self._batch_size,
                                                dataset_name)
    image_path_pattern = 'testdata/fsns_train_%02d.png'
    batch = demo_inference.load_images(image_path_pattern, self._batch_size,
                                       dataset_name)
    tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
    graph = tf.get_default_graph()
    moving_mean_tensor = graph.get_tensor_by_name(tensor_name + ':0')
    # Expected value is read straight from the checkpoint file.
    expected = tf.train.NewCheckpointReader(_CHECKPOINT).get_tensor(
        tensor_name)

    creator = monitored_session.ChiefSessionCreator(
        checkpoint_filename_with_path=_CHECKPOINT)
    with monitored_session.MonitoredSession(session_creator=creator) as sess:
      actual = sess.run(moving_mean_tensor, feed_dict={raw_images: batch})

    self.assertAllEqual(expected, actual)

  def test_correct_results_on_test_data(self):
    """Checks demo_inference.run output against the expected strings."""
    expected_predictions = [
        'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
        'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
        'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
        'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
        'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░',
        'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░',
        'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░',
        'Rue Victor Schoulcher░░░░░░░░░░░░░░░░',
        'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░',
        'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░',
        'Rue André Maginot░░░░░░░░░░░░░░░░░░░░',
        'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░',
        'Rue des Landelles░░░░░░░░░░░░░░░░░░░░',
        'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░',
        'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░',
        'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░',  # GT='Rue Thérésa'
        'Route de la Balme░░░░░░░░░░░░░░░░░░░░',
        'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░',
        'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░',
        'Place de la Mairie░░░░░░░░░░░░░░░░░░░',
        'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░',
        'Rue de la Libération░░░░░░░░░░░░░░░░░',
        'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░',
        'Avenue de la Grand Mare░░░░░░░░░░░░░░',
        'Rue Pierre Brossolette░░░░░░░░░░░░░░░',
        'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
        'Rue du Docteur Mourre░░░░░░░░░░░░░░░░',
        'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░',
        'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░',
        'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░',
        'Impasse Pierre Mourgues░░░░░░░░░░░░░░',
        'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░',
    ]
    image_path_pattern = 'testdata/fsns_train_%02d.png'
    predictions = demo_inference.run(_CHECKPOINT, self._batch_size, 'fsns',
                                     image_path_pattern)
    self.assertEqual(expected_predictions, predictions)
if __name__ == '__main__':
tf.test.main()
......@@ -201,9 +201,9 @@ class Model(object):
with tf.variable_scope('conv_tower_fn/INCE'):
if reuse:
tf.get_variable_scope().reuse_variables()
with slim.arg_scope(
[slim.batch_norm, slim.dropout], is_training=is_training):
with slim.arg_scope(inception.inception_v3_arg_scope()):
with slim.arg_scope([slim.batch_norm, slim.dropout],
is_training=is_training):
net, _ = inception.inception_v3_base(
images, final_endpoint=mparams.final_endpoint)
return net
......@@ -565,6 +565,9 @@ class Model(object):
all_assign_ops.append(assign_op)
all_feed_dict.update(feed_dict)
logging.info('variables_to_restore:\n%s' % utils.variables_to_restore().keys())
logging.info('moving_average_variables:\n%s' % [v.op.name for v in tf.moving_average_variables()])
logging.info('trainable_variables:\n%s' % [v.op.name for v in tf.trainable_variables()])
if master_checkpoint:
assign_from_checkpoint(utils.variables_to_restore(), master_checkpoint)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment