Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
51b7c2b3
Commit
51b7c2b3
authored
Sep 06, 2017
by
Martin Wicke
Committed by
GitHub
Sep 06, 2017
Browse files
Merge pull request #2166 from alexgorban/master
Demo script to do inference on a trained Attention OCR model
parents
6024579b
dff0f0c1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
147 additions
and
28 deletions
+147
-28
attention_ocr/python/demo_inference.py
attention_ocr/python/demo_inference.py
+88
-0
attention_ocr/python/model.py
attention_ocr/python/model.py
+41
-26
attention_ocr/python/model_test.py
attention_ocr/python/model_test.py
+18
-2
No files found.
attention_ocr/python/demo_inference.py
0 → 100644
View file @
51b7c2b3
"""A script to run inference on a set of image files.
NOTE #1: The Attention OCR model was trained only using FSNS train dataset and
it will work only for images which look more or less similar to french street
names. In order to apply it to images from a different distribution you need
to retrain (or at least fine-tune) it using images from that distribution.
NOTE #2: This script exists for demo purposes only. It is highly recommended
to use tools and mechanisms provided by the TensorFlow Serving system to run
inference on TensorFlow models in production:
https://www.tensorflow.org/serving/serving_basic
Usage:
python demo_inference.py --batch_size=32
\
--image_path_pattern=./datasets/data/fsns/temp/fsns_train_%02d.png
"""
import
numpy
as
np
import
PIL.Image
import
tensorflow
as
tf
from
tensorflow.python.platform
import
flags
import
common_flags
import
datasets
import
model
as
attention_ocr
# Module-level flag setup: FLAGS exposes all command-line flags defined
# below and by common_flags.define() (checkpoint, batch_size, dataset_name,
# split_name, ...).
FLAGS = flags.FLAGS
common_flags.define()

# e.g. ./datasets/data/fsns/temp/fsns_train_%02d.png
flags.DEFINE_string('image_path_pattern', '',
                    'A file pattern with a placeholder for the image index.')
def get_dataset_image_size(dataset_name, ds_module=None):
  """Returns the (width, height) of images in the named dataset.

  Args:
    dataset_name: name of a module inside the datasets package whose
      DEFAULT_CONFIG['image_shape'] is an (height, width, depth) tuple.
    ds_module: optional dataset module (or any object with a compatible
      DEFAULT_CONFIG attribute) to use directly instead of looking
      dataset_name up in the datasets package. Useful for testing; by
      default the behavior is unchanged.

  Returns:
    A (width, height) tuple of ints.
  """
  # Ideally this info should be exposed through the dataset interface itself.
  # But currently it is not available by other means.
  if ds_module is None:
    ds_module = getattr(datasets, dataset_name)
  height, width, _ = ds_module.DEFAULT_CONFIG['image_shape']
  return width, height
def load_images(file_pattern, batch_size, dataset_name):
  """Reads a batch of image files into a single float32 numpy array.

  Args:
    file_pattern: a %-style pattern with one integer placeholder, e.g.
      './datasets/data/fsns/temp/fsns_train_%02d.png'; image i is read
      from file_pattern % i.
    batch_size: number of images to read (indices 0..batch_size-1).
    dataset_name: name of a module in the datasets package; its
      DEFAULT_CONFIG['image_shape'] determines the expected image size.

  Returns:
    A numpy array of shape (batch_size, height, width, 3), dtype float32.

  Raises:
    ValueError: if an image's size does not match the dataset image_shape.
  """
  width, height = get_dataset_image_size(dataset_name)
  images_actual_data = np.ndarray(
      shape=(batch_size, height, width, 3), dtype='float32')
  for i in range(batch_size):
    path = file_pattern % i
    print("Reading %s" % path)
    # Open in binary mode: PIL decodes raw bytes, and the text-mode default
    # of GFile breaks PNG decoding on Python 3.
    pil_image = PIL.Image.open(tf.gfile.GFile(path, 'rb'))
    # Normalize to 3 channels so grayscale/RGBA/palette inputs don't break
    # the assignment below; convert('RGB') is an identity for RGB images.
    images_actual_data[i, ...] = np.asarray(pil_image.convert('RGB'))
  return images_actual_data
def load_model(checkpoint, batch_size, dataset_name):
  """Builds the Attention OCR inference graph.

  Args:
    checkpoint: path to a model checkpoint to restore weights from.
    batch_size: static batch dimension of the input placeholder.
    dataset_name: name of a module in the datasets package; determines the
      input image size and the charset.

  Returns:
    A (images_placeholder, endpoints, init_fn) tuple: the float32 input
    placeholder of shape [batch_size, height, width, 3], the model's
    OutputEndpoints, and an init function that restores the checkpoint
    into a session.
  """
  width, height = get_dataset_image_size(dataset_name)
  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  ocr_model = common_flags.create_model(
      num_char_classes=dataset.num_char_classes,
      seq_length=dataset.max_sequence_length,
      num_views=dataset.num_of_views,
      null_code=dataset.null_code,
      charset=dataset.charset)
  images_placeholder = tf.placeholder(
      tf.float32, shape=[batch_size, height, width, 3])
  endpoints = ocr_model.create_base(images_placeholder, labels_one_hot=None)
  init_fn = ocr_model.create_init_fn_to_restore(checkpoint)
  return images_placeholder, endpoints, init_fn
def main(_):
  """Loads the model and a batch of images, then prints predicted strings."""
  images_placeholder, endpoints, init_fn = load_model(
      FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name)
  images_data = load_images(
      FLAGS.image_path_pattern, FLAGS.batch_size, FLAGS.dataset_name)
  with tf.Session() as sess:
    tf.tables_initializer().run()  # required by the CharsetMapper
    init_fn(sess)
    predictions = sess.run(
        endpoints.predicted_text,
        feed_dict={images_placeholder: images_data})
  print("Predicted strings:")
  for line in predictions:
    print(line)
if __name__ == '__main__':
  # tf.app.run() parses the command-line flags and then invokes main().
  tf.app.run()
attention_ocr/python/model.py
View file @
51b7c2b3
...
@@ -34,25 +34,25 @@ import metrics
...
@@ -34,25 +34,25 @@ import metrics
import
sequence_layers
import
sequence_layers
import
utils
import
utils
OutputEndpoints
=
collections
.
namedtuple
(
'OutputEndpoints'
,
[
OutputEndpoints
=
collections
.
namedtuple
(
'OutputEndpoints'
,
[
'chars_logit'
,
'chars_log_prob'
,
'predicted_chars'
,
'predicted_scores'
'chars_logit'
,
'chars_log_prob'
,
'predicted_chars'
,
'predicted_scores'
,
'predicted_text'
])
])
# TODO(gorban): replace with tf.HParams when it is released.
# TODO(gorban): replace with tf.HParams when it is released.
ModelParams
=
collections
.
namedtuple
(
'ModelParams'
,
[
ModelParams
=
collections
.
namedtuple
(
'ModelParams'
,
[
'num_char_classes'
,
'seq_length'
,
'num_views'
,
'null_code'
'num_char_classes'
,
'seq_length'
,
'num_views'
,
'null_code'
])
])
ConvTowerParams
=
collections
.
namedtuple
(
'ConvTowerParams'
,
[
'final_endpoint'
])
ConvTowerParams
=
collections
.
namedtuple
(
'ConvTowerParams'
,
[
'final_endpoint'
])
SequenceLogitsParams
=
collections
.
namedtuple
(
'SequenceLogitsParams'
,
[
SequenceLogitsParams
=
collections
.
namedtuple
(
'SequenceLogitsParams'
,
[
'use_attention'
,
'use_autoregression'
,
'num_lstm_units'
,
'weight_decay'
,
'use_attention'
,
'use_autoregression'
,
'num_lstm_units'
,
'weight_decay'
,
'lstm_state_clip_value'
'lstm_state_clip_value'
])
])
SequenceLossParams
=
collections
.
namedtuple
(
'SequenceLossParams'
,
[
SequenceLossParams
=
collections
.
namedtuple
(
'SequenceLossParams'
,
[
'label_smoothing'
,
'ignore_nulls'
,
'average_across_timesteps'
'label_smoothing'
,
'ignore_nulls'
,
'average_across_timesteps'
])
])
EncodeCoordinatesParams
=
collections
.
namedtuple
(
'EncodeCoordinatesParams'
,
[
EncodeCoordinatesParams
=
collections
.
namedtuple
(
'EncodeCoordinatesParams'
,
[
...
@@ -125,11 +125,12 @@ class Model(object):
...
@@ -125,11 +125,12 @@ class Model(object):
"""Class to create the Attention OCR Model."""
"""Class to create the Attention OCR Model."""
def
__init__
(
self
,
def
__init__
(
self
,
num_char_classes
,
num_char_classes
,
seq_length
,
seq_length
,
num_views
,
num_views
,
null_code
,
null_code
,
mparams
=
None
):
mparams
=
None
,
charset
=
None
):
"""Initialized model parameters.
"""Initialized model parameters.
Args:
Args:
...
@@ -140,6 +141,13 @@ class Model(object):
...
@@ -140,6 +141,13 @@ class Model(object):
indicates end of a sequence.
indicates end of a sequence.
mparams: a dictionary with hyper parameters for methods, keys -
mparams: a dictionary with hyper parameters for methods, keys -
function names, values - corresponding namedtuples.
function names, values - corresponding namedtuples.
charset: an optional dictionary with a mapping between character ids and
utf8 strings. If specified the OutputEndpoints.predicted_text will
utf8 encoded strings corresponding to the character ids returned by
OutputEndpoints.predicted_chars (by default the predicted_text contains
an empty vector).
NOTE: Make sure you call tf.tables_initializer().run() if the charset
specified.
"""
"""
super
(
Model
,
self
).
__init__
()
super
(
Model
,
self
).
__init__
()
self
.
_params
=
ModelParams
(
self
.
_params
=
ModelParams
(
...
@@ -150,24 +158,25 @@ class Model(object):
...
@@ -150,24 +158,25 @@ class Model(object):
self
.
_mparams
=
self
.
default_mparams
()
self
.
_mparams
=
self
.
default_mparams
()
if
mparams
:
if
mparams
:
self
.
_mparams
.
update
(
mparams
)
self
.
_mparams
.
update
(
mparams
)
self
.
_charset
=
charset
def
default_mparams
(
self
):
def
default_mparams
(
self
):
return
{
return
{
'conv_tower_fn'
:
'conv_tower_fn'
:
ConvTowerParams
(
final_endpoint
=
'Mixed_5d'
),
ConvTowerParams
(
final_endpoint
=
'Mixed_5d'
),
'sequence_logit_fn'
:
'sequence_logit_fn'
:
SequenceLogitsParams
(
SequenceLogitsParams
(
use_attention
=
True
,
use_attention
=
True
,
use_autoregression
=
True
,
use_autoregression
=
True
,
num_lstm_units
=
256
,
num_lstm_units
=
256
,
weight_decay
=
0.00004
,
weight_decay
=
0.00004
,
lstm_state_clip_value
=
10.0
),
lstm_state_clip_value
=
10.0
),
'sequence_loss_fn'
:
'sequence_loss_fn'
:
SequenceLossParams
(
SequenceLossParams
(
label_smoothing
=
0.1
,
label_smoothing
=
0.1
,
ignore_nulls
=
True
,
ignore_nulls
=
True
,
average_across_timesteps
=
False
),
average_across_timesteps
=
False
),
'encode_coordinates_fn'
:
EncodeCoordinatesParams
(
enabled
=
False
)
'encode_coordinates_fn'
:
EncodeCoordinatesParams
(
enabled
=
False
)
}
}
def
set_mparam
(
self
,
function
,
**
kwargs
):
def
set_mparam
(
self
,
function
,
**
kwargs
):
...
@@ -241,7 +250,7 @@ class Model(object):
...
@@ -241,7 +250,7 @@ class Model(object):
A tensor with the same size as any input tensors.
A tensor with the same size as any input tensors.
"""
"""
batch_size
,
height
,
width
,
num_features
=
[
batch_size
,
height
,
width
,
num_features
=
[
d
.
value
for
d
in
nets_list
[
0
].
get_shape
().
dims
d
.
value
for
d
in
nets_list
[
0
].
get_shape
().
dims
]
]
xy_flat_shape
=
(
batch_size
,
1
,
height
*
width
,
num_features
)
xy_flat_shape
=
(
batch_size
,
1
,
height
*
width
,
num_features
)
nets_for_merge
=
[]
nets_for_merge
=
[]
...
@@ -323,10 +332,10 @@ class Model(object):
...
@@ -323,10 +332,10 @@ class Model(object):
return
net
return
net
def
create_base
(
self
,
def
create_base
(
self
,
images
,
images
,
labels_one_hot
,
labels_one_hot
,
scope
=
'AttentionOcr_v1'
,
scope
=
'AttentionOcr_v1'
,
reuse
=
None
):
reuse
=
None
):
"""Creates a base part of the Model (no gradients, losses or summaries).
"""Creates a base part of the Model (no gradients, losses or summaries).
Args:
Args:
...
@@ -348,8 +357,8 @@ class Model(object):
...
@@ -348,8 +357,8 @@ class Model(object):
logging
.
debug
(
'Views=%d single view: %s'
,
len
(
views
),
views
[
0
])
logging
.
debug
(
'Views=%d single view: %s'
,
len
(
views
),
views
[
0
])
nets
=
[
nets
=
[
self
.
conv_tower_fn
(
v
,
is_training
,
reuse
=
(
i
!=
0
))
self
.
conv_tower_fn
(
v
,
is_training
,
reuse
=
(
i
!=
0
))
for
i
,
v
in
enumerate
(
views
)
for
i
,
v
in
enumerate
(
views
)
]
]
logging
.
debug
(
'Conv tower: %s'
,
nets
[
0
])
logging
.
debug
(
'Conv tower: %s'
,
nets
[
0
])
...
@@ -363,13 +372,18 @@ class Model(object):
...
@@ -363,13 +372,18 @@ class Model(object):
logging
.
debug
(
'chars_logit: %s'
,
chars_logit
)
logging
.
debug
(
'chars_logit: %s'
,
chars_logit
)
predicted_chars
,
chars_log_prob
,
predicted_scores
=
(
predicted_chars
,
chars_log_prob
,
predicted_scores
=
(
self
.
char_predictions
(
chars_logit
))
self
.
char_predictions
(
chars_logit
))
if
self
.
_charset
:
character_mapper
=
CharsetMapper
(
self
.
_charset
)
predicted_text
=
character_mapper
.
get_text
(
predicted_chars
)
else
:
predicted_text
=
tf
.
constant
([])
return
OutputEndpoints
(
return
OutputEndpoints
(
chars_logit
=
chars_logit
,
chars_logit
=
chars_logit
,
chars_log_prob
=
chars_log_prob
,
chars_log_prob
=
chars_log_prob
,
predicted_chars
=
predicted_chars
,
predicted_chars
=
predicted_chars
,
predicted_scores
=
predicted_scores
)
predicted_scores
=
predicted_scores
,
predicted_text
=
predicted_text
)
def
create_loss
(
self
,
data
,
endpoints
):
def
create_loss
(
self
,
data
,
endpoints
):
"""Creates all losses required to train the model.
"""Creates all losses required to train the model.
...
@@ -523,7 +537,8 @@ class Model(object):
...
@@ -523,7 +537,8 @@ class Model(object):
tf
.
summary
.
scalar
(
summary_name
,
tf
.
Print
(
value
,
[
value
],
summary_name
))
tf
.
summary
.
scalar
(
summary_name
,
tf
.
Print
(
value
,
[
value
],
summary_name
))
return
names_to_updates
.
values
()
return
names_to_updates
.
values
()
def
create_init_fn_to_restore
(
self
,
master_checkpoint
,
inception_checkpoint
):
def
create_init_fn_to_restore
(
self
,
master_checkpoint
,
inception_checkpoint
=
None
):
"""Creates an init operations to restore weights from various checkpoints.
"""Creates an init operations to restore weights from various checkpoints.
Args:
Args:
...
...
attention_ocr/python/model_test.py
View file @
51b7c2b3
...
@@ -73,9 +73,10 @@ class ModelTest(tf.test.TestCase):
...
@@ -73,9 +73,10 @@ class ModelTest(tf.test.TestCase):
high
=
self
.
num_char_classes
,
high
=
self
.
num_char_classes
,
size
=
(
self
.
batch_size
,
self
.
seq_length
)).
astype
(
'int64'
))
size
=
(
self
.
batch_size
,
self
.
seq_length
)).
astype
(
'int64'
))
def
create_model
(
self
):
def
create_model
(
self
,
charset
=
None
):
return
model
.
Model
(
return
model
.
Model
(
self
.
num_char_classes
,
self
.
seq_length
,
num_views
=
4
,
null_code
=
62
)
self
.
num_char_classes
,
self
.
seq_length
,
num_views
=
4
,
null_code
=
62
,
charset
=
charset
)
def
test_char_related_shapes
(
self
):
def
test_char_related_shapes
(
self
):
ocr_model
=
self
.
create_model
()
ocr_model
=
self
.
create_model
()
...
@@ -244,6 +245,21 @@ class ModelTest(tf.test.TestCase):
...
@@ -244,6 +245,21 @@ class ModelTest(tf.test.TestCase):
self
.
assertAllEqual
(
conv_w_coords_tf
,
conv_w_coords_alt_tf
)
self
.
assertAllEqual
(
conv_w_coords_tf
,
conv_w_coords_alt_tf
)
def
test_predicted_text_has_correct_shape_w_charset
(
self
):
charset
=
create_fake_charset
(
self
.
num_char_classes
)
ocr_model
=
self
.
create_model
(
charset
=
charset
)
with
self
.
test_session
()
as
sess
:
endpoints_tf
=
ocr_model
.
create_base
(
images
=
self
.
fake_images
,
labels_one_hot
=
None
)
sess
.
run
(
tf
.
global_variables_initializer
())
tf
.
tables_initializer
().
run
()
endpoints
=
sess
.
run
(
endpoints_tf
)
self
.
assertEqual
(
endpoints
.
predicted_text
.
shape
,
(
self
.
batch_size
,))
self
.
assertEqual
(
len
(
endpoints
.
predicted_text
[
0
]),
self
.
seq_length
)
class
CharsetMapperTest
(
tf
.
test
.
TestCase
):
class
CharsetMapperTest
(
tf
.
test
.
TestCase
):
def
test_text_corresponds_to_ids
(
self
):
def
test_text_corresponds_to_ids
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment