Unverified Commit 71349a10 authored by Martin Wicke's avatar Martin Wicke Committed by GitHub
Browse files

Merge pull request #2565 from alexgorban/master

#attention_ocr: fix deprecation warnings and update usage examples
parents 3653ef1b d906b135
...@@ -34,7 +34,7 @@ pip install --upgrade tensorflow-gpu ...@@ -34,7 +34,7 @@ pip install --upgrade tensorflow-gpu
2. At least 158GB of free disk space to download the FSNS dataset: 2. At least 158GB of free disk space to download the FSNS dataset:
``` ```
cd models/attention_ocr/python/datasets cd research/attention_ocr/python/datasets
aria2c -c -j 20 -i ../../../street/python/fsns_urls.txt aria2c -c -j 20 -i ../../../street/python/fsns_urls.txt
cd .. cd ..
``` ```
...@@ -50,7 +50,7 @@ cd .. ...@@ -50,7 +50,7 @@ cd ..
To run all unit tests: To run all unit tests:
``` ```
cd models/attention_ocr/python cd research/attention_ocr/python
python -m unittest discover -p '*_test.py' python -m unittest discover -p '*_test.py'
``` ```
......
...@@ -12,6 +12,7 @@ https://www.tensorflow.org/serving/serving_basic ...@@ -12,6 +12,7 @@ https://www.tensorflow.org/serving/serving_basic
Usage: Usage:
python demo_inference.py --batch_size=32 \ python demo_inference.py --batch_size=32 \
--checkpoint=model.ckpt-399731\
--image_path_pattern=./datasets/data/fsns/temp/fsns_train_%02d.png --image_path_pattern=./datasets/data/fsns/temp/fsns_train_%02d.png
""" """
import numpy as np import numpy as np
......
...@@ -299,7 +299,7 @@ class Model(object): ...@@ -299,7 +299,7 @@ class Model(object):
with shape [batch_size x seq_length]. with shape [batch_size x seq_length].
""" """
log_prob = utils.logits_to_log_prob(chars_logit) log_prob = utils.logits_to_log_prob(chars_logit)
ids = tf.to_int32(tf.argmax(log_prob, dimension=2), name='predicted_chars') ids = tf.to_int32(tf.argmax(log_prob, axis=2), name='predicted_chars')
mask = tf.cast( mask = tf.cast(
slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool) slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool)
all_scores = tf.nn.softmax(chars_logit) all_scores = tf.nn.softmax(chars_logit)
......
...@@ -19,7 +19,6 @@ import numpy as np ...@@ -19,7 +19,6 @@ import numpy as np
import string import string
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib import slim from tensorflow.contrib import slim
from tensorflow.contrib.tfprof import model_analyzer
import model import model
import data_provider import data_provider
...@@ -127,9 +126,9 @@ class ModelTest(tf.test.TestCase): ...@@ -127,9 +126,9 @@ class ModelTest(tf.test.TestCase):
ocr_model = self.create_model() ocr_model = self.create_model()
ocr_model.create_base(images=self.fake_images, labels_one_hot=None) ocr_model.create_base(images=self.fake_images, labels_one_hot=None)
with self.test_session() as sess: with self.test_session() as sess:
tfprof_root = model_analyzer.print_model_analysis( tfprof_root = tf.profiler.profile(
sess.graph, sess.graph,
tfprof_options=model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS) options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
model_size_bytes = 4 * tfprof_root.total_parameters model_size_bytes = 4 * tfprof_root.total_parameters
self.assertLess(model_size_bytes, 1 * 2**30) self.assertLess(model_size_bytes, 1 * 2**30)
......
...@@ -216,7 +216,7 @@ class SequenceLayerBase(object): ...@@ -216,7 +216,7 @@ class SequenceLayerBase(object):
Returns: Returns:
A tensor with shape [batch_size, num_char_classes] A tensor with shape [batch_size, num_char_classes]
""" """
prediction = tf.argmax(logit, dimension=1) prediction = tf.argmax(logit, axis=1)
return slim.one_hot_encoding(prediction, self._params.num_char_classes) return slim.one_hot_encoding(prediction, self._params.num_char_classes)
def get_input(self, prev, i): def get_input(self, prev, i):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment