Commit d723e734 authored by Alexander Gorban's avatar Alexander Gorban
Browse files

Fix all deprecation warnings.

1. Update README.md
2. argmax, use axis instead of deminsion
3. use tf.profiler.profile instead of model_analyzer.print_model_analysis
parent b4cf2302
...@@ -34,7 +34,7 @@ pip install --upgrade tensorflow-gpu ...@@ -34,7 +34,7 @@ pip install --upgrade tensorflow-gpu
2. At least 158GB of free disk space to download the FSNS dataset: 2. At least 158GB of free disk space to download the FSNS dataset:
``` ```
cd models/attention_ocr/python/datasets cd research/attention_ocr/python/datasets
aria2c -c -j 20 -i ../../../street/python/fsns_urls.txt aria2c -c -j 20 -i ../../../street/python/fsns_urls.txt
cd .. cd ..
``` ```
...@@ -50,7 +50,7 @@ cd .. ...@@ -50,7 +50,7 @@ cd ..
To run all unit tests: To run all unit tests:
``` ```
cd models/attention_ocr/python cd research/attention_ocr/python
python -m unittest discover -p '*_test.py' python -m unittest discover -p '*_test.py'
``` ```
......
...@@ -299,7 +299,7 @@ class Model(object): ...@@ -299,7 +299,7 @@ class Model(object):
with shape [batch_size x seq_length]. with shape [batch_size x seq_length].
""" """
log_prob = utils.logits_to_log_prob(chars_logit) log_prob = utils.logits_to_log_prob(chars_logit)
ids = tf.to_int32(tf.argmax(log_prob, dimension=2), name='predicted_chars') ids = tf.to_int32(tf.argmax(log_prob, axis=2), name='predicted_chars')
mask = tf.cast( mask = tf.cast(
slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool) slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool)
all_scores = tf.nn.softmax(chars_logit) all_scores = tf.nn.softmax(chars_logit)
......
...@@ -19,7 +19,6 @@ import numpy as np ...@@ -19,7 +19,6 @@ import numpy as np
import string import string
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib import slim from tensorflow.contrib import slim
from tensorflow.contrib.tfprof import model_analyzer
import model import model
import data_provider import data_provider
...@@ -127,9 +126,9 @@ class ModelTest(tf.test.TestCase): ...@@ -127,9 +126,9 @@ class ModelTest(tf.test.TestCase):
ocr_model = self.create_model() ocr_model = self.create_model()
ocr_model.create_base(images=self.fake_images, labels_one_hot=None) ocr_model.create_base(images=self.fake_images, labels_one_hot=None)
with self.test_session() as sess: with self.test_session() as sess:
tfprof_root = model_analyzer.print_model_analysis( tfprof_root = tf.profiler.profile(
sess.graph, sess.graph,
tfprof_options=model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS) options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
model_size_bytes = 4 * tfprof_root.total_parameters model_size_bytes = 4 * tfprof_root.total_parameters
self.assertLess(model_size_bytes, 1 * 2**30) self.assertLess(model_size_bytes, 1 * 2**30)
......
...@@ -216,7 +216,7 @@ class SequenceLayerBase(object): ...@@ -216,7 +216,7 @@ class SequenceLayerBase(object):
Returns: Returns:
A tensor with shape [batch_size, num_char_classes] A tensor with shape [batch_size, num_char_classes]
""" """
prediction = tf.argmax(logit, dimension=1) prediction = tf.argmax(logit, axis=1)
return slim.one_hot_encoding(prediction, self._params.num_char_classes) return slim.one_hot_encoding(prediction, self._params.num_char_classes)
def get_input(self, prev, i): def get_input(self, prev, i):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment