"src/diffusers/pipelines/sana/__init__.py" did not exist on "22c4f079b1293415de58645ed1df7a92f55635e5"
Unverified Commit 730035d6 authored by Alexander Gorban's avatar Alexander Gorban Committed by GitHub
Browse files

Update research/attention_ocr model to be Python3 compatible (#8134)

* research/attention_ocr: Minor changes to make it compatible with python 3.

* research/attention_ocr: Script to create a smaller test file.
parent 7f926353
......@@ -25,10 +25,10 @@ Pull requests:
1. Install the TensorFlow library ([instructions][TF]). For example:
```
virtualenv --system-site-packages ~/.tensorflow
python3 -m venv ~/.tensorflow
source ~/.tensorflow/bin/activate
pip install --upgrade pip
pip install --upgrade tensorflow-gpu
pip install --upgrade tensorflow-gpu==1.15
```
2. At least 158GB of free disk space to download the FSNS dataset:
......@@ -51,7 +51,7 @@ To run all unit tests:
```
cd research/attention_ocr/python
python -m unittest discover -p '*_test.py'
find . -name "*_test.py" -printf '%P\n' | xargs python3 -m unittest
```
To train from scratch:
......
......@@ -109,8 +109,8 @@ def central_crop(image, crop_size):
tf.greater_equal(image_width, target_width),
['image_width < target_width', image_width, target_width])
with tf.control_dependencies([assert_op1, assert_op2]):
offset_width = (image_width - target_width) / 2
offset_height = (image_height - target_height) / 2
offset_width = tf.cast((image_width - target_width) / 2, tf.int32)
offset_height = tf.cast((image_height - target_height) / 2, tf.int32)
return tf.image.crop_to_bounding_box(image, offset_height, offset_width,
target_height, target_width)
......@@ -137,7 +137,7 @@ def preprocess_image(image, augment=False, central_crop_size=None,
else:
images = tf.split(value=image, num_or_size_splits=num_towers, axis=1)
if central_crop_size:
view_crop_size = (central_crop_size[0] / num_towers,
view_crop_size = (int(central_crop_size[0] / num_towers),
central_crop_size[1])
images = [central_crop(img, view_crop_size) for img in images]
if augment:
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
import fsns
import fsns_test
from datasets import fsns
from datasets import fsns_test
__all__ = [fsns, fsns_test]
......@@ -79,7 +79,7 @@ def read_charset(filename, null_character=u'\u2591'):
logging.warning('incorrect charset file. line #%d: %s', i, line)
continue
code = int(m.group(1))
char = m.group(2).decode('utf-8')
char = m.group(2)
if char == '<nul>':
char = null_character
charset[code] = char
......
......@@ -20,15 +20,15 @@ import os
import tensorflow as tf
from tensorflow.contrib import slim
import fsns
import unittest_utils
from datasets import fsns
from datasets import unittest_utils
FLAGS = tf.flags.FLAGS
def get_test_split():
  """Returns the 'test' split of the FSNS dataset backed by the local test shard.

  Copies the default FSNS config and shrinks the 'test' split to the 5-record
  fixture file ('fsns-00000-of-00001') so unit tests run on tiny local data
  instead of the full dataset.

  Returns:
    A dataset object as produced by fsns.get_split for the 'test' split.
  """
  config = fsns.DEFAULT_CONFIG.copy()
  # 5 matches KEEP_NUM_RECORDS used when the test shard was generated.
  config['splits'] = {'test': {'size': 5, 'pattern': 'fsns-00000-of-00001'}}
  return fsns.get_split('test', dataset_dir(), config)
......@@ -43,12 +43,12 @@ class FsnsTest(tf.test.TestCase):
'PNG', shape=(150, 600, 3))
serialized = unittest_utils.create_serialized_example({
'image/encoded': [encoded],
'image/format': ['PNG'],
'image/format': [b'PNG'],
'image/class':
expected_label,
'image/unpadded_class':
range(10),
'image/text': ['Raw text'],
'image/text': [b'Raw text'],
'image/orig_width': [150],
'image/width': [600]
})
......@@ -60,7 +60,7 @@ class FsnsTest(tf.test.TestCase):
self.assertAllEqual(expected_image, data.image)
self.assertAllEqual(expected_label, data.label)
self.assertEqual(['Raw text'], data.text)
self.assertEqual([b'Raw text'], data.text)
self.assertEqual([1], data.num_of_views)
def test_label_has_shape_defined(self):
......
"""Builds a small FSNS test fixture: download one shard, keep only 5 records."""
import urllib.request
import tensorflow as tf
import itertools
# Full original shard published with the FSNS dataset (2016-09-27 release).
URL = 'http://download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001'
# The unmodified download is kept under a '.orig' name; the truncated copy
# reuses the original shard filename so test code can reference it unchanged.
DST_ORIG = 'fsns-00000-of-00001.orig'
DST = 'fsns-00000-of-00001'
# Number of TFRecord examples retained in the fixture file.
KEEP_NUM_RECORDS = 5
print('Downloading %s ...' % URL)
urllib.request.urlretrieve(URL, DST_ORIG)
print('Writing %d records from %s to %s ...' % (KEEP_NUM_RECORDS, DST_ORIG, DST))
# Stream the first KEEP_NUM_RECORDS serialized examples into the new shard.
# NOTE(review): tf.python_io.tf_record_iterator is a TF 1.x API (deprecated in
# TF 2); fine here since the project pins tensorflow-gpu 1.15.
with tf.io.TFRecordWriter(DST) as writer:
  for raw_record in itertools.islice(tf.python_io.tf_record_iterator(DST_ORIG), KEEP_NUM_RECORDS):
    writer.write(raw_record)
......@@ -15,12 +15,11 @@
"""Functions to make unit testing easier."""
import StringIO
import numpy as np
import io
from PIL import Image as PILImage
import tensorflow as tf
def create_random_image(image_format, shape):
"""Creates an image with random values.
......@@ -32,10 +31,10 @@ def create_random_image(image_format, shape):
A tuple (<numpy ndarray>, <a string with encoded image>)
"""
image = np.random.randint(low=0, high=255, size=shape, dtype='uint8')
io = StringIO.StringIO()
fd = io.BytesIO()
image_pil = PILImage.fromarray(image)
image_pil.save(io, image_format, subsampling=0, quality=100)
return image, io.getvalue()
image_pil.save(fd, image_format, subsampling=0, quality=100)
return image, fd.getvalue()
def create_serialized_example(name_to_values):
......@@ -52,7 +51,7 @@ def create_serialized_example(name_to_values):
example = tf.train.Example()
for name, values in name_to_values.items():
feature = example.features.feature[name]
if isinstance(values[0], str):
if isinstance(values[0], str) or isinstance(values[0], bytes):
add = feature.bytes_list.value.extend
elif isinstance(values[0], float):
add = feature.float32_list.value.extend
......
......@@ -14,13 +14,13 @@
# ==============================================================================
"""Tests for unittest_utils."""
import StringIO
import numpy as np
import io
from PIL import Image as PILImage
import tensorflow as tf
import unittest_utils
from datasets import unittest_utils
class UnittestUtilsTest(tf.test.TestCase):
......@@ -30,13 +30,13 @@ class UnittestUtilsTest(tf.test.TestCase):
def test_encoded_image_corresponds_to_numpy_array(self):
image, encoded = unittest_utils.create_random_image('PNG', (20, 10, 3))
pil_image = PILImage.open(StringIO.StringIO(encoded))
pil_image = PILImage.open(io.BytesIO(encoded))
self.assertAllEqual(image, np.array(pil_image))
def test_created_example_has_correct_values(self):
example_serialized = unittest_utils.create_serialized_example({
'labels': [1, 2, 3],
'data': ['FAKE']
'data': [b'FAKE']
})
example = tf.train.Example()
example.ParseFromString(example_serialized)
......
......@@ -49,7 +49,7 @@ def load_images(file_pattern, batch_size, dataset_name):
for i in range(batch_size):
path = file_pattern % i
print("Reading %s" % path)
pil_image = PIL.Image.open(tf.gfile.GFile(path))
pil_image = PIL.Image.open(tf.gfile.GFile(path, 'rb'))
images_actual_data[i, ...] = np.asarray(pil_image)
return images_actual_data
......@@ -81,7 +81,7 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
session_creator=session_creator) as sess:
predictions = sess.run(endpoints.predicted_text,
feed_dict={images_placeholder: images_data})
return predictions.tolist()
return [pr_bytes.decode('utf-8') for pr_bytes in predictions.tolist()]
def main(_):
......
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import os
import demo_inference
import tensorflow as tf
from tensorflow.python.training import monitored_session
......@@ -18,6 +19,7 @@ class DemoInferenceTest(tf.test.TestCase):
'Please download and extract it from %s' %
(filename, _CHECKPOINT_URL))
self._batch_size = 32
tf.flags.FLAGS.dataset_dir = os.path.join(os.path.dirname(__file__), 'datasets/testdata/fsns')
def test_moving_variables_properly_loaded_from_a_checkpoint(self):
batch_size = 32
......@@ -48,7 +50,7 @@ class DemoInferenceTest(tf.test.TestCase):
'fsns',
image_path_pattern)
self.assertEqual([
'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
u'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
......
......@@ -37,7 +37,7 @@ class AccuracyTest(tf.test.TestCase):
Yields:
A session object that should be used as a context manager.
"""
with self.test_session() as sess:
with self.cached_session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
yield sess
......
......@@ -63,7 +63,7 @@ EncodeCoordinatesParams = collections.namedtuple('EncodeCoordinatesParams', [
def _dict_to_array(id_to_char, default_character):
num_char_classes = max(id_to_char.keys()) + 1
array = [default_character] * num_char_classes
for k, v in id_to_char.iteritems():
for k, v in id_to_char.items():
array[k] = v
return array
......@@ -534,10 +534,10 @@ class Model(object):
streaming=True,
rej_char=self._params.null_code))
for name, value in names_to_values.iteritems():
for name, value in names_to_values.items():
summary_name = 'eval/' + name
tf.summary.scalar(summary_name, tf.Print(value, [value], summary_name))
return names_to_updates.values()
return list(names_to_updates.values())
def create_init_fn_to_restore(self, master_checkpoint,
inception_checkpoint=None):
......
......@@ -16,7 +16,6 @@
"""Tests for the model."""
import numpy as np
from six.moves import xrange
import string
import tensorflow as tf
from tensorflow.contrib import slim
......@@ -27,7 +26,7 @@ import data_provider
def create_fake_charset(num_char_classes):
  """Creates a dummy charset {id: printable char} for tests.

  Args:
    num_char_classes: Number of character ids to generate; ids run from 0 to
      num_char_classes - 1. Characters cycle through string.printable.

  Returns:
    A dict mapping each id to a character from string.printable.
  """
  # Python 3: xrange is gone; range is the lazy iterator now.
  return {
      i: string.printable[i % len(string.printable)]
      for i in range(num_char_classes)
  }
......@@ -179,13 +178,13 @@ class ModelTest(tf.test.TestCase):
tf.reshape(
tf.contrib.layers.one_hot_encoding(
tf.constant([i]), num_classes=h), [h, 1]), [1, w])
for i in xrange(h)
for i in range(h)
]
h_loc = tf.concat([tf.expand_dims(t, 2) for t in h_loc], 2)
w_loc = [
tf.tile(
tf.contrib.layers.one_hot_encoding(tf.constant([i]), num_classes=w),
[h, 1]) for i in xrange(w)
[h, 1]) for i in range(w)
]
w_loc = tf.concat([tf.expand_dims(t, 2) for t in w_loc], 2)
loc = tf.concat([h_loc, w_loc], 2)
......@@ -272,7 +271,7 @@ class CharsetMapperTest(tf.test.TestCase):
tf.tables_initializer().run()
text = sess.run(charset_mapper.get_text(ids))
self.assertAllEqual(text, ['hello', 'world'])
self.assertAllEqual(text, [b'hello', b'world'])
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment