Unverified commit 730035d6 authored by Alexander Gorban, committed by GitHub

Update research/attention_ocr model to be Python3 compatible (#8134)

* research/attention_ocr: Minor changes to make it compatible with python 3.

* research/attention_ocr: Script to create a smaller test file.
parent 7f926353
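
Taken together, the hunks below apply a handful of recurring Python 2 → 3 migration patterns: true division replacing floor division, `bytes` replacing `str` for binary data, `dict.items()` replacing `iteritems()`, and `range` replacing `xrange`. A condensed, self-contained sketch of the four patterns (plain Python, independent of this codebase):

```python
# 1. Division: `/` is true division in Python 3; use `//` or an explicit
#    cast where an integer is required.
assert 5 / 2 == 2.5 and 5 // 2 == 2

# 2. Text vs. binary: encoded payloads (images, serialized protos, string
#    tensors) are bytes, which no longer compare equal to str.
assert b'PNG' != 'PNG' and b'PNG'.decode('utf-8') == 'PNG'

# 3. Dict iteration: iteritems()/itervalues() are gone; items()/values()
#    return lazy views.
assert list({'a': 1}.items()) == [('a', 1)]

# 4. Ranges: xrange is gone; range is already lazy.
assert list(range(3)) == [0, 1, 2]
```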
@@ -25,10 +25,10 @@ Pull requests:
 1. Install the TensorFlow library ([instructions][TF]). For example:
    ```
-   virtualenv --system-site-packages ~/.tensorflow
+   python3 -m venv ~/.tensorflow
    source ~/.tensorflow/bin/activate
    pip install --upgrade pip
-   pip install --upgrade tensorflow-gpu
+   pip install --upgrade tensorflow-gpu==1.15
    ```
 2. At least 158GB of free disk space to download the FSNS dataset:
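
Assuming the `tensorflow-gpu==1.15` pin above succeeds, a quick sanity check that the right major version is active (this model uses `tf.contrib`, which was removed in TensorFlow 2.x, so 1.15 is the last release line it runs on):

```python
import tensorflow as tf

# attention_ocr depends on tf.contrib, which exists only in TF 1.x;
# on TF 2.x the import below raises ImportError.
print(tf.__version__)                  # expected: 1.15.x
from tensorflow.contrib import slim    # noqa: F401
```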
@@ -51,7 +51,7 @@ To run all unit tests:
 ```
 cd research/attention_ocr/python
-python -m unittest discover -p '*_test.py'
+find . -name "*_test.py" -printf '%P\n' | xargs python3 -m unittest
 ```
 To train from scratch:
@@ -109,8 +109,8 @@ def central_crop(image, crop_size):
       tf.greater_equal(image_width, target_width),
       ['image_width < target_width', image_width, target_width])
   with tf.control_dependencies([assert_op1, assert_op2]):
-    offset_width = (image_width - target_width) / 2
-    offset_height = (image_height - target_height) / 2
+    offset_width = tf.cast((image_width - target_width) / 2, tf.int32)
+    offset_height = tf.cast((image_height - target_height) / 2, tf.int32)
     return tf.image.crop_to_bounding_box(image, offset_height, offset_width,
                                          target_height, target_width)
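
The `tf.cast` calls are needed because `/` is true division under Python 3: with integer operands the result is a float (and for tensors, a float tensor), while `tf.image.crop_to_bounding_box` requires integer offsets. A minimal sketch of the behavior with plain Python integers:

```python
# Python 2: 5 / 2 == 2 (floor), so the old code produced integer offsets.
# Python 3: 5 / 2 == 2.5 (true division), so offsets became floats and
# downstream ops that require integers started to fail.
image_width, target_width = 600, 150
offset_width = (image_width - target_width) / 2
print(offset_width)                        # 225.0 -- a float under Python 3

# Two equivalent fixes: cast the result (as the commit does for tensors),
# or use floor division, which keeps the integer type in both versions.
print(int(offset_width))                   # 225
print((image_width - target_width) // 2)   # 225
```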
@@ -137,7 +137,7 @@ def preprocess_image(image, augment=False, central_crop_size=None,
   else:
     images = tf.split(value=image, num_or_size_splits=num_towers, axis=1)
   if central_crop_size:
-    view_crop_size = (central_crop_size[0] / num_towers,
+    view_crop_size = (int(central_crop_size[0] / num_towers),
                      central_crop_size[1])
     images = [central_crop(img, view_crop_size) for img in images]
   if augment:
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
-import fsns
-import fsns_test
+from datasets import fsns
+from datasets import fsns_test

 __all__ = [fsns, fsns_test]
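
Python 3 removed implicit relative imports (PEP 328), so `import fsns` inside the `datasets` package no longer resolves to the sibling module; imports must be absolute, as the commit writes them, or explicitly relative. A small sketch, assuming it runs from the directory that contains the `datasets` package:

```python
# Absolute form used by the commit; requires the parent of the `datasets`
# package on sys.path (e.g. running from research/attention_ocr/python).
from datasets import fsns

# An explicitly relative form would also work, but only from inside the
# package itself (e.g. in datasets/__init__.py):
#   from . import fsns
print(fsns.__name__)  # datasets.fsns
```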
@@ -79,7 +79,7 @@ def read_charset(filename, null_character=u'\u2591'):
       logging.warning('incorrect charset file. line #%d: %s', i, line)
       continue
     code = int(m.group(1))
-    char = m.group(2).decode('utf-8')
+    char = m.group(2)
     if char == '<nul>':
       char = null_character
     charset[code] = char
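
The `.decode('utf-8')` call is dropped because files opened in text mode under Python 3 already yield `str`, so regex match groups are `str` and have no `decode` method. A sketch of the parsing step in isolation (the tab-separated `code<TAB>char` line format mirrors what `read_charset` parses; the exact regex here is illustrative):

```python
import re

line = '42\tB'                       # a "code<TAB>character" charset line
m = re.match(r'(\d+)\t(.+)', line)
code, char = int(m.group(1)), m.group(2)
print(code, repr(char))              # 42 'B' -- already str under Python 3
# m.group(2).decode('utf-8') would raise:
#   AttributeError: 'str' object has no attribute 'decode'
```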
@@ -20,15 +20,15 @@ import os
 import tensorflow as tf
 from tensorflow.contrib import slim

-import fsns
-import unittest_utils
+from datasets import fsns
+from datasets import unittest_utils

 FLAGS = tf.flags.FLAGS


 def get_test_split():
   config = fsns.DEFAULT_CONFIG.copy()
-  config['splits'] = {'test': {'size': 50, 'pattern': 'fsns-00000-of-00001'}}
+  config['splits'] = {'test': {'size': 5, 'pattern': 'fsns-00000-of-00001'}}
   return fsns.get_split('test', dataset_dir(), config)
@@ -43,12 +43,12 @@ class FsnsTest(tf.test.TestCase):
                                                    'PNG', shape=(150, 600, 3))
     serialized = unittest_utils.create_serialized_example({
         'image/encoded': [encoded],
-        'image/format': ['PNG'],
+        'image/format': [b'PNG'],
         'image/class':
             expected_label,
         'image/unpadded_class':
             range(10),
-        'image/text': ['Raw text'],
+        'image/text': [b'Raw text'],
         'image/orig_width': [150],
         'image/width': [600]
     })

@@ -60,7 +60,7 @@ class FsnsTest(tf.test.TestCase):
     self.assertAllEqual(expected_image, data.image)
     self.assertAllEqual(expected_label, data.label)
-    self.assertEqual(['Raw text'], data.text)
+    self.assertEqual([b'Raw text'], data.text)
     self.assertEqual([1], data.num_of_views)

   def test_label_has_shape_defined(self):
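
The `b` prefixes are required because under Python 3 protobuf `bytes` fields, such as those behind `tf.train.BytesList`, reject `str`, and string tensors likewise evaluate to `bytes`. A minimal sketch using the TF 1.x `tf.train` API:

```python
import tensorflow as tf

# BytesList fields must hold bytes under Python 3; passing 'PNG' (str)
# raises TypeError: "'PNG' has type str, but expected one of: bytes".
example = tf.train.Example(features=tf.train.Features(feature={
    'image/format': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'PNG'])),
}))
print(example.features.feature['image/format'].bytes_list.value)  # [b'PNG']
```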
new file (the script to create a smaller test file, per the commit message):
+import urllib.request
+import tensorflow as tf
+import itertools
+
+URL = 'http://download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001'
+DST_ORIG = 'fsns-00000-of-00001.orig'
+DST = 'fsns-00000-of-00001'
+KEEP_NUM_RECORDS = 5
+
+print('Downloading %s ...' % URL)
+urllib.request.urlretrieve(URL, DST_ORIG)
+
+print('Writing %d records from %s to %s ...' % (KEEP_NUM_RECORDS, DST_ORIG, DST))
+with tf.io.TFRecordWriter(DST) as writer:
+  for raw_record in itertools.islice(tf.python_io.tf_record_iterator(DST_ORIG), KEEP_NUM_RECORDS):
+    writer.write(raw_record)
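
The script truncates the downloaded FSNS test shard to five records so the checked-in test data stays small. Assuming it has been run in the current directory, a quick sanity check that the truncated file holds the expected record count, using the same TF 1.x iterator the script relies on:

```python
import tensorflow as tf

# Count the records kept in the truncated TFRecord file; should print 5,
# matching KEEP_NUM_RECORDS above.
num_records = sum(
    1 for _ in tf.python_io.tf_record_iterator('fsns-00000-of-00001'))
print('kept %d records' % num_records)
```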
@@ -15,12 +15,11 @@
 """Functions to make unit testing easier."""

-import StringIO
 import numpy as np
+import io
 from PIL import Image as PILImage
 import tensorflow as tf


 def create_random_image(image_format, shape):
   """Creates an image with random values.

@@ -32,10 +31,10 @@ def create_random_image(image_format, shape):
     A tuple (<numpy ndarray>, <a string with encoded image>)
   """
   image = np.random.randint(low=0, high=255, size=shape, dtype='uint8')
-  io = StringIO.StringIO()
+  fd = io.BytesIO()
   image_pil = PILImage.fromarray(image)
-  image_pil.save(io, image_format, subsampling=0, quality=100)
-  return image, io.getvalue()
+  image_pil.save(fd, image_format, subsampling=0, quality=100)
+  return image, fd.getvalue()

 def create_serialized_example(name_to_values):
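
`StringIO.StringIO` is gone in Python 3 and encoded images are binary data, so the commit switches to `io.BytesIO`; the buffer variable is also renamed to `fd` so it no longer shadows the newly imported `io` module. A self-contained round-trip sketch of the same pattern:

```python
import io
import numpy as np
from PIL import Image as PILImage

# Encode a random image into an in-memory buffer and decode it back.
image = np.random.randint(low=0, high=255, size=(20, 10, 3), dtype='uint8')
fd = io.BytesIO()
PILImage.fromarray(image).save(fd, 'PNG')
encoded = fd.getvalue()            # bytes, suitable for a BytesList feature
decoded = np.array(PILImage.open(io.BytesIO(encoded)))
assert (image == decoded).all()    # PNG is lossless, so the round trip is exact
```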
@@ -52,7 +51,7 @@ def create_serialized_example(name_to_values):
   example = tf.train.Example()
   for name, values in name_to_values.items():
     feature = example.features.feature[name]
-    if isinstance(values[0], str):
+    if isinstance(values[0], str) or isinstance(values[0], bytes):
       add = feature.bytes_list.value.extend
     elif isinstance(values[0], float):
       add = feature.float32_list.value.extend
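
A stylistic aside: `isinstance` accepts a tuple of types, so the commit's two checks can be collapsed into one. (Note also that the `tf.train.Feature` proto names its float field `float_list`; the `float32_list` branch above appears to be dead code that these tests, which only pass ints and bytes, never reach.)

```python
# Equivalent single-call form of the commit's check.
for value in ('PNG', b'PNG', 3.14, 7):
    print(value, isinstance(value, (str, bytes)))
# 'PNG' and b'PNG' print True; 3.14 and 7 print False.
```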
@@ -14,13 +14,13 @@
 # ==============================================================================
 """Tests for unittest_utils."""

-import StringIO
 import numpy as np
+import io
 from PIL import Image as PILImage
 import tensorflow as tf

-import unittest_utils
+from datasets import unittest_utils


 class UnittestUtilsTest(tf.test.TestCase):
@@ -30,13 +30,13 @@ class UnittestUtilsTest(tf.test.TestCase):

   def test_encoded_image_corresponds_to_numpy_array(self):
     image, encoded = unittest_utils.create_random_image('PNG', (20, 10, 3))
-    pil_image = PILImage.open(StringIO.StringIO(encoded))
+    pil_image = PILImage.open(io.BytesIO(encoded))
     self.assertAllEqual(image, np.array(pil_image))

   def test_created_example_has_correct_values(self):
     example_serialized = unittest_utils.create_serialized_example({
         'labels': [1, 2, 3],
-        'data': ['FAKE']
+        'data': [b'FAKE']
     })
     example = tf.train.Example()
     example.ParseFromString(example_serialized)
@@ -49,7 +49,7 @@ def load_images(file_pattern, batch_size, dataset_name):
   for i in range(batch_size):
     path = file_pattern % i
     print("Reading %s" % path)
-    pil_image = PIL.Image.open(tf.gfile.GFile(path))
+    pil_image = PIL.Image.open(tf.gfile.GFile(path, 'rb'))
     images_actual_data[i, ...] = np.asarray(pil_image)
   return images_actual_data
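
`tf.gfile.GFile(path)` defaults to text mode, which under Python 3 yields `str` and breaks PIL, so the commit opens the file with `'rb'`. The builtin `open` shows the same distinction, a minimal sketch:

```python
# Write a small binary payload, then compare text- and binary-mode reads.
with open('blob.bin', 'wb') as f:
    f.write(b'\x89PNG\r\n')        # the first bytes of a real PNG signature

with open('blob.bin', 'rb') as f:
    data = f.read()
print(type(data))                  # <class 'bytes'> -- what PIL expects
# Opening the same file with mode 'r' would try to decode it as UTF-8 and
# fail on the 0x89 byte with UnicodeDecodeError.
```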
@@ -81,7 +81,7 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
       session_creator=session_creator) as sess:
     predictions = sess.run(endpoints.predicted_text,
                            feed_dict={images_placeholder: images_data})
-  return predictions.tolist()
+  return [pr_bytes.decode('utf-8') for pr_bytes in predictions.tolist()]

 def main(_):
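
Under Python 3, string tensors evaluate to numpy arrays of `bytes`, so each prediction must be decoded before it can be compared against `str` literals. A minimal sketch with sample values (these are illustrative, not real model output):

```python
import numpy as np

# sess.run on a string tensor yields bytes objects under Python 3.
predictions = np.array([b'Boulevard de Lunel', b'Rue de Provence'])
decoded = [pr_bytes.decode('utf-8') for pr_bytes in predictions.tolist()]
print(decoded)  # ['Boulevard de Lunel', 'Rue de Provence']
```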
 #!/usr/bin/python
 # -*- coding: UTF-8 -*-
+import os
 import demo_inference
 import tensorflow as tf
 from tensorflow.python.training import monitored_session

@@ -18,6 +19,7 @@ class DemoInferenceTest(tf.test.TestCase):
                        'Please download and extract it from %s' %
                        (filename, _CHECKPOINT_URL))
     self._batch_size = 32
+    tf.flags.FLAGS.dataset_dir = os.path.join(os.path.dirname(__file__), 'datasets/testdata/fsns')

   def test_moving_variables_properly_loaded_from_a_checkpoint(self):
     batch_size = 32

@@ -48,7 +50,7 @@ class DemoInferenceTest(tf.test.TestCase):
                                      'fsns',
                                      image_path_pattern)
     self.assertEqual([
-        'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
+        u'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
         'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
         'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
         'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
@@ -37,7 +37,7 @@ class AccuracyTest(tf.test.TestCase):
     Yields:
       A session object that should be used as a context manager.
     """
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(tf.global_variables_initializer())
       sess.run(tf.local_variables_initializer())
       yield sess
@@ -63,7 +63,7 @@ EncodeCoordinatesParams = collections.namedtuple('EncodeCoordinatesParams', [
 def _dict_to_array(id_to_char, default_character):
   num_char_classes = max(id_to_char.keys()) + 1
   array = [default_character] * num_char_classes
-  for k, v in id_to_char.iteritems():
+  for k, v in id_to_char.items():
     array[k] = v
   return array
@@ -534,10 +534,10 @@ class Model(object):
             streaming=True,
             rej_char=self._params.null_code))

-    for name, value in names_to_values.iteritems():
+    for name, value in names_to_values.items():
       summary_name = 'eval/' + name
       tf.summary.scalar(summary_name, tf.Print(value, [value], summary_name))
-    return names_to_updates.values()
+    return list(names_to_updates.values())

   def create_init_fn_to_restore(self, master_checkpoint,
                                 inception_checkpoint=None):
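
Python 3 removed `dict.iteritems`, and `dict.values()` now returns a lazy view rather than a list, hence the `list(...)` wrapper where downstream code expects a concrete list. A minimal sketch of both points:

```python
d = {'a': 1, 'b': 2}

# items() replaces iteritems() and iterates lazily.
for k, v in d.items():
    print(k, v)

values = d.values()          # a dict view, not a list
print(list(values))          # [1, 2] -- materialize when a list is required
# Views reflect later mutations, another reason to copy when a snapshot
# is needed:
d['c'] = 3
print(list(values))          # [1, 2, 3]
```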
@@ -16,7 +16,6 @@
 """Tests for the model."""

 import numpy as np
-from six.moves import xrange
 import string
 import tensorflow as tf
 from tensorflow.contrib import slim
@@ -27,7 +26,7 @@ import data_provider

 def create_fake_charset(num_char_classes):
   charset = {}
-  for i in xrange(num_char_classes):
+  for i in range(num_char_classes):
     charset[i] = string.printable[i % len(string.printable)]
   return charset
@@ -179,13 +178,13 @@ class ModelTest(tf.test.TestCase):
         tf.reshape(
             tf.contrib.layers.one_hot_encoding(
                 tf.constant([i]), num_classes=h), [h, 1]), [1, w])
-        for i in xrange(h)
+        for i in range(h)
     ]
     h_loc = tf.concat([tf.expand_dims(t, 2) for t in h_loc], 2)
     w_loc = [
         tf.tile(
             tf.contrib.layers.one_hot_encoding(tf.constant([i]), num_classes=w),
-            [h, 1]) for i in xrange(w)
+            [h, 1]) for i in range(w)
     ]
     w_loc = tf.concat([tf.expand_dims(t, 2) for t in w_loc], 2)
     loc = tf.concat([h_loc, w_loc], 2)
@@ -272,7 +271,7 @@ class CharsetMapperTest(tf.test.TestCase):
       tf.tables_initializer().run()
       text = sess.run(charset_mapper.get_text(ids))
-      self.assertAllEqual(text, ['hello', 'world'])
+      self.assertAllEqual(text, [b'hello', b'world'])

 if __name__ == '__main__':