"tests/vscode:/vscode.git/clone" did not exist on "7cc369687edb071282a634758278ed43b4cd191d"
Commit 44fa1d37 authored by Alex Lee's avatar Alex Lee
Browse files

Merge remote-tracking branch 'upstream/master'

parents d3628a74 6e367f67
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains LSTM text classification model.
Model trains with adversarial or virtual adversarial training.
Computational time:
1.8 hours to train 10000 steps without adversarial or virtual adversarial
training, on 1 layer 1024 hidden units LSTM, 256 embeddings, 400 truncated
BP, 64 minibatch and on single GPU (Pascal Titan X, cuDNNv5).
4 hours to train 10000 steps with adversarial or virtual adversarial
training, with above condition.
To initialize embedding and LSTM cell weights from a pretrained model, set
FLAGS.pretrained_model_dir to the pretrained model's checkpoint directory.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
import graphs
import train_utils
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('pretrained_model_dir', None,
'Directory path to pretrained model to restore from')
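# Example invocation (hypothetical script name and paths, for illustration
# only):
#   python train_classifier.py --pretrained_model_dir=/tmp/pretrained \
#     --train_dir=/tmp/text_train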
def main(_):
"""Trains LSTM classification model."""
tf.logging.set_verbosity(tf.logging.INFO)
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
model = graphs.get_model()
train_op, loss, global_step = model.classifier_training()
train_utils.run_training(
train_op,
loss,
global_step,
variables_to_restore=model.pretrained_variables,
pretrained_model_dir=FLAGS.pretrained_model_dir)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for training adversarial text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
# Dependency imports
import numpy as np
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'Master address.')
flags.DEFINE_integer('task', 0, 'Task id of the replica running the training.')
flags.DEFINE_integer('ps_tasks', 0, 'Number of parameter servers.')
flags.DEFINE_string('train_dir', '/tmp/text_train',
'Directory for logs and checkpoints.')
flags.DEFINE_integer('max_steps', 1000000, 'Number of batches to run.')
flags.DEFINE_boolean('log_device_placement', False,
'Whether to log device placement.')
def run_training(train_op,
loss,
global_step,
variables_to_restore=None,
pretrained_model_dir=None):
"""Sets up and runs training loop."""
tf.gfile.MakeDirs(FLAGS.train_dir)
# Create pretrain Saver
if pretrained_model_dir:
assert variables_to_restore
tf.logging.info('Will attempt restore from %s: %s', pretrained_model_dir,
variables_to_restore)
saver_for_restore = tf.train.Saver(variables_to_restore)
# Init ops
if FLAGS.sync_replicas:
local_init_op = tf.get_collection('local_init_op')[0]
ready_for_local_init_op = tf.get_collection('ready_for_local_init_op')[0]
else:
local_init_op = tf.train.Supervisor.USE_DEFAULT
ready_for_local_init_op = tf.train.Supervisor.USE_DEFAULT
is_chief = FLAGS.task == 0
sv = tf.train.Supervisor(
logdir=FLAGS.train_dir,
is_chief=is_chief,
save_summaries_secs=5 * 60,
save_model_secs=5 * 60,
local_init_op=local_init_op,
ready_for_local_init_op=ready_for_local_init_op,
global_step=global_step)
# Delay starting standard services to allow possible pretrained model restore.
with sv.managed_session(
master=FLAGS.master,
config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement),
start_standard_services=False) as sess:
# Initialization
if is_chief:
if pretrained_model_dir:
maybe_restore_pretrained_model(sess, saver_for_restore,
pretrained_model_dir)
if FLAGS.sync_replicas:
sess.run(tf.get_collection('chief_init_op')[0])
sv.start_standard_services(sess)
sv.start_queue_runners(sess)
# Training loop
global_step_val = 0
while not sv.should_stop() and global_step_val < FLAGS.max_steps:
global_step_val = train_step(sess, train_op, loss, global_step)
sv.stop()
# Final checkpoint
if is_chief:
sv.saver.save(sess, sv.save_path, global_step=global_step)
def maybe_restore_pretrained_model(sess, saver_for_restore, model_dir):
"""Restores pretrained model if there is no ckpt model."""
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
checkpoint_exists = ckpt and ckpt.model_checkpoint_path
if checkpoint_exists:
tf.logging.info('Checkpoint exists in FLAGS.train_dir; skipping '
'pretraining restore')
return
pretrain_ckpt = tf.train.get_checkpoint_state(model_dir)
if not (pretrain_ckpt and pretrain_ckpt.model_checkpoint_path):
raise ValueError(
'Asked to restore model from %s but no checkpoint found.' % model_dir)
saver_for_restore.restore(sess, pretrain_ckpt.model_checkpoint_path)
def train_step(sess, train_op, loss, global_step):
"""Runs a single training step."""
start_time = time.time()
_, loss_val, global_step_val = sess.run([train_op, loss, global_step])
duration = time.time() - start_time
# Logging
if global_step_val % 10 == 0:
examples_per_sec = FLAGS.batch_size / duration
sec_per_batch = float(duration)
format_str = 'step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)'
tf.logging.info(format_str % (global_step_val, loss_val, examples_per_sec,
sec_per_batch))
if np.isnan(loss_val):
raise OverflowError('Loss is nan')
return global_step_val
## Attention-based Extraction of Structured Information from Street View Imagery
*A TensorFlow model for real-world image text extraction problems.*
This folder contains the code needed to train a new Attention OCR model on the
[FSNS dataset][FSNS] to transcribe street names in France. You can also use
this code to train the model on your own data.
More details can be found in our paper:
["Attention-based Extraction of Structured Information from Street View
Imagery"](https://arxiv.org/abs/1704.03549)
## Contacts
Authors:
Zbigniew Wojna <zbigniewwojna@gmail.com>,
Alexander Gorban <gorban@google.com>
Pull requests:
[alexgorban](https://github.com/alexgorban)
## Requirements
1. Install the TensorFlow library ([instructions][TF]). For example:
```
virtualenv --system-site-packages ~/.tensorflow
source ~/.tensorflow/bin/activate
pip install --upgrade pip
pip install --upgrade tensorflow_gpu
```
2. At least 158GB of free disk space to download the FSNS dataset:
```
cd models/attention_ocr/python/datasets
aria2c -c -j 20 -i ../../../street/python/fsns_urls.txt
cd ..
```
3. 16GB of RAM or more; 32GB is recommended.
4. `train.py` works with both CPU and GPU, though a GPU is preferable. It has been tested with a Titan X and with a GTX 980.
[TF]: https://www.tensorflow.org/install/
[FSNS]: https://github.com/tensorflow/models/tree/master/street
## How to use this code
To run all unit tests:
```
cd models/attention_ocr/python
python -m unittest discover -p '*_test.py'
```
To train from scratch:
```
python train.py
```
To train a model using pre-trained Inception weights as initialization:
```
wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
tar xf inception_v3_2016_08_28.tar.gz
python train.py --checkpoint_inception=inception_v3.ckpt
```
To fine-tune the Attention OCR model using a checkpoint:
```
wget http://download.tensorflow.org/models/attention_ocr_2017_05_17.tar.gz
tar xf attention_ocr_2017_05_17.tar.gz
python train.py --checkpoint=model.ckpt-399731
```
## How to use your own image data to train the model
You need to define a new dataset. There are two options:
1. Store data in the same format as the FSNS dataset and just reuse the
[python/datasets/fsns.py](https://github.com/tensorflow/models/blob/master/attention_ocr/python/datasets/fsns.py)
module. E.g., create a file `datasets/newtextdataset.py`:
```
import fsns

DEFAULT_DATASET_DIR = 'path/to/the/dataset'

DEFAULT_CONFIG = {
    'name': 'MYDATASET',
    'splits': {
        'train': {
            'size': 123,
            'pattern': 'tfexample_train*'
        },
        'test': {
            'size': 123,
            'pattern': 'tfexample_test*'
        }
    },
    'charset_filename': 'charset_size.txt',
    'image_shape': (150, 600, 3),
    'num_of_views': 4,
    'max_sequence_length': 37,
    'null_code': 42,
    'items_to_descriptions': {
        'image': 'A [150 x 600 x 3] color image.',
        'label': 'Character codes.',
        'text': 'A unicode string.',
        'length': 'The length of the encoded text.',
        'num_of_views': 'A number of different views stored within the image.'
    }
}


def get_split(split_name, dataset_dir=None, config=None):
  if not dataset_dir:
    dataset_dir = DEFAULT_DATASET_DIR
  if not config:
    config = DEFAULT_CONFIG
  return fsns.get_split(split_name, dataset_dir, config)
```
You will also need to include it in `datasets/__init__.py` (as sketched below)
and specify the dataset name on the command line.
```
python train.py --dataset_name=newtextdataset
```
Please note that eval.py will also require the same flag.
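For illustration, the `datasets/__init__.py` edit could mirror the existing
imports of `fsns` (`newtextdataset` is the hypothetical module from the
example above):
```
import fsns
import fsns_test
import newtextdataset

__all__ = ['fsns', 'fsns_test', 'newtextdataset']
```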
2. Define a new dataset format. The model needs the following data to train:
- images: input images, shape [batch_size x H x W x 3];
- labels: ground truth label ids, shape=[batch_size x seq_length];
- labels_one_hot: labels in one-hot encoding, shape [batch_size x seq_length x num_char_classes];
Refer to [python/data_provider.py](https://github.com/tensorflow/models/blob/master/attention_ocr/python/data_provider.py#L33)
for more details. You can use [python/datasets/fsns.py](https://github.com/tensorflow/models/blob/master/attention_ocr/python/datasets/fsns.py)
as an example.
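For orientation, here is a minimal sketch of the three tensors the model
consumes, using FSNS-like shapes (the placeholder usage and the concrete
numbers are illustrative assumptions, not a required API):
```
import tensorflow as tf

batch_size, seq_length, num_char_classes = 32, 37, 134  # illustrative values
# images: input images, shape [batch_size x H x W x 3].
images = tf.placeholder(tf.float32, shape=[batch_size, 150, 600, 3])
# labels: ground truth label ids, shape [batch_size x seq_length].
labels = tf.placeholder(tf.int64, shape=[batch_size, seq_length])
# labels_one_hot: shape [batch_size x seq_length x num_char_classes].
labels_one_hot = tf.one_hot(labels, num_char_classes)
```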
## How to use a pre-trained model
The inference part has not been released yet, but it is straightforward to
implement in Python or C++.
The recommended way is to use the [Serving infrastructure](https://tensorflow.github.io/serving/serving_basic).
Alternatively you can:
1. Define a placeholder for images (or feed a numpy array directly).
2. [Create a graph](https://github.com/tensorflow/models/blob/master/attention_ocr/python/eval.py#L60):
`endpoints = model.create_base(images_placeholder, labels_one_hot=None)`
3. [Load a pretrained model](https://github.com/tensorflow/models/blob/master/attention_ocr/python/model.py#L494).
4. Run computations through the graph:
`predictions = sess.run(endpoints.predicted_chars, feed_dict={images_placeholder: images_actual_data})`
5. Convert character IDs (predictions) to UTF-8 using the provided charset file.
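Put together, a minimal Python sketch of steps 1-4 might look like this. The
`create_model` arguments are assumptions based on the FSNS configuration in
`datasets/fsns.py`, and the checkpoint path is the one downloaded in the
fine-tuning section above:
```
import numpy as np
import tensorflow as tf
import common_flags

common_flags.define()
# FSNS parameters (see datasets/fsns.py): 134 character classes, sequences
# of length 37, 4 views per image, null code 133.
ocr_model = common_flags.create_model(
    num_char_classes=134, seq_length=37, num_views=4, null_code=133)

# Step 1: a placeholder for images (or feed a numpy array directly).
images_placeholder = tf.placeholder(tf.float32, shape=[1, 150, 600, 3])
# Step 2: create the graph.
endpoints = ocr_model.create_base(images_placeholder, labels_one_hot=None)

with tf.Session() as sess:
  # Step 3: load a pretrained model.
  tf.train.Saver().restore(sess, 'model.ckpt-399731')
  # Step 4: run computations through the graph. Real inputs should be
  # preprocessed like data_provider.preprocess_image; zeros are a stand-in.
  images_actual_data = np.zeros((1, 150, 600, 3), dtype=np.float32)
  predictions = sess.run(endpoints.predicted_chars,
                         feed_dict={images_placeholder: images_actual_data})
```
Step 5 then maps each id in `predictions` through the charset returned by
`read_charset` in `datasets/fsns.py`.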
## Disclaimer
This code is a modified version of the internal model we used for our paper.
It currently reaches 83.79% full sequence accuracy after 400k steps of
training. The main differences between this version and the version used in
the paper: for the paper we used distributed training with 50 GPU (K80)
workers and asynchronous updates, while the provided checkpoint was created
with this code after ~6 days of training on a single GPU (Titan X), reaching
81% after 24 hours of training. The coordinate encoding is also missing
(TODO(alexgorban@)).
# A GPU/screen config to run all jobs for training and evaluation in parallel.
# Execute:
# source /path/to/your/virtualenv/bin/activate
# screen -R TF -c all_jobs.screenrc
screen -t train 0 python train.py --train_log_dir=workdir/train
screen -t eval_train 1 python eval.py --split_name=train --train_log_dir=workdir/train --eval_log_dir=workdir/eval_train
screen -t eval_test 2 python eval.py --split_name=test --train_log_dir=workdir/train --eval_log_dir=workdir/eval_test
screen -t tensorboard 3 tensorboard --logdir=workdir
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define flags are common for both train.py and eval.py scripts."""
import sys
from tensorflow.python.platform import flags
import logging
import datasets
import model
FLAGS = flags.FLAGS
logging.basicConfig(
level=logging.DEBUG,
stream=sys.stderr,
format='%(levelname)s '
'%(asctime)s.%(msecs)06d: '
'%(filename)s: '
'%(lineno)d '
'%(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
def define():
"""Define common flags."""
# yapf: disable
flags.DEFINE_integer('batch_size', 32,
'Batch size.')
flags.DEFINE_integer('crop_width', None,
'Width of the central crop for images.')
flags.DEFINE_integer('crop_height', None,
'Height of the central crop for images.')
flags.DEFINE_string('train_log_dir', '/tmp/attention_ocr/train',
'Directory where to write event logs.')
flags.DEFINE_string('dataset_name', 'fsns',
'Name of the dataset. Supported: fsns')
flags.DEFINE_string('split_name', 'train',
'Dataset split name to run evaluation for: test, train.')
flags.DEFINE_string('dataset_dir', None,
'Dataset root folder.')
flags.DEFINE_string('checkpoint', '',
'Path for checkpoint to restore weights from.')
flags.DEFINE_string('master',
'',
'BNS name of the TensorFlow master to use.')
# Model hyper parameters
flags.DEFINE_float('learning_rate', 0.004,
'learning rate')
flags.DEFINE_string('optimizer', 'momentum',
'the optimizer to use')
flags.DEFINE_float('momentum', 0.9,
'momentum value for the momentum optimizer if used')
flags.DEFINE_bool('use_augment_input', True,
'If True will use image augmentation')
# Method hyper parameters
# conv_tower_fn
flags.DEFINE_string('final_endpoint', 'Mixed_5d',
'Endpoint to cut inception tower')
# sequence_logit_fn
flags.DEFINE_bool('use_attention', True,
'If True will use the attention mechanism')
flags.DEFINE_bool('use_autoregression', True,
'If True will use autoregression (a feedback link)')
flags.DEFINE_integer('num_lstm_units', 256,
'number of LSTM units for sequence LSTM')
flags.DEFINE_float('weight_decay', 0.00004,
'weight decay for char prediction FC layers')
flags.DEFINE_float('lstm_state_clip_value', 10.0,
'cell state is clipped by this value prior to the cell'
' output activation')
# 'sequence_loss_fn'
flags.DEFINE_float('label_smoothing', 0.1,
'weight for label smoothing')
flags.DEFINE_bool('ignore_nulls', True,
'ignore null characters for computing the loss')
flags.DEFINE_bool('average_across_timesteps', False,
'divide the returned cost by the total label weight')
# yapf: enable
def get_crop_size():
if FLAGS.crop_width and FLAGS.crop_height:
return (FLAGS.crop_width, FLAGS.crop_height)
else:
return None
def create_dataset(split_name):
ds_module = getattr(datasets, FLAGS.dataset_name)
return ds_module.get_split(split_name, dataset_dir=FLAGS.dataset_dir)
def create_mparams():
return {
'conv_tower_fn':
model.ConvTowerParams(final_endpoint=FLAGS.final_endpoint),
'sequence_logit_fn':
model.SequenceLogitsParams(
use_attention=FLAGS.use_attention,
use_autoregression=FLAGS.use_autoregression,
num_lstm_units=FLAGS.num_lstm_units,
weight_decay=FLAGS.weight_decay,
lstm_state_clip_value=FLAGS.lstm_state_clip_value),
'sequence_loss_fn':
model.SequenceLossParams(
label_smoothing=FLAGS.label_smoothing,
ignore_nulls=FLAGS.ignore_nulls,
average_across_timesteps=FLAGS.average_across_timesteps)
}
def create_model(*args, **kwargs):
ocr_model = model.Model(mparams=create_mparams(), *args, **kwargs)
return ocr_model
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to read, decode and pre-process input data for the Model.
"""
import collections
import functools
import tensorflow as tf
from tensorflow.contrib import slim
import inception_preprocessing
# Tuple to store input data endpoints for the Model.
# It has following fields (tensors):
# images: input images,
# shape [batch_size x H x W x 3];
# images_orig: original images before preprocessing,
# shape [batch_size x H x W x 3];
# labels: ground truth label ids,
# shape=[batch_size x seq_length];
# labels_one_hot: labels in one-hot encoding,
# shape [batch_size x seq_length x num_char_classes];
InputEndpoints = collections.namedtuple(
'InputEndpoints', ['images', 'images_orig', 'labels', 'labels_one_hot'])
# A namedtuple to define a configuration for shuffled batch fetching.
# num_batching_threads: A number of parallel threads to fetch data.
# queue_capacity: a max number of elements in the batch shuffling queue.
# min_after_dequeue: a minimum number of elements in the queue after a
# dequeue, used to ensure a level of mixing of elements.
ShuffleBatchConfig = collections.namedtuple('ShuffleBatchConfig', [
'num_batching_threads', 'queue_capacity', 'min_after_dequeue'
])
DEFAULT_SHUFFLE_CONFIG = ShuffleBatchConfig(
num_batching_threads=8, queue_capacity=3000, min_after_dequeue=1000)
def augment_image(image):
"""Augmentation the image with a random modification.
Args:
image: input Tensor image of rank 3, with the last dimension
of size 3.
Returns:
Distorted Tensor image of the same shape.
"""
with tf.variable_scope('AugmentImage'):
height = image.get_shape().dims[0].value
width = image.get_shape().dims[1].value
# Random crop cut from the street sign image, resized to the same size.
# Ensures that the crop covers at least 0.8 of the area of the input image.
bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
tf.shape(image),
bounding_boxes=tf.zeros([0, 0, 4]),
min_object_covered=0.8,
aspect_ratio_range=[0.8, 1.2],
area_range=[0.8, 1.0],
use_image_if_no_bounding_boxes=True)
distorted_image = tf.slice(image, bbox_begin, bbox_size)
# Randomly chooses one of the 4 interpolation methods
distorted_image = inception_preprocessing.apply_with_random_selector(
distorted_image,
lambda x, method: tf.image.resize_images(x, [height, width], method),
num_cases=4)
distorted_image.set_shape([height, width, 3])
# Color distortion
distorted_image = inception_preprocessing.apply_with_random_selector(
distorted_image,
functools.partial(
inception_preprocessing.distort_color, fast_mode=False),
num_cases=4)
distorted_image = tf.clip_by_value(distorted_image, -1.5, 1.5)
return distorted_image
def central_crop(image, crop_size):
"""Returns a central crop for the specified size of an image.
Args:
image: A tensor with shape [height, width, channels]
crop_size: A tuple (crop_width, crop_height)
Returns:
A tensor of shape [crop_height, crop_width, channels].
"""
with tf.variable_scope('CentralCrop'):
target_width, target_height = crop_size
image_height, image_width = tf.shape(image)[0], tf.shape(image)[1]
assert_op1 = tf.Assert(
tf.greater_equal(image_height, target_height),
['image_height < target_height', image_height, target_height])
assert_op2 = tf.Assert(
tf.greater_equal(image_width, target_width),
['image_width < target_width', image_width, target_width])
with tf.control_dependencies([assert_op1, assert_op2]):
offset_width = (image_width - target_width) // 2
offset_height = (image_height - target_height) // 2
return tf.image.crop_to_bounding_box(image, offset_height, offset_width,
target_height, target_width)
def preprocess_image(image, augment=False, central_crop_size=None,
num_towers=4):
"""Normalizes image to have values in a narrow range around zero.
Args:
image: a [H x W x 3] uint8 tensor.
augment: optional, if True do random image distortion.
central_crop_size: A tuple (crop_width, crop_height).
num_towers: optional, number of views of the same object stored side by side
within the input image.
Returns:
A float32 tensor of shape [H x W x 3] with RGB values in the required
range.
"""
with tf.variable_scope('PreprocessImage'):
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
if augment or central_crop_size:
if num_towers == 1:
images = [image]
else:
images = tf.split(value=image, num_or_size_splits=num_towers, axis=1)
if central_crop_size:
view_crop_size = (central_crop_size[0] // num_towers,
central_crop_size[1])
images = [central_crop(img, view_crop_size) for img in images]
if augment:
images = [augment_image(img) for img in images]
image = tf.concat(images, 1)
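# Map values from [0, 1] to approximately [-1.25, 1.25]: x -> (x - 0.5) * 2.5.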
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.5)
return image
def get_data(dataset,
batch_size,
augment=False,
central_crop_size=None,
shuffle_config=None,
shuffle=True):
"""Wraps calls to DatasetDataProviders and shuffle_batch.
For more details about supported Dataset objects refer to datasets/fsns.py.
Args:
dataset: a slim.data.dataset.Dataset object.
batch_size: number of samples per batch.
augment: optional, if True does random image distortion.
central_crop_size: A tuple (crop_width, crop_height).
shuffle_config: A namedtuple ShuffleBatchConfig.
shuffle: if True use data shuffling.
Returns:
An InputEndpoints tuple.
"""
if not shuffle_config:
shuffle_config = DEFAULT_SHUFFLE_CONFIG
provider = slim.dataset_data_provider.DatasetDataProvider(
dataset,
shuffle=shuffle,
common_queue_capacity=2 * batch_size,
common_queue_min=batch_size)
image_orig, label = provider.get(['image', 'label'])
image = preprocess_image(
image_orig, augment, central_crop_size, num_towers=dataset.num_of_views)
label_one_hot = slim.one_hot_encoding(label, dataset.num_char_classes)
images, images_orig, labels, labels_one_hot = (tf.train.shuffle_batch(
[image, image_orig, label, label_one_hot],
batch_size=batch_size,
num_threads=shuffle_config.num_batching_threads,
capacity=shuffle_config.queue_capacity,
min_after_dequeue=shuffle_config.min_after_dequeue))
return InputEndpoints(
images=images,
images_orig=images_orig,
labels=labels,
labels_one_hot=labels_one_hot)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
import numpy as np
import tensorflow as tf
from tensorflow.contrib.slim import queues
import datasets
import data_provider
class DataProviderTest(tf.test.TestCase):
def setUp(self):
tf.test.TestCase.setUp(self)
def test_preprocessed_image_values_are_in_range(self):
image_shape = (5, 4, 3)
fake_image = np.random.randint(low=0, high=255, size=image_shape)
image_tf = data_provider.preprocess_image(fake_image)
with self.test_session() as sess:
image_np = sess.run(image_tf)
self.assertEqual(image_np.shape, image_shape)
min_value, max_value = np.min(image_np), np.max(image_np)
self.assertTrue((-1.28 < min_value) and (min_value < 1.27))
self.assertTrue((-1.28 < max_value) and (max_value < 1.27))
def test_provided_data_has_correct_shape(self):
batch_size = 4
data = data_provider.get_data(
dataset=datasets.fsns_test.get_test_split(),
batch_size=batch_size,
augment=True,
central_crop_size=None)
with self.test_session() as sess, queues.QueueRunners(sess):
images_np, labels_np = sess.run([data.images, data.labels_one_hot])
self.assertEqual(images_np.shape, (batch_size, 150, 600, 3))
self.assertEqual(labels_np.shape, (batch_size, 37, 134))
def test_optionally_applies_central_crop(self):
batch_size = 4
data = data_provider.get_data(
dataset=datasets.fsns_test.get_test_split(),
batch_size=batch_size,
augment=True,
central_crop_size=(500, 100))
with self.test_session() as sess, queues.QueueRunners(sess):
images_np = sess.run(data.images)
self.assertEqual(images_np.shape, (batch_size, 100, 500, 3))
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import fsns
import fsns_test
__all__ = ['fsns', 'fsns_test']
# -*- coding: utf-8 -*-
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configuration to read FSNS dataset https://goo.gl/3Ldm8v."""
import os
import re
import tensorflow as tf
from tensorflow.contrib import slim
import logging
DEFAULT_DATASET_DIR = os.path.join(os.path.dirname(__file__), 'data/fsns')
# The dataset configuration; use it only as a default value.
DEFAULT_CONFIG = {
'name': 'FSNS',
'splits': {
'train': {
'size': 1044868,
'pattern': 'train/train*'
},
'test': {
'size': 20404,
'pattern': 'test/test*'
},
'validation': {
'size': 16150,
'pattern': 'validation/validation*'
}
},
'charset_filename': 'charset_size=134.txt',
'image_shape': (150, 600, 3),
'num_of_views': 4,
'max_sequence_length': 37,
'null_code': 133,
'items_to_descriptions': {
'image': 'A [150 x 600 x 3] color image.',
'label': 'Character codes.',
'text': 'A unicode string.',
'length': 'The length of the encoded text.',
'num_of_views': 'A number of different views stored within the image.'
}
}
def read_charset(filename, null_character=u'\u2591'):
"""Reads a charset definition from a tab separated text file.
charset file has to have format compatible with the FSNS dataset.
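Each line has the format '<code>\t<character>'; e.g. the illustrative line
u'36\tf' would map character code 36 to the letter 'f'.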
Args:
filename: a path to the charset file.
null_character: a unicode character used to replace the '<nul>' character.
The default value is the light shade block '░'.
Returns:
a dictionary mapping character codes to unicode characters.
"""
pattern = re.compile(r'(\d+)\t(.+)')
charset = {}
with tf.gfile.GFile(filename) as f:
for i, line in enumerate(f):
m = pattern.match(line)
if m is None:
logging.warning('incorrect charset file. line #%d: %s', i, line)
continue
code = int(m.group(1))
char = m.group(2).decode('utf-8')
if char == '<nul>':
char = null_character
charset[code] = char
return charset
class _NumOfViewsHandler(slim.tfexample_decoder.ItemHandler):
"""Convenience handler to determine number of views stored in an image."""
def __init__(self, width_key, original_width_key, num_of_views):
super(_NumOfViewsHandler, self).__init__([width_key, original_width_key])
self._width_key = width_key
self._original_width_key = original_width_key
self._num_of_views = num_of_views
def tensors_to_item(self, keys_to_tensors):
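# An FSNS image stores `num_of_views` views side by side, padded to a fixed
# total width; the unpadded original width therefore determines how many
# views are actually present: num_of_views * orig_width / width.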
return tf.to_int64(
self._num_of_views * keys_to_tensors[self._original_width_key] /
keys_to_tensors[self._width_key])
def get_split(split_name, dataset_dir=None, config=None):
"""Returns a dataset tuple for FSNS dataset.
Args:
split_name: A train/test split name.
dataset_dir: The base directory of the dataset sources, by default it uses
a predefined CNS path (see DEFAULT_DATASET_DIR).
config: A dictionary with dataset configuration. If None - will use the
DEFAULT_CONFIG.
Returns:
A `Dataset` namedtuple.
Raises:
ValueError: if `split_name` is not a valid train/test split.
"""
if not dataset_dir:
dataset_dir = DEFAULT_DATASET_DIR
if not config:
config = DEFAULT_CONFIG
if split_name not in config['splits']:
raise ValueError('split name %s was not recognized.' % split_name)
logging.info('Using %s dataset split_name=%s dataset_dir=%s', config['name'],
split_name, dataset_dir)
# Ignores the 'image/height' feature.
zero = tf.zeros([1], dtype=tf.int64)
keys_to_features = {
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/format':
tf.FixedLenFeature((), tf.string, default_value='png'),
'image/width':
tf.FixedLenFeature([1], tf.int64, default_value=zero),
'image/orig_width':
tf.FixedLenFeature([1], tf.int64, default_value=zero),
'image/class':
tf.FixedLenFeature([config['max_sequence_length']], tf.int64),
'image/unpadded_class':
tf.VarLenFeature(tf.int64),
'image/text':
tf.FixedLenFeature([1], tf.string, default_value=''),
}
items_to_handlers = {
'image':
slim.tfexample_decoder.Image(
shape=config['image_shape'],
image_key='image/encoded',
format_key='image/format'),
'label':
slim.tfexample_decoder.Tensor(tensor_key='image/class'),
'text':
slim.tfexample_decoder.Tensor(tensor_key='image/text'),
'num_of_views':
_NumOfViewsHandler(
width_key='image/width',
original_width_key='image/orig_width',
num_of_views=config['num_of_views'])
}
decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
charset_file = os.path.join(dataset_dir, config['charset_filename'])
charset = read_charset(charset_file)
file_pattern = os.path.join(dataset_dir,
config['splits'][split_name]['pattern'])
return slim.dataset.Dataset(
data_sources=file_pattern,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=config['splits'][split_name]['size'],
items_to_descriptions=config['items_to_descriptions'],
# additional parameters for convenience.
charset=charset,
num_char_classes=len(charset),
num_of_views=config['num_of_views'],
max_sequence_length=config['max_sequence_length'],
null_code=config['null_code'])
http://download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001