Commit 3eba37c7 authored by Alexander Gorban's avatar Alexander Gorban
Browse files

Open source release of Attention OCR - a model for real-world image text extraction.

parent 5188c975
## Attention-based Extraction of Structured Information from Street View Imagery
*A TensorFlow model for real-world image text extraction problems.*
This folder contains the code needed to train a new Attention OCR model on the
[FSNS dataset][FSNS] to transcribe street names in France. You can
also use it to train a model on your own data.
More details can be found in our paper:
["Attention-based Extraction of Structured Information from Street View
Imagery"](https://arxiv.org/abs/1704.03549)
## Contacts
Authors:
Zbigniew Wojna <zbigniewwojna@gmail.com>,
Alexander Gorban <gorban@google.com>
Pull requests:
[alexgorban](https://github.com/alexgorban)
## Requirements
1. Installed TensorFlow library ([instructions][TF]).
2. At least 158 GB of free disk space to download the FSNS dataset:
```
aria2c -c -j 20 -i ../street/python/fsns_urls.txt
```
3. 16 GB of RAM or more; 32 GB is recommended.
4. train.py works in both CPU and GPU modes; using a GPU is preferable.
The GPU mode was tested with Titan X and GTX980.
[TF]: https://www.tensorflow.org/install/
[FSNS]: https://github.com/tensorflow/models/tree/master/street
## How to use this code
To run all unit tests:
```
python -m unittest discover -p '*_test.py'
```
To train from scratch:
```
python train.py
```
To train a model using pre-trained Inception weights as initialization:
```
wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
tar xf inception_v3_2016_08_28.tar.gz
python train.py --checkpoint_inception=inception_v3.ckpt
```
To fine tune the Attention OCR model using a checkpoint:
```
wget http://download.tensorflow.org/models/attention_ocr_2017_05_01.tar.gz
tar xf attention_ocr_2017_05_01.tar.gz
python train.py --checkpoint=model.ckpt-232572
```
## Disclaimer
This code is a modified version of the internal model we used for our paper.
Currently it reaches 82.71% full sequence accuracy after 215k steps of training.
The main difference between this version and the version used in the paper is
the training setup: for the paper we used distributed training with 50 GPU (K80)
workers (asynchronous updates), while the provided checkpoint was created using
this code after ~60 hours of training on a single GPU (Titan X).
# A GNU screen config to run all jobs for training and evaluation in parallel.
# Execute:
# source /path/to/your/virtualenv/bin/activate
# screen -R TF -c all_jobs.screenrc
screen -t train 0 python train.py --train_log_dir=workdir/train
screen -t eval_train 1 python eval.py --split_name=train --train_log_dir=workdir/train --eval_log_dir=workdir/eval_train
screen -t eval_test 2 python eval.py --split_name=test --train_log_dir=workdir/train --eval_log_dir=workdir/eval_test
screen -t tensorboard 3 tensorboard --logdir=workdir
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines flags common to both the train.py and eval.py scripts."""
import sys
from tensorflow.python.platform import flags
import logging
import datasets
import model
# Handle to the global flags object; the helpers below read their settings
# from it.
FLAGS = flags.FLAGS

# Send every log record (DEBUG and above) to stderr using a single-line
# layout:
#   LEVEL YYYY-mm-dd HH:MM:SS.mmm: file.py: lineno message
# `%(msecs)` is the millisecond portion of the record's timestamp, so it is
# zero-padded to three digits; a six digit padding would make the value read
# like microseconds.
logging.basicConfig(
    level=logging.DEBUG,
    stream=sys.stderr,
    format='%(levelname)s '
    '%(asctime)s.%(msecs)03d: '
    '%(filename)s: '
    '%(lineno)d '
    '%(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
def define():
  """Defines the command line flags shared by the training and eval scripts.

  The flags are registered on the global flags object as a side effect;
  nothing is returned.
  """
  # yapf: disable
  flags.DEFINE_integer('batch_size', 32,
                       'Batch size.')

  flags.DEFINE_integer('crop_width', None,
                       'Width of the central crop for images.')

  flags.DEFINE_integer('crop_height', None,
                       'Height of the central crop for images.')

  flags.DEFINE_string('train_log_dir', '/tmp/attention_ocr/train',
                      'Directory where to write event logs.')

  flags.DEFINE_string('dataset_name', 'fsns',
                      'Name of the dataset. Supported: fsns')

  flags.DEFINE_string('split_name', 'train',
                      'Dataset split name to run evaluation for: test,train.')

  flags.DEFINE_string('dataset_dir', None,
                      'Dataset root folder.')

  flags.DEFINE_string('checkpoint', '',
                      'Path for checkpoint to restore weights from.')

  flags.DEFINE_string('master',
                      '',
                      'BNS name of the TensorFlow master to use.')

  # Model hyper parameters
  flags.DEFINE_float('learning_rate', 0.004,
                     'learning rate')

  flags.DEFINE_string('optimizer', 'momentum',
                      'the optimizer to use')

  # The momentum is a real number, so it is declared as a float flag. With a
  # string flag the default is the float 0.9 but any value passed on the
  # command line would be a str.
  flags.DEFINE_float('momentum', 0.9,
                     'momentum value for the momentum optimizer if used')

  flags.DEFINE_bool('use_augment_input', True,
                    'If True will use image augmentation')

  # Method hyper parameters
  # conv_tower_fn
  flags.DEFINE_string('final_endpoint', 'Mixed_5d',
                      'Endpoint to cut inception tower')

  # sequence_logit_fn
  flags.DEFINE_bool('use_attention', True,
                    'If True will use the attention mechanism')

  flags.DEFINE_bool('use_autoregression', True,
                    'If True will use autoregression (a feedback link)')

  flags.DEFINE_integer('num_lstm_units', 256,
                       'number of LSTM units for sequence LSTM')

  flags.DEFINE_float('weight_decay', 0.00004,
                     'weight decay for char prediction FC layers')

  flags.DEFINE_float('lstm_state_clip_value', 10.0,
                     'cell state is clipped by this value prior to the cell'
                     ' output activation')

  # 'sequence_loss_fn'
  flags.DEFINE_float('label_smoothing', 0.1,
                     'weight for label smoothing')

  flags.DEFINE_bool('ignore_nulls', True,
                    'ignore null characters for computing the loss')

  flags.DEFINE_bool('average_across_timesteps', False,
                    'divide the returned cost by the total label weight')
  # yapf: enable
def get_crop_size():
  """Returns the central crop size requested through the flags.

  Returns:
    A tuple (crop_width, crop_height) when both the `crop_width` and the
    `crop_height` flags hold truthy values, otherwise None.
  """
  if not (FLAGS.crop_width and FLAGS.crop_height):
    return None
  return (FLAGS.crop_width, FLAGS.crop_height)
def create_dataset(split_name):
  """Creates the dataset selected by the flags for the given split.

  The dataset module is looked up by the `dataset_name` flag as an attribute
  of the `datasets` package.

  Args:
    split_name: name of the dataset split, forwarded to `get_split`.

  Returns:
    Whatever the `get_split` function of the selected dataset module returns
    for `split_name` and the `dataset_dir` flag.
  """
  dataset_module = getattr(datasets, FLAGS.dataset_name)
  root_dir = FLAGS.dataset_dir
  return dataset_module.get_split(split_name, dataset_dir=root_dir)
def create_mparams():
  """Builds the model parameters from the hyper parameter flags.

  Returns:
    A dictionary with the keys 'conv_tower_fn', 'sequence_logit_fn' and
    'sequence_loss_fn'; each value is the corresponding params object from
    the `model` module, filled in from the flags.
  """
  conv_tower_params = model.ConvTowerParams(
      final_endpoint=FLAGS.final_endpoint)
  sequence_logits_params = model.SequenceLogitsParams(
      use_attention=FLAGS.use_attention,
      use_autoregression=FLAGS.use_autoregression,
      num_lstm_units=FLAGS.num_lstm_units,
      weight_decay=FLAGS.weight_decay,
      lstm_state_clip_value=FLAGS.lstm_state_clip_value)
  sequence_loss_params = model.SequenceLossParams(
      label_smoothing=FLAGS.label_smoothing,
      ignore_nulls=FLAGS.ignore_nulls,
      average_across_timesteps=FLAGS.average_across_timesteps)
  return {
      'conv_tower_fn': conv_tower_params,
      'sequence_logit_fn': sequence_logits_params,
      'sequence_loss_fn': sequence_loss_params,
  }
def create_model(*args, **kwargs):
  """Creates a `model.Model` configured from the hyper parameter flags.

  Args:
    *args: positional arguments forwarded to `model.Model`.
    **kwargs: keyword arguments forwarded to `model.Model`.

  Returns:
    The constructed `model.Model` object; its `mparams` argument is taken
    from `create_mparams()`.
  """
  mparams = create_mparams()
  return model.Model(*args, mparams=mparams, **kwargs)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to read, decode and pre-process input data for the Model.
"""
import collections
import functools
import tensorflow as tf
from tensorflow.contrib import slim
import inception_preprocessing
# Input data endpoints for the Model, stored as a namedtuple of tensors:
#   images: input images, shape [batch_size x H x W x 3];
#   images_orig: the images as fetched from the dataset, before
#     preprocessing;
#   labels: ground truth label ids, shape [batch_size x seq_length];
#   labels_one_hot: labels in one-hot encoding,
#     shape [batch_size x seq_length x num_char_classes].
InputEndpoints = collections.namedtuple(
    'InputEndpoints',
    ['images', 'images_orig', 'labels', 'labels_one_hot'])

# Configuration for fetching shuffled batches:
#   num_batching_threads: number of parallel threads used to fetch data;
#   queue_capacity: max number of elements in the batch shuffling queue;
#   min_after_dequeue: min number of elements left in the queue after a
#     dequeue, used to ensure a level of mixing of elements.
ShuffleBatchConfig = collections.namedtuple(
    'ShuffleBatchConfig',
    ['num_batching_threads', 'queue_capacity', 'min_after_dequeue'])

# Used by get_data() when the caller does not pass a shuffle config.
DEFAULT_SHUFFLE_CONFIG = ShuffleBatchConfig(
    num_batching_threads=8,
    queue_capacity=3000,
    min_after_dequeue=1000)
def augment_image(image):
  """Distorts an image with a random crop, resize and color change.

  Args:
    image: input Tensor image of rank 3, with the last dimension
      of size 3.

  Returns:
    Distorted Tensor image of the same shape.
  """
  with tf.variable_scope('AugmentImage'):
    static_dims = image.get_shape().dims
    height = static_dims[0].value
    width = static_dims[1].value

    # Sample a random window of the image (area_range and
    # min_object_covered are both bounded below by 0.8) and cut it out.
    crop_begin, crop_size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        bounding_boxes=tf.zeros([0, 0, 4]),
        min_object_covered=0.8,
        aspect_ratio_range=[0.8, 1.2],
        area_range=[0.8, 1.0],
        use_image_if_no_bounding_boxes=True)
    cropped = tf.slice(image, crop_begin, crop_size)

    # Scale the crop back to the input size; the interpolation method is
    # one of 4 cases picked at random.
    def resize_to_input_size(x, method):
      return tf.image.resize_images(x, [height, width], method)

    resized = inception_preprocessing.apply_with_random_selector(
        cropped, resize_to_input_size, num_cases=4)
    resized.set_shape([height, width, 3])

    # Color distortion, again with one of 4 randomly selected variants.
    distort_color_fn = functools.partial(
        inception_preprocessing.distort_color, fast_mode=False)
    distorted = inception_preprocessing.apply_with_random_selector(
        resized, distort_color_fn, num_cases=4)
    return tf.clip_by_value(distorted, -1.5, 1.5)
def central_crop(image, crop_size):
  """Returns a central crop for the specified size of an image.

  Args:
    image: A tensor with shape [height, width, channels]
    crop_size: A tuple (crop_width, crop_height)

  Returns:
    A tensor of shape [crop_height, crop_width, channels].
  """
  with tf.variable_scope('CentralCrop'):
    target_width, target_height = crop_size
    image_height, image_width = tf.shape(image)[0], tf.shape(image)[1]
    # Runtime checks: the image has to be at least as large as the crop.
    assert_op1 = tf.Assert(
        tf.greater_equal(image_height, target_height),
        ['image_height < target_height', image_height, target_height])
    assert_op2 = tf.Assert(
        tf.greater_equal(image_width, target_width),
        ['image_width < target_width', image_width, target_width])
    with tf.control_dependencies([assert_op1, assert_op2]):
      # Floor division keeps the offsets integral under both Python 2 and
      # Python 3; `/` is true division under Python 3 and would produce
      # floating point offsets.
      offset_width = (image_width - target_width) // 2
      offset_height = (image_height - target_height) // 2
      return tf.image.crop_to_bounding_box(image, offset_height, offset_width,
                                           target_height, target_width)
def preprocess_image(image, augment=False, central_crop_size=None,
                     num_towers=4):
  """Normalizes image to have values in a narrow range around zero.

  When cropping or augmentation is requested the image is split along its
  width (axis 1) into `num_towers` views, every view is processed on its own
  and the views are concatenated back along the same axis.

  Args:
    image: a [H x W x 3] uint8 tensor.
    augment: optional, if True do random image distortion.
    central_crop_size: A tuple (crop_width, crop_height).
    num_towers: optional, number of shots of the same image in the input image.

  Returns:
    A float32 tensor of shape [H x W x 3] with RGB values in the required
    range: the converted image is shifted by -0.5 and scaled by 2.5.
  """
  with tf.variable_scope('PreprocessImage'):
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    if augment or central_crop_size:
      if num_towers == 1:
        images = [image]
      else:
        images = tf.split(value=image, num_or_size_splits=num_towers, axis=1)
      if central_crop_size:
        # The requested width covers all views, so each view gets an equal
        # share. Floor division keeps the width an integer under Python 3 as
        # well, where `/` would produce a float.
        view_crop_size = (central_crop_size[0] // num_towers,
                          central_crop_size[1])
        images = [central_crop(img, view_crop_size) for img in images]
      if augment:
        images = [augment_image(img) for img in images]
      image = tf.concat(images, 1)

    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.5)

  return image
def get_data(dataset,
             batch_size,
             augment=False,
             central_crop_size=None,
             shuffle_config=None,
             shuffle=True):
  """Wraps calls to DatasetDataProviders and shuffle_batch.

  For more details about supported Dataset objects refer to datasets/fsns.py.

  Args:
    dataset: a slim.data.dataset.Dataset object.
    batch_size: number of samples per batch.
    augment: optional, if True does random image distortion.
    central_crop_size: A tuple (crop_width, crop_height).
    shuffle_config: A namedtuple ShuffleBatchConfig; DEFAULT_SHUFFLE_CONFIG
      is used when it is not provided.
    shuffle: forwarded to the DatasetDataProvider as its `shuffle` argument.

  Returns:
    An InputEndpoints namedtuple holding the batched tensors.
  """
  batch_config = shuffle_config or DEFAULT_SHUFFLE_CONFIG

  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      shuffle=shuffle,
      common_queue_capacity=2 * batch_size,
      common_queue_min=batch_size)
  image_orig, label = provider.get(['image', 'label'])

  # Per-sample tensors: the preprocessed image and the one-hot labels.
  image = preprocess_image(
      image_orig, augment, central_crop_size, num_towers=dataset.num_of_views)
  label_one_hot = slim.one_hot_encoding(label, dataset.num_char_classes)

  batched = tf.train.shuffle_batch(
      [image, image_orig, label, label_one_hot],
      batch_size=batch_size,
      num_threads=batch_config.num_batching_threads,
      capacity=batch_config.queue_capacity,
      min_after_dequeue=batch_config.min_after_dequeue)
  images, images_orig, labels, labels_one_hot = batched

  return InputEndpoints(
      images=images,
      images_orig=images_orig,
      labels=labels,
      labels_one_hot=labels_one_hot)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
import numpy as np
import tensorflow as tf
from tensorflow.contrib.slim import queues
import datasets
import data_provider
class DataProviderTest(tf.test.TestCase):
  """Checks value ranges and shapes of the tensors made by data_provider."""

  def setUp(self):
    tf.test.TestCase.setUp(self)

  def test_preprocessed_image_values_are_in_range(self):
    expected_shape = (5, 4, 3)
    fake_image = np.random.randint(low=0, high=255, size=expected_shape)
    preprocessed = data_provider.preprocess_image(fake_image)

    with self.test_session() as sess:
      result = sess.run(preprocessed)

    self.assertEqual(result.shape, expected_shape)
    # Both extremes of the preprocessed image have to stay inside the range.
    for extreme in (np.min(result), np.max(result)):
      self.assertTrue(-1.28 < extreme < 1.27)

  def test_provided_data_has_correct_shape(self):
    batch_size = 4
    endpoints = data_provider.get_data(
        dataset=datasets.fsns_test.get_test_split(),
        batch_size=batch_size,
        augment=True,
        central_crop_size=None)

    with self.test_session() as sess, queues.QueueRunners(sess):
      images_np, labels_np = sess.run(
          [endpoints.images, endpoints.labels_one_hot])

    self.assertEqual(images_np.shape, (batch_size, 150, 600, 3))
    self.assertEqual(labels_np.shape, (batch_size, 37, 134))

  def test_optionally_applies_central_crop(self):
    batch_size = 4
    endpoints = data_provider.get_data(
        dataset=datasets.fsns_test.get_test_split(),
        batch_size=batch_size,
        augment=True,
        central_crop_size=(500, 100))

    with self.test_session() as sess, queues.QueueRunners(sess):
      images_np = sess.run(endpoints.images)

    # The crop size is given as (width, height), the tensor is laid out as
    # [batch, height, width, channels].
    self.assertEqual(images_np.shape, (batch_size, 100, 500, 3))


if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import fsns
import fsns_test
# `__all__` has to list the public names as strings; with module objects in
# the list a `from datasets import *` fails with a TypeError.
__all__ = ['fsns', 'fsns_test']
# -*- coding: utf-8 -*-
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configuration to read FSNS dataset https://goo.gl/3Ldm8v."""
import os
import re
import tensorflow as tf
from tensorflow.contrib import slim
import logging
# Default location of the dataset: the `data/fsns` folder next to this module.
DEFAULT_DATASET_DIR = os.path.join(os.path.dirname(__file__), 'data/fsns')

# The dataset configuration, should be used only as a default value.
DEFAULT_CONFIG = {
    'name': 'FSNS',
    # For every split: the number of samples and the file pattern of its
    # data files relative to the dataset directory.
    'splits': {
        'train': {
            'size': 1044868,
            'pattern': 'train/train*'
        },
        'test': {
            'size': 20404,
            'pattern': 'test/test*'
        },
        'validation': {
            'size': 16150,
            'pattern': 'validation/validation*'
        }
    },
    # Name of the charset file, relative to the dataset directory.
    'charset_filename': 'charset_size=134.txt',
    'image_shape': (150, 600, 3),
    'num_of_views': 4,
    'max_sequence_length': 37,
    'null_code': 133,
    'items_to_descriptions': {
        'image': 'A [150 x 600 x 3] color image.',
        'label': 'Characters codes.',
        'text': 'A unicode string.',
        'length': 'A length of the encoded text.',
        'num_of_views': 'A number of different views stored within the image.'
    }
}
def read_charset(filename, null_character=u'\u2591'):
  """Reads a charset definition from a tab separated text file.

  charset file has to have format compatible with the FSNS dataset: every
  line holds a numeric character code, a tab and the character itself. Lines
  which do not match are reported with a warning and skipped. When a code
  occurs more than once, the last occurrence wins.

  Args:
    filename: a path to the charset file.
    null_character: a unicode character used to replace '<nul>' character. the
      default value is a light shade block '░'.

  Returns:
    a dictionary with keys equal to character codes and values - unicode
    characters.
  """
  pattern = re.compile(r'(\d+)\t(.+)')
  charset = {}
  with tf.gfile.GFile(filename) as f:
    for i, line in enumerate(f):
      m = pattern.match(line)
      if m is None:
        logging.warning('incorrect charset file. line #%d: %s', i, line)
        continue
      code = int(m.group(1))
      char = m.group(2)
      # Lines are byte strings under Python 2 and have to be decoded; under
      # Python 3 they are already text, which has no decode() method.
      if isinstance(char, bytes):
        char = char.decode('utf-8')
      if char == '<nul>':
        char = null_character
      charset[code] = char
  return charset
class _NumOfViewsHandler(slim.tfexample_decoder.ItemHandler):
  """Convenience handler to determine number of views stored in an image."""

  def __init__(self, width_key, original_width_key, num_of_views):
    """Stores the feature keys and the configured number of views.

    Args:
      width_key: key of the feature with the image width.
      original_width_key: key of the feature with the original image width.
      num_of_views: number of views, used as the multiplier of the ratio.
    """
    super(_NumOfViewsHandler, self).__init__([width_key, original_width_key])
    self._width_key = width_key
    self._original_width_key = original_width_key
    self._num_of_views = num_of_views

  def tensors_to_item(self, keys_to_tensors):
    """Returns num_of_views * original_width / width converted to int64."""
    original_width = keys_to_tensors[self._original_width_key]
    width = keys_to_tensors[self._width_key]
    return tf.to_int64(self._num_of_views * original_width / width)
def get_split(split_name, dataset_dir=None, config=None):
  """Returns a dataset tuple for FSNS dataset.

  Args:
    split_name: A train/test split name.
    dataset_dir: The base directory of the dataset sources, by default it uses
      a predefined path (see DEFAULT_DATASET_DIR).
    config: A dictionary with dataset configuration. If None - will use the
      DEFAULT_CONFIG.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  dataset_dir = dataset_dir or DEFAULT_DATASET_DIR
  config = config or DEFAULT_CONFIG

  if split_name not in config['splits']:
    raise ValueError('split name %s was not recognized.' % split_name)
  split_config = config['splits'][split_name]

  logging.info('Using %s dataset split_name=%s dataset_dir=%s', config['name'],
               split_name, dataset_dir)

  # Ignores the 'image/height' feature.
  zero = tf.zeros([1], dtype=tf.int64)
  max_sequence_length = config['max_sequence_length']
  keys_to_features = {
      'image/encoded':
      tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format':
      tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/width':
      tf.FixedLenFeature([1], tf.int64, default_value=zero),
      'image/orig_width':
      tf.FixedLenFeature([1], tf.int64, default_value=zero),
      'image/class':
      tf.FixedLenFeature([max_sequence_length], tf.int64),
      'image/unpadded_class':
      tf.VarLenFeature(tf.int64),
      'image/text':
      tf.FixedLenFeature([1], tf.string, default_value=''),
  }

  # Handlers which turn the parsed features into the dataset items.
  image_handler = slim.tfexample_decoder.Image(
      shape=config['image_shape'],
      image_key='image/encoded',
      format_key='image/format')
  label_handler = slim.tfexample_decoder.Tensor(tensor_key='image/class')
  text_handler = slim.tfexample_decoder.Tensor(tensor_key='image/text')
  num_of_views_handler = _NumOfViewsHandler(
      width_key='image/width',
      original_width_key='image/orig_width',
      num_of_views=config['num_of_views'])
  items_to_handlers = {
      'image': image_handler,
      'label': label_handler,
      'text': text_handler,
      'num_of_views': num_of_views_handler,
  }
  decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                    items_to_handlers)

  charset = read_charset(
      os.path.join(dataset_dir, config['charset_filename']))
  file_pattern = os.path.join(dataset_dir, split_config['pattern'])

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=split_config['size'],
      items_to_descriptions=config['items_to_descriptions'],
      # additional parameters for convenience.
      charset=charset,
      num_char_classes=len(charset),
      num_of_views=config['num_of_views'],
      max_sequence_length=max_sequence_length,
      null_code=config['null_code'])
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for FSNS datasets module."""
import collections
import os
import tensorflow as tf
from tensorflow.contrib import slim
import fsns
import unittest_utils
# Handle to the global TensorFlow flags object.
FLAGS = tf.flags.FLAGS
def get_test_split():
  """Returns the 'test' split of the FSNS dataset backed by the test data.

  The default config is copied and its 'splits' entry is replaced by a single
  'test' split of 50 samples which reads the `fsns-00000-of-00001` file from
  the test data directory.
  """
  test_config = fsns.DEFAULT_CONFIG.copy()
  test_config['splits'] = {
      'test': {
          'size': 50,
          'pattern': 'fsns-00000-of-00001'
      }
  }
  return fsns.get_split('test', dataset_dir(), test_config)
def dataset_dir():
  """Returns the path of the `testdata/fsns` folder next to this module."""
  module_dir = os.path.dirname(__file__)
  return os.path.join(module_dir, 'testdata/fsns')
class FsnsTest(tf.test.TestCase):
  """Tests for the FSNS dataset configuration and its example decoder."""

  def test_decodes_example_proto(self):
    expected_label = range(37)
    expected_image, encoded = unittest_utils.create_random_image(
        'PNG', shape=(150, 600, 3))
    features = {
        'image/encoded': [encoded],
        'image/format': ['PNG'],
        'image/class': expected_label,
        'image/unpadded_class': range(10),
        'image/text': ['Raw text'],
        'image/orig_width': [150],
        'image/width': [600]
    }
    serialized = unittest_utils.create_serialized_example(features)

    decoder = fsns.get_split('train', dataset_dir()).decoder
    with self.test_session() as sess:
      # Name the decoded tensors after the items the decoder provides.
      decoded_type = collections.namedtuple('DecodedData',
                                            decoder.list_items())
      decoded_tensors = decoded_type(*decoder.decode(serialized))
      data = sess.run(decoded_tensors)

    self.assertAllEqual(expected_image, data.image)
    self.assertAllEqual(expected_label, data.label)
    self.assertEqual(['Raw text'], data.text)
    self.assertEqual([1], data.num_of_views)

  def test_label_has_shape_defined(self):
    decoder = fsns.get_split('train', dataset_dir()).decoder
    # Only the static shape is inspected, so the content does not matter.
    [label_tf] = decoder.decode('fake', ['label'])
    self.assertEqual(label_tf.get_shape().dims[0], 37)

  def test_dataset_tuple_has_all_extra_attributes(self):
    dataset = fsns.get_split('train', dataset_dir())
    self.assertTrue(dataset.charset)
    self.assertTrue(dataset.num_char_classes)
    self.assertTrue(dataset.num_of_views)
    self.assertTrue(dataset.max_sequence_length)
    self.assertTrue(dataset.null_code)

  def test_can_use_the_test_data(self):
    batch_size = 1
    provider = slim.dataset_data_provider.DatasetDataProvider(
        get_test_split(),
        shuffle=True,
        common_queue_capacity=2 * batch_size,
        common_queue_min=batch_size)
    image_tf, label_tf = provider.get(['image', 'label'])

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      with slim.queues.QueueRunners(sess):
        image_np, label_np = sess.run([image_tf, label_tf])

    self.assertEqual((150, 600, 3), image_np.shape)
    self.assertEqual((37, ), label_np.shape)


if __name__ == '__main__':
  tf.test.main()
0
133 <nul>
1 l
2 ’
3 é
4 t
5 e
6 i
7 n
8 s
9 x
10 g
11 u
12 o
13 1
14 8
15 7
16 0
17 -
18 .
19 p
20 a
21 r
22 è
23 d
24 c
25 V
26 v
27 b
28 m
29 )
30 C
31 z
32 S
33 y
34 ,
35 k
36 É
37 A
38 h
39 E
40 »
41 D
42 /
43 H
44 M
45 (
46 G
47 P
48 ç
2 '
49 R
50 f
51 "
52 2
53 j
54 |
55 N
56 6
57 °
58 5
59 T
60 O
61 U
62 3
63 %
64 9
65 q
66 Z
67 B
68 K
69 w
70 W
71 :
72 4
73 L
74 F
75 ]
76 ï
2 ‘
77 I
78 J
79 ä
80 î
81 ;
82 à
83 ê
84 X
85 ü
86 Y
87 ô
88 =
89 +
90 \
91 {
92 }
93 _
94 Q
95 œ
96 ñ
97 *
98 !
99 Ü
51 “
100 â
101 Ç
102 Œ
103 û
104 ?
105 $
106 ë
107 «
108 €
109 &
110 <
51 ”
111 æ
112 #
113 ®
114 Â
115 È
116 >
117 [
17 —
118 Æ
119 ù
120 Î
121 Ô
122 ÿ
123 À
124 Ê
125 @
126 Ï
127 ©
128 Ë
129 Ù
130 £
131 Ÿ
132 Û
http://download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to make unit testing easier."""
import StringIO
import numpy as np
from PIL import Image as PILImage
import tensorflow as tf
def create_random_image(image_format, shape):
  """Creates an image with random values.

  Args:
    image_format: An image format (PNG or JPEG).
    shape: A tuple with image shape (including channels).

  Returns:
    A tuple (<numpy ndarray>, <a string with encoded image>)
  """
  pixels = np.random.randint(low=0, high=255, size=shape, dtype='uint8')
  # Encode the pixels into an in-memory stream using the requested format.
  encoded_stream = StringIO.StringIO()
  PILImage.fromarray(pixels).save(
      encoded_stream, image_format, subsampling=0, quality=100)
  return pixels, encoded_stream.getvalue()
def create_serialized_example(name_to_values):
  """Creates a serialized tf.Example proto using a dictionary.

  It automatically detects type of values and define a corresponding feature.
  The type is detected from the first element only, so every value has to be
  a non-empty sequence.

  Args:
    name_to_values: A dictionary which maps a feature name to a sequence of
      strings, floats or integers.

  Returns:
    The tf.Example proto serialized to a string.

  Raises:
    AssertionError: if the first element of a sequence has an unsupported
      type.
  """
  example = tf.train.Example()
  for name, values in name_to_values.items():
    feature = example.features.feature[name]
    first = values[0]
    if isinstance(first, (bytes, str)):
      add = feature.bytes_list.value.extend
    elif isinstance(first, float):
      # The float field of the tf.train.Feature proto is named `float_list`;
      # there is no `float32_list` field.
      add = feature.float_list.value.extend
    elif isinstance(first, int):
      add = feature.int64_list.value.extend
    else:
      raise AssertionError('Unsupported type: %s' % type(first))
    add(values)
  return example.SerializeToString()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for unittest_utils."""
import StringIO
import numpy as np
from PIL import Image as PILImage
import tensorflow as tf
import unittest_utils
class UnittestUtilsTest(tf.test.TestCase):
  """Tests for the helper functions of the unittest_utils module."""

  def test_creates_an_image_of_specified_shape(self):
    shape = (10, 20, 3)
    image, _ = unittest_utils.create_random_image('PNG', shape)
    self.assertEqual(image.shape, (10, 20, 3))

  def test_encoded_image_corresponds_to_numpy_array(self):
    image, encoded = unittest_utils.create_random_image('PNG', (20, 10, 3))
    # Decoding the encoded string has to give back the original pixels.
    decoded = np.array(PILImage.open(StringIO.StringIO(encoded)))
    self.assertAllEqual(image, decoded)

  def test_created_example_has_correct_values(self):
    name_to_values = {
        'labels': [1, 2, 3],
        'data': ['FAKE']
    }
    serialized = unittest_utils.create_serialized_example(name_to_values)
    example = tf.train.Example()
    example.ParseFromString(serialized)
    self.assertProtoEquals("""
features {
feature {
key: "labels"
value { int64_list {
value: 1
value: 2
value: 3
}}
}
feature {
key: "data"
value { bytes_list {
value: "FAKE"
}}
}
}
""", example)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to evaluate a trained Attention OCR model.
A simple usage example:
python eval.py
"""
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow import app
from tensorflow.python.platform import flags
import data_provider
import common_flags
# Handle to the global flags object; main() reads its settings from it.
FLAGS = flags.FLAGS
# Registers the flags which are common for the train.py and eval.py scripts.
common_flags.define()

# Flags which are specific to the evaluation script.
# yapf: disable
flags.DEFINE_integer('num_batches', 100,
                     'Number of batches to run eval for.')

flags.DEFINE_string('eval_log_dir', '/tmp/attention_ocr/eval',
                    'Directory where the evaluation results are saved to.')

flags.DEFINE_integer('eval_interval_secs', 60,
                     'Frequency in seconds to run evaluations.')

flags.DEFINE_integer('number_of_steps', None,
                     'Number of times to run evaluation.')
# yapf: enable
def main(_):
  """Builds the evaluation graph and runs the slim evaluation loop.

  All settings are read from the command line flags; the positional argument
  is ignored.
  """
  eval_log_dir = FLAGS.eval_log_dir
  if not tf.gfile.Exists(eval_log_dir):
    tf.gfile.MakeDirs(eval_log_dir)

  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  ocr_model = common_flags.create_model(dataset.num_char_classes,
                                        dataset.max_sequence_length,
                                        dataset.num_of_views,
                                        dataset.null_code)
  # Input pipeline without augmentation.
  data = data_provider.get_data(
      dataset,
      FLAGS.batch_size,
      augment=False,
      central_crop_size=common_flags.get_crop_size())

  # The base graph is created without feeding the ground truth labels.
  endpoints = ocr_model.create_base(data.images, labels_one_hot=None)
  ocr_model.create_loss(data, endpoints)
  eval_ops = ocr_model.create_summaries(
      data, endpoints, dataset.charset, is_training=False)
  slim.get_or_create_global_step()

  # The session is configured with a GPU device count of zero.
  session_config = tf.ConfigProto(device_count={"GPU": 0})
  slim.evaluation.evaluation_loop(
      master=FLAGS.master,
      checkpoint_dir=FLAGS.train_log_dir,
      logdir=eval_log_dir,
      eval_op=eval_ops,
      num_evals=FLAGS.num_batches,
      eval_interval_secs=FLAGS.eval_interval_secs,
      max_number_of_evaluations=FLAGS.number_of_steps,
      session_config=session_config)


if __name__ == '__main__':
  app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides utilities to preprocess images for the Inception networks."""
# TODO(gorban): add as a dependency, when slim or tensorflow/models are pipfied
# Source:
# https://raw.githubusercontent.com/tensorflow/models/a9d0e6e8923a4/slim/preprocessing/inception_preprocessing.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
def apply_with_random_selector(x, func, num_cases):
  """Applies func(x, sel) for a selector sampled from [0...num_cases-1].

  One call of func is added to the graph per possible selector value; the
  selector itself is sampled when the graph runs.

  Args:
    x: input Tensor.
    func: Python function to apply; it is called as func(tensor, case) where
      case is a Python integer.
    num_cases: Python int32, number of cases to sample sel from.

  Returns:
    The result of func(x, sel), where func receives the value of the
    selector as a python integer, but sel is sampled dynamically.
  """
  selector = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)

  branch_outputs = []
  for case in range(num_cases):
    # Element [1] of switch() is the output used when the predicate is true,
    # so the real x is passed only to the func call of the sampled case.
    routed_x = control_flow_ops.switch(x, tf.equal(selector, case))[1]
    branch_outputs.append(func(routed_x, case))

  # Element [0] of merge() is the merged output tensor.
  return control_flow_ops.merge(branch_outputs)[0]
def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
  """Randomly distorts the colors of an image Tensor.

  The individual color distortions do not commute, so the order in which they
  are applied matters. Instead of randomly permuting the ops, a small set of
  fixed orderings is provided and the caller picks one via color_ordering.

  Args:
    image: 3-D Tensor containing single image in [0, 1].
    color_ordering: Python int selecting the order of the distortions. In
      fast mode 0 means brightness then saturation and every other value
      means saturation then brightness (the value is not validated). In slow
      mode valid values are 0-3.
    fast_mode: Avoids slower ops (random_hue and random_contrast).
    scope: Optional scope for name_scope.

  Returns:
    3-D Tensor color-distorted image on range [0, 1]

  Raises:
    ValueError: if fast_mode is False and color_ordering is not in [0, 3].
  """

  def brightness(img):
    return tf.image.random_brightness(img, max_delta=32. / 255.)

  def saturation(img):
    return tf.image.random_saturation(img, lower=0.5, upper=1.5)

  def hue(img):
    return tf.image.random_hue(img, max_delta=0.2)

  def contrast(img):
    return tf.image.random_contrast(img, lower=0.5, upper=1.5)

  with tf.name_scope(scope, 'distort_color', [image]):
    if fast_mode:
      if color_ordering == 0:
        pipeline = (brightness, saturation)
      else:
        pipeline = (saturation, brightness)
    else:
      # The position in this tuple is the color_ordering value it serves.
      slow_orderings = (
          (brightness, saturation, hue, contrast),
          (saturation, brightness, contrast, hue),
          (contrast, hue, brightness, saturation),
          (hue, saturation, contrast, brightness),
      )
      matching = [
          ops for ordering, ops in enumerate(slow_orderings)
          if color_ordering == ordering
      ]
      if not matching:
        raise ValueError('color_ordering must be in [0, 3]')
      pipeline = matching[0]

    for distortion in pipeline:
      image = distortion(image)

    # The random_* ops do not necessarily clamp.
    return tf.clip_by_value(image, 0.0, 1.0)
def distorted_bounding_box_crop(image,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
  """Crops the image to a randomly distorted version of a bounding box.

  See `tf.image.sample_distorted_bounding_box` for more documentation.

  Args:
    image: 3-D Tensor of image (it will be converted to floats in [0, 1]).
    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
      where each coordinate is [0, 1) and the coordinates are arranged
      as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the
      whole image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
      area of the image must contain at least this fraction of any bounding
      box supplied.
    aspect_ratio_range: An optional list of `floats`. The cropped area of the
      image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `floats`. The cropped area of the image
      must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a
      cropped region of the image of the specified constraints. After
      `max_attempts` failures, return the entire image.
    scope: Optional scope for name_scope.

  Returns:
    A tuple, a 3-D Tensor cropped_image and the distorted bbox
  """
  with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
    # Many image datasets come with a human-annotated bounding box around the
    # object of interest. A new box is sampled here: a randomly distorted
    # version of the annotated one that obeys the allowed aspect ratios,
    # sizes and overlap with the annotation. If no box is supplied the whole
    # image is treated as the bounding box.
    crop_begin, crop_size, distorted_bbox = (
        tf.image.sample_distorted_bounding_box(
            tf.shape(image),
            bounding_boxes=bbox,
            min_object_covered=min_object_covered,
            aspect_ratio_range=aspect_ratio_range,
            area_range=area_range,
            max_attempts=max_attempts,
            use_image_if_no_bounding_boxes=True))

    # Cut the sampled region out of the image.
    cropped_image = tf.slice(image, crop_begin, crop_size)
    return cropped_image, distorted_bbox
def preprocess_for_train(image,
                         height,
                         width,
                         bbox,
                         fast_mode=True,
                         scope=None):
  """Distorts one image for training a network.

  Random distortions augment the training data and help to make the network
  invariant to aspects of the image that do not affect the label. Image
  summaries are created for the intermediate results of the pipeline.

  Args:
    image: 3-D Tensor of image. If dtype is tf.float32 then the range should
      be [0, 1], otherwise it would be converted to tf.float32 assuming that
      the range is [0, MAX], where MAX is largest positive representable
      number for int(8/16/32) data type (see `tf.image.convert_image_dtype`
      for details).
    height: integer
    width: integer
    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
      where each coordinate is [0, 1) and the coordinates are arranged
      as [ymin, xmin, ymax, xmax]. If None, a single box covering the whole
      image is used.
    fast_mode: Optional boolean, if True avoids slower transformations (i.e.
      bi-cubic resizing, random_hue or random_contrast).
    scope: Optional scope for name_scope.

  Returns:
    3-D float Tensor of distorted image used for training with range [-1, 1].
  """

  def resize(img, method):
    return tf.image.resize_images(img, [height, width], method=method)

  def recolor(img, ordering):
    return distort_color(img, ordering, fast_mode)

  with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
    if bbox is None:
      bbox = tf.constant(
          [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
    if image.dtype != tf.float32:
      image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # Each bounding box has shape [1, num_boxes, box coords] and
    # the coordinates are ordered [ymin, xmin, ymax, xmax].
    tf.summary.image(
        'image_with_bounding_boxes',
        tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), bbox))

    distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
    # Restore the shape since the dynamic slice based upon the bbox_size loses
    # the third dimension.
    distorted_image.set_shape([None, None, 3])
    tf.summary.image(
        'images_with_distorted_bounding_box',
        tf.image.draw_bounding_boxes(
            tf.expand_dims(image, 0), distorted_bbox))

    # Resizing may distort the image because the aspect ratio is not
    # respected. ResizeMethod contains 4 enumerated resizing methods and one
    # of them is sampled at random; fast mode restricts the choice to the
    # first one (bilinear).
    num_resize_cases = 1 if fast_mode else 4
    distorted_image = apply_with_random_selector(
        distorted_image, resize, num_cases=num_resize_cases)
    tf.summary.image('cropped_resized_image',
                     tf.expand_dims(distorted_image, 0))

    # Randomly flip the image horizontally.
    distorted_image = tf.image.random_flip_left_right(distorted_image)

    # Randomly distort the colors. There are 4 ways to do it.
    distorted_image = apply_with_random_selector(
        distorted_image, recolor, num_cases=4)
    tf.summary.image('final_distorted_image',
                     tf.expand_dims(distorted_image, 0))

    # Map the [0, 1] range to [-1, 1].
    centered_image = tf.subtract(distorted_image, 0.5)
    return tf.multiply(centered_image, 2.0)
def preprocess_for_eval(image,
                        height,
                        width,
                        central_fraction=0.875,
                        scope=None):
  """Prepares one image for evaluation.

  If central_fraction is specified the central fraction of the input image is
  cropped. If both height and width are specified the image is then resized
  to that size with resize_bilinear. Finally the values are shifted by -0.5
  and scaled by 2.

  Args:
    image: 3-D Tensor of image. If dtype is tf.float32 then the range should
      be [0, 1], otherwise it would be converted to tf.float32 assuming that
      the range is [0, MAX], where MAX is largest positive representable
      number for int(8/16/32) data type (see `tf.image.convert_image_dtype`
      for details)
    height: integer
    width: integer
    central_fraction: Optional Float, fraction of the image to crop.
    scope: Optional scope for name_scope.

  Returns:
    3-D float Tensor of prepared image.
  """
  with tf.name_scope(scope, 'eval_image', [image, height, width]):
    if image.dtype != tf.float32:
      image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # With the default central_fraction this keeps the central 87.5% of the
    # original image.
    if central_fraction:
      image = tf.image.central_crop(image, central_fraction=central_fraction)

    if height and width:
      # resize_bilinear works on batches, so a batch dimension is added for
      # the resize and removed again afterwards.
      batched_image = tf.expand_dims(image, 0)
      resized_batch = tf.image.resize_bilinear(
          batched_image, [height, width], align_corners=False)
      image = tf.squeeze(resized_batch, [0])

    centered_image = tf.subtract(image, 0.5)
    return tf.multiply(centered_image, 2.0)
def preprocess_image(image,
                     height,
                     width,
                     is_training=False,
                     bbox=None,
                     fast_mode=True):
  """Pre-processes one image for training or evaluation.

  Dispatches to preprocess_for_train or preprocess_for_eval depending on
  is_training. bbox and fast_mode are only used by the training path.

  Args:
    image: 3-D Tensor [height, width, channels] with the image.
    height: integer, image expected height.
    width: integer, image expected width.
    is_training: Boolean. If true it would transform an image for train,
      otherwise it would transform it for evaluation.
    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
      where each coordinate is [0, 1) and the coordinates are arranged as
      [ymin, xmin, ymax, xmax]. May be None.
    fast_mode: Optional boolean, if True avoids slower transformations.

  Returns:
    3-D float Tensor containing an appropriately scaled image
  """
  if not is_training:
    return preprocess_for_eval(image, height, width)
  return preprocess_for_train(image, height, width, bbox, fast_mode)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Quality metrics for the model."""
import tensorflow as tf
def char_accuracy(predictions, targets, rej_char, streaming=False):
  """Computes character level accuracy.

  Both predictions and targets should have the same shape
  [batch_size x seq_length]. Positions whose target equals rej_char get zero
  weight and therefore do not contribute to the accuracy.

  Args:
    predictions: predicted characters ids.
    targets: ground truth character ids.
    rej_char: the character id used to mark an empty element (end of
      sequence).
    streaming: if True, uses the streaming mean from the slim.metric module.

  Returns:
    If streaming is True, the result of tf.contrib.metrics.streaming_mean
    (a value tensor together with an update op); otherwise a tensor with the
    mean of the per-example character accuracies.
  """
  with tf.variable_scope('CharAccuracy'):
    predictions.get_shape().assert_is_compatible_with(targets.get_shape())

    targets = tf.to_int32(targets)
    rej_char_tensor = tf.constant(rej_char, shape=targets.get_shape())
    # 1.0 for real characters, 0.0 for rej_char positions.
    weights = tf.to_float(tf.not_equal(targets, rej_char_tensor))
    char_matches = tf.to_float(tf.equal(predictions, targets))

    # NOTE(review): an example whose targets are all rej_char has a zero
    # denominator here.
    matched_per_example = tf.reduce_sum(tf.multiply(char_matches, weights), 1)
    counted_per_example = tf.reduce_sum(weights, 1)
    accuracy_per_example = tf.div(matched_per_example, counted_per_example)

    if not streaming:
      return tf.reduce_mean(accuracy_per_example)
    return tf.contrib.metrics.streaming_mean(accuracy_per_example)
def sequence_accuracy(predictions, targets, rej_char, streaming=False):
  """Computes sequence level accuracy.

  Both input tensors should have the same shape: [batch_size x seq_length].
  A sequence counts as correct only if every position matches the target;
  predictions at positions whose target is rej_char are ignored.

  Args:
    predictions: predicted character classes.
    targets: ground truth character classes.
    rej_char: the character id used to mark empty element (end of sequence).
    streaming: if True, uses the streaming mean from the slim.metric module.

  Returns:
    If streaming is True, the result of tf.contrib.metrics.streaming_mean
    (a value tensor together with an update op); otherwise a tensor with the
    fraction of fully correct sequences.
  """
  with tf.variable_scope('SequenceAccuracy'):
    predictions.get_shape().assert_is_compatible_with(targets.get_shape())

    targets = tf.to_int32(targets)
    rej_char_tensor = tf.constant(
        rej_char, shape=targets.get_shape(), dtype=tf.int32)
    is_real_char = tf.not_equal(targets, rej_char_tensor)

    # Predictions at rej_char positions are overwritten with rej_char, so
    # they always match the target there.
    padding_fill = tf.zeros_like(predictions) + rej_char
    masked_predictions = tf.to_int32(
        tf.where(is_real_char, predictions, padding_fill))

    char_matches = tf.to_float(tf.equal(masked_predictions, targets))
    matches_per_sequence = tf.cast(
        tf.reduce_sum(char_matches, reduction_indices=[1]), dtype=tf.int32)

    # A sequence is correct when all seq_length positions match.
    seq_length = targets.get_shape().dims[1].value
    required_matches = tf.constant(
        seq_length, shape=matches_per_sequence.get_shape())
    accuracy_per_example = tf.to_float(
        tf.equal(matches_per_sequence, required_matches))

    if not streaming:
      return tf.reduce_mean(accuracy_per_example)
    return tf.contrib.metrics.streaming_mean(accuracy_per_example)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the metrics module."""
import contextlib
import numpy as np
import tensorflow as tf
import metrics
class AccuracyTest(tf.test.TestCase):
  """Tests char_accuracy and sequence_accuracy on small random label batches."""

  def setUp(self):
    super(AccuracyTest, self).setUp()
    # Fixed seed keeps the generated labels reproducible.
    self.rng = np.random.RandomState([11, 23, 50])
    self.num_char_classes = 3
    self.batch_size = 4
    self.seq_length = 5
    # Labels are drawn from [0, num_char_classes), so this id never occurs in
    # the generated data.
    self.rej_char = 42

  @contextlib.contextmanager
  def initialized_session(self):
    """Wrapper for test session context manager with required initialization.

    Yields:
      A session object that should be used as a context manager.
    """
    with self.test_session() as sess:
      initializers = (tf.global_variables_initializer(),
                      tf.local_variables_initializer())
      for init_op in initializers:
        sess.run(init_op)
      yield sess

  def _fake_labels(self):
    """Returns random int32 labels of shape [batch_size, seq_length]."""
    shape = (self.batch_size, self.seq_length)
    return self.rng.randint(
        low=0, high=self.num_char_classes, size=shape, dtype='int32')

  def _incorrect_copy(self, values, bad_indexes):
    """Returns a copy of values with the entries at bad_indexes incremented."""
    corrupted = np.copy(values)
    corrupted[bad_indexes] = values[bad_indexes] + 1
    return corrupted

  def test_sequence_accuracy_identical_samples(self):
    labels_tf = tf.convert_to_tensor(self._fake_labels())

    accuracy_tf = metrics.sequence_accuracy(labels_tf, labels_tf,
                                            self.rej_char)
    with self.initialized_session() as sess:
      accuracy_np = sess.run(accuracy_tf)

    self.assertAlmostEqual(accuracy_np, 1.0)

  def test_sequence_accuracy_one_char_difference(self):
    ground_truth_np = self._fake_labels()
    prediction_np = self._incorrect_copy(ground_truth_np, bad_indexes=(0, 0))
    ground_truth_tf = tf.convert_to_tensor(ground_truth_np)
    prediction_tf = tf.convert_to_tensor(prediction_np)

    accuracy_tf = metrics.sequence_accuracy(prediction_tf, ground_truth_tf,
                                            self.rej_char)
    with self.initialized_session() as sess:
      accuracy_np = sess.run(accuracy_tf)

    # 1 of 4 sequences is incorrect.
    expected_accuracy = 1.0 - 1.0 / self.batch_size
    self.assertAlmostEqual(accuracy_np, expected_accuracy)

  def test_char_accuracy_one_char_difference_with_padding(self):
    ground_truth_np = self._fake_labels()
    prediction_np = self._incorrect_copy(ground_truth_np, bad_indexes=(0, 0))
    ground_truth_tf = tf.convert_to_tensor(ground_truth_np)
    prediction_tf = tf.convert_to_tensor(prediction_np)

    accuracy_tf = metrics.char_accuracy(prediction_tf, ground_truth_tf,
                                        self.rej_char)
    with self.initialized_session() as sess:
      accuracy_np = sess.run(accuracy_tf)

    # 1 of batch_size * seq_length characters is incorrect.
    chars_count = self.seq_length * self.batch_size
    self.assertAlmostEqual(accuracy_np, 1.0 - 1.0 / chars_count)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to build the Attention OCR model.

Usage example:
  ocr_model = model.Model(num_char_classes, seq_length, num_of_views,
                          null_code)
  data = ... # create namedtuple InputEndpoints
  endpoints = ocr_model.create_base(data.images, data.labels_one_hot)
  # endpoints.predicted_chars is a tensor with predicted character codes.
  total_loss = ocr_model.create_loss(data, endpoints)
"""
import sys
import collections
import logging
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim.nets import inception
import metrics
import sequence_layers
import utils
# Tensors produced by Model.create_base.
OutputEndpoints = collections.namedtuple(
    'OutputEndpoints',
    [
        'chars_logit',
        'chars_log_prob',
        'predicted_chars',
        'predicted_scores',
    ])

# TODO(gorban): replace with tf.HParams when it is released.
# Basic model dimensions passed to the Model constructor.
ModelParams = collections.namedtuple(
    'ModelParams',
    [
        'num_char_classes',
        'seq_length',
        'num_views',
        'null_code',
    ])

# Hyper parameters of Model.conv_tower_fn.
ConvTowerParams = collections.namedtuple(
    'ConvTowerParams',
    [
        'final_endpoint',
    ])

# Hyper parameters of Model.sequence_logit_fn.
SequenceLogitsParams = collections.namedtuple(
    'SequenceLogitsParams',
    [
        'use_attention',
        'use_autoregression',
        'num_lstm_units',
        'weight_decay',
        'lstm_state_clip_value',
    ])

# Hyper parameters of Model.sequence_loss_fn.
SequenceLossParams = collections.namedtuple(
    'SequenceLossParams',
    [
        'label_smoothing',
        'ignore_nulls',
        'average_across_timesteps',
    ])
def _dict_to_array(id_to_char, default_character):
  """Converts an id-to-character dictionary into a dense list.

  Args:
    id_to_char: a dictionary with integer ids as keys (they are used as list
      indexes) and characters as values.
    default_character: value stored at every index that has no entry in
      id_to_char.

  Returns:
    A list of length max(id) + 1 whose element i is id_to_char[i] when the id
    is present and default_character otherwise.

  Raises:
    ValueError: if id_to_char is empty.
  """
  num_char_classes = max(id_to_char.keys()) + 1
  array = [default_character] * num_char_classes
  # items() exists on both Python 2 and Python 3 dictionaries, whereas
  # iteritems() is Python 2 only.
  for char_id, char in id_to_char.items():
    array[char_id] = char
  return array
class CharsetMapper(object):
  """Maps tensors of character ids to strings with a lookup table.

  It works only when the character set is 1:1 mapping between individual
  characters and individual ids.

  Make sure you call tf.tables_initializer().run() as part of the init op.
  """

  def __init__(self, charset, default_character='?'):
    """Creates a lookup table.

    Args:
      charset: a dictionary with id-to-character mapping.
      default_character: character used for ids without an entry in charset;
        it is also the default value of the lookup table.
    """
    id_to_char_list = _dict_to_array(charset, default_character)
    self.table = tf.contrib.lookup.index_to_string_table_from_tensor(
        mapping=tf.constant(id_to_char_list),
        default_value=default_character)

  def get_text(self, ids):
    """Returns a string corresponding to a sequence of character ids.

    Args:
      ids: a tensor with shape [batch_size, max_sequence_length]

    Returns:
      A tensor in which the characters looked up for every row of ids are
      joined along dimension 1.
    """
    chars = self.table.lookup(tf.to_int64(ids))
    return tf.reduce_join(chars, reduction_indices=1)
def get_softmax_loss_fn(label_smoothing):
  """Selects the dense or sparse softmax loss depending on label_smoothing.

  With label smoothing the labels are dense distributions over the classes,
  without it they are plain class ids, hence the two different TensorFlow
  losses.

  Args:
    label_smoothing: weight for label smoothing

  Returns:
    a function which takes labels and predictions as arguments and returns
    a softmax loss for the selected type of labels (sparse or dense).
  """

  def dense_loss_fn(labels, logits):
    return tf.nn.softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)

  def sparse_loss_fn(labels, logits):
    return tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)

  return dense_loss_fn if label_smoothing > 0 else sparse_loss_fn
class Model(object):
"""Class to create the Attention OCR Model."""
def __init__(self,
num_char_classes,
seq_length,
num_views,
null_code,
mparams=None):
"""Initialized model parameters.
Args:
num_char_classes: size of character set.
seq_length: number of characters in a sequence.
num_views: Number of views (conv towers) to use.
null_code: A character code corresponding to a character which
indicates end of a sequence.
mparams: a dictionary with hyper parameters for methods, keys -
function names, values - corresponding namedtuples.
"""
super(Model, self).__init__()
self._params = ModelParams(
num_char_classes=num_char_classes,
seq_length=seq_length,
num_views=num_views,
null_code=null_code)
self._mparams = self.default_mparams()
if mparams:
self._mparams.update(mparams)
def default_mparams(self):
return {
'conv_tower_fn':
ConvTowerParams(final_endpoint='Mixed_5d'),
'sequence_logit_fn':
SequenceLogitsParams(
use_attention=True,
use_autoregression=True,
num_lstm_units=256,
weight_decay=0.00004,
lstm_state_clip_value=10.0),
'sequence_loss_fn':
SequenceLossParams(
label_smoothing=0.1,
ignore_nulls=True,
average_across_timesteps=False)
}
def set_mparam(self, function, **kwargs):
self._mparams[function] = self._mparams[function]._replace(**kwargs)
def conv_tower_fn(self, images, is_training=True, reuse=None):
"""Computes convolutional features using the InceptionV3 model.
Args:
images: A tensor of shape [batch_size, height, width, channels].
is_training: whether is training or not.
reuse: whether or not the network and its variables should be reused. To
be able to reuse 'scope' must be given.
Returns:
A tensor of shape [batch_size, OH, OW, N], where OWxOH is resolution of
output feature map and N is number of output features (depends on the
network architecture).
"""
mparams = self._mparams['conv_tower_fn']
logging.debug('Using final_endpoint=%s', mparams.final_endpoint)
with tf.variable_scope('conv_tower_fn/INCE'):
if reuse:
tf.get_variable_scope().reuse_variables()
with slim.arg_scope(inception.inception_v3_arg_scope()):
net, _ = inception.inception_v3_base(
images, final_endpoint=mparams.final_endpoint)
return net
def _create_lstm_inputs(self, net):
"""Splits an input tensor into a list of tensors (features).
Args:
net: A feature map of shape [batch_size, num_features, feature_size].
Raises:
AssertionError: if num_features is less than seq_length.
Returns:
A list with seq_length tensors of shape [batch_size, feature_size]
"""
num_features = net.get_shape().dims[1].value
if num_features < self._params.seq_length:
raise AssertionError('Incorrect dimension #1 of input tensor'
' %d should be bigger than %d (shape=%s)' %
(num_features, self._params.seq_length,
net.get_shape()))
elif num_features > self._params.seq_length:
logging.warning('Ignoring some features: use %d of %d (shape=%s)',
self._params.seq_length, num_features, net.get_shape())
net = tf.slice(net, [0, 0, 0], [-1, self._params.seq_length, -1])
return tf.unstack(net, axis=1)
def sequence_logit_fn(self, net, labels_one_hot):
mparams = self._mparams['sequence_logit_fn']
# TODO(gorban): remove /alias suffixes from the scopes.
with tf.variable_scope('sequence_logit_fn/SQLR'):
layer_class = sequence_layers.get_layer_class(mparams.use_attention,
mparams.use_autoregression)
layer = layer_class(net, labels_one_hot, self._params, mparams)
return layer.create_logits()
def max_pool_views(self, nets_list):
"""Max pool across all nets in spatial dimensions.
Args:
nets_list: A list of 4D tensors with identical size.
Returns:
A tensor with the same size as any input tensors.
"""
batch_size, height, width, num_features = [
d.value for d in nets_list[0].get_shape().dims
]
xy_flat_shape = (batch_size, 1, height * width, num_features)
nets_for_merge = []
with tf.variable_scope('max_pool_views', values=nets_list):
for net in nets_list:
nets_for_merge.append(tf.reshape(net, xy_flat_shape))
merged_net = tf.concat(nets_for_merge, 1)
net = slim.max_pool2d(
merged_net, kernel_size=[len(nets_list), 1], stride=1)
net = tf.reshape(net, (batch_size, height, width, num_features))
return net
def pool_views_fn(self, nets):
"""Combines output of multiple convolutional towers into a single tensor.
It stacks towers one on top another (in height dim) in a 4x1 grid.
The order is arbitrary design choice and shouldn't matter much.
Args:
nets: list of tensors of shape=[batch_size, height, width, num_features].
Returns:
A tensor of shape [batch_size, seq_length, features_size].
"""
with tf.variable_scope('pool_views_fn/STCK'):
net = tf.concat(nets, 1)
batch_size = net.get_shape().dims[0].value
feature_size = net.get_shape().dims[3].value
return tf.reshape(net, [batch_size, -1, feature_size])
def char_predictions(self, chars_logit):
"""Returns confidence scores (softmax values) for predicted characters.
Args:
chars_logit: chars logits, a tensor with shape
[batch_size x seq_length x num_char_classes]
Returns:
A tuple (ids, log_prob, scores), where:
ids - predicted characters, a int32 tensor with shape
[batch_size x seq_length];
log_prob - a log probability of all characters, a float tensor with
shape [batch_size, seq_length, num_char_classes];
scores - corresponding confidence scores for characters, a float
tensor
with shape [batch_size x seq_length].
"""
log_prob = utils.logits_to_log_prob(chars_logit)
ids = tf.to_int32(tf.argmax(log_prob, dimension=2), name='predicted_chars')
mask = tf.cast(
slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool)
all_scores = tf.nn.softmax(chars_logit)
selected_scores = tf.boolean_mask(all_scores, mask, name='char_scores')
scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length))
return ids, log_prob, scores
def create_base(self,
images,
labels_one_hot,
scope='AttentionOcr_v1',
reuse=None):
"""Creates a base part of the Model (no gradients, losses or summaries).
Args:
images: A tensor of shape [batch_size, height, width, channels].
labels_one_hot: Optional (can be None) one-hot encoding for ground truth
labels. If provided the function will create a model for training.
scope: Optional variable_scope.
reuse: whether or not the network and its variables should be reused. To
be able to reuse 'scope' must be given.
Returns:
A named tuple OutputEndpoints.
"""
logging.debug('images: %s', images)
is_training = labels_one_hot is not None
with tf.variable_scope(scope, reuse=reuse):
views = tf.split(
value=images, num_or_size_splits=self._params.num_views, axis=2)
logging.debug('Views=%d single view: %s', len(views), views[0])
nets = [
self.conv_tower_fn(v, is_training, reuse=(i != 0))
for i, v in enumerate(views)
]
logging.debug('Conv tower: %s', nets[0])
net = self.pool_views_fn(nets)
logging.debug('Pooled views: %s', net)
chars_logit = self.sequence_logit_fn(net, labels_one_hot)
logging.debug('chars_logit: %s', chars_logit)
predicted_chars, chars_log_prob, predicted_scores = (
self.char_predictions(chars_logit))
return OutputEndpoints(
chars_logit=chars_logit,
chars_log_prob=chars_log_prob,
predicted_chars=predicted_chars,
predicted_scores=predicted_scores)
def create_loss(self, data, endpoints):
"""Creates all losses required to train the model.
Args:
data: InputEndpoints namedtuple.
endpoints: Model namedtuple.
Returns:
Total loss.
"""
# NOTE: the return value of ModelLoss is not used directly for the
# gradient computation because under the hood it calls slim.losses.AddLoss,
# which registers the loss in an internal collection and later returns it
# as part of GetTotalLoss. We need to use total loss because model may have
# multiple losses including regularization losses.
self.sequence_loss_fn(endpoints.chars_logit, data.labels)
total_loss = slim.losses.get_total_loss()
tf.summary.scalar('TotalLoss', total_loss)
return total_loss
def label_smoothing_regularization(self, chars_labels, weight=0.1):
"""Applies a label smoothing regularization.
Uses the same method as in https://arxiv.org/abs/1512.00567.
Args:
chars_labels: ground truth ids of charactes,
shape=[batch_size, seq_length];
weight: label-smoothing regularization weight.
Returns:
A sensor with the same shape as the input.
"""
one_hot_labels = tf.one_hot(
chars_labels, depth=self._params.num_char_classes, axis=-1)
pos_weight = 1.0 - weight
neg_weight = weight / self._params.num_char_classes
return one_hot_labels * pos_weight + neg_weight
def sequence_loss_fn(self, chars_logits, chars_labels):
"""Loss function for char sequence.
Depending on values of hyper parameters it applies label smoothing and can
also ignore all null chars after the first one.
Args:
chars_logits: logits for predicted characters,
shape=[batch_size, seq_length, num_char_classes];
chars_labels: ground truth ids of characters,
shape=[batch_size, seq_length];
mparams: method hyper parameters.
Returns:
A Tensor with shape [batch_size] - the log-perplexity for each sequence.
"""
mparams = self._mparams['sequence_loss_fn']
with tf.variable_scope('sequence_loss_fn/SLF'):
if mparams.label_smoothing > 0:
smoothed_one_hot_labels = self.label_smoothing_regularization(
chars_labels, mparams.label_smoothing)
labels_list = tf.unstack(smoothed_one_hot_labels, axis=1)
else:
# NOTE: in case of sparse softmax we are not using one-hot
# encoding.
labels_list = tf.unstack(chars_labels, axis=1)
batch_size, seq_length, _ = chars_logits.shape.as_list()
if mparams.ignore_nulls:
weights = tf.ones((batch_size, seq_length), dtype=tf.float32)
else:
# Suppose that reject character is the last in the charset.
reject_char = tf.constant(
self._params.num_char_classes - 1,
shape=(batch_size, seq_length),
dtype=tf.int64)
known_char = tf.not_equal(chars_labels, reject_char)
weights = tf.to_float(known_char)
logits_list = tf.unstack(chars_logits, axis=1)
weights_list = tf.unstack(weights, axis=1)
loss = tf.contrib.legacy_seq2seq.sequence_loss(
logits_list,
labels_list,
weights_list,
softmax_loss_function=get_softmax_loss_fn(mparams.label_smoothing),
average_across_timesteps=mparams.average_across_timesteps)
tf.losses.add_loss(loss)
return loss
def create_summaries(self, data, endpoints, charset, is_training):
  """Creates all summaries for the model.

  Image summaries are always created. A training job additionally gets
  histograms of all trainable variables, an evaluation job gets scalar
  summaries of the character and sequence accuracy metrics.

  Args:
    data: InputEndpoints namedtuple.
    endpoints: OutputEndpoints namedtuple.
    charset: A dictionary with mapping between character codes and
      unicode characters. Use the one provided by a dataset.charset.
      Currently unused: it is only needed by the text summaries which are
      commented out below.
    is_training: If True will create summary prefixes for training job,
      otherwise - for evaluation.

  Returns:
    None for a training job; for an evaluation job - a list of update ops of
    the evaluation metrics.
  """

  def sname(label):
    """Prefixes a summary label with the name of the current job type."""
    prefix = 'train' if is_training else 'eval'
    return '%s/%s' % (prefix, label)

  max_outputs = 4
  # TODO(gorban): uncomment, when tf.summary.text released.
  # charset_mapper = CharsetMapper(charset)
  # pr_text = charset_mapper.get_text(
  #     endpoints.predicted_chars[:max_outputs,:])
  # tf.summary.text(sname('text/pr'), pr_text)
  # gt_text = charset_mapper.get_text(data.labels[:max_outputs,:])
  # tf.summary.text(sname('text/gt'), gt_text)
  tf.summary.image(sname('image'), data.images, max_outputs=max_outputs)

  if is_training:
    tf.summary.image(
        sname('image/orig'), data.images_orig, max_outputs=max_outputs)
    for var in tf.trainable_variables():
      tf.summary.histogram(var.op.name, var)
    return None

  names_to_values = {}
  names_to_updates = {}

  def use_metric(name, value_update_tuple):
    """Registers a metric given as a (value, update op) pair."""
    names_to_values[name] = value_update_tuple[0]
    names_to_updates[name] = value_update_tuple[1]

  use_metric('CharacterAccuracy',
             metrics.char_accuracy(
                 endpoints.predicted_chars,
                 data.labels,
                 streaming=True,
                 rej_char=self._params.null_code))
  # Sequence accuracy computed by cutting sequence at the first null char
  use_metric('SequenceAccuracy',
             metrics.sequence_accuracy(
                 endpoints.predicted_chars,
                 data.labels,
                 streaming=True,
                 rej_char=self._params.null_code))

  # items() instead of the Python 2 only iteritems() keeps this loop working
  # on both Python 2 and Python 3.
  for name, value in names_to_values.items():
    summary_name = 'eval/' + name
    # tf.Print also logs the metric value every time the summary is evaluated.
    tf.summary.scalar(summary_name, tf.Print(value, [value], summary_name))
  # On Python 3 values() is a view; materialize it to return an actual list
  # as documented.
  return list(names_to_updates.values())
def create_init_fn_to_restore(self, master_checkpoint, inception_checkpoint):
  """Builds a function which restores model weights from checkpoints.

  Every non-empty checkpoint path contributes one assign op together with
  its feed values; the returned function runs all of them in a single
  session call.

  Args:
    master_checkpoint: path to a checkpoint which contains all weights for
      the whole model. Skipped if empty.
    inception_checkpoint: path to a checkpoint which contains weights for the
      inception part only. Skipped if empty.

  Returns:
    a function which takes a session and runs the initialization ops in it.
  """
  restore_ops = []
  restore_feeds = {}

  def schedule_restore(checkpoint_path, variables_to_load):
    """Queues restoring of the given variables from checkpoint_path."""
    logging.info('Request to re-store %d weights from %s',
                 len(variables_to_load), checkpoint_path)
    if not variables_to_load:
      # Nothing matched: most likely a misconfiguration, so abort the job.
      logging.error('Can\'t find any variables to restore.')
      sys.exit(1)
    op_and_feed = slim.assign_from_checkpoint(checkpoint_path,
                                              variables_to_load)
    restore_ops.append(op_and_feed[0])
    restore_feeds.update(op_and_feed[1])

  if master_checkpoint:
    schedule_restore(master_checkpoint, utils.variables_to_restore())

  if inception_checkpoint:
    inception_variables = utils.variables_to_restore(
        'AttentionOcr_v1/conv_tower_fn/INCE', strip_scope=True)
    schedule_restore(inception_checkpoint, inception_variables)

  def init_assign_fn(sess):
    """Runs all queued assign ops in the session sess."""
    logging.info('Restoring checkpoint(s)')
    sess.run(restore_ops, restore_feeds)

  return init_assign_fn
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the model."""
import numpy as np
import string
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.tfprof import model_analyzer
import model
import data_provider
def create_fake_charset(num_char_classes):
  """Creates a fake charset for tests.

  Character codes are mapped to the characters of string.printable; codes
  beyond its length wrap around to the beginning.

  Args:
    num_char_classes: number of character codes in the charset.

  Returns:
    A dictionary which maps every code in [0, num_char_classes) to a
    printable character.
  """
  printable = string.printable
  # range() instead of the Python 2 only xrange() keeps the helper usable on
  # both Python 2 and Python 3.
  return {
      code: printable[code % len(printable)]
      for code in range(num_char_classes)
  }
class ModelTest(tf.test.TestCase):
  """Smoke tests which run model.Model on randomly generated fake inputs."""

  def setUp(self):
    """Defines dimensions used by the tests and creates the fake inputs."""
    tf.test.TestCase.setUp(self)

    # A fixed seed keeps the generated fake inputs the same between runs.
    self.rng = np.random.RandomState([11, 23, 50])

    self.batch_size = 4
    self.image_width = 600
    self.image_height = 30
    self.seq_length = 40
    self.num_char_classes = 72
    self.null_code = 62
    self.num_views = 4

    feature_size = 288
    self.conv_tower_shape = (self.batch_size, 1, 72, feature_size)
    self.features_shape = (self.batch_size, self.seq_length, feature_size)
    self.chars_logit_shape = (self.batch_size, self.seq_length,
                              self.num_char_classes)
    self.length_logit_shape = (self.batch_size, self.seq_length + 1)

    self.initialize_fakes()

  def initialize_fakes(self):
    """Creates constant tensors with random images, logits and labels.

    The tensors are generated in a fixed order from self.rng, so their values
    depend on the order of the statements below.
    """
    self.images_shape = (self.batch_size, self.image_height, self.image_width,
                         3)
    self.fake_images = tf.constant(
        self.rng.randint(low=0, high=255,
                         size=self.images_shape).astype('float32'),
        name='input_node')
    self.fake_conv_tower_np = tf.constant(
        self.rng.randn(*self.conv_tower_shape).astype('float32'))
    self.fake_logits = tf.constant(
        self.rng.randn(*self.chars_logit_shape).astype('float32'))
    self.fake_labels = tf.constant(
        self.rng.randint(
            low=0,
            high=self.num_char_classes,
            size=(self.batch_size, self.seq_length)).astype('int64'))

  def create_model(self):
    """Returns a model.Model configured with the dimensions from setUp."""
    # Use the values defined in setUp instead of repeating the literals, so
    # the model always agrees with the shapes the tests assert on.
    return model.Model(
        self.num_char_classes,
        self.seq_length,
        num_views=self.num_views,
        null_code=self.null_code)

  def test_char_related_shapes(self):
    """Checks shapes of all character related endpoints of the base model."""
    ocr_model = self.create_model()
    with self.test_session() as sess:
      endpoints_tf = ocr_model.create_base(
          images=self.fake_images, labels_one_hot=None)

      sess.run(tf.global_variables_initializer())
      endpoints = sess.run(endpoints_tf)

      self.assertEqual((self.batch_size, self.seq_length,
                        self.num_char_classes), endpoints.chars_logit.shape)
      self.assertEqual((self.batch_size, self.seq_length,
                        self.num_char_classes), endpoints.chars_log_prob.shape)
      self.assertEqual((self.batch_size, self.seq_length),
                       endpoints.predicted_chars.shape)
      self.assertEqual((self.batch_size, self.seq_length),
                       endpoints.predicted_scores.shape)

  def test_predicted_scores_are_within_range(self):
    """Checks that every predicted score lies in the [0, 1] interval."""
    ocr_model = self.create_model()

    _, _, scores = ocr_model.char_predictions(self.fake_logits)
    with self.test_session() as sess:
      scores_np = sess.run(scores)

    values_in_range = (scores_np >= 0.0) & (scores_np <= 1.0)
    self.assertTrue(
        np.all(values_in_range),
        msg=('Scores contains out of the range values %s' %
             scores_np[np.logical_not(values_in_range)]))

  def test_conv_tower_shape(self):
    """Checks the shape of the output of the convolutional tower."""
    with self.test_session() as sess:
      ocr_model = self.create_model()
      conv_tower = ocr_model.conv_tower_fn(self.fake_images)

      sess.run(tf.global_variables_initializer())
      conv_tower_np = sess.run(conv_tower)

      self.assertEqual(self.conv_tower_shape, conv_tower_np.shape)

  def test_model_size_less_then1_gb(self):
    """Checks that trainable parameters need less than 1 GiB of memory."""
    # NOTE: Actual amount of memory occupied by TF during training will be at
    # least 4X times bigger because of space need to store original weights,
    # updates, gradients and variances. It also depends on the type of used
    # optimizer.
    ocr_model = self.create_model()
    ocr_model.create_base(images=self.fake_images, labels_one_hot=None)
    with self.test_session() as sess:
      tfprof_root = model_analyzer.print_model_analysis(
          sess.graph,
          tfprof_options=model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
      # Assumes every parameter takes 4 bytes (i.e. is stored as float32).
      model_size_bytes = 4 * tfprof_root.total_parameters
      self.assertLess(model_size_bytes, 1 * 2**30)

  def test_create_summaries_is_runnable(self):
    """Checks that the evaluation ops from create_summaries can be run."""
    ocr_model = self.create_model()
    data = data_provider.InputEndpoints(
        images=self.fake_images,
        images_orig=self.fake_images,
        labels=self.fake_labels,
        labels_one_hot=slim.one_hot_encoding(self.fake_labels,
                                             self.num_char_classes))
    endpoints = ocr_model.create_base(
        images=self.fake_images, labels_one_hot=None)
    charset = create_fake_charset(self.num_char_classes)
    summaries = ocr_model.create_summaries(
        data, endpoints, charset, is_training=False)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      tf.tables_initializer().run()
      sess.run(summaries)  # just check it is runnable

  def test_sequence_loss_function_without_label_smoothing(self):
    """Checks that the sequence loss without label smoothing is a scalar."""
    # Named ocr_model (not model) to avoid shadowing the imported module.
    ocr_model = self.create_model()
    ocr_model.set_mparam('sequence_loss_fn', label_smoothing=0)

    loss = ocr_model.sequence_loss_fn(self.fake_logits, self.fake_labels)
    with self.test_session() as sess:
      loss_np = sess.run(loss)

    # This test checks that the loss function is 'runnable'.
    self.assertEqual(loss_np.shape, tuple())
class CharsetMapperTest(tf.test.TestCase):
  """Tests for model.CharsetMapper."""

  def test_text_corresponds_to_ids(self):
    """Checks that rows of character ids are converted to expected strings."""
    expected_text = ['hello', 'world']
    fake_charset = create_fake_charset(36)
    char_ids = tf.constant(
        [[17, 14, 21, 21, 24], [32, 24, 27, 21, 13]], dtype=tf.int64)
    mapper = model.CharsetMapper(fake_charset)

    with self.test_session() as sess:
      # The tables have to be initialized before the lookup is run.
      tf.tables_initializer().run()
      decoded_text = sess.run(mapper.get_text(char_ids))

    self.assertAllEqual(decoded_text, expected_text)
# Run the test cases of this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various implementations of sequence layers for character prediction.
A 'sequence layer' is a part of a computation graph which is responsible for
producing a sequence of characters using extracted image features. There are
many reasonable ways to implement such layers. All of them are using RNNs.
This module provides implementations which use an 'attention' mechanism to
spatially 'pool' image features and can also use a previously predicted
character to predict the next one (aka auto regression).
Usage:
Select one of available classes, e.g. Attention or use a wrapper function to
pick one based on your requirements:
layer_class = sequence_layers.get_layer_class(use_attention=True,
use_autoregression=True)
layer = layer_class(net, labels_one_hot, model_params, method_params)
char_logits = layer.create_logits()
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import abc
import logging
import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim
def orthogonal_initializer(shape, dtype=tf.float32, *args, **kwargs):
  """Generates orthonormal matrices with random values.

  Orthonormal initialization is important for RNNs:
    http://arxiv.org/abs/1312.6120
    http://smerity.com/articles/2016/orthogonal_init.html

  For non-square shapes the returned matrix will be semi-orthonormal: if the
  number of columns exceeds the number of rows, then the rows are orthonormal
  vectors; but if the number of rows exceeds the number of columns, then the
  columns are orthonormal vectors.

  We use SVD decomposition to generate an orthonormal matrix with random
  values. The same way as it is done in the Lasagne library for Theano. Note
  that both u and v returned by the svd are orthogonal and random. We just need
  to pick one with the right shape.

  Shapes with more than two dimensions are flattened to
  [shape[0], prod(shape[1:])] before the decomposition and reshaped back
  afterwards; a 1-D shape [n] is treated as an [n, 1] matrix.

  Args:
    shape: a shape of the tensor matrix to initialize. Has to contain at
      least one dimension.
    dtype: a dtype of the initialized tensor.
    *args: not used.
    **kwargs: not used.

  Returns:
    An initialized tensor.
  """
  del args
  del kwargs
  # np.prod of an empty sequence is the float 1.0, which np.random.randn does
  # not accept as a dimension; the cast makes 1-D shapes work and does not
  # change the value for any other shape.
  flat_shape = (shape[0], int(np.prod(shape[1:])))
  w = np.random.randn(*flat_shape)
  # With full_matrices=False exactly one of u and v has the flattened shape
  # (both do for a square matrix).
  u, _, v = np.linalg.svd(w, full_matrices=False)
  w = u if u.shape == flat_shape else v
  return tf.constant(w.reshape(shape), dtype=dtype)
# Method parameters of a sequence layer:
#   num_lstm_units: passed to the LSTM cell as its number of units; it is also
#     the first dimension of the softmax weights.
#   weight_decay: scale of the L2 regularizer of the softmax parameters.
#   lstm_state_clip_value: passed to the LSTM cell as cell_clip.
SequenceLayerParams = collections.namedtuple(
    'SequenceLogitsParams',
    ('num_lstm_units', 'weight_decay', 'lstm_state_clip_value'))
class SequenceLayerBase(object):
  """A base abstract class for all sequence layers.

  A child class has to define following methods:
    get_train_input
    get_eval_input
    unroll_cell
  """
  # NOTE: the __metaclass__ attribute is honored by Python 2 only; Python 3
  # ignores it, so there the abstract methods are not enforced on
  # instantiation.
  __metaclass__ = abc.ABCMeta

  def __init__(self, net, labels_one_hot, model_params, method_params):
    """Stores argument in member variable for further use.

    Also creates the softmax weights and biases which are shared by the
    logits of all characters.

    Args:
      net: A tensor with shape [batch_size, num_features, feature_size] which
        contains some extracted image features.
      labels_one_hot: An optional (can be None) ground truth labels for the
        input features. Is a tensor with shape
        [batch_size, seq_length, num_char_classes]
      model_params: A namedtuple with model parameters (model.ModelParams).
      method_params: A SequenceLayerParams instance.
    """
    self._params = model_params
    self._mparams = method_params
    self._net = net
    self._labels_one_hot = labels_one_hot
    self._batch_size = net.get_shape().dims[0].value

    # Initialize parameters for char logits which will be computed on the fly
    # inside an LSTM decoder.
    self._char_logits = {}
    regularizer = slim.l2_regularizer(self._mparams.weight_decay)
    self._softmax_w = slim.model_variable(
        'softmax_w',
        [self._mparams.num_lstm_units, self._params.num_char_classes],
        initializer=orthogonal_initializer,
        regularizer=regularizer)
    self._softmax_b = slim.model_variable(
        'softmax_b', [self._params.num_char_classes],
        initializer=tf.zeros_initializer(),
        regularizer=regularizer)

  @abc.abstractmethod
  def get_train_input(self, prev, i):
    """Returns a sample to be used to predict a character during training.

    This function is used as a loop_function for an RNN decoder.

    Args:
      prev: output tensor from previous step of the RNN. A tensor with shape:
        [batch_size, num_char_classes].
      i: index of a character in the output sequence.

    Returns:
      A tensor with shape [batch_size, ?] - depth depends on implementation
      details.
    """
    pass

  @abc.abstractmethod
  def get_eval_input(self, prev, i):
    """Returns a sample to be used to predict a character during inference.

    This function is used as a loop_function for an RNN decoder.

    Args:
      prev: output tensor from previous step of the RNN. A tensor with shape:
        [batch_size, num_char_classes].
      i: index of a character in the output sequence.

    Returns:
      A tensor with shape [batch_size, ?] - depth depends on implementation
      details.
    """
    raise AssertionError('Not implemented')

  @abc.abstractmethod
  def unroll_cell(self, decoder_inputs, initial_state, loop_function, cell):
    """Unrolls an RNN cell for all inputs.

    This is a placeholder to call some RNN decoder. Its interface is similar
    to the one of tf.seq2seq.rnn_decode.

    Args:
      decoder_inputs: A list of 2D Tensors* [batch_size x input_size]. In fact,
        most of existing decoders in presence of a loop_function use only the
        first element to determine batch_size and length of the list to
        determine number of steps.
      initial_state: 2D Tensor with shape [batch_size x cell.state_size].
      loop_function: function will be applied to the i-th output in order to
        generate the i+1-st input (see self.get_input).
      cell: rnn_cell.RNNCell defining the cell function and size.

    Returns:
      A tuple of the form (outputs, state), where:
        outputs: A list of character logits of the same length as
        decoder_inputs of 2D Tensors with shape [batch_size x num_characters].
        state: The state of each cell at the final time-step.
          It is a 2D Tensor of shape [batch_size x cell.state_size].
    """
    pass

  def is_training(self):
    """Returns True if the layer is created for training stage."""
    return self._labels_one_hot is not None

  def char_logit(self, inputs, char_index):
    """Creates logits for a character if required.

    The logits are cached per character index, so repeated calls for the same
    char_index return the tensor created by the first call.

    Args:
      inputs: A tensor with shape [batch_size, ?] (depth is implementation
        dependent).
      char_index: A integer index of a character in the output sequence.

    Returns:
      A tensor with shape [batch_size, num_char_classes]
    """
    if char_index not in self._char_logits:
      self._char_logits[char_index] = tf.nn.xw_plus_b(inputs, self._softmax_w,
                                                      self._softmax_b)
    return self._char_logits[char_index]

  def char_one_hot(self, logit):
    """Creates one hot encoding for a logit of a character.

    Args:
      logit: A tensor with shape [batch_size, num_char_classes].

    Returns:
      A tensor with shape [batch_size, num_char_classes]
    """
    # `axis` replaces the deprecated `dimension` argument of tf.argmax.
    prediction = tf.argmax(logit, axis=1)
    return slim.one_hot_encoding(prediction, self._params.num_char_classes)

  def get_input(self, prev, i):
    """A wrapper for get_train_input and get_eval_input.

    Args:
      prev: output tensor from previous step of the RNN. A tensor with shape:
        [batch_size, num_char_classes].
      i: index of a character in the output sequence.

    Returns:
      A tensor with shape [batch_size, ?] - depth depends on implementation
      details.
    """
    if self.is_training():
      return self.get_train_input(prev, i)
    else:
      return self.get_eval_input(prev, i)

  def create_logits(self):
    """Creates character sequence logits for a net specified in the constructor.

    A "main" method for the sequence layer which glues together all pieces.

    Returns:
      A tensor with shape [batch_size, seq_length, num_char_classes].
    """
    with tf.variable_scope('LSTM'):
      first_label = self.get_input(prev=None, i=0)
      # Only the first input is provided explicitly, inputs for the remaining
      # steps are produced by the loop function.
      decoder_inputs = [first_label] + [None] * (self._params.seq_length - 1)
      lstm_cell = tf.contrib.rnn.LSTMCell(
          self._mparams.num_lstm_units,
          use_peepholes=False,
          cell_clip=self._mparams.lstm_state_clip_value,
          state_is_tuple=True,
          initializer=orthogonal_initializer)
      lstm_outputs, _ = self.unroll_cell(
          decoder_inputs=decoder_inputs,
          initial_state=lstm_cell.zero_state(self._batch_size, tf.float32),
          loop_function=self.get_input,
          cell=lstm_cell)

    with tf.variable_scope('logits'):
      # `axis` replaces the deprecated `dim` argument of tf.expand_dims.
      logits_list = [
          tf.expand_dims(self.char_logit(logit, i), axis=1)
          for i, logit in enumerate(lstm_outputs)
      ]

    return tf.concat(logits_list, 1)
class NetSlice(SequenceLayerBase):
  """A sequence layer which feeds a slice of image features per character."""

  def __init__(self, *args, **kwargs):
    """See SequenceLayerBase.__init__ for the arguments."""
    super(NetSlice, self).__init__(*args, **kwargs)
    zero_label_shape = [self._batch_size, self._params.num_char_classes]
    self._zero_label = tf.zeros(zero_label_shape)

  def get_image_feature(self, char_index):
    """Returns a subset of image features for a character.

    Args:
      char_index: an index of a character.

    Returns:
      A tensor with shape [batch_size, ?]. The output depth depends on the
      depth of input net.
    """
    net_dims = [dim.value for dim in self._net.get_shape()]
    batch_size, features_num, _ = net_dims
    slice_len = int(features_num / self._params.seq_length)
    # In case when features_num != seq_length, we just pick a subset of image
    # features, this choice is arbitrary and there is no intuitive geometrical
    # interpretation. If features_num is not dividable by seq_length there will
    # be unused image features.
    slice_end = char_index + slice_len
    feature = tf.reshape(self._net[:, char_index:slice_end, :],
                         [batch_size, -1])
    logging.debug('Image feature: %s', feature)
    return feature

  def get_eval_input(self, prev, i):
    """See SequenceLayerBase.get_eval_input for details."""
    # The previous output is not used: the input is image features only.
    del prev
    return self.get_image_feature(i)

  def get_train_input(self, prev, i):
    """See SequenceLayerBase.get_train_input for details."""
    return self.get_eval_input(prev, i)

  def unroll_cell(self, decoder_inputs, initial_state, loop_function, cell):
    """See SequenceLayerBase.unroll_cell for details."""
    # The loop_function argument is not used, self.get_input is always passed
    # to the decoder.
    decoder = tf.contrib.legacy_seq2seq.rnn_decoder
    return decoder(
        decoder_inputs=decoder_inputs,
        initial_state=initial_state,
        cell=cell,
        loop_function=self.get_input)
class NetSliceWithAutoregression(NetSlice):
  """A NetSlice layer which additionally uses auto regression.

  The "auto regression" means that the one hot encoded previous character is
  appended to the image features which are used as an input for the current
  character.
  """

  def __init__(self, *args, **kwargs):
    """See SequenceLayerBase.__init__ for the arguments."""
    super(NetSliceWithAutoregression, self).__init__(*args, **kwargs)

  def get_eval_input(self, prev, i):
    """See SequenceLayerBase.get_eval_input for details."""
    # There is no previous character for the first step, use zeros instead.
    if i == 0:
      previous_char = self._zero_label
    else:
      previous_char = self.char_one_hot(
          self.char_logit(prev, char_index=i - 1))
    features = self.get_image_feature(char_index=i)
    return tf.concat([features, previous_char], 1)

  def get_train_input(self, prev, i):
    """See SequenceLayerBase.get_train_input for details."""
    # During training the ground truth label of the previous character is
    # used instead of the prediction.
    if i == 0:
      previous_char = self._zero_label
    else:
      previous_char = self._labels_one_hot[:, i - 1, :]
    features = self.get_image_feature(i)
    return tf.concat([features, previous_char], 1)
class Attention(SequenceLayerBase):
  """A sequence layer which selects image features with attention."""

  def __init__(self, *args, **kwargs):
    """See SequenceLayerBase.__init__ for the arguments."""
    super(Attention, self).__init__(*args, **kwargs)
    zero_label_shape = [self._batch_size, self._params.num_char_classes]
    self._zero_label = tf.zeros(zero_label_shape)

  def get_eval_input(self, prev, i):
    """See SequenceLayerBase.get_eval_input for details."""
    # The attention_decoder will fetch image features from the net, no need for
    # extra inputs.
    del i, prev
    return self._zero_label

  def get_train_input(self, prev, i):
    """See SequenceLayerBase.get_train_input for details."""
    return self.get_eval_input(prev, i)

  def unroll_cell(self, decoder_inputs, initial_state, loop_function, cell):
    """See SequenceLayerBase.unroll_cell for details."""
    # The loop_function argument is not used, self.get_input is always passed
    # to the decoder.
    decoder = tf.contrib.legacy_seq2seq.attention_decoder
    return decoder(
        decoder_inputs=decoder_inputs,
        initial_state=initial_state,
        attention_states=self._net,
        cell=cell,
        loop_function=self.get_input)
class AttentionWithAutoregression(Attention):
  """An Attention layer which also feeds back the previous character."""

  def __init__(self, *args, **kwargs):
    """See SequenceLayerBase.__init__ for the arguments."""
    super(AttentionWithAutoregression, self).__init__(*args, **kwargs)

  def get_train_input(self, prev, i):
    """See SequenceLayerBase.get_train_input for details."""
    # There is no previous character for the first step.
    if i == 0:
      return self._zero_label
    # TODO(gorban): update to gradually introduce gt labels.
    return self._labels_one_hot[:, i - 1, :]

  def get_eval_input(self, prev, i):
    """See SequenceLayerBase.get_eval_input for details."""
    # There is no previous character for the first step.
    if i == 0:
      return self._zero_label
    previous_logit = self.char_logit(prev, char_index=i - 1)
    return self.char_one_hot(previous_logit)
def get_layer_class(use_attention, use_autoregression):
  """A convenience function to get a layer class based on requirements.

  Both arguments are interpreted by their truthiness, so every combination of
  values maps to exactly one of the four available classes.

  Args:
    use_attention: if True a returned class will use attention.
    use_autoregression: if True a returned class will use auto regression.

  Returns:
    One of available sequence layers (child classes for SequenceLayerBase).
  """
  # The two flags have only four possible combinations and each of them has a
  # dedicated class, so no "unsupported" fallback is needed.
  if use_attention:
    layer_class = (
        AttentionWithAutoregression if use_autoregression else Attention)
  else:
    layer_class = (
        NetSliceWithAutoregression if use_autoregression else NetSlice)

  logging.debug('Use %s as a layer class', layer_class.__name__)
  return layer_class
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment