"vscode:/vscode.git/clone" did not exist on "f61ef323e0702bc2711529b9dea8e85148137337"
Unverified commit c0cd713f authored by Neal Wu, committed by GitHub

Merge pull request #3111 from alexgorban/master

Fix demo_inference to properly normalize input.
parents b719165d f02e6013
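For context on what "normalize input" means here: the demo previously fed raw pixel values into a float32 placeholder, while training went through data_provider.preprocess_image. A rough standalone sketch of that kind of Inception-style rescaling (the exact scale and offset are whatever data_provider.preprocess_image applies; the numbers below are only illustrative):

import numpy as np

def preprocess_image_np(image_uint8):
  # Illustrative only: map uint8 pixels in [0, 255] to float32 roughly in
  # [-1, 1], mirroring the TF-side preprocessing the fixed script now applies.
  image = image_uint8.astype(np.float32) / 255.0
  return (image - 0.5) * 2.0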
"""A script to run inference on a set of image files. """A script to run inference on a set of image files.
NOTE #1: The Attention OCR model was trained only using FSNS train dataset and NOTE #1: The Attention OCR model was trained only using FSNS train dataset and
it will work only for images which look more or less similar to french street it will work only for images which look more or less similar to french street
names. In order to apply it to images from a different distribution you need names. In order to apply it to images from a different distribution you need
to retrain (or at least fine-tune) it using images from that distribution. to retrain (or at least fine-tune) it using images from that distribution.
NOTE #2: This script exists for demo purposes only. It is highly recommended NOTE #2: This script exists for demo purposes only. It is highly recommended
to use tools and mechanisms provided by the TensorFlow Serving system to run to use tools and mechanisms provided by the TensorFlow Serving system to run
...@@ -20,10 +20,11 @@ import PIL.Image ...@@ -20,10 +20,11 @@ import PIL.Image
import tensorflow as tf
from tensorflow.python.platform import flags
+from tensorflow.python.training import monitored_session
import common_flags
import datasets
-import model as attention_ocr
+import data_provider
FLAGS = flags.FLAGS
common_flags.define()
@@ -44,7 +45,7 @@ def get_dataset_image_size(dataset_name):
def load_images(file_pattern, batch_size, dataset_name):
  width, height = get_dataset_image_size(dataset_name)
  images_actual_data = np.ndarray(shape=(batch_size, height, width, 3),
-                                  dtype='float32')
+                                  dtype='uint8')
  for i in range(batch_size):
    path = file_pattern % i
    print("Reading %s" % path)
@@ -53,34 +54,40 @@ def load_images(file_pattern, batch_size, dataset_name):
  return images_actual_data


-def load_model(checkpoint, batch_size, dataset_name):
+def create_model(batch_size, dataset_name):
  width, height = get_dataset_image_size(dataset_name)
  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(
      num_char_classes=dataset.num_char_classes,
      seq_length=dataset.max_sequence_length,
      num_views=dataset.num_of_views,
      null_code=dataset.null_code,
      charset=dataset.charset)
-  images_placeholder = tf.placeholder(tf.float32,
-                                      shape=[batch_size, height, width, 3])
-  endpoints = model.create_base(images_placeholder, labels_one_hot=None)
-  init_fn = model.create_init_fn_to_restore(checkpoint)
-  return images_placeholder, endpoints, init_fn
+  raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3])
+  images = tf.map_fn(data_provider.preprocess_image, raw_images,
+                     dtype=tf.float32)
+  endpoints = model.create_base(images, labels_one_hot=None)
+  return raw_images, endpoints


+def run(checkpoint, batch_size, dataset_name, image_path_pattern):
+  images_placeholder, endpoints = create_model(batch_size,
+                                               dataset_name)
+  images_data = load_images(image_path_pattern, batch_size,
+                            dataset_name)
+  session_creator = monitored_session.ChiefSessionCreator(
+      checkpoint_filename_with_path=checkpoint)
+  with monitored_session.MonitoredSession(
+      session_creator=session_creator) as sess:
+    predictions = sess.run(endpoints.predicted_text,
+                           feed_dict={images_placeholder: images_data})
+  return predictions.tolist()


def main(_):
-  images_placeholder, endpoints, init_fn = load_model(FLAGS.checkpoint,
-                                                      FLAGS.batch_size,
-                                                      FLAGS.dataset_name)
-  images_data = load_images(FLAGS.image_path_pattern, FLAGS.batch_size,
-                            FLAGS.dataset_name)
-  with tf.Session() as sess:
-    tf.tables_initializer().run()  # required by the CharsetMapper
-    init_fn(sess)
-    predictions = sess.run(endpoints.predicted_text,
-                           feed_dict={images_placeholder: images_data})
  print("Predicted strings:")
+  predictions = run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name,
+                    FLAGS.image_path_pattern)
  for line in predictions:
    print(line)
...
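The rewritten script also drops the hand-rolled tf.Session / init_fn / tf.tables_initializer() sequence in favor of MonitoredSession: ChiefSessionCreator restores the given checkpoint, and the default scaffold runs the standard initializers (including the table initializers needed by the CharsetMapper), which is why the explicit calls could go away. A minimal TF 1.x sketch of that pattern, with a placeholder graph standing in for the real model:

import tensorflow as tf
from tensorflow.python.training import monitored_session

output = tf.constant('hello')  # stand-in for endpoints.predicted_text
checkpoint = None              # e.g. 'model.ckpt-399731'; None skips restoring

session_creator = monitored_session.ChiefSessionCreator(
    checkpoint_filename_with_path=checkpoint)
with monitored_session.MonitoredSession(
    session_creator=session_creator) as sess:
  print(sess.run(output))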
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import demo_inference
import tensorflow as tf
from tensorflow.python.training import monitored_session

_CHECKPOINT = 'model.ckpt-399731'
_CHECKPOINT_URL = 'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz'


class DemoInferenceTest(tf.test.TestCase):
  def setUp(self):
    super(DemoInferenceTest, self).setUp()
    for suffix in ['.meta', '.index', '.data-00000-of-00001']:
      filename = _CHECKPOINT + suffix
      self.assertTrue(tf.gfile.Exists(filename),
                      msg='Missing checkpoint file %s. '
                          'Please download and extract it from %s' %
                          (filename, _CHECKPOINT_URL))
    self._batch_size = 32

  def test_moving_variables_properly_loaded_from_a_checkpoint(self):
    batch_size = 32
    dataset_name = 'fsns'
    images_placeholder, endpoints = demo_inference.create_model(batch_size,
                                                                dataset_name)
    image_path_pattern = 'testdata/fsns_train_%02d.png'
    images_data = demo_inference.load_images(image_path_pattern, batch_size,
                                             dataset_name)
    tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
    moving_mean_tf = tf.get_default_graph().get_tensor_by_name(
        tensor_name + ':0')
    reader = tf.train.NewCheckpointReader(_CHECKPOINT)
    moving_mean_expected = reader.get_tensor(tensor_name)

    session_creator = monitored_session.ChiefSessionCreator(
        checkpoint_filename_with_path=_CHECKPOINT)
    with monitored_session.MonitoredSession(
        session_creator=session_creator) as sess:
      moving_mean_np = sess.run(moving_mean_tf,
                                feed_dict={images_placeholder: images_data})

    self.assertAllEqual(moving_mean_expected, moving_mean_np)

  def test_correct_results_on_test_data(self):
    image_path_pattern = 'testdata/fsns_train_%02d.png'
    predictions = demo_inference.run(_CHECKPOINT, self._batch_size,
                                     'fsns',
                                     image_path_pattern)
    self.assertEqual([
'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░',
'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░',
'Rue Victor Schoulcher░░░░░░░░░░░░░░░░',
'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░',
'Rue André Maginot░░░░░░░░░░░░░░░░░░░░',
'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Landelles░░░░░░░░░░░░░░░░░░░░',
'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░',
'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░',
'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░', # GT='Rue Thérésa'
'Route de la Balme░░░░░░░░░░░░░░░░░░░░',
'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░',
'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░',
'Place de la Mairie░░░░░░░░░░░░░░░░░░░',
'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░',
'Rue de la Libération░░░░░░░░░░░░░░░░░',
'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░',
'Avenue de la Grand Mare░░░░░░░░░░░░░░',
'Rue Pierre Brossolette░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue du Docteur Mourre░░░░░░░░░░░░░░░░',
'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░',
'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░',
'Impasse Pierre Mourgues░░░░░░░░░░░░░░',
'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░'
    ], predictions)


if __name__ == '__main__':
  tf.test.main()
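The checkpoint this test expects is not bundled with the code; setUp only asserts that the model.ckpt-399731.* files are present. A rough sketch of fetching and unpacking them with the Python standard library (assuming the tarball at _CHECKPOINT_URL contains the checkpoint files at its top level):

import tarfile
import urllib.request

url = 'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz'
archive_name, _ = urllib.request.urlretrieve(url, 'attention_ocr_2017_08_09.tar.gz')
with tarfile.open(archive_name) as archive:
  archive.extractall('.')  # should yield model.ckpt-399731.{meta,index,data-...}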
@@ -85,7 +85,7 @@ class CharsetMapper(object):
    """
    mapping_strings = tf.constant(_dict_to_array(charset, default_character))
    self.table = tf.contrib.lookup.index_to_string_table_from_tensor(
        mapping=mapping_strings, default_value=default_character)

  def get_text(self, ids):
    """Returns a string corresponding to a sequence of character ids.
@@ -94,7 +94,7 @@ class CharsetMapper(object):
      ids: a tensor with shape [batch_size, max_sequence_length]
    """
    return tf.reduce_join(
        self.table.lookup(tf.to_int64(ids)), reduction_indices=1)
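To see what get_text computes, here is a self-contained TF 1.x sketch with a made-up three-character charset; the real mapping comes from the dataset, and the default/null character ('░' in the FSNS charset) is what pads the predicted strings in the test above up to the maximum sequence length:

import tensorflow as tf

mapping_strings = tf.constant(['a', 'b', '░'])  # hypothetical charset array
table = tf.contrib.lookup.index_to_string_table_from_tensor(
    mapping=mapping_strings, default_value='░')
ids = tf.constant([[0, 1, 2, 2]], dtype=tf.int64)
text = tf.reduce_join(table.lookup(ids), reduction_indices=1)

with tf.Session() as sess:
  sess.run(tf.tables_initializer())  # lookup tables need explicit initialization
  print(sess.run(text))              # -> ['ab░░'] (as a bytes string)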
def get_softmax_loss_fn(label_smoothing):
@@ -111,12 +111,12 @@ def get_softmax_loss_fn(label_smoothing):
    def loss_fn(labels, logits):
      return (tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=labels))
  else:
    def loss_fn(labels, logits):
      return tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=labels)
  return loss_fn
@@ -125,12 +125,12 @@ class Model(object):
  """Class to create the Attention OCR Model."""

  def __init__(self,
               num_char_classes,
               seq_length,
               num_views,
               null_code,
               mparams=None,
               charset=None):
    """Initialized model parameters.

    Args:
@@ -151,10 +151,10 @@ class Model(object):
    """
    super(Model, self).__init__()
    self._params = ModelParams(
        num_char_classes=num_char_classes,
        seq_length=seq_length,
        num_views=num_views,
        null_code=null_code)
    self._mparams = self.default_mparams()
    if mparams:
      self._mparams.update(mparams)
@@ -166,16 +166,16 @@ class Model(object):
        ConvTowerParams(final_endpoint='Mixed_5d'),
        'sequence_logit_fn':
        SequenceLogitsParams(
            use_attention=True,
            use_autoregression=True,
            num_lstm_units=256,
            weight_decay=0.00004,
            lstm_state_clip_value=10.0),
        'sequence_loss_fn':
        SequenceLossParams(
            label_smoothing=0.1,
            ignore_nulls=True,
            average_across_timesteps=False),
        'encode_coordinates_fn': EncodeCoordinatesParams(enabled=False)
    }
@@ -201,11 +201,11 @@ class Model(object):
    with tf.variable_scope('conv_tower_fn/INCE'):
      if reuse:
        tf.get_variable_scope().reuse_variables()
-      with slim.arg_scope(
-          [slim.batch_norm, slim.dropout], is_training=is_training):
-        with slim.arg_scope(inception.inception_v3_arg_scope()):
+      with slim.arg_scope(inception.inception_v3_arg_scope()):
+        with slim.arg_scope([slim.batch_norm, slim.dropout],
+                            is_training=is_training):
          net, _ = inception.inception_v3_base(
              images, final_endpoint=mparams.final_endpoint)
      return net

  def _create_lstm_inputs(self, net):
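The only non-whitespace change in this hunk swaps the nesting of the two arg_scopes, presumably so the batch_norm/dropout is_training setting is applied in the innermost scope. With slim, the innermost arg_scope wins for the ops it lists, as in this small illustration (hypothetical values, unrelated to the Inception defaults):

import tensorflow as tf
slim = tf.contrib.slim

with slim.arg_scope([slim.batch_norm], decay=0.9, is_training=True):
  with slim.arg_scope([slim.batch_norm], is_training=False):
    # Here slim.batch_norm(...) would see decay=0.9 from the outer scope and
    # is_training=False from the inner one: inner values override outer ones.
    pass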
@@ -261,7 +261,7 @@ class Model(object):
      nets_for_merge.append(tf.reshape(net, xy_flat_shape))
    merged_net = tf.concat(nets_for_merge, 1)
    net = slim.max_pool2d(
        merged_net, kernel_size=[len(nets_list), 1], stride=1)
    net = tf.reshape(net, (batch_size, height, width, num_features))
    return net
@@ -303,7 +303,7 @@ class Model(object):
    log_prob = utils.logits_to_log_prob(chars_logit)
    ids = tf.to_int32(tf.argmax(log_prob, axis=2), name='predicted_chars')
    mask = tf.cast(
        slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool)
    all_scores = tf.nn.softmax(chars_logit)
    selected_scores = tf.boolean_mask(all_scores, mask, name='char_scores')
    scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length))
@@ -334,10 +334,10 @@ class Model(object):
    return net

  def create_base(self,
                  images,
                  labels_one_hot,
                  scope='AttentionOcr_v1',
                  reuse=None):
    """Creates a base part of the Model (no gradients, losses or summaries).

    Args:
@@ -355,7 +355,7 @@ class Model(object):
    is_training = labels_one_hot is not None
    with tf.variable_scope(scope, reuse=reuse):
      views = tf.split(
          value=images, num_or_size_splits=self._params.num_views, axis=2)
      logging.debug('Views=%d single view: %s', len(views), views[0])
      nets = [
@@ -381,11 +381,11 @@ class Model(object):
      else:
        predicted_text = tf.constant([])
      return OutputEndpoints(
          chars_logit=chars_logit,
          chars_log_prob=chars_log_prob,
          predicted_chars=predicted_chars,
          predicted_scores=predicted_scores,
          predicted_text=predicted_text)

  def create_loss(self, data, endpoints):
    """Creates all losses required to train the model.
@@ -421,7 +421,7 @@ class Model(object):
      A sensor with the same shape as the input.
    """
    one_hot_labels = tf.one_hot(
        chars_labels, depth=self._params.num_char_classes, axis=-1)
    pos_weight = 1.0 - weight
    neg_weight = weight / self._params.num_char_classes
    return one_hot_labels * pos_weight + neg_weight
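A quick numeric check of this regularization (with hypothetical values): for weight = 0.1 and 4 character classes, pos_weight = 0.9 and neg_weight = 0.025, so a one-hot row [0, 1, 0, 0] becomes [0.025, 0.925, 0.025, 0.025] and still sums to 1:

import numpy as np

weight, num_char_classes = 0.1, 4         # hypothetical values
pos_weight = 1.0 - weight                 # 0.9
neg_weight = weight / num_char_classes    # 0.025
one_hot = np.array([0., 1., 0., 0.])
print(one_hot * pos_weight + neg_weight)  # [0.025 0.925 0.025 0.025]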
@@ -446,7 +446,7 @@ class Model(object):
    with tf.variable_scope('sequence_loss_fn/SLF'):
      if mparams.label_smoothing > 0:
        smoothed_one_hot_labels = self.label_smoothing_regularization(
            chars_labels, mparams.label_smoothing)
        labels_list = tf.unstack(smoothed_one_hot_labels, axis=1)
      else:
        # NOTE: in case of sparse softmax we are not using one-hot
@@ -459,20 +459,20 @@ class Model(object):
      else:
        # Suppose that reject character is the last in the charset.
        reject_char = tf.constant(
            self._params.num_char_classes - 1,
            shape=(batch_size, seq_length),
            dtype=tf.int64)
        known_char = tf.not_equal(chars_labels, reject_char)
        weights = tf.to_float(known_char)

      logits_list = tf.unstack(chars_logits, axis=1)
      weights_list = tf.unstack(weights, axis=1)
      loss = tf.contrib.legacy_seq2seq.sequence_loss(
          logits_list,
          labels_list,
          weights_list,
          softmax_loss_function=get_softmax_loss_fn(mparams.label_smoothing),
          average_across_timesteps=mparams.average_across_timesteps)
      tf.losses.add_loss(loss)
      return loss
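The weights built in the else branch zero out timesteps whose label is the reject (null) character, assumed to be the last id in the charset, so padded positions do not contribute to the sequence loss. A small NumPy illustration with made-up values:

import numpy as np

num_char_classes = 134                        # hypothetical charset size
reject_char = num_char_classes - 1            # 133
chars_labels = np.array([[12, 7, 133, 133]])  # one sequence, two padded steps
weights = (chars_labels != reject_char).astype(np.float32)
print(weights)                                # [[1. 1. 0. 0.]]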
@@ -507,7 +507,7 @@ class Model(object):
    if is_training:
      tf.summary.image(
          sname('image/orig'), data.images_orig, max_outputs=max_outputs)

    for var in tf.trainable_variables():
      tf.summary.histogram(var.op.name, var)

    return None
@@ -522,17 +522,17 @@ class Model(object):
    use_metric('CharacterAccuracy',
               metrics.char_accuracy(
                   endpoints.predicted_chars,
                   data.labels,
                   streaming=True,
                   rej_char=self._params.null_code))
    # Sequence accuracy computed by cutting sequence at the first null char
    use_metric('SequenceAccuracy',
               metrics.sequence_accuracy(
                   endpoints.predicted_chars,
                   data.labels,
                   streaming=True,
                   rej_char=self._params.null_code))

    for name, value in names_to_values.iteritems():
      summary_name = 'eval/' + name
@@ -540,7 +540,7 @@ class Model(object):
    return names_to_updates.values()

  def create_init_fn_to_restore(self, master_checkpoint,
                                inception_checkpoint=None):
    """Creates an init operations to restore weights from various checkpoints.

    Args:
@@ -565,12 +565,15 @@ class Model(object):
      all_assign_ops.append(assign_op)
      all_feed_dict.update(feed_dict)

+    logging.info('variables_to_restore:\n%s' % utils.variables_to_restore().keys())
+    logging.info('moving_average_variables:\n%s' % [v.op.name for v in tf.moving_average_variables()])
+    logging.info('trainable_variables:\n%s' % [v.op.name for v in tf.trainable_variables()])
    if master_checkpoint:
      assign_from_checkpoint(utils.variables_to_restore(), master_checkpoint)

    if inception_checkpoint:
      variables = utils.variables_to_restore(
          'AttentionOcr_v1/conv_tower_fn/INCE', strip_scope=True)
      assign_from_checkpoint(variables, inception_checkpoint)

    def init_assign_fn(sess):
...